Browse Source

Refactor(socks): replace net.IP with netip.Addr (#397)

pull/398/head
Jason Lyu 2 months ago
committed by GitHub
parent
commit
978803cdf8
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 2
      proxy/socks5.go
  2. 36
      transport/socks4/socks4.go
  3. 84
      transport/socks4/socks4_test.go
  4. 26
      transport/socks5/socks5.go

2
proxy/socks5.go

@ -186,5 +186,5 @@ func (pc *socksPacketConn) Close() error {
}
func serializeSocksAddr(m *M.Metadata) socks5.Addr {
return socks5.SerializeAddr("", m.DstIP.AsSlice(), m.DstPort)
return socks5.SerializeAddr("", m.DstIP, m.DstPort)
}

36
transport/socks4/socks4.go

@ -7,6 +7,7 @@ import (
"errors"
"io"
"net"
"net/netip"
"strconv"
)
@ -31,6 +32,7 @@ const (
var (
errVersionMismatched = errors.New("version code mismatched")
errIPv6NotSupported = errors.New("IPv6 not supported")
errCmdNotSupported = errors.New("command not supported")
ErrRequestRejected = errors.New("request rejected or failed")
ErrRequestIdentdFailed = errors.New("request rejected because SOCKS server cannot connect to identd on the client")
@ -39,6 +41,10 @@ var (
)
func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID string) (err error) {
if command == CmdBind {
return errCmdNotSupported
}
var (
host string
port uint16
@ -47,24 +53,26 @@ func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID stri
return err
}
ip := net.ParseIP(host)
if ip == nil /* HOST */ {
ip = net.IPv4(0, 0, 0, 1)
} else if ip.To4() == nil /* IPv6 */ {
ip, _ := netip.ParseAddr(host)
switch {
case !ip.IsValid(): /* HOST */
ip = netip.AddrFrom4([4]byte{0, 0, 0, 1})
case ip.Is4In6(): /* IPv4-mapped IPv6 */
ip = netip.AddrFrom4(ip.As4())
case ip.Is4(): /* IPv4 */
case ip.Is6(): /* IPv6 */
return errIPv6NotSupported
}
dstIP := /* [4]byte */ ip.To4()
req := &bytes.Buffer{}
req.WriteByte(Version)
req.WriteByte(command)
binary.Write(req, binary.BigEndian, port)
req.Write(dstIP)
_ = binary.Write(req, binary.BigEndian, port)
req.Write(ip.AsSlice())
req.WriteString(userID)
req.WriteByte(0) /* NULL */
if isReservedIP(dstIP) /* SOCKS4A */ {
if isReservedIP(ip) /* SOCKS4A */ {
req.WriteString(host)
req.WriteByte(0) /* NULL */
}
@ -103,13 +111,9 @@ func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID stri
// Internet Assigned Numbers Authority -- such an address is inadmissible
// as a destination IP address and thus should never occur if the client
// can resolve the domain name.)
func isReservedIP(ip net.IP) bool {
subnet := net.IPNet{
IP: net.IPv4zero,
Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00),
}
return !ip.IsUnspecified() && subnet.Contains(ip)
func isReservedIP(ip netip.Addr) bool {
prefix := netip.PrefixFrom(netip.IPv4Unspecified(), 24)
return !ip.IsUnspecified() && prefix.Contains(ip)
}
func splitHostPort(addr string) (string, uint16, error) {

84
transport/socks4/socks4_test.go

@ -0,0 +1,84 @@
package socks4
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsReservedIP(t *testing.T) {
reservedIPs := []string{
"0.0.0.1",
"0.0.0.2",
"0.0.0.50",
"0.0.0.100",
"0.0.0.255",
}
for _, ip := range reservedIPs {
assert.True(t, isReservedIP(netip.MustParseAddr(ip)))
}
unReservedIPs := []string{
"0.0.0.0",
"0.0.1.0",
"1.1.1.1",
"10.0.0.0",
"255.255.255.255",
}
for _, ip := range unReservedIPs {
assert.False(t, isReservedIP(netip.MustParseAddr(ip)))
}
}
func TestSplitHostPort(t *testing.T) {
tests := []struct {
addr string
host string
port uint16
}{
{
"1.1.1.1:80",
"1.1.1.1",
80,
},
{
"1.1.1.1:0",
"1.1.1.1",
0,
},
{
"0.0.0.0:0",
"0.0.0.0",
0,
},
{
"[::1]:443",
"::1",
443,
},
{
"example.com:80",
"example.com",
80,
},
}
for _, tt := range tests {
host, port, err := splitHostPort(tt.addr)
assert.NoError(t, err)
assert.Equal(t, tt.host, host)
assert.Equal(t, tt.port, port)
}
addrs := []string{
"1.1.1.1:-80",
"1.1.1.1:abcd",
"::1:80",
"[::1]",
"example.com",
}
for _, addr := range addrs {
_, _, err := splitHostPort(addr)
assert.Error(t, err)
}
}

26
transport/socks5/socks5.go

@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"strconv"
)
@ -308,7 +309,7 @@ func SplitAddr(b []byte) Addr {
// SerializeAddr serializes destination address and port to Addr.
// If a domain name is provided, AtypDomainName would be used first.
func SerializeAddr(domainName string, dstIP net.IP, dstPort uint16) Addr {
func SerializeAddr(domainName string, dstIP netip.Addr, dstPort uint16) Addr {
var (
buf [][]byte
port [2]byte
@ -318,10 +319,10 @@ func SerializeAddr(domainName string, dstIP net.IP, dstPort uint16) Addr {
if domainName != "" /* Domain Name */ {
length := len(domainName)
buf = [][]byte{{AtypDomainName, uint8(length)}, []byte(domainName), port[:]}
} else if dstIP.To4() != nil /* IPv4 */ {
buf = [][]byte{{AtypIPv4}, dstIP.To4(), port[:]}
} else if dstIP.Is4() /* IPv4 */ {
buf = [][]byte{{AtypIPv4}, dstIP.AsSlice(), port[:]}
} else /* IPv6 */ {
buf = [][]byte{{AtypIPv6}, dstIP.To16(), port[:]}
buf = [][]byte{{AtypIPv6}, dstIP.AsSlice(), port[:]}
}
return bytes.Join(buf, nil)
}
@ -329,14 +330,13 @@ func SerializeAddr(domainName string, dstIP net.IP, dstPort uint16) Addr {
// ParseAddr parses a socks addr from net.Addr.
// This is a fast path of ParseAddrString(addr.String())
func ParseAddr(addr net.Addr) Addr {
switch v := addr.(type) {
case *net.TCPAddr:
return SerializeAddr("", v.IP, uint16(v.Port))
case *net.UDPAddr:
return SerializeAddr("", v.IP, uint16(v.Port))
default:
return ParseAddrString(addr.String())
if v, ok := addr.(interface {
AddrPort() netip.AddrPort
}); ok {
ap := v.AddrPort()
return SerializeAddr("", ap.Addr(), ap.Port())
}
return ParseAddrString(addr.String())
}
// ParseAddrString parses the address in string s to Addr. Returns nil if failed.
@ -351,10 +351,10 @@ func ParseAddrString(s string) Addr {
return nil
}
if ip := net.ParseIP(host); ip != nil {
if ip, _ := netip.ParseAddr(host); ip.IsValid() {
return SerializeAddr("", ip, uint16(dstPort))
}
return SerializeAddr(host, nil, uint16(dstPort))
return SerializeAddr(host, netip.Addr{}, uint16(dstPort))
}
// DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`

Loading…
Cancel
Save