diff --git a/proxy/proxy.go b/proxy/proxy.go index 3b0cdbd..ed23af7 100755 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -8,6 +8,8 @@ import ( M "github.com/xjasonlyu/tun2socks/constant" "github.com/xjasonlyu/tun2socks/proxy/proto" + + "go.uber.org/atomic" ) const ( @@ -15,7 +17,7 @@ const ( ) var ( - _defaultDialer Dialer = &Base{} + _defaultDialer atomic.Value ) type Dialer interface { @@ -29,24 +31,28 @@ type Proxy interface { Proto() proto.Proto } +func init() { + _defaultDialer.Store(&Base{}) +} + // SetDialer sets default Dialer. func SetDialer(d Dialer) { - _defaultDialer = d + _defaultDialer.Store(d) } // Dial uses default Dialer to dial TCP. func Dial(metadata *M.Metadata) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) defer cancel() - return _defaultDialer.DialContext(ctx, metadata) + return _defaultDialer.Load().(Dialer).DialContext(ctx, metadata) } // DialContext uses default Dialer to dial TCP with context. func DialContext(ctx context.Context, metadata *M.Metadata) (net.Conn, error) { - return _defaultDialer.DialContext(ctx, metadata) + return _defaultDialer.Load().(Dialer).DialContext(ctx, metadata) } // DialUDP uses default Dialer to dial UDP. func DialUDP(metadata *M.Metadata) (net.PacketConn, error) { - return _defaultDialer.DialUDP(metadata) + return _defaultDialer.Load().(Dialer).DialUDP(metadata) } diff --git a/tunnel/udp.go b/tunnel/udp.go index f9bac4d..692a44d 100755 --- a/tunnel/udp.go +++ b/tunnel/udp.go @@ -14,6 +14,8 @@ import ( "github.com/xjasonlyu/tun2socks/log" "github.com/xjasonlyu/tun2socks/proxy" "github.com/xjasonlyu/tun2socks/tunnel/statistic" + + "go.uber.org/atomic" ) var ( @@ -23,11 +25,11 @@ var ( // _udpSessionTimeout is the default timeout for // each UDP session. - _udpSessionTimeout = 60 * time.Second + _udpSessionTimeout = atomic.NewInt64(int64(60 * time.Second)) ) func SetUDPTimeout(v int) { - _udpSessionTimeout = time.Duration(v) * time.Second + _udpSessionTimeout.Store(int64(time.Duration(v) * time.Second)) } func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn { @@ -119,7 +121,7 @@ func handleUDPToRemote(packet core.UDPPacket, pc net.PacketConn, remote net.Addr if _, err := pc.WriteTo(packet.Data() /* data */, remote); err != nil { log.Warnf("[UDP] write to %s error: %v", remote, err) } - pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) /* reset timeout */ + pc.SetReadDeadline(time.Now().Add(time.Duration(_udpSessionTimeout.Load()))) /* reset timeout */ log.Infof("[UDP] %s --> %s", packet.RemoteAddr(), remote) } @@ -129,7 +131,7 @@ func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) { defer pool.Put(buf) for /* just loop */ { - pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) + pc.SetReadDeadline(time.Now().Add(time.Duration(_udpSessionTimeout.Load()))) n, from, err := pc.ReadFrom(buf) if err != nil { if !errors.Is(err, os.ErrDeadlineExceeded) /* ignore i/o timeout */ {