
Refactor(pool): use generic buffer/pool (#399)

Branch: pull/401/head
Author: Jason Lyu, 2 months ago (committed by GitHub)
Commit: 391d3d9f89
  1. buffer/allocator/allocator.go (28)
  2. buffer/allocator/allocator_test.go (8)
  3. buffer/pool.go (29)
  4. internal/pool/buffer.go (17)
  5. internal/pool/pool.go (47)
  6. internal/pool/pool_test.go (85)
  7. proxy/relay.go (6)
  8. transport/internal/bufferpool/bufferpool.go (25)
  9. transport/shadowsocks/shadowaead/packet.go (6)
  10. transport/shadowsocks/shadowaead/stream.go (18)
  11. transport/shadowsocks/shadowstream/packet.go (6)
  12. transport/simple-obfs/http.go (12)
  13. transport/simple-obfs/tls.go (15)
  14. transport/socks4/socks4.go (8)
  15. transport/socks5/socks5.go (17)
  16. tunnel/tcp.go (6)
  17. tunnel/udp.go (6)

internal/pool/alloc.go → buffer/allocator/allocator.go (28)

@@ -1,30 +1,29 @@
package pool
package allocator
import (
"errors"
"math/bits"
"sync"
)
var _allocator = NewAllocator()
"github.com/xjasonlyu/tun2socks/v2/internal/pool"
)
// Allocator for incoming frames, optimized to prevent overwriting
// after zeroing.
type Allocator struct {
buffers []sync.Pool
buffers []*pool.Pool[[]byte]
}
// NewAllocator initiates a []byte allocator for frames less than
// 65536 bytes, the waste(memory fragmentation) of space allocation
// is guaranteed to be no more than 50%.
func NewAllocator() *Allocator {
// New initiates a []byte allocator for frames less than 65536 bytes,
// the waste(memory fragmentation) of space allocation is guaranteed
// to be no more than 50%.
func New() *Allocator {
alloc := &Allocator{}
alloc.buffers = make([]sync.Pool, 17) // 1B -> 64K
alloc.buffers = make([]*pool.Pool[[]byte], 17) // 1B -> 64K
for k := range alloc.buffers {
i := k
alloc.buffers[k].New = func() any {
alloc.buffers[k] = pool.New(func() []byte {
return make([]byte, 1<<uint32(i))
}
})
}
return alloc
}
@@ -37,10 +36,10 @@ func (alloc *Allocator) Get(size int) []byte {
b := msb(size)
if size == 1<<b {
return alloc.buffers[b].Get().([]byte)[:size]
return alloc.buffers[b].Get()[:size]
}
return alloc.buffers[b+1].Get().([]byte)[:size]
return alloc.buffers[b+1].Get()[:size]
}
// Put returns a []byte to pool for future use,
@@ -51,7 +50,6 @@ func (alloc *Allocator) Put(buf []byte) error {
return errors.New("allocator Put() incorrect buffer size")
}
//nolint:staticcheck
alloc.buffers[b].Put(buf)
return nil
}
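
For context, a minimal usage sketch of the refactored allocator (the API shown is what this hunk introduces, but the example program itself is not part of the commit):

package main

import (
	"fmt"

	"github.com/xjasonlyu/tun2socks/v2/buffer/allocator"
)

func main() {
	alloc := allocator.New()

	// Get rounds the request up to the next power-of-two size class,
	// so a 1500-byte request is served from the 2048-byte pool.
	b := alloc.Get(1500)
	fmt.Println(len(b), cap(b)) // 1500 2048

	// Put only accepts slices whose capacity is a power of two,
	// i.e. buffers that originally came from the allocator.
	if err := alloc.Put(b); err != nil {
		panic(err)
	}
}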

internal/pool/alloc_test.go → buffer/allocator/allocator_test.go (8)

@@ -1,4 +1,4 @@
package pool
package allocator
import (
"math/rand"
@@ -8,7 +8,7 @@ import (
)
func TestAllocGet(t *testing.T) {
alloc := NewAllocator()
alloc := New()
assert.Nil(t, alloc.Get(0))
assert.Equal(t, 1, len(alloc.Get(1)))
assert.Equal(t, 2, len(alloc.Get(2)))
@@ -23,7 +23,7 @@ func TestAllocGet(t *testing.T) {
}
func TestAllocPut(t *testing.T) {
alloc := NewAllocator()
alloc := New()
assert.NotNil(t, alloc.Put(nil), "put nil misbehavior")
assert.NotNil(t, alloc.Put(make([]byte, 3)), "put elem:3 []bytes misbehavior")
assert.Nil(t, alloc.Put(make([]byte, 4)), "put elem:4 []bytes misbehavior")
@@ -33,7 +33,7 @@ func TestAllocPut(t *testing.T) {
}
func TestAllocPutThenGet(t *testing.T) {
alloc := NewAllocator()
alloc := New()
data := alloc.Get(4)
_ = alloc.Put(data)
newData := alloc.Get(4)

buffer/pool.go (29)

@@ -0,0 +1,29 @@
// Package buffer provides a pool of []byte.
package buffer
import (
"github.com/xjasonlyu/tun2socks/v2/buffer/allocator"
)
const (
// MaxSegmentSize is the largest possible UDP datagram size.
MaxSegmentSize = (1 << 16) - 1
// RelayBufferSize is the default buffer size for TCP relays.
// io.Copy default buffer size is 32 KiB, but the maximum packet
// size of vmess/shadowsocks is about 16 KiB, so define a buffer
// of 20 KiB to reduce the memory of each TCP relay.
RelayBufferSize = 20 << 10
)
var _allocator = allocator.New()
// Get gets a []byte from default allocator with most appropriate cap.
func Get(size int) []byte {
return _allocator.Get(size)
}
// Put returns a []byte to default allocator for future use.
func Put(buf []byte) error {
return _allocator.Put(buf)
}
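
A rough sketch of how callers are expected to use the new buffer package (the helper below is hypothetical, not from this commit; real call sites such as tunnel/tcp.go appear later in the diff):

package example

import (
	"io"

	"github.com/xjasonlyu/tun2socks/v2/buffer"
)

// relayOnce copies a single read from src to dst using a pooled buffer
// instead of allocating a fresh 20 KiB slice on every call.
func relayOnce(dst io.Writer, src io.Reader) error {
	buf := buffer.Get(buffer.RelayBufferSize) // 20 KiB from the pool
	defer buffer.Put(buf)                     // return it for reuse

	n, err := src.Read(buf)
	if err != nil {
		return err
	}
	_, err = dst.Write(buf[:n])
	return err
}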

internal/pool/buffer.go (17)

@@ -1,17 +0,0 @@
package pool
import (
"bytes"
"sync"
)
var _bufferPool = sync.Pool{New: func() any { return &bytes.Buffer{} }}
func GetBuffer() *bytes.Buffer {
return _bufferPool.Get().(*bytes.Buffer)
}
func PutBuffer(buf *bytes.Buffer) {
buf.Reset()
_bufferPool.Put(buf)
}

internal/pool/pool.go (47)

@@ -1,23 +1,38 @@
// Package pool provides a pool of []byte.
// Package pool provides internal pool utilities.
package pool
const (
// MaxSegmentSize is the largest possible UDP datagram size.
MaxSegmentSize = (1 << 16) - 1
// RelayBufferSize is a buffer of 20 KiB to reduce the memory
// of each TCP relay as io.Copy default buffer size is 32 KiB,
// but the maximum packet size of vmess/shadowsocks is about
// 16 KiB, so define .
RelayBufferSize = 20 << 10
import (
"sync"
)
// Get gets a []byte from default allocator with most appropriate cap.
func Get(size int) []byte {
return _allocator.Get(size)
// A Pool is a generic wrapper around [sync.Pool] to provide strongly-typed
// object pooling.
//
// Note that SA6002 (ref: https://staticcheck.io/docs/checks/#SA6002) will
// not be detected, so all internal pool use must take care to only store
// pointer types.
type Pool[T any] struct {
pool sync.Pool
}
// New returns a new [Pool] for T, and will use fn to construct new Ts when
// the pool is empty.
func New[T any](fn func() T) *Pool[T] {
return &Pool[T]{
pool: sync.Pool{
New: func() any {
return fn()
},
},
}
}
// Get gets a T from the pool, or creates a new one if the pool is empty.
func (p *Pool[T]) Get() T {
return p.pool.Get().(T)
}
// Put returns a []byte to default allocator for future use.
func Put(buf []byte) error {
return _allocator.Put(buf)
// Put returns x into the pool.
func (p *Pool[T]) Put(x T) {
p.pool.Put(x)
}
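
The generic wrapper keeps sync.Pool semantics but adds compile-time typing. A minimal sketch, assuming code inside this module (internal/ packages cannot be imported from outside), with a hypothetical package and function for illustration:

package somepkg

import (
	"bytes"

	"github.com/xjasonlyu/tun2socks/v2/internal/pool"
)

// Store pointer types only (see the SA6002 note above), so Put does not
// copy the value into an interface on every call.
var bufPool = pool.New(func() *bytes.Buffer { return new(bytes.Buffer) })

func render(s string) string {
	buf := bufPool.Get() // typed *bytes.Buffer, no type assertion needed
	defer bufPool.Put(buf)

	buf.Reset() // callers decide how to reset state before reuse
	buf.WriteString(s)
	return buf.String()
}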

internal/pool/pool_test.go (85)

@@ -0,0 +1,85 @@
package pool
import (
"runtime/debug"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
type pooledValue[T any] struct {
value T
}
func TestNew(t *testing.T) {
// Disable GC to avoid the victim cache during the test.
defer debug.SetGCPercent(debug.SetGCPercent(-1))
p := New(func() *pooledValue[string] {
return &pooledValue[string]{
value: "new",
}
})
// Probabilistically, 75% of sync.Pool.Put calls will succeed when -race
// is enabled (see ref below); attempt to make this quasi-deterministic by
// brute force (i.e., put significantly more objects in the pool than we
// will need for the test) in order to avoid testing without race enabled.
//
// ref: https://cs.opensource.google/go/go/+/refs/tags/go1.20.2:src/sync/pool.go;l=100-103
for i := 0; i < 1_000; i++ {
p.Put(&pooledValue[string]{
value: t.Name(),
})
}
// Ensure that we always get the expected value. Note that this must only
// run a fraction of the number of times that Put is called above.
for i := 0; i < 10; i++ {
func() {
x := p.Get()
defer p.Put(x)
require.Equal(t, t.Name(), x.value)
}()
}
// Depool all objects that might be in the pool to ensure that it's empty.
for i := 0; i < 1_000; i++ {
p.Get()
}
// Now that the pool is empty, it should use the value specified in the
// underlying sync.Pool.New func.
require.Equal(t, "new", p.Get().value)
}
func TestNew_Race(t *testing.T) {
p := New(func() *pooledValue[int] {
return &pooledValue[int]{
value: -1,
}
})
var wg sync.WaitGroup
defer wg.Wait()
// Run a number of goroutines that read and write pool object fields to
// tease out races.
for i := 0; i < 1_000; i++ {
i := i
wg.Add(1)
go func() {
defer wg.Done()
x := p.Get()
defer p.Put(x)
// Must both read and write the field.
if n := x.value; n >= -1 {
x.value = i
}
}()
}
}

proxy/relay.go (6)

@@ -13,8 +13,8 @@ import (
"github.com/go-gost/relay"
"github.com/xjasonlyu/tun2socks/v2/buffer"
"github.com/xjasonlyu/tun2socks/v2/dialer"
"github.com/xjasonlyu/tun2socks/v2/internal/pool"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/proxy/proto"
)
@@ -169,8 +169,8 @@ func (rc *relayConn) Read(b []byte) (n int, err error) {
return io.ReadFull(rc.Conn, b[:dLen])
}
buf := pool.Get(dLen)
defer pool.Put(buf)
buf := buffer.Get(dLen)
defer buffer.Put(buf)
_, err = io.ReadFull(rc.Conn, buf)
n = copy(b, buf)

transport/internal/bufferpool/bufferpool.go (25)

@@ -0,0 +1,25 @@
package bufferpool
import (
"bytes"
"github.com/xjasonlyu/tun2socks/v2/internal/pool"
)
const _size = 1024 // by default, create 1 KiB buffers
var _pool = pool.New(func() *bytes.Buffer {
buf := &bytes.Buffer{}
buf.Grow(_size)
return buf
})
func Get() *bytes.Buffer {
buf := _pool.Get()
buf.Reset()
return buf
}
func Put(b *bytes.Buffer) {
_pool.Put(b)
}
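
Since Get resets the buffer before handing it out, callers can write into it immediately. A hypothetical in-module caller (not part of this commit) might look like:

package obfs

import (
	"encoding/binary"
	"io"

	"github.com/xjasonlyu/tun2socks/v2/transport/internal/bufferpool"
)

// writeFrame assembles a length-prefixed frame in a pooled bytes.Buffer
// and writes it out in a single call.
func writeFrame(w io.Writer, payload []byte) error {
	buf := bufferpool.Get() // already reset, safe to write into
	defer bufferpool.Put(buf)

	_ = binary.Write(buf, binary.BigEndian, uint16(len(payload)))
	buf.Write(payload)
	_, err := w.Write(buf.Bytes())
	return err
}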

transport/shadowsocks/shadowaead/packet.go (6)

@@ -6,7 +6,7 @@ import (
"io"
"net"
"github.com/xjasonlyu/tun2socks/v2/internal/pool"
"github.com/xjasonlyu/tun2socks/v2/buffer"
)
// ErrShortPacket means that the packet is too short for a valid encrypted packet.
@@ -70,8 +70,8 @@ func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn {
// WriteTo encrypts b and write to addr using the embedded PacketConn.
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
buf := pool.Get(maxPacketSize)
defer pool.Put(buf)
buf := buffer.Get(maxPacketSize)
defer buffer.Put(buf)
buf, err := Pack(buf, b, c)
if err != nil {
return 0, err

transport/shadowsocks/shadowaead/stream.go (18)

@@ -7,7 +7,7 @@ import (
"io"
"net"
"github.com/xjasonlyu/tun2socks/v2/internal/pool"
"github.com/xjasonlyu/tun2socks/v2/buffer"
)
const (
@@ -29,8 +29,8 @@ func NewWriter(w io.Writer, aead cipher.AEAD) *Writer { return &Writer{Writer: w
// Write encrypts p and writes to the embedded io.Writer.
func (w *Writer) Write(p []byte) (n int, err error) {
buf := pool.Get(bufSize)
defer pool.Put(buf)
buf := buffer.Get(bufSize)
defer buffer.Put(buf)
nonce := w.nonce[:w.NonceSize()]
tag := w.Overhead()
off := 2 + tag
@@ -65,8 +65,8 @@ func (w *Writer) Write(p []byte) (n int, err error) {
// writes to the embedded io.Writer. Returns number of bytes read from r and
// any error encountered.
func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
buf := pool.Get(bufSize)
defer pool.Put(buf)
buf := buffer.Get(bufSize)
defer buffer.Put(buf)
nonce := w.nonce[:w.NonceSize()]
tag := w.Overhead()
off := 2 + tag
@@ -142,7 +142,7 @@ func (r *Reader) Read(p []byte) (int, error) {
if len(p) >= payloadSizeMask+r.Overhead() {
return r.read(p)
}
b := pool.Get(bufSize)
b := buffer.Get(bufSize)
n, err := r.read(b)
if err != nil {
return 0, err
@@ -154,7 +154,7 @@ func (r *Reader) Read(p []byte) (int, error) {
n := copy(p, r.buf[r.off:])
r.off += n
if r.off == len(r.buf) {
pool.Put(r.buf[:cap(r.buf)])
buffer.Put(r.buf[:cap(r.buf)])
r.buf = nil
}
return n, nil
@@ -165,7 +165,7 @@ func (r *Reader) Read(p []byte) (int, error) {
// bytes written to w and any error encountered.
func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
if r.buf == nil {
r.buf = pool.Get(bufSize)
r.buf = buffer.Get(bufSize)
r.off = len(r.buf)
}
@@ -176,7 +176,7 @@ func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
n += int64(nw)
if ew != nil {
if r.off == len(r.buf) {
pool.Put(r.buf[:cap(r.buf)])
buffer.Put(r.buf[:cap(r.buf)])
r.buf = nil
}
err = ew

transport/shadowsocks/shadowstream/packet.go (6)

@@ -6,7 +6,7 @@ import (
"io"
"net"
"github.com/xjasonlyu/tun2socks/v2/internal/pool"
"github.com/xjasonlyu/tun2socks/v2/buffer"
)
// ErrShortPacket means the packet is too short to be a valid encrypted packet.
@@ -55,8 +55,8 @@ func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn {
const maxPacketSize = 64 * 1024
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
buf := pool.Get(maxPacketSize)
defer pool.Put(buf)
buf := buffer.Get(maxPacketSize)
defer buffer.Put(buf)
buf, err := Pack(buf, b, c.Cipher)
if err != nil {
return 0, err

transport/simple-obfs/http.go (12)

@@ -10,7 +10,7 @@ import (
"net"
"net/http"
"github.com/xjasonlyu/tun2socks/v2/internal/pool"
"github.com/xjasonlyu/tun2socks/v2/buffer"
)
// HTTPObfs is shadowsocks http simple-obfs implementation
@@ -29,22 +29,22 @@ func (ho *HTTPObfs) Read(b []byte) (int, error) {
n := copy(b, ho.buf[ho.offset:])
ho.offset += n
if ho.offset == len(ho.buf) {
pool.Put(ho.buf)
buffer.Put(ho.buf)
ho.buf = nil
}
return n, nil
}
if ho.firstResponse {
buf := pool.Get(pool.RelayBufferSize)
buf := buffer.Get(buffer.RelayBufferSize)
n, err := ho.Conn.Read(buf)
if err != nil {
pool.Put(buf)
buffer.Put(buf)
return 0, err
}
idx := bytes.Index(buf[:n], []byte("\r\n\r\n"))
if idx == -1 {
pool.Put(buf)
buffer.Put(buf)
return 0, io.EOF
}
ho.firstResponse = false
@@ -54,7 +54,7 @@ func (ho *HTTPObfs) Read(b []byte) (int, error) {
ho.buf = buf[:idx+4+length]
ho.offset = idx + 4 + n
} else {
pool.Put(buf)
buffer.Put(buf)
}
return n, nil
}

transport/simple-obfs/tls.go (15)

@@ -1,14 +1,14 @@
package obfs
import (
"bytes"
"crypto/rand"
"encoding/binary"
"io"
"net"
"time"
"github.com/xjasonlyu/tun2socks/v2/internal/pool"
"github.com/xjasonlyu/tun2socks/v2/buffer"
"github.com/xjasonlyu/tun2socks/v2/transport/internal/bufferpool"
)
const (
@@ -25,12 +25,12 @@ type TLSObfs struct {
}
func (to *TLSObfs) read(b []byte, discardN int) (int, error) {
buf := pool.Get(discardN)
buf := buffer.Get(discardN)
_, err := io.ReadFull(to.Conn, buf)
if err != nil {
return 0, err
}
pool.Put(buf)
buffer.Put(buf)
sizeBuf := make([]byte, 2)
_, err = io.ReadFull(to.Conn, sizeBuf)
@@ -99,8 +99,8 @@ func (to *TLSObfs) write(b []byte) (int, error) {
return len(b), err
}
buf := pool.GetBuffer()
defer pool.PutBuffer(buf)
buf := bufferpool.Get()
defer bufferpool.Put(buf)
buf.Write([]byte{0x17, 0x03, 0x03})
binary.Write(buf, binary.BigEndian, uint16(len(b)))
buf.Write(b)
@@ -124,7 +124,8 @@ func makeClientHelloMsg(data []byte, server string) []byte {
rand.Read(random)
rand.Read(sessionID)
buf := &bytes.Buffer{}
buf := bufferpool.Get()
defer bufferpool.Put(buf)
// handshake, TLS 1.0 version, length
buf.WriteByte(22)

transport/socks4/socks4.go (8)

@@ -2,13 +2,14 @@
package socks4
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"net/netip"
"strconv"
"github.com/xjasonlyu/tun2socks/v2/transport/internal/bufferpool"
)
const Version = 0x04
@@ -64,13 +65,14 @@ func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID stri
return errIPv6NotSupported
}
req := &bytes.Buffer{}
req := bufferpool.Get()
defer bufferpool.Put(req)
req.WriteByte(Version)
req.WriteByte(command)
_ = binary.Write(req, binary.BigEndian, port)
req.Write(ip.AsSlice())
req.WriteString(userID)
req.WriteByte(0) /* NULL */
req.WriteByte(0x00) /* NULL */
if isReservedIP(ip) /* SOCKS4A */ {
req.WriteString(host)

transport/socks5/socks5.go (17)

@@ -10,6 +10,8 @@ import (
"net"
"net/netip"
"strconv"
"github.com/xjasonlyu/tun2socks/v2/transport/internal/bufferpool"
)
// AuthMethod is the authentication method as defined in RFC 1928 section 3.
@@ -206,10 +208,9 @@ func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (
return nil, errors.New("auth username/password too long")
}
authMsgLen := 1 + 1 + uLen + 1 + pLen
// password protocol version
authMsg := bytes.NewBuffer(make([]byte, 0, authMsgLen))
authMsg := bufferpool.Get()
defer bufferpool.Put(authMsg)
authMsg.WriteByte(0x01 /* VER */)
authMsg.WriteByte(byte(uLen) /* ULEN */)
authMsg.WriteString(user.Username /* UNAME */)
@@ -233,7 +234,15 @@ func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (
}
// VER, CMD, RSV, ADDR
if _, err := rw.Write(bytes.Join([][]byte{{Version, byte(command), 0x00 /* RSV */}, addr}, nil)); err != nil {
req := bufferpool.Get()
defer bufferpool.Put(req)
req.Grow(3 + MaxAddrLen)
req.WriteByte(Version)
req.WriteByte(byte(command))
req.WriteByte(0x00 /* RSV */)
req.Write(addr)
if _, err := rw.Write(req.Bytes()); err != nil {
return nil, err
}

tunnel/tcp.go (6)

@@ -7,8 +7,8 @@ import (
"sync"
"time"
"github.com/xjasonlyu/tun2socks/v2/buffer"
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/internal/pool"
"github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
@@ -56,11 +56,11 @@ func pipe(origin, remote net.Conn) {
func unidirectionalStream(dst, src net.Conn, dir string, wg *sync.WaitGroup) {
defer wg.Done()
buf := pool.Get(pool.RelayBufferSize)
buf := buffer.Get(buffer.RelayBufferSize)
if _, err := io.CopyBuffer(dst, src, buf); err != nil {
log.Debugf("[TCP] copy data for %s: %v", dir, err)
}
pool.Put(buf)
buffer.Put(buf)
// Do the upload/download side TCP half-close.
if cr, ok := src.(interface{ CloseRead() error }); ok {
cr.CloseRead()

tunnel/udp.go (6)

@@ -6,8 +6,8 @@ import (
"sync"
"time"
"github.com/xjasonlyu/tun2socks/v2/buffer"
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/internal/pool"
"github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
@@ -66,8 +66,8 @@ func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string
}
func copyPacketData(dst, src net.PacketConn, to net.Addr, timeout time.Duration) error {
buf := pool.Get(pool.MaxSegmentSize)
defer pool.Put(buf)
buf := buffer.Get(buffer.MaxSegmentSize)
defer buffer.Put(buf)
for {
src.SetReadDeadline(time.Now().Add(timeout))
