Skip to content

Commit

Permalink
feat: connect timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
Menci committed Sep 10, 2024
1 parent 3141ef5 commit ff18b39
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 13 deletions.
38 changes: 37 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"regexp"
"strconv"
"strings"
"time"

"gopkg.in/yaml.v2"
)
Expand All @@ -30,8 +31,10 @@ type ServiceConfig struct {
Connect string `yaml:"connect"`
logLevel string `yaml:"logLevel,omitempty"`
ProxyProtocol bool `yaml:"proxyProtocol,omitempty"`
timeout string `yaml:"timeout,omitempty"`

LogLevel LogLevel
Timeout time.Duration
}

func parseLogLevel(s string) (LogLevel, error) {
Expand All @@ -50,8 +53,11 @@ func parseLogLevel(s string) (LogLevel, error) {
}

type Config struct {
timeout string `yaml:"timeout,omitempty"`
Tailscale TailscaleConfig `yaml:"tailscale"`
Services map[string]*ServiceConfig `yaml:"services"`

Timeout time.Duration
}

type boolFlag struct {
Expand All @@ -75,6 +81,7 @@ func (f *boolFlag) String() string {

type arguments struct {
conf string
timeout string
tsHostname string
tsAuthKey string
tsEphemeral boolFlag
Expand All @@ -89,6 +96,7 @@ type arguments struct {
func parseArguments() *arguments {
flags := &arguments{}
flag.StringVar(&flags.conf, "conf", "", "YAML Configuration file")
flag.StringVar(&flags.timeout, "timeout", "", "Default connection timeout of services (and Tailscale proxies)")
flag.StringVar(&flags.tsHostname, "ts-hostname", "", "Tailscale hostname")
flag.StringVar(&flags.tsAuthKey, "ts-authkey", "", "Tailscale authentication key (default to $TS_AUTHKEY)")
flag.Var(&flags.tsEphemeral, "ts-ephemeral", "Set the Tailscale host to ephemeral")
Expand All @@ -102,6 +110,7 @@ func parseArguments() *arguments {
fmt.Fprint(f, "\nTsukasa - A flexible port forwarder among TCP, UNIX Socket and Tailscale TCP ports.\n\n")
flag.PrintDefaults()
fmt.Fprintf(f, "\nExample: %s \\\n", os.Args[0])
fmt.Fprintln(f, " --timeout 10s \\")
fmt.Fprintln(f, " --ts-hostname Tsukasa \\")
fmt.Fprintln(f, " --ts-authkey \"$TS_AUTHKEY\" \\")
fmt.Fprintln(f, " --ts-ephemeral false \\")
Expand Down Expand Up @@ -167,6 +176,11 @@ func parseService(s string) (name string, service *ServiceConfig, err error) {
return "", nil, fmt.Errorf("no value expected for option `proxy-protocol`")
}
service.ProxyProtocol = true
case "timeout":
if value == nil {
return "", nil, fmt.Errorf("required value for option `timeout`")
}
service.timeout = *value
default:
return "", nil, fmt.Errorf("unknown service argument: %s", key)
}
Expand All @@ -176,6 +190,10 @@ func parseService(s string) (name string, service *ServiceConfig, err error) {
}

func mergeConfig(c *Config, a *arguments) error {
if a.timeout != "" {
c.timeout = a.timeout
}

if a.tsHostname != "" {
c.Tailscale.Hostname = a.tsHostname
}
Expand Down Expand Up @@ -227,7 +245,7 @@ func (c *Config) ValidateTailscaleConfig() error {
return nil
}

func (c *Config) ValidateServices() error {
func (c *Config) ProcessServices() error {
for name, service := range c.Services {
if service.Listen == "" {
return fmt.Errorf("missing listen address for service %s", name)
Expand All @@ -236,6 +254,16 @@ func (c *Config) ValidateServices() error {
if service.Connect == "" {
return fmt.Errorf("missing connect address for service %s", name)
}

if service.timeout == "" {
service.Timeout = c.Timeout
} else {
if timeout, err := time.ParseDuration(service.timeout); err != nil {
return fmt.Errorf("invalid timeout for service %s: %v", name, err)
} else {
service.Timeout = timeout
}
}
}

return nil
Expand Down Expand Up @@ -264,6 +292,14 @@ func GetConfig() (*Config, error) {
return nil, err
}

if c.timeout != "" {
if timeout, err := time.ParseDuration(c.timeout); err != nil {
return nil, fmt.Errorf("invalid default timeout: %v", err)
} else {
c.Timeout = timeout
}
}

if c.Tailscale.AuthKey == "" {
c.Tailscale.AuthKey = os.Getenv("TS_AUTHKEY")
}
Expand Down
25 changes: 17 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package main

import (
"context"
"net"
"os"
"os/signal"
"sync"
Expand All @@ -19,7 +21,7 @@ func main() {
logger.Fatalf("invalid config: %v", err)
}

if err := config.ValidateServices(); err != nil {
if err := config.ProcessServices(); err != nil {
logger.Fatalf("invalid service config: %v", err)
}

Expand Down Expand Up @@ -79,13 +81,20 @@ func main() {

somethingRunning := false

if config.Tailscale.Listen.Socks5 != "" {
somethingRunning = true
StartProxy(tsLogger, config.Tailscale.Listen.Socks5, tsnet.Dial, Socks5)
}
if config.Tailscale.Listen.HTTP != "" {
somethingRunning = true
StartProxy(tsLogger, config.Tailscale.Listen.HTTP, tsnet.Dial, HTTP)
if config.Tailscale.Listen.Socks5 != "" || config.Tailscale.Listen.HTTP != "" {
proxyDial := func(ctx context.Context, network, address string) (net.Conn, error) {
ctx2, cancel := context.WithTimeout(ctx, config.Timeout)
defer cancel()
return tsnet.Dial(ctx2, network, address)
}
if config.Tailscale.Listen.Socks5 != "" {
somethingRunning = true
StartProxy(tsLogger, config.Tailscale.Listen.Socks5, proxyDial, Socks5)
}
if config.Tailscale.Listen.HTTP != "" {
somethingRunning = true
StartProxy(tsLogger, config.Tailscale.Listen.HTTP, proxyDial, HTTP)
}
}

// Start services.
Expand Down
13 changes: 9 additions & 4 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"strconv"
"sync"
"time"

"tailscale.com/tsnet"
)
Expand Down Expand Up @@ -39,6 +40,7 @@ type Service struct {
ConnectPort int16
ConnectProxyProtocol bool
LogLevel LogLevel
Timeout time.Duration
}

func parsePort(portString string) (int16, error) {
Expand Down Expand Up @@ -95,8 +97,9 @@ func CreateService(serviceContext *ServiceContext, name string, config *ServiceC
ServiceContext: serviceContext,
Config: config,
Name: name,
LogLevel: config.LogLevel,
ConnectProxyProtocol: config.ProxyProtocol,
LogLevel: config.LogLevel,
Timeout: config.Timeout,
}
if service.ListenType, service.ListenAddress, service.ListenPort, err = parseUrl(urlTypeListen, config.Listen); err != nil {
return nil, err
Expand Down Expand Up @@ -135,15 +138,17 @@ func (s *Service) CreateConnector() (func() (net.Conn, error), error) {
switch s.ConnectType {
case AddressTCP:
return func() (net.Conn, error) {
return net.Dial("tcp", s.ConnectAddress+":"+strconv.Itoa(int(s.ConnectPort)))
return net.DialTimeout("tcp", s.ConnectAddress+":"+strconv.Itoa(int(s.ConnectPort)), s.Timeout)
}, nil
case AddressUNIXSocket:
return func() (net.Conn, error) {
return net.Dial("unix", s.ConnectAddress)
return net.DialTimeout("unix", s.ConnectAddress, s.Timeout)
}, nil
case AddressTailscaleTCP:
return func() (net.Conn, error) {
return s.ServiceContext.TsNet.Dial(context.Background(), "tcp", s.ConnectAddress+":"+strconv.Itoa(int(s.ConnectPort)))
ctx, cancel := context.WithTimeout(context.Background(), s.Timeout)
defer cancel()
return s.ServiceContext.TsNet.Dial(ctx, "tcp", s.ConnectAddress+":"+strconv.Itoa(int(s.ConnectPort)))
}, nil
default:
return nil, fmt.Errorf("invalid connect address type: %v", s.ConnectType)
Expand Down

0 comments on commit ff18b39

Please sign in to comment.