Browse Source

Refactor: optimize UDP module

Symmetric NAT support for now.
pull/133/head
xjasonlyu 3 years ago
parent
commit
dd0cde04b4
  1. 28
      core/adapter.go
  2. 6
      core/handler.go
  3. 20
      core/stack/tcp.go
  4. 190
      core/stack/udp.go
  5. 8
      engine/tunnel.go
  6. 25
      tunnel/addr.go
  7. 23
      tunnel/tcp.go
  8. 51
      tunnel/tunnel.go
  9. 123
      tunnel/udp.go
  10. 20
      tunnel/util.go

28
core/adapter.go

@ -2,33 +2,15 @@ package core
import (
"net"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// TCPConn implements the net.Conn interface.
type TCPConn interface {
net.Conn
ID() *stack.TransportEndpointID
}
type UDPPacket interface {
// Data get the payload of UDP Packet.
Data() []byte
// Drop call after packet is used, could release resources in this function.
Drop()
// ID returns the transport endpoint id of packet.
ID() *stack.TransportEndpointID
// LocalAddr returns the source IP/Port of packet.
LocalAddr() net.Addr
// RemoteAddr returns the destination IP/Port of packet.
RemoteAddr() net.Addr
// WriteBack writes the payload with source IP/Port equals addr
// - variable source IP/Port is important to STUN
// - if addr is not provided, WriteBack will write out UDP packet with SourceIP/Port equals to original Target.
WriteBack([]byte, net.Addr) (int, error)
// UDPConn implements net.Conn and net.PacketConn.
type UDPConn interface {
net.Conn
net.PacketConn
}

6
core/handler.go

@ -1,6 +1,8 @@
package core
// Handler is a TCP/UDP connection handler that implements
// HandleTCPConn and HandleUDPConn methods.
type Handler interface {
Add(TCPConn)
AddPacket(UDPPacket)
HandleTCPConn(TCPConn)
HandleUDPConn(UDPConn)
}

20
core/stack/tcp.go

@ -2,12 +2,10 @@ package stack
import (
"fmt"
"net"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
@ -36,10 +34,9 @@ func withTCPHandler() Option {
return func(s *Stack) error {
tcpForwarder := tcp.NewForwarder(s.Stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
id := r.ID()
ep, err := r.CreateEndpoint(&wq)
if err != nil {
// prevent potential half-open TCP connection leak.
// RST: prevent potential half-open TCP connection leak.
r.Complete(true)
return
}
@ -47,11 +44,7 @@ func withTCPHandler() Option {
setKeepalive(ep)
conn := &tcpConn{
Conn: gonet.NewTCPConn(&wq, ep),
id: &id,
}
s.handler.Add(conn)
s.handler.HandleTCPConn(gonet.NewTCPConn(&wq, ep))
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
return nil
@ -72,12 +65,3 @@ func setKeepalive(ep tcpip.Endpoint) error {
}
return nil
}
type tcpConn struct {
net.Conn
id *stack.TransportEndpointID
}
func (c *tcpConn) ID() *stack.TransportEndpointID {
return c.id
}

190
core/stack/udp.go

@ -1,192 +1,24 @@
package stack
import (
"fmt"
"net"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
// udpNoChecksum disables UDP checksum if set to true.
udpNoChecksum = true
"gvisor.dev/gvisor/pkg/waiter"
)
func withUDPHandler() Option {
return func(s *Stack) error {
udpHandlePacket := func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
// Ref: gVisor pkg/tcpip/transport/udp/endpoint.go HandlePacket.
udpHdr := header.UDP(pkt.TransportHeader().View())
if int(udpHdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize {
// Malformed packet.
s.Stats().UDP.MalformedPacketsReceived.Increment()
return true
}
if !verifyChecksum(udpHdr, pkt) {
// Checksum error.
s.Stats().UDP.ChecksumErrors.Increment()
return true
}
s.Stats().UDP.PacketsReceived.Increment()
packet := &udpPacket{
s: s,
id: &id,
data: pkt.Data().ExtractVV(),
nicID: pkt.NICID,
netHdr: pkt.Network(),
netProto: pkt.NetworkProtocolNumber,
udpForwarder := udp.NewForwarder(s.Stack, func(r *udp.ForwarderRequest) {
var wq waiter.Queue
ep, err := r.CreateEndpoint(&wq)
if err != nil {
// TODO: handler errors in the future.
return
}
s.handler.AddPacket(packet)
return true
}
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpHandlePacket)
s.handler.HandleUDPConn(gonet.NewUDPConn(s.Stack, &wq, ep))
})
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
return nil
}
}
type udpPacket struct {
s *Stack
id *stack.TransportEndpointID
data buffer.VectorisedView
nicID tcpip.NICID
netHdr header.Network
netProto tcpip.NetworkProtocolNumber
}
func (p *udpPacket) Data() []byte {
return p.data.ToView()
}
func (p *udpPacket) Drop() {}
func (p *udpPacket) ID() *stack.TransportEndpointID {
return p.id
}
func (p *udpPacket) LocalAddr() net.Addr {
return &net.UDPAddr{IP: net.IP(p.id.LocalAddress), Port: int(p.id.LocalPort)}
}
func (p *udpPacket) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: net.IP(p.id.RemoteAddress), Port: int(p.id.RemotePort)}
}
func (p *udpPacket) WriteBack(b []byte, addr net.Addr) (int, error) {
v := buffer.View(b)
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{})
}
var (
localAddress tcpip.Address
localPort uint16
)
if udpAddr, ok := addr.(*net.UDPAddr); !ok {
localAddress = p.netHdr.DestinationAddress()
localPort = p.id.LocalPort
} else if ipv4 := udpAddr.IP.To4(); ipv4 != nil {
localAddress = tcpip.Address(ipv4)
localPort = uint16(udpAddr.Port)
} else {
localAddress = tcpip.Address(udpAddr.IP)
localPort = uint16(udpAddr.Port)
}
route, err := p.s.FindRoute(p.nicID, localAddress, p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
if err != nil {
return 0, fmt.Errorf("%#v find route: %s", p.id, err)
}
defer route.Release()
data := v.ToVectorisedView()
if err = sendUDP(route, data, localPort, p.id.RemotePort, udpNoChecksum); err != nil {
return 0, fmt.Errorf("%v", err)
}
return data.Size(), nil
}
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, noChecksum bool) tcpip.Error {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
Data: data,
})
defer pkt.DecRef()
// Initialize the UDP header.
udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
pkt.TransportProtocolNumber = udp.ProtocolNumber
length := uint16(pkt.Size())
udpHdr.Encode(&header.UDPFields{
SrcPort: localPort,
DstPort: remotePort,
Length: length,
})
// Set the checksum field unless TX checksum offload is enabled.
// On IPv4, UDP checksum is optional, and a zero value indicates the
// transmitter skipped the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
if r.RequiresTXTransportChecksum() &&
(!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
}
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
}
ttl := r.DefaultTTL()
if err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: udp.ProtocolNumber,
TTL: ttl,
TOS: 0, /* default */
}, pkt); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return err
}
// Track count of packets sent.
r.Stats().UDP.PacketsSent.Increment()
return nil
}
// Ref: gVisor pkg/tcpip/transport/udp/endpoint.go verifyChecksum.
// verifyChecksum verifies the checksum unless RX checksum offload is enabled.
func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
if pkt.RXTransportChecksumValidated {
return true
}
// On IPv4, UDP checksum is optional, and a zero value means the transmitter
// omitted the checksum generation, as per RFC 768:
//
// An all zero transmitted checksum value means that the transmitter
// generated no checksum (for debugging or for higher level protocols that
// don't care).
//
// On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
//
// Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
// checksum is not optional.
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 {
return true
}
netHdr := pkt.Network()
payloadChecksum := pkt.Data().AsRange().Checksum()
return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum)
}

