You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

129 lines
3.0 KiB

4 years ago
package tunnel
import (
"net"
"sync"
4 years ago
"time"
"github.com/xjasonlyu/tun2socks/v2/common/pool"
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
4 years ago
)
// _udpSessionTimeout is the default timeout for each UDP session.
var _udpSessionTimeout = 60 * time.Second
4 years ago
func SetUDPTimeout(t time.Duration) {
_udpSessionTimeout = t
}
func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn {
return statistic.NewUDPTracker(conn, metadata, statistic.DefaultManager)
}
3 years ago
// TODO: Port Restricted NAT support.
func handleUDPConn(uc adapter.UDPConn) {
defer uc.Close()
id := uc.ID()
metadata := &M.Metadata{
Network: M.UDP,
SrcIP: net.IP(id.RemoteAddress),
SrcPort: id.RemotePort,
DstIP: net.IP(id.LocalAddress),
DstPort: id.LocalPort,
}
4 years ago
pc, err := proxy.DialUDP(metadata)
if err != nil {
log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err)
4 years ago
return
}
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr())
4 years ago
pc = newUDPTracker(pc, metadata)
defer pc.Close()
4 years ago
var remote net.Addr
if udpAddr := metadata.UDPAddr(); udpAddr != nil {
remote = udpAddr
} else {
remote = metadata.Addr()
}
pc = newSymmetricNATPacketConn(pc, metadata)
log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress())
relayPacket(uc, pc, remote)
}
func relayPacket(left net.PacketConn, right net.PacketConn, to net.Addr) {
wg := sync.WaitGroup{}
wg.Add(2)
4 years ago
go func() {
defer wg.Done()
if err := copyPacketBuffer(right, left, to, _udpSessionTimeout); err != nil {
log.Warnf("[UDP] %v", err)
4 years ago
}
}()
4 years ago
go func() {
defer wg.Done()
if err := copyPacketBuffer(left, right, nil, _udpSessionTimeout); err != nil {
log.Warnf("[UDP] %v", err)
4 years ago
}
}()
4 years ago
wg.Wait()
4 years ago
}
func copyPacketBuffer(dst net.PacketConn, src net.PacketConn, to net.Addr, timeout time.Duration) error {
4 years ago
buf := pool.Get(pool.MaxSegmentSize)
4 years ago
defer pool.Put(buf)
for {
src.SetReadDeadline(time.Now().Add(timeout))
n, _, err := src.ReadFrom(buf)
if ne, ok := err.(net.Error); ok && ne.Timeout() {
return nil /* ignore I/O timeout */
} else if err != nil {
return err
4 years ago
}
if _, err = dst.WriteTo(buf[:n], to); err != nil {
return err
}
dst.SetReadDeadline(time.Now().Add(timeout))
4 years ago
}
}
type symmetricNATPacketConn struct {
net.PacketConn
src string
dst string
}
func newSymmetricNATPacketConn(pc net.PacketConn, metadata *M.Metadata) *symmetricNATPacketConn {
return &symmetricNATPacketConn{
PacketConn: pc,
src: metadata.SourceAddress(),
dst: metadata.DestinationAddress(),
}
}
func (pc *symmetricNATPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
for {
n, from, err := pc.PacketConn.ReadFrom(p)
if from != nil && from.String() != pc.dst {
log.Warnf("[UDP] symmetric NAT %s->%s: drop packet from %s", pc.src, pc.dst, from)
continue
}
return n, from, err
}
}