|
|
@ -1,7 +1,6 @@ |
|
|
|
package tunnel |
|
|
|
|
|
|
|
import ( |
|
|
|
"errors" |
|
|
|
"io" |
|
|
|
"net" |
|
|
|
"sync" |
|
|
@ -43,42 +42,34 @@ func handleTCPConn(originConn adapter.TCPConn) { |
|
|
|
defer remoteConn.Close() |
|
|
|
|
|
|
|
log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) |
|
|
|
if err = pipe(originConn, remoteConn); err != nil { |
|
|
|
log.Debugf("[TCP] %s <-> %s: %v", metadata.SourceAddress(), metadata.DestinationAddress(), err) |
|
|
|
} |
|
|
|
pipe(originConn, remoteConn) |
|
|
|
} |
|
|
|
|
|
|
|
// pipe copies copy data to & from provided net.Conn(s) bidirectionally.
|
|
|
|
func pipe(origin, remote net.Conn) error { |
|
|
|
func pipe(origin, remote net.Conn) { |
|
|
|
wg := sync.WaitGroup{} |
|
|
|
wg.Add(2) |
|
|
|
|
|
|
|
var leftErr, rightErr error |
|
|
|
|
|
|
|
go func() { |
|
|
|
defer wg.Done() |
|
|
|
if err := copyBuffer(remote, origin); err != nil { |
|
|
|
leftErr = errors.Join(leftErr, err) |
|
|
|
} |
|
|
|
remote.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) |
|
|
|
}() |
|
|
|
|
|
|
|
go func() { |
|
|
|
defer wg.Done() |
|
|
|
if err := copyBuffer(origin, remote); err != nil { |
|
|
|
rightErr = errors.Join(rightErr, err) |
|
|
|
} |
|
|
|
origin.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) |
|
|
|
}() |
|
|
|
go unidirectionalStream(remote, origin, "origin->remote", &wg) |
|
|
|
go unidirectionalStream(origin, remote, "remote->origin", &wg) |
|
|
|
|
|
|
|
wg.Wait() |
|
|
|
return errors.Join(leftErr, rightErr) |
|
|
|
} |
|
|
|
|
|
|
|
func copyBuffer(dst io.Writer, src io.Reader) error { |
|
|
|
func unidirectionalStream(dst, src net.Conn, dir string, wg *sync.WaitGroup) { |
|
|
|
defer wg.Done() |
|
|
|
buf := pool.Get(pool.RelayBufferSize) |
|
|
|
defer pool.Put(buf) |
|
|
|
|
|
|
|
_, err := io.CopyBuffer(dst, src, buf) |
|
|
|
return err |
|
|
|
if _, err := io.CopyBuffer(dst, src, buf); err != nil { |
|
|
|
log.Debugf("[TCP] copy data for %s: %v", dir, err) |
|
|
|
} |
|
|
|
pool.Put(buf) |
|
|
|
// Do the upload/download side TCP half-close.
|
|
|
|
if cr, ok := src.(interface{ CloseRead() error }); ok { |
|
|
|
cr.CloseRead() |
|
|
|
} |
|
|
|
if cw, ok := dst.(interface{ CloseWrite() error }); ok { |
|
|
|
cw.CloseWrite() |
|
|
|
} |
|
|
|
// Set TCP half-close timeout.
|
|
|
|
dst.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) |
|
|
|
} |
|
|
|