diff --git a/proxy/socks/tcp.go b/proxy/socks/tcp.go index 81da988..e736113 100644 --- a/proxy/socks/tcp.go +++ b/proxy/socks/tcp.go @@ -4,6 +4,7 @@ import ( "io" "net" "strconv" + "sync" "time" "golang.org/x/net/proxy" @@ -33,33 +34,44 @@ func NewTCPHandler(proxyHost string, proxyPort uint16, fakeDns dns.FakeDns, sess } func (h *tcpHandler) relay(localConn, remoteConn net.Conn) { - upCh := make(chan struct{}) + var once sync.Once + closeOnce := func() { + once.Do(func() { + localConn.Close() + remoteConn.Close() + }) + } // Close - defer func() { - localConn.Close() - remoteConn.Close() - }() + defer closeOnce() - flag := remoteConn.LocalAddr().String() + var wg sync.WaitGroup + wg.Add(1) + + upCh := make(chan struct{}) // UpLink go func() { - io.Copy(remoteConn, localConn) - remoteConn.SetReadDeadline(time.Now()) - - log.Warnf("up link finished: %v", flag) - + if _, err := io.Copy(remoteConn, localConn); err != nil { + closeOnce() + } upCh <- struct{}{} + wg.Done() }() // DownLink - io.Copy(localConn, remoteConn) - localConn.SetReadDeadline(time.Now()) + if _, err := io.Copy(localConn, remoteConn); err != nil { + closeOnce() + } - log.Warnf("down link finished: %v", flag) + select { + case <-upCh: + case <-time.After(10 * time.Second): + closeOnce() + } - <-upCh // Wait for UpLink done. + //<-upCh // Wait for UpLink done. + wg.Wait() if h.sessionStater != nil { h.sessionStater.RemoveSession(localConn)