From 2d0bd1d2199988277dba97876620b59f1bb23c6b Mon Sep 17 00:00:00 2001 From: LanceLi Date: Mon, 3 Apr 2023 19:15:24 +0800 Subject: [PATCH] Improve: tunnel/tcp pipe (#219) Co-authored-by: xjasonlyu --- tunnel/statistic/tracker.go | 15 ++++++++++++ tunnel/tcp.go | 47 +++++++++++++++---------------------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 15755af..393c200 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -1,6 +1,7 @@ package statistic import ( + "errors" "net" "time" @@ -79,6 +80,20 @@ func (tt *tcpTracker) Close() error { return tt.Conn.Close() } +func (tt *tcpTracker) CloseRead() error { + if cr, ok := tt.Conn.(interface{ CloseRead() error }); ok { + return cr.CloseRead() + } + return errors.New("CloseRead is not implemented") +} + +func (tt *tcpTracker) CloseWrite() error { + if cw, ok := tt.Conn.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + return errors.New("CloseWrite is not implemented") +} + type udpTracker struct { net.PacketConn `json:"-"` diff --git a/tunnel/tcp.go b/tunnel/tcp.go index a45de8b..2618246 100644 --- a/tunnel/tcp.go +++ b/tunnel/tcp.go @@ -1,7 +1,6 @@ package tunnel import ( - "errors" "io" "net" "sync" @@ -43,42 +42,34 @@ func handleTCPConn(originConn adapter.TCPConn) { defer remoteConn.Close() log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) - if err = pipe(originConn, remoteConn); err != nil { - log.Debugf("[TCP] %s <-> %s: %v", metadata.SourceAddress(), metadata.DestinationAddress(), err) - } + pipe(originConn, remoteConn) } // pipe copies copy data to & from provided net.Conn(s) bidirectionally. -func pipe(origin, remote net.Conn) error { +func pipe(origin, remote net.Conn) { wg := sync.WaitGroup{} wg.Add(2) - var leftErr, rightErr error - - go func() { - defer wg.Done() - if err := copyBuffer(remote, origin); err != nil { - leftErr = errors.Join(leftErr, err) - } - remote.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) - }() - - go func() { - defer wg.Done() - if err := copyBuffer(origin, remote); err != nil { - rightErr = errors.Join(rightErr, err) - } - origin.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) - }() + go unidirectionalStream(remote, origin, "origin->remote", &wg) + go unidirectionalStream(origin, remote, "remote->origin", &wg) wg.Wait() - return errors.Join(leftErr, rightErr) } -func copyBuffer(dst io.Writer, src io.Reader) error { +func unidirectionalStream(dst, src net.Conn, dir string, wg *sync.WaitGroup) { + defer wg.Done() buf := pool.Get(pool.RelayBufferSize) - defer pool.Put(buf) - - _, err := io.CopyBuffer(dst, src, buf) - return err + if _, err := io.CopyBuffer(dst, src, buf); err != nil { + log.Debugf("[TCP] copy data for %s: %v", dir, err) + } + pool.Put(buf) + // Do the upload/download side TCP half-close. + if cr, ok := src.(interface{ CloseRead() error }); ok { + cr.CloseRead() + } + if cw, ok := dst.(interface{ CloseWrite() error }); ok { + cw.CloseWrite() + } + // Set TCP half-close timeout. + dst.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) }