Browse Source

Refactor: new dialer impl

pull/133/head
xjasonlyu 3 years ago
parent
commit
a4bedf6080
  1. 66
      common/singledo/singledo.go
  2. 69
      common/singledo/singledo_test.go
  3. 21
      component/dialer/bind.go
  4. 32
      component/dialer/bind_darwin.go
  5. 27
      component/dialer/bind_linux.go
  6. 15
      component/dialer/bind_others.go
  7. 35
      component/dialer/control.go
  8. 48
      component/dialer/dialer.go
  9. 14
      component/dialer/fwmark.go
  10. 21
      component/dialer/fwmark_linux.go
  11. 14
      component/dialer/fwmark_others.go
  12. 19
      component/dialer/sockopt.go
  13. 61
      component/dialer/sockopt_darwin.go
  14. 39
      component/dialer/sockopt_linux.go
  15. 7
      component/dialer/sockopt_others.go
  16. 21
      engine/engine.go

66
common/singledo/singledo.go

@ -0,0 +1,66 @@
package singledo
// Ref: github.com/Dreamacro/clash/common/singledo
import (
"sync"
"time"
)
type call struct {
wg sync.WaitGroup
val any
err error
}
type Single struct {
mux sync.Mutex
last time.Time
wait time.Duration
call *call
result *Result
}
type Result struct {
Val any
Err error
}
// Do single.Do likes sync.singleFlight
//lint:ignore ST1008 it likes sync.singleFlight
func (s *Single) Do(fn func() (any, error)) (v any, err error, shared bool) {
s.mux.Lock()
now := time.Now()
if now.Before(s.last.Add(s.wait)) {
s.mux.Unlock()
return s.result.Val, s.result.Err, true
}
if call := s.call; call != nil {
s.mux.Unlock()
call.wg.Wait()
return call.val, call.err, true
}
call := &call{}
call.wg.Add(1)
s.call = call
s.mux.Unlock()
call.val, call.err = fn()
call.wg.Done()
s.mux.Lock()
s.call = nil
s.result = &Result{call.val, call.err}
s.last = now
s.mux.Unlock()
return call.val, call.err, false
}
func (s *Single) Reset() {
s.last = time.Time{}
}
func NewSingle(wait time.Duration) *Single {
return &Single{wait: wait}
}

69
common/singledo/singledo_test.go

@ -0,0 +1,69 @@
package singledo
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
)
func TestBasic(t *testing.T) {
single := NewSingle(time.Millisecond * 30)
foo := 0
shardCount := atomic.NewInt32(0)
call := func() (any, error) {
foo++
time.Sleep(time.Millisecond * 5)
return nil, nil
}
var wg sync.WaitGroup
const n = 5
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
_, _, shard := single.Do(call)
if shard {
shardCount.Inc()
}
wg.Done()
}()
}
wg.Wait()
assert.Equal(t, 1, foo)
assert.Equal(t, int32(4), shardCount.Load())
}
func TestTimer(t *testing.T) {
single := NewSingle(time.Millisecond * 30)
foo := 0
call := func() (any, error) {
foo++
return nil, nil
}
single.Do(call)
time.Sleep(10 * time.Millisecond)
_, _, shard := single.Do(call)
assert.Equal(t, 1, foo)
assert.True(t, shard)
}
func TestReset(t *testing.T) {
single := NewSingle(time.Millisecond * 30)
foo := 0
call := func() (any, error) {
foo++
return nil, nil
}
single.Do(call)
single.Reset()
single.Do(call)
assert.Equal(t, 2, foo)
}

21
component/dialer/bind.go

@ -1,21 +0,0 @@
package dialer
import (
"net"
"sync"
)
var _bindOnce sync.Once
// BindToInterface binds dialer to specific interface.
func BindToInterface(name string) error {
i, err := net.InterfaceByName(name)
if err != nil {
return err
}
_bindOnce.Do(func() {
addControl(bindToInterface(i))
})
return nil
}

32
component/dialer/bind_darwin.go

