package tunnel import ( "io" "net" "sync" "time" "github.com/xjasonlyu/tun2socks/v2/buffer" "github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) // TODO: Port Restricted NAT support. func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) { defer uc.Close() id := uc.ID() metadata := &M.Metadata{ Network: M.UDP, SrcIP: parseTCPIPAddress(id.RemoteAddress), SrcPort: id.RemotePort, DstIP: parseTCPIPAddress(id.LocalAddress), DstPort: id.LocalPort, } pc, err := t.Dialer().DialUDP(metadata) if err != nil { log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err) return } metadata.MidIP, metadata.MidPort = parseNetAddr(pc.LocalAddr()) pc = statistic.NewUDPTracker(pc, metadata, t.manager) defer pc.Close() var remote net.Addr if udpAddr := metadata.UDPAddr(); udpAddr != nil { remote = udpAddr } else { remote = metadata.Addr() } pc = newSymmetricNATPacketConn(pc, metadata) log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) pipePacket(uc, pc, remote, t.udpTimeout.Load()) } func pipePacket(origin, remote net.PacketConn, to net.Addr, timeout time.Duration) { wg := sync.WaitGroup{} wg.Add(2) go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg, timeout) go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg, timeout) wg.Wait() } func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup, timeout time.Duration) { defer wg.Done() if err := copyPacketData(dst, src, to, timeout); err != nil { log.Debugf("[UDP] copy data for %s: %v", dir, err) } } func copyPacketData(dst, src net.PacketConn, to net.Addr, timeout time.Duration) error { buf := buffer.Get(buffer.MaxSegmentSize) defer buffer.Put(buf) for { src.SetReadDeadline(time.Now().Add(timeout)) n, _, err := src.ReadFrom(buf) if ne, ok := err.(net.Error); ok && ne.Timeout() { return nil /* ignore I/O timeout */ } else if err == io.EOF { return nil /* ignore EOF */ } else if err != nil { return err } if _, err = dst.WriteTo(buf[:n], to); err != nil { return err } dst.SetReadDeadline(time.Now().Add(timeout)) } } type symmetricNATPacketConn struct { net.PacketConn src string dst string } func newSymmetricNATPacketConn(pc net.PacketConn, metadata *M.Metadata) *symmetricNATPacketConn { return &symmetricNATPacketConn{ PacketConn: pc, src: metadata.SourceAddress(), dst: metadata.DestinationAddress(), } } func (pc *symmetricNATPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { for { n, from, err := pc.PacketConn.ReadFrom(p) if from != nil && from.String() != pc.dst { log.Warnf("[UDP] symmetric NAT %s->%s: drop packet from %s", pc.src, pc.dst, from) continue } return n, from, err } }