diff --git a/engine/engine.go b/engine/engine.go index 3c21353..00eda02 100755 --- a/engine/engine.go +++ b/engine/engine.go @@ -10,6 +10,7 @@ import ( "github.com/xjasonlyu/tun2socks/log" "github.com/xjasonlyu/tun2socks/proxy" "github.com/xjasonlyu/tun2socks/stats" + "github.com/xjasonlyu/tun2socks/tunnel" ) var _engine = &engine{} @@ -30,15 +31,16 @@ func Insert(k *Key) { } type Key struct { - MTU int - Mark int - Proxy string - Stats string - Token string - Device string - LogLevel string - Interface string - Version bool + MTU int + Mark int + UDPTimeout int + Proxy string + Stats string + Token string + Device string + LogLevel string + Interface string + Version bool } type engine struct { @@ -64,6 +66,7 @@ func (e *engine) start() error { e.setMark, e.setInterface, e.setStats, + e.setUDPTimeout, e.setProxy, e.setDevice, e.setStack, @@ -123,6 +126,13 @@ func (e *engine) setStats() error { return nil } +func (e *engine) setUDPTimeout() error { + if e.UDPTimeout > 0 { + tunnel.SetUDPTimeout(e.UDPTimeout) + } + return nil +} + func (e *engine) setProxy() (err error) { if e.Proxy == "" { return errors.New("empty proxy") diff --git a/main.go b/main.go index 844345e..8966ee5 100755 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ var key = new(engine.Key) func init() { flag.IntVar(&key.Mark, "fwmark", 0, "Set firewall MARK (Linux only)") flag.IntVar(&key.MTU, "mtu", 0, "Set device maximum transmission unit (MTU)") + flag.IntVar(&key.UDPTimeout, "udp-timeout", 0, "Set timeout for each UDP session") flag.BoolVar(&key.Version, "version", false, "Show version information and quit") flag.StringVar(&key.Device, "device", "", "Use this device [driver://]name") flag.StringVar(&key.Interface, "interface", "", "Use network INTERFACE (Linux/MacOS only)") diff --git a/tunnel/udp.go b/tunnel/udp.go index 5dce3fd..f9bac4d 100755 --- a/tunnel/udp.go +++ b/tunnel/udp.go @@ -16,16 +16,20 @@ import ( "github.com/xjasonlyu/tun2socks/tunnel/statistic" ) -const ( - udpSessionTimeout = 60 * time.Second -) - var ( // _natTable uses source udp packet information // as key to store destination udp packetConn. _natTable = nat.NewTable() + + // _udpSessionTimeout is the default timeout for + // each UDP session. + _udpSessionTimeout = 60 * time.Second ) +func SetUDPTimeout(v int) { + _udpSessionTimeout = time.Duration(v) * time.Second +} + func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn { return statistic.NewUDPTracker(conn, metadata, statistic.DefaultManager) } @@ -115,7 +119,7 @@ func handleUDPToRemote(packet core.UDPPacket, pc net.PacketConn, remote net.Addr if _, err := pc.WriteTo(packet.Data() /* data */, remote); err != nil { log.Warnf("[UDP] write to %s error: %v", remote, err) } - pc.SetReadDeadline(time.Now().Add(udpSessionTimeout)) /* reset timeout */ + pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) /* reset timeout */ log.Infof("[UDP] %s --> %s", packet.RemoteAddr(), remote) } @@ -125,7 +129,7 @@ func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) { defer pool.Put(buf) for /* just loop */ { - pc.SetReadDeadline(time.Now().Add(udpSessionTimeout)) + pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) n, from, err := pc.ReadFrom(buf) if err != nil { if !errors.Is(err, os.ErrDeadlineExceeded) /* ignore i/o timeout */ {