diff --git a/proxy/socks5.go b/proxy/socks5.go index 290adae..bdc9b04 100644 --- a/proxy/socks5.go +++ b/proxy/socks5.go @@ -186,5 +186,5 @@ func (pc *socksPacketConn) Close() error { } func serializeSocksAddr(m *M.Metadata) socks5.Addr { - return socks5.SerializeAddr("", m.DstIP.AsSlice(), m.DstPort) + return socks5.SerializeAddr("", m.DstIP, m.DstPort) } diff --git a/transport/socks4/socks4.go b/transport/socks4/socks4.go index 9eab102..514f2d5 100644 --- a/transport/socks4/socks4.go +++ b/transport/socks4/socks4.go @@ -7,6 +7,7 @@ import ( "errors" "io" "net" + "net/netip" "strconv" ) @@ -31,6 +32,7 @@ const ( var ( errVersionMismatched = errors.New("version code mismatched") errIPv6NotSupported = errors.New("IPv6 not supported") + errCmdNotSupported = errors.New("command not supported") ErrRequestRejected = errors.New("request rejected or failed") ErrRequestIdentdFailed = errors.New("request rejected because SOCKS server cannot connect to identd on the client") @@ -39,6 +41,10 @@ var ( ) func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID string) (err error) { + if command == CmdBind { + return errCmdNotSupported + } + var ( host string port uint16 @@ -47,24 +53,26 @@ func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID stri return err } - ip := net.ParseIP(host) - if ip == nil /* HOST */ { - ip = net.IPv4(0, 0, 0, 1) - } else if ip.To4() == nil /* IPv6 */ { + ip, _ := netip.ParseAddr(host) + switch { + case !ip.IsValid(): /* HOST */ + ip = netip.AddrFrom4([4]byte{0, 0, 0, 1}) + case ip.Is4In6(): /* IPv4-mapped IPv6 */ + ip = netip.AddrFrom4(ip.As4()) + case ip.Is4(): /* IPv4 */ + case ip.Is6(): /* IPv6 */ return errIPv6NotSupported } - dstIP := /* [4]byte */ ip.To4() - req := &bytes.Buffer{} req.WriteByte(Version) req.WriteByte(command) - binary.Write(req, binary.BigEndian, port) - req.Write(dstIP) + _ = binary.Write(req, binary.BigEndian, port) + req.Write(ip.AsSlice()) req.WriteString(userID) req.WriteByte(0) /* NULL */ - if isReservedIP(dstIP) /* SOCKS4A */ { + if isReservedIP(ip) /* SOCKS4A */ { req.WriteString(host) req.WriteByte(0) /* NULL */ } @@ -103,13 +111,9 @@ func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID stri // Internet Assigned Numbers Authority -- such an address is inadmissible // as a destination IP address and thus should never occur if the client // can resolve the domain name.) -func isReservedIP(ip net.IP) bool { - subnet := net.IPNet{ - IP: net.IPv4zero, - Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00), - } - - return !ip.IsUnspecified() && subnet.Contains(ip) +func isReservedIP(ip netip.Addr) bool { + prefix := netip.PrefixFrom(netip.IPv4Unspecified(), 24) + return !ip.IsUnspecified() && prefix.Contains(ip) } func splitHostPort(addr string) (string, uint16, error) { diff --git a/transport/socks4/socks4_test.go b/transport/socks4/socks4_test.go new file mode 100644 index 0000000..c38b399 --- /dev/null +++ b/transport/socks4/socks4_test.go @@ -0,0 +1,84 @@ +package socks4 + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsReservedIP(t *testing.T) { + reservedIPs := []string{ + "0.0.0.1", + "0.0.0.2", + "0.0.0.50", + "0.0.0.100", + "0.0.0.255", + } + for _, ip := range reservedIPs { + assert.True(t, isReservedIP(netip.MustParseAddr(ip))) + } + + unReservedIPs := []string{ + "0.0.0.0", + "0.0.1.0", + "1.1.1.1", + "10.0.0.0", + "255.255.255.255", + } + for _, ip := range unReservedIPs { + assert.False(t, isReservedIP(netip.MustParseAddr(ip))) + } +} + +func TestSplitHostPort(t *testing.T) { + tests := []struct { + addr string + host string + port uint16 + }{ + { + "1.1.1.1:80", + "1.1.1.1", + 80, + }, + { + "1.1.1.1:0", + "1.1.1.1", + 0, + }, + { + "0.0.0.0:0", + "0.0.0.0", + 0, + }, + { + "[::1]:443", + "::1", + 443, + }, + { + "example.com:80", + "example.com", + 80, + }, + } + for _, tt := range tests { + host, port, err := splitHostPort(tt.addr) + assert.NoError(t, err) + assert.Equal(t, tt.host, host) + assert.Equal(t, tt.port, port) + } + + addrs := []string{ + "1.1.1.1:-80", + "1.1.1.1:abcd", + "::1:80", + "[::1]", + "example.com", + } + for _, addr := range addrs { + _, _, err := splitHostPort(addr) + assert.Error(t, err) + } +} diff --git a/transport/socks5/socks5.go b/transport/socks5/socks5.go index e257149..43bab41 100644 --- a/transport/socks5/socks5.go +++ b/transport/socks5/socks5.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net" + "net/netip" "strconv" ) @@ -308,7 +309,7 @@ func SplitAddr(b []byte) Addr { // SerializeAddr serializes destination address and port to Addr. // If a domain name is provided, AtypDomainName would be used first. -func SerializeAddr(domainName string, dstIP net.IP, dstPort uint16) Addr { +func SerializeAddr(domainName string, dstIP netip.Addr, dstPort uint16) Addr { var ( buf [][]byte port [2]byte @@ -318,10 +319,10 @@ func SerializeAddr(domainName string, dstIP net.IP, dstPort uint16) Addr { if domainName != "" /* Domain Name */ { length := len(domainName) buf = [][]byte{{AtypDomainName, uint8(length)}, []byte(domainName), port[:]} - } else if dstIP.To4() != nil /* IPv4 */ { - buf = [][]byte{{AtypIPv4}, dstIP.To4(), port[:]} + } else if dstIP.Is4() /* IPv4 */ { + buf = [][]byte{{AtypIPv4}, dstIP.AsSlice(), port[:]} } else /* IPv6 */ { - buf = [][]byte{{AtypIPv6}, dstIP.To16(), port[:]} + buf = [][]byte{{AtypIPv6}, dstIP.AsSlice(), port[:]} } return bytes.Join(buf, nil) } @@ -329,14 +330,13 @@ func SerializeAddr(domainName string, dstIP net.IP, dstPort uint16) Addr { // ParseAddr parses a socks addr from net.Addr. // This is a fast path of ParseAddrString(addr.String()) func ParseAddr(addr net.Addr) Addr { - switch v := addr.(type) { - case *net.TCPAddr: - return SerializeAddr("", v.IP, uint16(v.Port)) - case *net.UDPAddr: - return SerializeAddr("", v.IP, uint16(v.Port)) - default: - return ParseAddrString(addr.String()) + if v, ok := addr.(interface { + AddrPort() netip.AddrPort + }); ok { + ap := v.AddrPort() + return SerializeAddr("", ap.Addr(), ap.Port()) } + return ParseAddrString(addr.String()) } // ParseAddrString parses the address in string s to Addr. Returns nil if failed. @@ -351,10 +351,10 @@ func ParseAddrString(s string) Addr { return nil } - if ip := net.ParseIP(host); ip != nil { + if ip, _ := netip.ParseAddr(host); ip.IsValid() { return SerializeAddr("", ip, uint16(dstPort)) } - return SerializeAddr(host, nil, uint16(dstPort)) + return SerializeAddr(host, netip.Addr{}, uint16(dstPort)) } // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`