Browse Source

Fix: nil addr panic (#99)

pull/133/head
xjasonlyu 3 years ago
parent
commit
5679d15442
  1. 8
      core/adapter/adapter.go
  2. 18
      core/stack/tcp.go
  3. 16
      core/stack/udp.go
  4. 2
      tunnel/addr.go
  5. 13
      tunnel/tcp.go
  6. 13
      tunnel/udp.go

8
core/adapter/adapter.go

@ -2,15 +2,23 @@ package adapter
import ( import (
"net" "net"
"gvisor.dev/gvisor/pkg/tcpip/stack"
) )
// TCPConn implements the net.Conn interface. // TCPConn implements the net.Conn interface.
type TCPConn interface { type TCPConn interface {
net.Conn net.Conn
// ID returns the transport endpoint id of TCPConn.
ID() *stack.TransportEndpointID
} }
// UDPConn implements net.Conn and net.PacketConn. // UDPConn implements net.Conn and net.PacketConn.
type UDPConn interface { type UDPConn interface {
net.Conn net.Conn
net.PacketConn net.PacketConn
// ID returns the transport endpoint id of UDPConn.
ID() *stack.TransportEndpointID
} }

18
core/stack/tcp.go

@ -6,6 +6,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
) )
@ -40,11 +41,15 @@ func withTCPHandler() Option {
r.Complete(true) r.Complete(true)
return return
} }
r.Complete(false) defer r.Complete(false)
setKeepalive(ep) setKeepalive(ep)
s.handler.HandleTCPConn(gonet.NewTCPConn(&wq, ep)) conn := &tcpConn{
TCPConn: gonet.NewTCPConn(&wq, ep),
id: r.ID(),
}
s.handler.HandleTCPConn(conn)
}) })
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
return nil return nil
@ -65,3 +70,12 @@ func setKeepalive(ep tcpip.Endpoint) error {
} }
return nil return nil
} }
type tcpConn struct {
*gonet.TCPConn
id stack.TransportEndpointID
}
func (c *tcpConn) ID() *stack.TransportEndpointID {
return &c.id
}

16
core/stack/udp.go

@ -2,6 +2,7 @@ package stack
import ( import (
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
) )
@ -16,9 +17,22 @@ func withUDPHandler() Option {
return return
} }
s.handler.HandleUDPConn(gonet.NewUDPConn(s.Stack, &wq, ep)) conn := &udpConn{
UDPConn: gonet.NewUDPConn(s.Stack, &wq, ep),
id: r.ID(),
}
s.handler.HandleUDPConn(conn)
}) })
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
return nil return nil
} }
} }
type udpConn struct {
*gonet.UDPConn
id stack.TransportEndpointID
}
func (c *udpConn) ID() *stack.TransportEndpointID {
return &c.id
}

2
tunnel/addr.go

@ -12,6 +12,8 @@ func parseAddr(addr net.Addr) (net.IP, uint16) {
return v.IP, uint16(v.Port) return v.IP, uint16(v.Port)
case *net.UDPAddr: case *net.UDPAddr:
return v.IP, uint16(v.Port) return v.IP, uint16(v.Port)
case nil:
return nil, 0
default: default:
return parseAddrString(addr.String()) return parseAddrString(addr.String())
} }

13
tunnel/tcp.go

@ -25,16 +25,13 @@ func newTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn {
func handleTCPConn(localConn adapter.TCPConn) { func handleTCPConn(localConn adapter.TCPConn) {
defer localConn.Close() defer localConn.Close()
var ( id := localConn.ID()
srcIP, srcPort = parseAddr(localConn.RemoteAddr())
dstIP, dstPort = parseAddr(localConn.LocalAddr())
)
metadata := &M.Metadata{ metadata := &M.Metadata{
Network: M.TCP, Network: M.TCP,
SrcIP: srcIP, SrcIP: net.IP(id.RemoteAddress),
SrcPort: srcPort, SrcPort: id.RemotePort,
DstIP: dstIP, DstIP: net.IP(id.LocalAddress),
DstPort: dstPort, DstPort: id.LocalPort,
} }
targetConn, err := proxy.Dial(metadata) targetConn, err := proxy.Dial(metadata)

13
tunnel/udp.go

@ -29,16 +29,13 @@ func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn {
func handleUDPConn(uc adapter.UDPConn) { func handleUDPConn(uc adapter.UDPConn) {
defer uc.Close() defer uc.Close()
var ( id := uc.ID()
srcIP, srcPort = parseAddr(uc.RemoteAddr())
dstIP, dstPort = parseAddr(uc.LocalAddr())
)
metadata := &M.Metadata{ metadata := &M.Metadata{
Network: M.UDP, Network: M.UDP,
SrcIP: srcIP, SrcIP: net.IP(id.RemoteAddress),
SrcPort: srcPort, SrcPort: id.RemotePort,
DstIP: dstIP, DstIP: net.IP(id.LocalAddress),
DstPort: dstPort, DstPort: id.LocalPort,
} }
pc, err := proxy.DialUDP(metadata) pc, err := proxy.DialUDP(metadata)

Loading…
Cancel
Save