@ -1,32 +0,0 @@
package dialer
import (
"net"
"syscall"
"golang.org/x/sys/unix"
)
func bindToInterface(i *net.Interface) controlFunc {
return func(network, address string, c syscall.RawConn) (err error) {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}
var innerErr error
err = c.Control(func(fd uintptr) {
switch network {
case "tcp4", "udp4":
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, i.Index)
case "tcp6", "udp6":
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, i.Index)
}
})
if innerErr != nil {
err = innerErr
}
return
}
}

27
component/dialer/bind_linux.go

@ -1,27 +0,0 @@
package dialer
import (
"net"
"syscall"
"golang.org/x/sys/unix"
)
func bindToInterface(i *net.Interface) controlFunc {
return func(network, address string, c syscall.RawConn) (err error) {
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return nil
}
var innerErr error
err = c.Control(func(fd uintptr) {
innerErr = unix.BindToDevice(int(fd), i.Name)
})
if innerErr != nil {
err = innerErr
}
return
}
}

15
component/dialer/bind_others.go

@ -1,15 +0,0 @@
//go:build !linux && !darwin
package dialer
import (
"errors"
"net"
"syscall"
)
func bindToInterface(_ *net.Interface) controlFunc {
return func(string, string, syscall.RawConn) error {
return errors.New("unsupported platform")
}
}

35
component/dialer/control.go

@ -1,35 +0,0 @@
package dialer
import (
"errors"
"net"
"syscall"
)
type controlFunc func(string, string, syscall.RawConn) error
var _controlPool = make([]controlFunc, 0, 2)
func addControl(f controlFunc) {
_controlPool = append(_controlPool, f)
}
func setControl(i any) {
control := func(address, network string, c syscall.RawConn) error {
for _, f := range _controlPool {
if err := f(address, network, c); err != nil {
return err
}
}
return nil
}
switch v := i.(type) {
case *net.Dialer:
v.Control = control
case *net.ListenConfig:
v.Control = control
default:
panic(errors.New("wrong type"))
}
}

48
component/dialer/dialer.go

@ -3,20 +3,56 @@ package dialer
import (
"context"
"net"
"syscall"
"go.uber.org/atomic"
)
var (
DefaultInterfaceName = atomic.NewString("")
DefaultRoutingMark = atomic.NewInt32(0)
)
func Dial(network, address string) (net.Conn, error) {
return DialContext(context.Background(), network, address)
type Options struct {
// InterfaceName is the name of interface/device to bind.
// If a socket is bound to an interface, only packets received
// from that particular interface are processed by the socket.
InterfaceName string
// RoutingMark is the mark for each packet sent through this
// socket. Changing the mark can be used for mark-based routing
// without netfilter or for packet filtering.
RoutingMark int
}
func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
d := &net.Dialer{}
setControl(d)
return DialContextWithOptions(ctx, network, address, &Options{
InterfaceName: DefaultInterfaceName.Load(),
RoutingMark: int(DefaultRoutingMark.Load()),
})
}
func DialContextWithOptions(ctx context.Context, network, address string, opts *Options) (net.Conn, error) {
d := &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return setSocketOptions(network, address, c, opts)
},
}
return d.DialContext(ctx, network, address)
}
func ListenPacket(network, address string) (net.PacketConn, error) {
lc := &net.ListenConfig{}
setControl(lc)
return ListenPacketWithOptions(network, address, &Options{
InterfaceName: DefaultInterfaceName.Load(),
RoutingMark: int(DefaultRoutingMark.Load()),
})
}
func ListenPacketWithOptions(network, address string, opts *Options) (net.PacketConn, error) {
lc := &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
return setSocketOptions(network, address, c, opts)
},
}
return lc.ListenPacket(context.Background(), network, address)
}

14
component/dialer/fwmark.go

@ -1,14 +0,0 @@
package dialer
import (
"sync"
)
var _setOnce sync.Once
// SetMark sets the mark for each packet sent through this dialer(socket).
func SetMark(i int) {
_setOnce.Do(func() {
addControl(setMark(i))
})
}

21
component/dialer/fwmark_linux.go

