Browse Source

Refactor(socks): replace net.IP with netip.Addr (#397)

pull/398/head
Jason Lyu 2 months ago
committed by GitHub
parent
commit
978803cdf8
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 2
      proxy/socks5.go
  2. 36
      transport/socks4/socks4.go
  3. 84
      transport/socks4/socks4_test.go
  4. 26
      transport/socks5/socks5.go

2
proxy/socks5.go

@ -186,5 +186,5 @@ func (pc *socksPacketConn) Close() error {
} }
func serializeSocksAddr(m *M.Metadata) socks5.Addr { func serializeSocksAddr(m *M.Metadata) socks5.Addr {
return socks5.SerializeAddr("", m.DstIP.AsSlice(), m.DstPort) return socks5.SerializeAddr("", m.DstIP, m.DstPort)
} }

36
transport/socks4/socks4.go

@ -7,6 +7,7 @@ import (
"errors" "errors"
"io" "io"
"net" "net"
"net/netip"
"strconv" "strconv"
) )
@ -31,6 +32,7 @@ const (
var ( var (
errVersionMismatched = errors.New("version code mismatched") errVersionMismatched = errors.New("version code mismatched")
errIPv6NotSupported = errors.New("IPv6 not supported") errIPv6NotSupported = errors.New("IPv6 not supported")
errCmdNotSupported = errors.New("command not supported")
ErrRequestRejected = errors.New("request rejected or failed") ErrRequestRejected = errors.New("request rejected or failed")
ErrRequestIdentdFailed = errors.New("request rejected because SOCKS server cannot connect to identd on the client") ErrRequestIdentdFailed = errors.New("request rejected because SOCKS server cannot connect to identd on the client")
@ -39,6 +41,10 @@ var (
) )
func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID string) (err error) { func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID string) (err error) {
if command == CmdBind {
return errCmdNotSupported
}
var ( var (
host string host string
port uint16 port uint16
@ -47,24 +53,26 @@ func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID stri
return err return err
} }
ip := net.ParseIP(host) ip, _ := netip.ParseAddr(host)
if ip == nil /* HOST */ { switch {
ip = net.IPv4(0, 0, 0, 1) case !ip.IsValid(): /* HOST */
} else if ip.To4() == nil /* IPv6 */ { ip = netip.AddrFrom4([4]byte{0, 0, 0, 1})
case ip.Is4In6(): /* IPv4-mapped IPv6 */
ip = netip.AddrFrom4(ip.As4())
case ip.Is4(): /* IPv4 */
case ip.Is6(): /* IPv6 */
return errIPv6NotSupported return errIPv6NotSupported
} }
dstIP := /* [4]byte */ ip.To4()
req := &bytes.Buffer{} req := &bytes.Buffer{}
req.WriteByte(Version) req.WriteByte(Version)
req.WriteByte(command) req.WriteByte(command)
binary.Write(req, binary.BigEndian, port) _ = binary.Write(req, binary.BigEndian, port)
req.Write(dstIP) req.Write(ip.AsSlice())
req.WriteString(userID) req.WriteString(userID)
req.WriteByte(0) /* NULL */ req.WriteByte(0) /* NULL */
if isReservedIP(dstIP) /* SOCKS4A */ { if isReservedIP(ip) /* SOCKS4A */ {
req.WriteString(host) req.WriteString(host)
req.WriteByte(0) /* NULL */ req.WriteByte(0) /* NULL */
} }
@ -103,13 +111,9 @@ func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID stri
// Internet Assigned Numbers Authority -- such an address is inadmissible // Internet Assigned Numbers Authority -- such an address is inadmissible
// as a destination IP address and thus should never occur if the client // as a destination IP address and thus should never occur if the client
// can resolve the domain name.) // can resolve the domain name.)
func isReservedIP(ip net.IP) bool { func isReservedIP(ip netip.Addr) bool {
subnet := net.IPNet{ prefix := netip.PrefixFrom(netip.IPv4Unspecified(), 24)
IP: net.IPv4zero, return !ip.IsUnspecified() && prefix.Contains(ip)
Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00),
}
return !ip.IsUnspecified() && subnet.Contains(ip)
} }
func splitHostPort(addr string) (string, uint16, error) { func splitHostPort(addr string) (string, uint16, error) {

84
transport/socks4/socks4_test.go

@ -0,0 +1,84 @@
package socks4
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsReservedIP(t *testing.T) {
reservedIPs := []string{
"0.0.0.1",
"0.0.0.2",
"0.0.0.50",
"0.0.0.100",
"0.0.0.255",
}
for _, ip := range reservedIPs {
assert.True(t, isReservedIP(netip.MustParseAddr(ip)))
}
unReservedIPs := []string{
"0.0.0.0",
"0.0.1.0",
"1.1.1.1",
"10.0.0.0",
"255.255.255.255",
}
for _, ip := range unReservedIPs {
assert.False(t, isReservedIP(netip.MustParseAddr(ip)))
}
}
func TestSplitHostPort(t *testing.T) {
tests := []struct {
addr string
host string
port uint16
}{
{
"1.1.1.1:80",
"1.1.1.1",
80,
},
{
"1.1.1.1:0",
"1.1.1.1",
0,
},
{
"0.0.0.0:0",
"0.0.0.0",
0,
},
{
"[::1]:443",
"::1",
443,
},
{
"example.com:80",
"example.com",
80,
},
}
for _, tt := range tests {
host, port, err := splitHostPort(tt.addr)
assert.NoError(t, err)
assert.Equal(t, tt.host, host)
assert.Equal(t, tt.port, port)
}
addrs := []string{
"1.1.1.1:-80",
"1.1.1.1:abcd",
"::1:80",
"[::1]",
"example.com",
}
for _, addr := range addrs {
_, _, err := splitHostPort(addr)
assert.Error(t, err)
}
}

26
transport/socks5/socks5.go

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"strconv" "strconv"
) )
@ -308,7 +309,7 @@ func SplitAddr(b []byte) Addr {
// SerializeAddr serializes destination address and port to Addr. // SerializeAddr serializes destination address and port to Addr.
// If a domain name is provided, AtypDomainName would be used first. // If a domain name is provided, AtypDomainName would be used first.
func SerializeAddr(domainName string, dstIP net.IP, dstPort uint16) Addr { func SerializeAddr(domainName string, dstIP netip.Addr, dstPort uint16) Addr {
var ( var (
buf [][]byte buf [][]byte
port [2]byte port [2]byte
@ -318,10 +319,10 @@ func SerializeAddr(domainName string, dstIP net.IP, dstPort uint16) Addr {
if domainName != "" /* Domain Name */ { if domainName != "" /* Domain Name */ {
length := len(domainName) length := len(domainName)
buf = [][]byte{{AtypDomainName, uint8(length)}, []byte(domainName), port[:]} buf = [][]byte{{AtypDomainName, uint8(length)}, []byte(domainName), port[:]}
} else if dstIP.To4() != nil /* IPv4 */ { } else if dstIP.Is4() /* IPv4 */ {
buf = [][]byte{{AtypIPv4}, dstIP.To4(), port[:]} buf = [][]byte{{AtypIPv4}, dstIP.AsSlice(), port[:]}
} else /* IPv6 */ { } else /* IPv6 */ {
buf = [][]byte{{AtypIPv6}, dstIP.To16(), port[:]} buf = [][]byte{{AtypIPv6}, dstIP.AsSlice(), port[:]}
} }
return bytes.Join(buf, nil) return bytes.Join(buf, nil)
} }
@ -329,14 +330,13 @@ func SerializeAddr(domainName string, dstIP net.IP, dstPort uint16) Addr {
// ParseAddr parses a socks addr from net.Addr. // ParseAddr parses a socks addr from net.Addr.
// This is a fast path of ParseAddrString(addr.String()) // This is a fast path of ParseAddrString(addr.String())
func ParseAddr(addr net.Addr) Addr { func ParseAddr(addr net.Addr) Addr {
switch v := addr.(type) { if v, ok := addr.(interface {
case *net.TCPAddr: AddrPort() netip.AddrPort
return SerializeAddr("", v.IP, uint16(v.Port)) }); ok {
case *net.UDPAddr: ap := v.AddrPort()
return SerializeAddr("", v.IP, uint16(v.Port)) return SerializeAddr("", ap.Addr(), ap.Port())
default:
return ParseAddrString(addr.String())
} }
return ParseAddrString(addr.String())
} }
// ParseAddrString parses the address in string s to Addr. Returns nil if failed. // ParseAddrString parses the address in string s to Addr. Returns nil if failed.
@ -351,10 +351,10 @@ func ParseAddrString(s string) Addr {
return nil return nil
} }
if ip := net.ParseIP(host); ip != nil { if ip, _ := netip.ParseAddr(host); ip.IsValid() {
return SerializeAddr("", ip, uint16(dstPort)) return SerializeAddr("", ip, uint16(dstPort))
} }
return SerializeAddr(host, nil, uint16(dstPort)) return SerializeAddr(host, netip.Addr{}, uint16(dstPort))
} }
// DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet` // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`

Loading…
Cancel
Save