8
engine/tunnel.go

@ -9,10 +9,10 @@ var _ core.Handler = (*fakeTunnel)(nil)
type fakeTunnel struct{}
func (*fakeTunnel) Add(conn core.TCPConn) {
tunnel.Add(conn)
func (*fakeTunnel) HandleTCPConn(conn core.TCPConn) {
tunnel.TCPIn() <- conn
}
func (*fakeTunnel) AddPacket(packet core.UDPPacket) {
tunnel.AddPacket(packet)
func (*fakeTunnel) HandleUDPConn(conn core.UDPConn) {
tunnel.UDPIn() <- conn
}

25
tunnel/addr.go

@ -0,0 +1,25 @@
package tunnel
import (
"net"
"strconv"
)
// parseAddr parses net.Addr to IP and port.
func parseAddr(addr net.Addr) (net.IP, uint16) {
switch v := addr.(type) {
case *net.TCPAddr:
return v.IP, uint16(v.Port)
case *net.UDPAddr:
return v.IP, uint16(v.Port)
default:
return parseAddrString(addr.String())
}
}
// parseAddrString parses address string to IP and port.
func parseAddrString(addr string) (net.IP, uint16) {
host, port, _ := net.SplitHostPort(addr)
portInt, _ := strconv.ParseUint(port, 10, 16)
return net.ParseIP(host), uint16(portInt)
}

23
tunnel/tcp.go

@ -22,16 +22,19 @@ func newTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn {
return statistic.NewTCPTracker(conn, metadata, statistic.DefaultManager)
}
func handleTCP(localConn core.TCPConn) {
func handleTCPConn(localConn core.TCPConn) {
defer localConn.Close()
id := localConn.ID()
var (
srcIP, srcPort = parseAddr(localConn.RemoteAddr())
dstIP, dstPort = parseAddr(localConn.LocalAddr())
)
metadata := &M.Metadata{
Net: M.TCP,
SrcIP: net.IP(id.RemoteAddress),
SrcPort: id.RemotePort,
DstIP: net.IP(id.LocalAddress),
DstPort: id.LocalPort,
SrcIP: srcIP,
SrcPort: srcPort,
DstIP: dstIP,
DstPort: dstPort,
}
targetConn, err := proxy.Dial(metadata)
@ -39,13 +42,7 @@ func handleTCP(localConn core.TCPConn) {
log.Warnf("[TCP] dial %s error: %v", metadata.DestinationAddress(), err)
return
}
if dialerAddr, ok := targetConn.LocalAddr().(*net.TCPAddr); ok {
metadata.MidIP = dialerAddr.IP
metadata.MidPort = uint16(dialerAddr.Port)
} else { /* fallback */
metadata.MidIP, metadata.MidPort = parseAddr(targetConn.LocalAddr().String())
}
metadata.MidIP, metadata.MidPort = parseAddr(targetConn.LocalAddr())
targetConn = newTCPTracker(targetConn, metadata)
defer targetConn.Close()

51
tunnel/tunnel.go

@ -1,55 +1,36 @@
package tunnel
import (
"runtime"
"github.com/xjasonlyu/tun2socks/v2/core"
"github.com/xjasonlyu/tun2socks/v2/log"
)
const (
// maxUDPQueueSize is the max number of UDP packets
// could be buffered. if queue is full, upcoming packets
// would be dropped util queue is ready again.
maxUDPQueueSize = 1 << 9
)
// Unbuffered TCP/UDP queues.
var (
_tcpQueue = make(chan core.TCPConn) /* unbuffered */
_udpQueue = make(chan core.UDPPacket, maxUDPQueueSize)
_numUDPWorkers = max(runtime.GOMAXPROCS(0), 4 /* at least 4 workers */)
_tcpQueue = make(chan core.TCPConn)
_udpQueue = make(chan core.UDPConn)
)
func init() {
go process()
}
// Add adds tcpConn to tcpQueue.
func Add(conn core.TCPConn) {
_tcpQueue <- conn
// TCPIn return fan-in TCP queue.
func TCPIn() chan<- core.TCPConn {
return _tcpQueue
}
// AddPacket adds udpPacket to udpQueue.
func AddPacket(packet core.UDPPacket) {
select {
case _udpQueue <- packet:
default:
log.Warnf("queue is currently full, packet will be dropped")
packet.Drop()
}
// UDPIn return fan-in UDP queue.
func UDPIn() chan<- core.UDPConn {
return _udpQueue
}
func process() {
for i := 0; i < _numUDPWorkers; i++ {
queue := _udpQueue
go func() {
for packet := range queue {
handleUDP(packet)
}
}()
}
for conn := range _tcpQueue {
go handleTCP(conn)
for {
select {
case conn := <-_tcpQueue:
go handleTCPConn(conn)
case conn := <-_udpQueue:
go handleUDPConn(conn)
}
}
}

123
tunnel/udp.go

@ -7,7 +7,6 @@ import (
"time"
"github.com/xjasonlyu/tun2socks/v2/common/pool"
"github.com/xjasonlyu/tun2socks/v2/component/nat"
"github.com/xjasonlyu/tun2socks/v2/core"
"github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
@ -15,15 +14,8 @@ import (
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
)
var (
// _natTable uses source udp packet information
// as key to store destination udp packetConn.
_natTable = nat.NewTable()
// _udpSessionTimeout is the default timeout for
// each UDP session.
_udpSessionTimeout = 60 * time.Second
)
// _udpSessionTimeout is the default timeout for each UDP session.
var _udpSessionTimeout = 60 * time.Second
func SetUDPTimeout(v int) {
_udpSessionTimeout = time.Duration(v) * time.Second
@ -33,98 +25,58 @@ func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn {
return statistic.NewUDPTracker(conn, metadata, statistic.DefaultManager)
}
func handleUDP(packet core.UDPPacket) {
id := packet.ID()
func handleUDPConn(uc core.UDPConn) {
defer uc.Close()
var (
srcIP, srcPort = parseAddr(uc.RemoteAddr())
dstIP, dstPort = parseAddr(uc.LocalAddr())
)
metadata := &M.Metadata{
Net: M.UDP,
SrcIP: net.IP(id.RemoteAddress),
SrcPort: id.RemotePort,
DstIP: net.IP(id.LocalAddress),
DstPort: id.LocalPort,
SrcIP: srcIP,
SrcPort: srcPort,
DstIP: dstIP,
DstPort: dstPort,
}
generateNATKey := func(m *M.Metadata) string {
return m.SourceAddress() /* as Full Cone NAT Key */
}
key := generateNATKey(metadata)
handle := func(drop bool) bool {
pc := _natTable.Get(key)
if pc != nil {
handleUDPToRemote(packet, pc, metadata /* as net.Addr */, drop)
return true
}
return false
}
if handle(true /* drop */) {
pc, err := proxy.DialUDP(metadata)
if err != nil {
log.Warnf("[UDP] dial %s error: %v", metadata.DestinationAddress(), err)
return
}
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr())
lockKey := key + "-lock"
cond, loaded := _natTable.GetOrCreateLock(lockKey)
go func() {
if loaded {
cond.L.Lock()
cond.Wait()
handle(true) /* drop after sending data to remote */
cond.L.Unlock()
return
}
pc = newUDPTracker(pc, metadata)
defer pc.Close()
defer func() {
_natTable.Delete(lockKey)
cond.Broadcast()
}()
go handleUDPToRemote(uc, pc, metadata)
handleUDPToLocal(uc, pc, metadata)
}
func handleUDPToRemote(uc core.UDPConn, pc net.PacketConn, remote net.Addr) {
buf := pool.Get(pool.MaxSegmentSize)
defer pool.Put(buf)
pc, err := proxy.DialUDP(metadata)
for {
n, err := uc.Read(buf)
if err != nil {
log.Warnf("[UDP] dial %s error: %v", metadata.DestinationAddress(), err)
return
}
if dialerAddr, ok := pc.LocalAddr().(*net.UDPAddr); ok {
metadata.MidIP = dialerAddr.IP
metadata.MidPort = uint16(dialerAddr.Port)
} else { /* fallback */
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr().String())
if _, err := pc.WriteTo(buf[:n], remote); err != nil {
log.Warnf("[UDP] write to %s error: %v", remote, err)
}
pc = newUDPTracker(pc, metadata)
go func() {
defer pc.Close()
defer packet.Drop()
defer _natTable.Delete(key)
handleUDPToLocal(packet, pc)
}()
_natTable.Set(key, pc)
handle(false /* drop */)
}()
}
func handleUDPToRemote(packet core.UDPPacket, pc net.PacketConn, remote net.Addr, drop bool) {
defer func() {
if drop {
packet.Drop()
}
}()
if _, err := pc.WriteTo(packet.Data() /* data */, remote); err != nil {
log.Warnf("[UDP] write to %s error: %v", remote, err)
log.Infof("[UDP] %s --> %s", uc.RemoteAddr(), remote)
}
pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) /* reset timeout */
log.Infof("[UDP] %s --> %s", packet.RemoteAddr(), remote)
}
func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) {
func handleUDPToLocal(uc core.UDPConn, pc net.PacketConn, remote net.Addr) {
buf := pool.Get(pool.MaxSegmentSize)
defer pool.Put(buf)
for /* just loop */ {
for {
pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout))
n, from, err := pc.ReadFrom(buf)
if err != nil {
@ -134,11 +86,14 @@ func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) {
return
}
if _, err := packet.WriteBack(buf[:n], from); err != nil {
log.Warnf("[UDP] write back from %s error: %v", from, err)
if from.Network() != remote.Network() || from.String() != remote.String() {
log.Warnf("[UDP] drop unknown packet from %s", from)
return
}
log.Infof("[UDP] %s <-- %s", packet.RemoteAddr(), from)
if _, err := uc.Write(buf[:n]); err != nil {
log.Warnf("[UDP] write back from %s error: %v", from, err)
return
}
}
}

20
tunnel/util.go

@ -1,20 +0,0 @@
package tunnel
import (
"net"
"strconv"
)
func max(a, b int) int {
if a > b {
return a
}
return b
}
// parseAddr parses address to IP and port.
func parseAddr(addr string) (net.IP, uint16) {
host, portStr, _ := net.SplitHostPort(addr)
portInt, _ := strconv.ParseUint(portStr, 10, 16)
return net.ParseIP(host), uint16(portInt)
}
Loading…
Cancel
Save