@ -1,21 +0,0 @@
package dialer
import (
"syscall"
"golang.org/x/sys/unix"
)
func setMark(m int) controlFunc {
return func(_, _ string, c syscall.RawConn) (err error) {
var innerErr error
err = c.Control(func(fd uintptr) {
innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, m)
})
if innerErr != nil {
err = innerErr
}
return
}
}

14
component/dialer/fwmark_others.go

@ -1,14 +0,0 @@
//go:build !linux
package dialer
import (
"errors"
"syscall"
)
func setMark(_ int) controlFunc {
return func(string, string, syscall.RawConn) error {
return errors.New("fwmark: linux only")
}
}

19
component/dialer/sockopt.go

@ -0,0 +1,19 @@
package dialer
func isTCPSocket(network string) bool {
switch network {
case "tcp", "tcp4", "tcp6":
return true
default:
return false
}
}
func isUDPSocket(network string) bool {
switch network {
case "udp", "udp4", "udp6":
return true
default:
return false
}
}

61
component/dialer/sockopt_darwin.go

@ -0,0 +1,61 @@
package dialer
import (
"net"
"syscall"
"time"
"github.com/xjasonlyu/tun2socks/v2/common/singledo"
"golang.org/x/sys/unix"
)
var interfaces = singledo.NewSingle(30 * time.Second)
func resolveInterfaceByName(name string) (*net.Interface, error) {
value, err, _ := interfaces.Do(func() (any, error) {
return net.InterfaceByName(name)
})
if err != nil {
return nil, err
}
return value.(*net.Interface), nil
}
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if !isTCPSocket(network) && !isUDPSocket(network) {
return
}
var innerErr error
err = c.Control(func(fd uintptr) {
// must be GlobalUnicast.
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}
if opts.InterfaceName != "" {
var iface *net.Interface
iface, innerErr = resolveInterfaceByName(opts.InterfaceName)
if innerErr != nil {
return
}
switch network {
case "tcp4", "udp4":
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, iface.Index)
case "tcp6", "udp6":
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, iface.Index)
}
if innerErr != nil {
return
}
}
})
if innerErr != nil {
err = innerErr
}
return
}

39
component/dialer/sockopt_linux.go

@ -0,0 +1,39 @@
package dialer
import (
"net"
"syscall"
"golang.org/x/sys/unix"
)
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) {
if !isTCPSocket(network) && !isUDPSocket(network) {
return
}
var innerErr error
err = c.Control(func(fd uintptr) {
// must be GlobalUnicast.
host, _, _ := net.SplitHostPort(address)
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() {
return
}
if opts.InterfaceName != "" {
if innerErr = unix.BindToDevice(int(fd), opts.InterfaceName); innerErr != nil {
return
}
}
if opts.RoutingMark != 0 {
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.RoutingMark); innerErr != nil {
return
}
}
})
if innerErr != nil {
err = innerErr
}
return
}

7
component/dialer/sockopt_others.go

@ -0,0 +1,7 @@
//go:build !linux && !darwin
package dialer
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) error {
return nil
}

21
engine/engine.go

@ -65,8 +65,7 @@ func (e *engine) start() error {
for _, f := range []func() error{
e.applyLogLevel,
e.applyMark,
e.applyInterface,
e.applyDialer,
e.applyStats,
e.applyUDPTimeout,
e.applyProxy,
@ -104,21 +103,15 @@ func (e *engine) applyLogLevel() error {
return nil
}
func (e *engine) applyMark() error {
if e.Mark != 0 {
dialer.SetMark(e.Mark)
log.Infof("[DIALER] set fwmark: %#x", e.Mark)
}
return nil
}
func (e *engine) applyInterface() error {
func (e *engine) applyDialer() error {
if e.Interface != "" {
if err := dialer.BindToInterface(e.Interface); err != nil {
return err
}
dialer.DefaultInterfaceName.Store(e.Interface)
log.Infof("[DIALER] bind to interface: %s", e.Interface)
}
if e.Mark != 0 {
dialer.DefaultRoutingMark.Store(int32(e.Mark))
log.Infof("[DIALER] set fwmark: %#x", e.Mark)
}
return nil
}

Loading…
Cancel
Save