You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

252 lines
4.7 KiB

package proxy
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"net"
"sync"
"github.com/go-gost/relay"
"github.com/xjasonlyu/tun2socks/v2/common/pool"
"github.com/xjasonlyu/tun2socks/v2/dialer"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/proxy/proto"
)
// Relay is a proxy client that tunnels traffic through a server speaking the
// go-gost relay protocol. The embedded Base carries the server address and
// the protocol identifier.
type Relay struct {
	*Base

	// user and pass are sent as a user-auth feature in the relay request;
	// the feature is only attached when user is non-empty.
	user, pass string

	// noDelay selects when the relay request is sent: when true it is written
	// (and its response read) while dialing; when false it is buffered and
	// sent together with the first payload written to the connection.
	noDelay bool
}

// Compile-time check that *Relay satisfies the Proxy interface.
var _ Proxy = (*Relay)(nil)
// NewRelay builds a Relay proxy that connects to the relay server at addr.
//
// user and pass are the optional credentials, and noDelay controls whether the
// relay request is sent while dialing (true) or deferred until the first write
// (false). The returned error is always nil.
func NewRelay(addr, user, pass string, noDelay bool) (*Relay, error) {
	base := &Base{
		addr:  addr,
		proto: proto.Relay,
	}
	rl := &Relay{
		Base:    base,
		user:    user,
		pass:    pass,
		noDelay: noDelay,
	}
	return rl, nil
}
// DialContext opens a connection to the destination described by metadata
// through the relay server, using ctx for the dial to the server.
//
// On failure the returned net.Conn is a true nil interface value. Returning
// the *relayConn from dialContext directly would wrap a nil pointer in a
// non-nil interface, so a caller testing `c != nil` on the error path would
// be misled and could dereference a nil connection.
func (rl *Relay) DialContext(ctx context.Context, metadata *M.Metadata) (c net.Conn, err error) {
	rc, err := rl.dialContext(ctx, metadata)
	if err != nil {
		return nil, err
	}
	return rc, nil
}
// DialUDP opens a packet-oriented tunnel to the destination described by
// metadata through the relay server. The dial is bounded by
// tcpConnectTimeout because the caller supplies no context.
//
// On failure the returned net.PacketConn is a true nil interface value rather
// than a non-nil interface wrapping a nil *relayConn, so `pc != nil` checks on
// the error path behave as expected.
func (rl *Relay) DialUDP(metadata *M.Metadata) (net.PacketConn, error) {
	ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
	defer cancel()

	rc, err := rl.dialContext(ctx, metadata)
	if err != nil {
		return nil, err
	}
	return rc, nil
}
// dialContext connects to the relay server and prepares a relayConn for the
// destination described by metadata.
//
// The network kind is validated before any connection is made, so an
// unsupported network costs neither a dial nor a handshake. In noDelay mode
// the relay request is written and its response read here; otherwise the
// request is staged in the connection's write buffer and goes out with the
// first payload write.
//
// On any error the returned *relayConn is nil. The deferred safeConnClose
// receives the final value of err (presumably closing the server connection
// when err is non-nil, confirm against its definition).
func (rl *Relay) dialContext(ctx context.Context, metadata *M.Metadata) (rc *relayConn, err error) {
	// Reject unsupported networks up front, before spending a connection.
	var udp bool
	switch metadata.Network {
	case M.TCP:
	case M.UDP:
		udp = true
	default:
		return nil, fmt.Errorf("network %s is unsupported", metadata.Network)
	}

	var c net.Conn
	c, err = dialer.DialContext(ctx, "tcp", rl.Addr())
	if err != nil {
		return nil, fmt.Errorf("connect to %s: %w", rl.Addr(), err)
	}
	setKeepAlive(c)

	// err is the named result, so this sees whatever error is finally returned.
	defer func(c net.Conn) {
		safeConnClose(c, err)
	}(c)

	req := relay.Request{
		Version: relay.Version1,
		Cmd:     relay.CmdConnect,
	}
	if udp {
		req.Cmd |= relay.FUDP
		req.Features = append(req.Features, &relay.NetworkFeature{
			Network: relay.NetworkUDP,
		})
	}
	if rl.user != "" {
		req.Features = append(req.Features, &relay.UserAuthFeature{
			Username: rl.user,
			Password: rl.pass,
		})
	}
	req.Features = append(req.Features, serializeRelayAddr(metadata))

	if rl.noDelay {
		// Complete the handshake now; the connection is ready on return.
		if _, err = req.WriteTo(c); err != nil {
			return nil, err
		}
		if err = readRelayResponse(c); err != nil {
			return nil, err
		}
		return newRelayConn(c, metadata.Addr(), true, udp), nil
	}

	// Deferred handshake: stage the request so it is sent with the first write.
	conn := newRelayConn(c, metadata.Addr(), false, udp)
	if _, err = req.WriteTo(conn.wbuf); err != nil {
		return nil, err
	}
	return conn, nil
}
// relayConn wraps the connection to the relay server and exposes it both as a
// stream (Read/Write) and as a packet connection (ReadFrom/WriteTo).
type relayConn struct {
	net.Conn

	// addr is the address reported by ReadFrom as the packet source.
	addr net.Addr
	// udp selects length-prefixed datagram framing for Read and Write.
	udp bool

	// wbuf holds the not-yet-sent relay request in deferred-handshake mode;
	// it is nil when the handshake already happened while dialing.
	wbuf *bytes.Buffer
	// once guards the one-time read of the relay response in Read.
	once sync.Once
}
// newRelayConn wraps c as a relayConn for the destination addr.
//
// When noDelay is false an empty write buffer is allocated so the caller can
// stage the relay request in it; when noDelay is true the buffer stays nil.
// udp selects datagram framing on the resulting connection.
func newRelayConn(c net.Conn, addr net.Addr, noDelay, udp bool) *relayConn {
	var pending *bytes.Buffer
	if !noDelay {
		pending = &bytes.Buffer{}
	}
	return &relayConn{
		Conn: c,
		addr: addr,
		udp:  udp,
		wbuf: pending,
	}
}
// ReadFrom reads one unit of data via Read and reports the connection's fixed
// destination address as the source.
func (rc *relayConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
	n, err = rc.Read(b)
	addr = rc.addr
	return n, addr, err
}
// Read reads data from the relay connection into b.
//
// In deferred-handshake mode (wbuf non-nil) the first call consumes the relay
// response before any payload; a failure there is returned from that call.
//
// For stream connections Read delegates to the underlying connection. For UDP
// connections each call consumes exactly one frame: a 2-byte big-endian length
// followed by that many payload bytes. If b is shorter than the payload, the
// whole frame is still consumed and the excess is discarded.
//
// When reading an oversized frame fails, Read returns 0 and the error instead
// of copying the scratch buffer into b, so the caller never receives bytes
// that were not actually read from the connection.
func (rc *relayConn) Read(b []byte) (n int, err error) {
	rc.once.Do(func() {
		if rc.wbuf != nil {
			err = readRelayResponse(rc.Conn)
		}
	})
	if err != nil {
		return 0, err
	}

	if !rc.udp {
		return rc.Conn.Read(b)
	}

	var hdr [2]byte
	if _, err = io.ReadFull(rc.Conn, hdr[:]); err != nil {
		return 0, err
	}
	dLen := int(binary.BigEndian.Uint16(hdr[:]))

	// Fast path: the caller's buffer can hold the whole datagram.
	if len(b) >= dLen {
		return io.ReadFull(rc.Conn, b[:dLen])
	}

	// Slow path: drain the full frame into a scratch buffer to keep the stream
	// aligned on frame boundaries, then hand back as much as fits in b.
	buf := pool.Get(dLen)
	defer pool.Put(buf)
	if _, err = io.ReadFull(rc.Conn, buf); err != nil {
		// buf was not fully filled; do not expose its contents to the caller.
		return 0, err
	}
	return copy(b, buf), nil
}
// WriteTo sends b through the tunnel via Write. The address argument is
// ignored.
func (rc *relayConn) WriteTo(b []byte, _ net.Addr) (int, error) {
	n, err := rc.Write(b)
	return n, err
}
// Write sends b to the relay server, using datagram framing for UDP
// connections and a plain stream write otherwise.
func (rc *relayConn) Write(b []byte) (int, error) {
	if !rc.udp {
		return rc.tcpWrite(b)
	}
	return rc.udpWrite(b)
}
// tcpWrite writes b to the stream connection.
//
// If a relay request is still pending in wbuf, it is sent together with b in a
// single write and the buffer is cleared. The returned count covers only bytes
// of b: when the combined write fails part-way, n is the number of payload
// bytes that actually went out (possibly 0) rather than len(b).
func (rc *relayConn) tcpWrite(b []byte) (n int, err error) {
	if rc.wbuf == nil || rc.wbuf.Len() == 0 {
		return rc.Conn.Write(b)
	}

	pending := rc.wbuf.Len() // bytes of the staged request preceding b
	rc.wbuf.Write(b)

	var written int
	written, err = rc.Conn.Write(rc.wbuf.Bytes())
	rc.wbuf.Reset()

	// Exclude the request prefix so n counts payload bytes only.
	n = written - pending
	if n < 0 {
		n = 0
	}
	return n, err
}
// udpWrite sends b as one datagram frame: a 2-byte big-endian length followed
// by the payload. Payloads longer than math.MaxUint16 are rejected.
//
// If a relay request is still pending in wbuf, the frame is appended to it and
// the buffer is flushed to the connection. Otherwise header and payload are
// handed over together through net.Buffers, which lets connections that
// support vectored writes send the frame in a single operation instead of two
// separate Write calls.
//
// It returns len(b) on success and 0 with the error on failure.
func (rc *relayConn) udpWrite(b []byte) (n int, err error) {
	if len(b) > math.MaxUint16 {
		return 0, errors.New("write: data maximum exceeded")
	}

	var hdr [2]byte
	binary.BigEndian.PutUint16(hdr[:], uint16(len(b)))

	if rc.wbuf != nil && rc.wbuf.Len() > 0 {
		// Send the staged relay request and the first frame together.
		rc.wbuf.Write(hdr[:])
		rc.wbuf.Write(b)
		_, err = rc.wbuf.WriteTo(rc.Conn)
	} else {
		frame := net.Buffers{hdr[:], b}
		_, err = frame.WriteTo(rc.Conn)
	}
	if err != nil {
		return 0, err
	}
	return len(b), nil
}
// readRelayResponse reads one relay response from r and validates it.
//
// It returns the read error if decoding fails, relay.ErrBadVersion if the
// response version is not relay.Version1, and a "status N" error if the status
// is anything other than relay.StatusOK.
func readRelayResponse(r io.Reader) error {
	var resp relay.Response
	_, err := resp.ReadFrom(r)
	switch {
	case err != nil:
		return err
	case resp.Version != relay.Version1:
		return relay.ErrBadVersion
	case resp.Status != relay.StatusOK:
		return fmt.Errorf("status %d", resp.Status)
	}
	return nil
}
// serializeRelayAddr builds the relay address feature for the destination in
// m. The address type is IPv4 when m.DstIP has a 4-byte form and IPv6
// otherwise.
func serializeRelayAddr(m *M.Metadata) *relay.AddrFeature {
	feature := &relay.AddrFeature{
		AType: relay.AddrIPv6,
		Host:  m.DstIP.String(),
		Port:  m.DstPort,
	}
	if m.DstIP.To4() != nil {
		feature.AType = relay.AddrIPv4
	}
	return feature
}