diff --git a/internal/pool/alloc.go b/buffer/allocator/allocator.go similarity index 64% rename from internal/pool/alloc.go rename to buffer/allocator/allocator.go index 9a4cc10..ea12ac9 100644 --- a/internal/pool/alloc.go +++ b/buffer/allocator/allocator.go @@ -1,30 +1,29 @@ -package pool +package allocator import ( "errors" "math/bits" - "sync" -) -var _allocator = NewAllocator() + "github.com/xjasonlyu/tun2socks/v2/internal/pool" +) // Allocator for incoming frames, optimized to prevent overwriting // after zeroing. type Allocator struct { - buffers []sync.Pool + buffers []*pool.Pool[[]byte] } -// NewAllocator initiates a []byte allocator for frames less than -// 65536 bytes, the waste(memory fragmentation) of space allocation -// is guaranteed to be no more than 50%. -func NewAllocator() *Allocator { +// New initiates a []byte allocator for frames less than 65536 bytes, +// the waste(memory fragmentation) of space allocation is guaranteed +// to be no more than 50%. +func New() *Allocator { alloc := &Allocator{} - alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K + alloc.buffers = make([]*pool.Pool[[]byte], 17) // 1B -> 64K for k := range alloc.buffers { i := k - alloc.buffers[k].New = func() any { + alloc.buffers[k] = pool.New(func() []byte { return make([]byte, 1<= -1 { + x.value = i + } + }() + } +} diff --git a/proxy/relay.go b/proxy/relay.go index 87343e7..7dadc97 100644 --- a/proxy/relay.go +++ b/proxy/relay.go @@ -13,8 +13,8 @@ import ( "github.com/go-gost/relay" + "github.com/xjasonlyu/tun2socks/v2/buffer" "github.com/xjasonlyu/tun2socks/v2/dialer" - "github.com/xjasonlyu/tun2socks/v2/internal/pool" M "github.com/xjasonlyu/tun2socks/v2/metadata" "github.com/xjasonlyu/tun2socks/v2/proxy/proto" ) @@ -169,8 +169,8 @@ func (rc *relayConn) Read(b []byte) (n int, err error) { return io.ReadFull(rc.Conn, b[:dLen]) } - buf := pool.Get(dLen) - defer pool.Put(buf) + buf := buffer.Get(dLen) + defer buffer.Put(buf) _, err = io.ReadFull(rc.Conn, buf) n = copy(b, buf) diff --git a/transport/internal/bufferpool/bufferpool.go b/transport/internal/bufferpool/bufferpool.go new file mode 100644 index 0000000..170ca9d --- /dev/null +++ b/transport/internal/bufferpool/bufferpool.go @@ -0,0 +1,25 @@ +package bufferpool + +import ( + "bytes" + + "github.com/xjasonlyu/tun2socks/v2/internal/pool" +) + +const _size = 1024 // by default, create 1 KiB buffers + +var _pool = pool.New(func() *bytes.Buffer { + buf := &bytes.Buffer{} + buf.Grow(_size) + return buf +}) + +func Get() *bytes.Buffer { + buf := _pool.Get() + buf.Reset() + return buf +} + +func Put(b *bytes.Buffer) { + _pool.Put(b) +} diff --git a/transport/shadowsocks/shadowaead/packet.go b/transport/shadowsocks/shadowaead/packet.go index 91b0bf4..25fb1dc 100755 --- a/transport/shadowsocks/shadowaead/packet.go +++ b/transport/shadowsocks/shadowaead/packet.go @@ -6,7 +6,7 @@ import ( "io" "net" - "github.com/xjasonlyu/tun2socks/v2/internal/pool" + "github.com/xjasonlyu/tun2socks/v2/buffer" ) // ErrShortPacket means that the packet is too short for a valid encrypted packet. @@ -70,8 +70,8 @@ func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn { // WriteTo encrypts b and write to addr using the embedded PacketConn. func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { - buf := pool.Get(maxPacketSize) - defer pool.Put(buf) + buf := buffer.Get(maxPacketSize) + defer buffer.Put(buf) buf, err := Pack(buf, b, c) if err != nil { return 0, err diff --git a/transport/shadowsocks/shadowaead/stream.go b/transport/shadowsocks/shadowaead/stream.go index 83cc641..3a13ba5 100755 --- a/transport/shadowsocks/shadowaead/stream.go +++ b/transport/shadowsocks/shadowaead/stream.go @@ -7,7 +7,7 @@ import ( "io" "net" - "github.com/xjasonlyu/tun2socks/v2/internal/pool" + "github.com/xjasonlyu/tun2socks/v2/buffer" ) const ( @@ -29,8 +29,8 @@ func NewWriter(w io.Writer, aead cipher.AEAD) *Writer { return &Writer{Writer: w // Write encrypts p and writes to the embedded io.Writer. func (w *Writer) Write(p []byte) (n int, err error) { - buf := pool.Get(bufSize) - defer pool.Put(buf) + buf := buffer.Get(bufSize) + defer buffer.Put(buf) nonce := w.nonce[:w.NonceSize()] tag := w.Overhead() off := 2 + tag @@ -65,8 +65,8 @@ func (w *Writer) Write(p []byte) (n int, err error) { // writes to the embedded io.Writer. Returns number of bytes read from r and // any error encountered. func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) { - buf := pool.Get(bufSize) - defer pool.Put(buf) + buf := buffer.Get(bufSize) + defer buffer.Put(buf) nonce := w.nonce[:w.NonceSize()] tag := w.Overhead() off := 2 + tag @@ -142,7 +142,7 @@ func (r *Reader) Read(p []byte) (int, error) { if len(p) >= payloadSizeMask+r.Overhead() { return r.read(p) } - b := pool.Get(bufSize) + b := buffer.Get(bufSize) n, err := r.read(b) if err != nil { return 0, err @@ -154,7 +154,7 @@ func (r *Reader) Read(p []byte) (int, error) { n := copy(p, r.buf[r.off:]) r.off += n if r.off == len(r.buf) { - pool.Put(r.buf[:cap(r.buf)]) + buffer.Put(r.buf[:cap(r.buf)]) r.buf = nil } return n, nil @@ -165,7 +165,7 @@ func (r *Reader) Read(p []byte) (int, error) { // bytes written to w and any error encountered. func (r *Reader) WriteTo(w io.Writer) (n int64, err error) { if r.buf == nil { - r.buf = pool.Get(bufSize) + r.buf = buffer.Get(bufSize) r.off = len(r.buf) } @@ -176,7 +176,7 @@ func (r *Reader) WriteTo(w io.Writer) (n int64, err error) { n += int64(nw) if ew != nil { if r.off == len(r.buf) { - pool.Put(r.buf[:cap(r.buf)]) + buffer.Put(r.buf[:cap(r.buf)]) r.buf = nil } err = ew diff --git a/transport/shadowsocks/shadowstream/packet.go b/transport/shadowsocks/shadowstream/packet.go index 082a6d5..732474f 100755 --- a/transport/shadowsocks/shadowstream/packet.go +++ b/transport/shadowsocks/shadowstream/packet.go @@ -6,7 +6,7 @@ import ( "io" "net" - "github.com/xjasonlyu/tun2socks/v2/internal/pool" + "github.com/xjasonlyu/tun2socks/v2/buffer" ) // ErrShortPacket means the packet is too short to be a valid encrypted packet. @@ -55,8 +55,8 @@ func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn { const maxPacketSize = 64 * 1024 func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { - buf := pool.Get(maxPacketSize) - defer pool.Put(buf) + buf := buffer.Get(maxPacketSize) + defer buffer.Put(buf) buf, err := Pack(buf, b, c.Cipher) if err != nil { return 0, err diff --git a/transport/simple-obfs/http.go b/transport/simple-obfs/http.go index bdae1a3..9c71d78 100644 --- a/transport/simple-obfs/http.go +++ b/transport/simple-obfs/http.go @@ -10,7 +10,7 @@ import ( "net" "net/http" - "github.com/xjasonlyu/tun2socks/v2/internal/pool" + "github.com/xjasonlyu/tun2socks/v2/buffer" ) // HTTPObfs is shadowsocks http simple-obfs implementation @@ -29,22 +29,22 @@ func (ho *HTTPObfs) Read(b []byte) (int, error) { n := copy(b, ho.buf[ho.offset:]) ho.offset += n if ho.offset == len(ho.buf) { - pool.Put(ho.buf) + buffer.Put(ho.buf) ho.buf = nil } return n, nil } if ho.firstResponse { - buf := pool.Get(pool.RelayBufferSize) + buf := buffer.Get(buffer.RelayBufferSize) n, err := ho.Conn.Read(buf) if err != nil { - pool.Put(buf) + buffer.Put(buf) return 0, err } idx := bytes.Index(buf[:n], []byte("\r\n\r\n")) if idx == -1 { - pool.Put(buf) + buffer.Put(buf) return 0, io.EOF } ho.firstResponse = false @@ -54,7 +54,7 @@ func (ho *HTTPObfs) Read(b []byte) (int, error) { ho.buf = buf[:idx+4+length] ho.offset = idx + 4 + n } else { - pool.Put(buf) + buffer.Put(buf) } return n, nil } diff --git a/transport/simple-obfs/tls.go b/transport/simple-obfs/tls.go index a9a3926..e91e6bc 100644 --- a/transport/simple-obfs/tls.go +++ b/transport/simple-obfs/tls.go @@ -1,14 +1,14 @@ package obfs import ( - "bytes" "crypto/rand" "encoding/binary" "io" "net" "time" - "github.com/xjasonlyu/tun2socks/v2/internal/pool" + "github.com/xjasonlyu/tun2socks/v2/buffer" + "github.com/xjasonlyu/tun2socks/v2/transport/internal/bufferpool" ) const ( @@ -25,12 +25,12 @@ type TLSObfs struct { } func (to *TLSObfs) read(b []byte, discardN int) (int, error) { - buf := pool.Get(discardN) + buf := buffer.Get(discardN) _, err := io.ReadFull(to.Conn, buf) if err != nil { return 0, err } - pool.Put(buf) + buffer.Put(buf) sizeBuf := make([]byte, 2) _, err = io.ReadFull(to.Conn, sizeBuf) @@ -99,8 +99,8 @@ func (to *TLSObfs) write(b []byte) (int, error) { return len(b), err } - buf := pool.GetBuffer() - defer pool.PutBuffer(buf) + buf := bufferpool.Get() + defer bufferpool.Put(buf) buf.Write([]byte{0x17, 0x03, 0x03}) binary.Write(buf, binary.BigEndian, uint16(len(b))) buf.Write(b) @@ -124,7 +124,8 @@ func makeClientHelloMsg(data []byte, server string) []byte { rand.Read(random) rand.Read(sessionID) - buf := &bytes.Buffer{} + buf := bufferpool.Get() + defer bufferpool.Put(buf) // handshake, TLS 1.0 version, length buf.WriteByte(22) diff --git a/transport/socks4/socks4.go b/transport/socks4/socks4.go index 514f2d5..e096da1 100644 --- a/transport/socks4/socks4.go +++ b/transport/socks4/socks4.go @@ -2,13 +2,14 @@ package socks4 import ( - "bytes" "encoding/binary" "errors" "io" "net" "net/netip" "strconv" + + "github.com/xjasonlyu/tun2socks/v2/transport/internal/bufferpool" ) const Version = 0x04 @@ -64,13 +65,14 @@ func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID stri return errIPv6NotSupported } - req := &bytes.Buffer{} + req := bufferpool.Get() + defer bufferpool.Put(req) req.WriteByte(Version) req.WriteByte(command) _ = binary.Write(req, binary.BigEndian, port) req.Write(ip.AsSlice()) req.WriteString(userID) - req.WriteByte(0) /* NULL */ + req.WriteByte(0x00) /* NULL */ if isReservedIP(ip) /* SOCKS4A */ { req.WriteString(host) diff --git a/transport/socks5/socks5.go b/transport/socks5/socks5.go index 43bab41..f8acbf6 100644 --- a/transport/socks5/socks5.go +++ b/transport/socks5/socks5.go @@ -10,6 +10,8 @@ import ( "net" "net/netip" "strconv" + + "github.com/xjasonlyu/tun2socks/v2/transport/internal/bufferpool" ) // AuthMethod is the authentication method as defined in RFC 1928 section 3. @@ -206,10 +208,9 @@ func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) ( return nil, errors.New("auth username/password too long") } - authMsgLen := 1 + 1 + uLen + 1 + pLen - // password protocol version - authMsg := bytes.NewBuffer(make([]byte, 0, authMsgLen)) + authMsg := bufferpool.Get() + defer bufferpool.Put(authMsg) authMsg.WriteByte(0x01 /* VER */) authMsg.WriteByte(byte(uLen) /* ULEN */) authMsg.WriteString(user.Username /* UNAME */) @@ -233,7 +234,15 @@ func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) ( } // VER, CMD, RSV, ADDR - if _, err := rw.Write(bytes.Join([][]byte{{Version, byte(command), 0x00 /* RSV */}, addr}, nil)); err != nil { + req := bufferpool.Get() + defer bufferpool.Put(req) + req.Grow(3 + MaxAddrLen) + req.WriteByte(Version) + req.WriteByte(byte(command)) + req.WriteByte(0x00 /* RSV */) + req.Write(addr) + + if _, err := rw.Write(req.Bytes()); err != nil { return nil, err } diff --git a/tunnel/tcp.go b/tunnel/tcp.go index 6ac563d..be7b4fc 100644 --- a/tunnel/tcp.go +++ b/tunnel/tcp.go @@ -7,8 +7,8 @@ import ( "sync" "time" + "github.com/xjasonlyu/tun2socks/v2/buffer" "github.com/xjasonlyu/tun2socks/v2/core/adapter" - "github.com/xjasonlyu/tun2socks/v2/internal/pool" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" @@ -56,11 +56,11 @@ func pipe(origin, remote net.Conn) { func unidirectionalStream(dst, src net.Conn, dir string, wg *sync.WaitGroup) { defer wg.Done() - buf := pool.Get(pool.RelayBufferSize) + buf := buffer.Get(buffer.RelayBufferSize) if _, err := io.CopyBuffer(dst, src, buf); err != nil { log.Debugf("[TCP] copy data for %s: %v", dir, err) } - pool.Put(buf) + buffer.Put(buf) // Do the upload/download side TCP half-close. if cr, ok := src.(interface{ CloseRead() error }); ok { cr.CloseRead() diff --git a/tunnel/udp.go b/tunnel/udp.go index 03c5e3e..8797386 100644 --- a/tunnel/udp.go +++ b/tunnel/udp.go @@ -6,8 +6,8 @@ import ( "sync" "time" + "github.com/xjasonlyu/tun2socks/v2/buffer" "github.com/xjasonlyu/tun2socks/v2/core/adapter" - "github.com/xjasonlyu/tun2socks/v2/internal/pool" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" @@ -66,8 +66,8 @@ func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string } func copyPacketData(dst, src net.PacketConn, to net.Addr, timeout time.Duration) error { - buf := pool.Get(pool.MaxSegmentSize) - defer pool.Put(buf) + buf := buffer.Get(buffer.MaxSegmentSize) + defer buffer.Put(buf) for { src.SetReadDeadline(time.Now().Add(timeout))