diff --git a/core/adapter/adapter.go b/core/adapter/adapter.go index e4e4296..9a5649e 100644 --- a/core/adapter/adapter.go +++ b/core/adapter/adapter.go @@ -2,15 +2,23 @@ package adapter import ( "net" + + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // TCPConn implements the net.Conn interface. type TCPConn interface { net.Conn + + // ID returns the transport endpoint id of TCPConn. + ID() *stack.TransportEndpointID } // UDPConn implements net.Conn and net.PacketConn. type UDPConn interface { net.Conn net.PacketConn + + // ID returns the transport endpoint id of UDPConn. + ID() *stack.TransportEndpointID } diff --git a/core/stack/tcp.go b/core/stack/tcp.go index ad468f3..d726c07 100644 --- a/core/stack/tcp.go +++ b/core/stack/tcp.go @@ -6,6 +6,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -40,11 +41,15 @@ func withTCPHandler() Option { r.Complete(true) return } - r.Complete(false) + defer r.Complete(false) setKeepalive(ep) - s.handler.HandleTCPConn(gonet.NewTCPConn(&wq, ep)) + conn := &tcpConn{ + TCPConn: gonet.NewTCPConn(&wq, ep), + id: r.ID(), + } + s.handler.HandleTCPConn(conn) }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) return nil @@ -65,3 +70,12 @@ func setKeepalive(ep tcpip.Endpoint) error { } return nil } + +type tcpConn struct { + *gonet.TCPConn + id stack.TransportEndpointID +} + +func (c *tcpConn) ID() *stack.TransportEndpointID { + return &c.id +} diff --git a/core/stack/udp.go b/core/stack/udp.go index 4a130e7..d4b9aef 100644 --- a/core/stack/udp.go +++ b/core/stack/udp.go @@ -2,6 +2,7 @@ package stack import ( "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -16,9 +17,22 @@ func withUDPHandler() Option { return } - s.handler.HandleUDPConn(gonet.NewUDPConn(s.Stack, &wq, ep)) + conn := &udpConn{ + UDPConn: gonet.NewUDPConn(s.Stack, &wq, ep), + id: r.ID(), + } + s.handler.HandleUDPConn(conn) }) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) return nil } } + +type udpConn struct { + *gonet.UDPConn + id stack.TransportEndpointID +} + +func (c *udpConn) ID() *stack.TransportEndpointID { + return &c.id +} diff --git a/tunnel/addr.go b/tunnel/addr.go index bf06b46..eb7ed57 100644 --- a/tunnel/addr.go +++ b/tunnel/addr.go @@ -12,6 +12,8 @@ func parseAddr(addr net.Addr) (net.IP, uint16) { return v.IP, uint16(v.Port) case *net.UDPAddr: return v.IP, uint16(v.Port) + case nil: + return nil, 0 default: return parseAddrString(addr.String()) } diff --git a/tunnel/tcp.go b/tunnel/tcp.go index e0a7017..3705cfb 100644 --- a/tunnel/tcp.go +++ b/tunnel/tcp.go @@ -25,16 +25,13 @@ func newTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn { func handleTCPConn(localConn adapter.TCPConn) { defer localConn.Close() - var ( - srcIP, srcPort = parseAddr(localConn.RemoteAddr()) - dstIP, dstPort = parseAddr(localConn.LocalAddr()) - ) + id := localConn.ID() metadata := &M.Metadata{ Network: M.TCP, - SrcIP: srcIP, - SrcPort: srcPort, - DstIP: dstIP, - DstPort: dstPort, + SrcIP: net.IP(id.RemoteAddress), + SrcPort: id.RemotePort, + DstIP: net.IP(id.LocalAddress), + DstPort: id.LocalPort, } targetConn, err := proxy.Dial(metadata) diff --git a/tunnel/udp.go b/tunnel/udp.go index b6f1f03..983d740 100644 --- a/tunnel/udp.go +++ b/tunnel/udp.go @@ -29,16 +29,13 @@ func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn { func handleUDPConn(uc adapter.UDPConn) { defer uc.Close() - var ( - srcIP, srcPort = parseAddr(uc.RemoteAddr()) - dstIP, dstPort = parseAddr(uc.LocalAddr()) - ) + id := uc.ID() metadata := &M.Metadata{ Network: M.UDP, - SrcIP: srcIP, - SrcPort: srcPort, - DstIP: dstIP, - DstPort: dstPort, + SrcIP: net.IP(id.RemoteAddress), + SrcPort: id.RemotePort, + DstIP: net.IP(id.LocalAddress), + DstPort: id.LocalPort, } pc, err := proxy.DialUDP(metadata)