Browse Source

Improve: tunnel/tcp pipe (#219)

Co-authored-by: xjasonlyu <xjasonlyu@gmail.com>
pull/248/head
LanceLi 2 years ago
committed by GitHub
parent
commit
2d0bd1d219
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 15
      tunnel/statistic/tracker.go
  2. 45
      tunnel/tcp.go

15
tunnel/statistic/tracker.go

@ -1,6 +1,7 @@
package statistic
import (
"errors"
"net"
"time"
@ -79,6 +80,20 @@ func (tt *tcpTracker) Close() error {
return tt.Conn.Close()
}
func (tt *tcpTracker) CloseRead() error {
if cr, ok := tt.Conn.(interface{ CloseRead() error }); ok {
return cr.CloseRead()
}
return errors.New("CloseRead is not implemented")
}
func (tt *tcpTracker) CloseWrite() error {
if cw, ok := tt.Conn.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
return errors.New("CloseWrite is not implemented")
}
type udpTracker struct {
net.PacketConn `json:"-"`

45
tunnel/tcp.go

@ -1,7 +1,6 @@
package tunnel
import (
"errors"
"io"
"net"
"sync"
@ -43,42 +42,34 @@ func handleTCPConn(originConn adapter.TCPConn) {
defer remoteConn.Close()
log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress())
if err = pipe(originConn, remoteConn); err != nil {
log.Debugf("[TCP] %s <-> %s: %v", metadata.SourceAddress(), metadata.DestinationAddress(), err)
}
pipe(originConn, remoteConn)
}
// pipe copies copy data to & from provided net.Conn(s) bidirectionally.
func pipe(origin, remote net.Conn) error {
func pipe(origin, remote net.Conn) {
wg := sync.WaitGroup{}
wg.Add(2)
var leftErr, rightErr error
go unidirectionalStream(remote, origin, "origin->remote", &wg)
go unidirectionalStream(origin, remote, "remote->origin", &wg)
go func() {
defer wg.Done()
if err := copyBuffer(remote, origin); err != nil {
leftErr = errors.Join(leftErr, err)
wg.Wait()
}
remote.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
}()
go func() {
func unidirectionalStream(dst, src net.Conn, dir string, wg *sync.WaitGroup) {
defer wg.Done()
if err := copyBuffer(origin, remote); err != nil {
rightErr = errors.Join(rightErr, err)
buf := pool.Get(pool.RelayBufferSize)
if _, err := io.CopyBuffer(dst, src, buf); err != nil {
log.Debugf("[TCP] copy data for %s: %v", dir, err)
}
origin.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
}()
wg.Wait()
return errors.Join(leftErr, rightErr)
pool.Put(buf)
// Do the upload/download side TCP half-close.
if cr, ok := src.(interface{ CloseRead() error }); ok {
cr.CloseRead()
}
func copyBuffer(dst io.Writer, src io.Reader) error {
buf := pool.Get(pool.RelayBufferSize)
defer pool.Put(buf)
_, err := io.CopyBuffer(dst, src, buf)
return err
if cw, ok := dst.(interface{ CloseWrite() error }); ok {
cw.CloseWrite()
}
// Set TCP half-close timeout.
dst.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
}

Loading…
Cancel
Save