xjasonlyu
3 years ago
16 changed files with 310 additions and 199 deletions
@ -0,0 +1,66 @@ |
|||
package singledo |
|||
|
|||
// Ref: github.com/Dreamacro/clash/common/singledo
|
|||
|
|||
import ( |
|||
"sync" |
|||
"time" |
|||
) |
|||
|
|||
type call struct { |
|||
wg sync.WaitGroup |
|||
val any |
|||
err error |
|||
} |
|||
|
|||
type Single struct { |
|||
mux sync.Mutex |
|||
last time.Time |
|||
wait time.Duration |
|||
call *call |
|||
result *Result |
|||
} |
|||
|
|||
type Result struct { |
|||
Val any |
|||
Err error |
|||
} |
|||
|
|||
// Do single.Do likes sync.singleFlight
|
|||
//lint:ignore ST1008 it likes sync.singleFlight
|
|||
func (s *Single) Do(fn func() (any, error)) (v any, err error, shared bool) { |
|||
s.mux.Lock() |
|||
now := time.Now() |
|||
if now.Before(s.last.Add(s.wait)) { |
|||
s.mux.Unlock() |
|||
return s.result.Val, s.result.Err, true |
|||
} |
|||
|
|||
if call := s.call; call != nil { |
|||
s.mux.Unlock() |
|||
call.wg.Wait() |
|||
return call.val, call.err, true |
|||
} |
|||
|
|||
call := &call{} |
|||
call.wg.Add(1) |
|||
s.call = call |
|||
s.mux.Unlock() |
|||
call.val, call.err = fn() |
|||
call.wg.Done() |
|||
|
|||
s.mux.Lock() |
|||
s.call = nil |
|||
s.result = &Result{call.val, call.err} |
|||
s.last = now |
|||
s.mux.Unlock() |
|||
return call.val, call.err, false |
|||
} |
|||
|
|||
func (s *Single) Reset() { |
|||
s.last = time.Time{} |
|||
} |
|||
|
|||
func NewSingle(wait time.Duration) *Single { |
|||
return &Single{wait: wait} |
|||
} |
@ -0,0 +1,69 @@ |
|||
package singledo |
|||
|
|||
import ( |
|||
"sync" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/stretchr/testify/assert" |
|||
"go.uber.org/atomic" |
|||
) |
|||
|
|||
func TestBasic(t *testing.T) { |
|||
single := NewSingle(time.Millisecond * 30) |
|||
foo := 0 |
|||
shardCount := atomic.NewInt32(0) |
|||
call := func() (any, error) { |
|||
foo++ |
|||
time.Sleep(time.Millisecond * 5) |
|||
return nil, nil |
|||
} |
|||
|
|||
var wg sync.WaitGroup |
|||
const n = 5 |
|||
wg.Add(n) |
|||
for i := 0; i < n; i++ { |
|||
go func() { |
|||
_, _, shard := single.Do(call) |
|||
if shard { |
|||
shardCount.Inc() |
|||
} |
|||
wg.Done() |
|||
}() |
|||
} |
|||
|
|||
wg.Wait() |
|||
assert.Equal(t, 1, foo) |
|||
assert.Equal(t, int32(4), shardCount.Load()) |
|||
} |
|||
|
|||
func TestTimer(t *testing.T) { |
|||
single := NewSingle(time.Millisecond * 30) |
|||
foo := 0 |
|||
call := func() (any, error) { |
|||
foo++ |
|||
return nil, nil |
|||
} |
|||
|
|||
single.Do(call) |
|||
time.Sleep(10 * time.Millisecond) |
|||
_, _, shard := single.Do(call) |
|||
|
|||
assert.Equal(t, 1, foo) |
|||
assert.True(t, shard) |
|||
} |
|||
|
|||
func TestReset(t *testing.T) { |
|||
single := NewSingle(time.Millisecond * 30) |
|||
foo := 0 |
|||
call := func() (any, error) { |
|||
foo++ |
|||
return nil, nil |
|||
} |
|||
|
|||
single.Do(call) |
|||
single.Reset() |
|||
single.Do(call) |
|||
|
|||
assert.Equal(t, 2, foo) |
|||
} |
@ -1,21 +0,0 @@ |
|||
package dialer |
|||
|
|||
import ( |
|||
"net" |
|||
"sync" |
|||
) |
|||
|
|||
var _bindOnce sync.Once |
|||
|
|||
// BindToInterface binds dialer to specific interface.
|
|||
func BindToInterface(name string) error { |
|||
i, err := net.InterfaceByName(name) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
_bindOnce.Do(func() { |
|||
addControl(bindToInterface(i)) |
|||
}) |
|||
return nil |
|||
} |
@ -1,32 +0,0 @@ |
|||
package dialer |
|||
|
|||
import ( |
|||
"net" |
|||
"syscall" |
|||
|
|||
"golang.org/x/sys/unix" |
|||
) |
|||
|
|||
func bindToInterface(i *net.Interface) controlFunc { |
|||
return func(network, address string, c syscall.RawConn) (err error) { |
|||
host, _, _ := net.SplitHostPort(address) |
|||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { |
|||
return |
|||
} |
|||
|
|||
var innerErr error |
|||
err = c.Control(func(fd uintptr) { |
|||
switch network { |
|||
case "tcp4", "udp4": |
|||
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, i.Index) |
|||
case "tcp6", "udp6": |
|||
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, i.Index) |
|||
} |
|||
}) |
|||
|
|||
if innerErr != nil { |
|||
err = innerErr |
|||
} |
|||
return |
|||
} |
|||
} |
@ -1,27 +0,0 @@ |
|||
package dialer |
|||
|
|||
import ( |
|||
"net" |
|||
"syscall" |
|||
|
|||
"golang.org/x/sys/unix" |
|||
) |
|||
|
|||
func bindToInterface(i *net.Interface) controlFunc { |
|||
return func(network, address string, c syscall.RawConn) (err error) { |
|||
host, _, _ := net.SplitHostPort(address) |
|||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { |
|||
return nil |
|||
} |
|||
|
|||
var innerErr error |
|||
err = c.Control(func(fd uintptr) { |
|||
innerErr = unix.BindToDevice(int(fd), i.Name) |
|||
}) |
|||
|
|||
if innerErr != nil { |
|||
err = innerErr |
|||
} |
|||
return |
|||
} |
|||
} |
@ -1,15 +0,0 @@ |
|||
//go:build !linux && !darwin
|
|||
|
|||
package dialer |
|||
|
|||
import ( |
|||
"errors" |
|||
"net" |
|||
"syscall" |
|||
) |
|||
|
|||
func bindToInterface(_ *net.Interface) controlFunc { |
|||
return func(string, string, syscall.RawConn) error { |
|||
return errors.New("unsupported platform") |
|||
} |
|||
} |
@ -1,35 +0,0 @@ |
|||
package dialer |
|||
|
|||
import ( |
|||
"errors" |
|||
"net" |
|||
"syscall" |
|||
) |
|||
|
|||
type controlFunc func(string, string, syscall.RawConn) error |
|||
|
|||
var _controlPool = make([]controlFunc, 0, 2) |
|||
|
|||
func addControl(f controlFunc) { |
|||
_controlPool = append(_controlPool, f) |
|||
} |
|||
|
|||
func setControl(i any) { |
|||
control := func(address, network string, c syscall.RawConn) error { |
|||
for _, f := range _controlPool { |
|||
if err := f(address, network, c); err != nil { |
|||
return err |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
switch v := i.(type) { |
|||
case *net.Dialer: |
|||
v.Control = control |
|||
case *net.ListenConfig: |
|||
v.Control = control |
|||
default: |
|||
panic(errors.New("wrong type")) |
|||
} |
|||
} |
@ -1,14 +0,0 @@ |
|||
package dialer |
|||
|
|||
import ( |
|||
"sync" |
|||
) |
|||
|
|||
var _setOnce sync.Once |
|||
|
|||
// SetMark sets the mark for each packet sent through this dialer(socket).
|
|||
func SetMark(i int) { |
|||
_setOnce.Do(func() { |
|||
addControl(setMark(i)) |
|||
}) |
|||
} |
@ -1,21 +0,0 @@ |
|||
package dialer |
|||
|
|||
import ( |
|||
"syscall" |
|||
|
|||
"golang.org/x/sys/unix" |
|||
) |
|||
|
|||
func setMark(m int) controlFunc { |
|||
return func(_, _ string, c syscall.RawConn) (err error) { |
|||
var innerErr error |
|||
err = c.Control(func(fd uintptr) { |
|||
innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, m) |
|||
}) |
|||
|
|||
if innerErr != nil { |
|||
err = innerErr |
|||
} |
|||
return |
|||
} |
|||
} |
@ -1,14 +0,0 @@ |
|||
//go:build !linux
|
|||
|
|||
package dialer |
|||
|
|||
import ( |
|||
"errors" |
|||
"syscall" |
|||
) |
|||
|
|||
func setMark(_ int) controlFunc { |
|||
return func(string, string, syscall.RawConn) error { |
|||
return errors.New("fwmark: linux only") |
|||
} |
|||
} |
@ -0,0 +1,19 @@ |
|||
package dialer |
|||
|
|||
func isTCPSocket(network string) bool { |
|||
switch network { |
|||
case "tcp", "tcp4", "tcp6": |
|||
return true |
|||
default: |
|||
return false |
|||
} |
|||
} |
|||
|
|||
func isUDPSocket(network string) bool { |
|||
switch network { |
|||
case "udp", "udp4", "udp6": |
|||
return true |
|||
default: |
|||
return false |
|||
} |
|||
} |
@ -0,0 +1,61 @@ |
|||
package dialer |
|||
|
|||
import ( |
|||
"net" |
|||
"syscall" |
|||
"time" |
|||
|
|||
"github.com/xjasonlyu/tun2socks/v2/common/singledo" |
|||
|
|||
"golang.org/x/sys/unix" |
|||
) |
|||
|
|||
var interfaces = singledo.NewSingle(30 * time.Second) |
|||
|
|||
func resolveInterfaceByName(name string) (*net.Interface, error) { |
|||
value, err, _ := interfaces.Do(func() (any, error) { |
|||
return net.InterfaceByName(name) |
|||
}) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return value.(*net.Interface), nil |
|||
} |
|||
|
|||
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) { |
|||
if !isTCPSocket(network) && !isUDPSocket(network) { |
|||
return |
|||
} |
|||
|
|||
var innerErr error |
|||
err = c.Control(func(fd uintptr) { |
|||
// must be GlobalUnicast.
|
|||
host, _, _ := net.SplitHostPort(address) |
|||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { |
|||
return |
|||
} |
|||
|
|||
if opts.InterfaceName != "" { |
|||
var iface *net.Interface |
|||
iface, innerErr = resolveInterfaceByName(opts.InterfaceName) |
|||
if innerErr != nil { |
|||
return |
|||
} |
|||
|
|||
switch network { |
|||
case "tcp4", "udp4": |
|||
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, iface.Index) |
|||
case "tcp6", "udp6": |
|||
innerErr = unix.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_BOUND_IF, iface.Index) |
|||
} |
|||
if innerErr != nil { |
|||
return |
|||
} |
|||
} |
|||
}) |
|||
|
|||
if innerErr != nil { |
|||
err = innerErr |
|||
} |
|||
return |
|||
} |
@ -0,0 +1,39 @@ |
|||
package dialer |
|||
|
|||
import ( |
|||
"net" |
|||
"syscall" |
|||
|
|||
"golang.org/x/sys/unix" |
|||
) |
|||
|
|||
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) (err error) { |
|||
if !isTCPSocket(network) && !isUDPSocket(network) { |
|||
return |
|||
} |
|||
|
|||
var innerErr error |
|||
err = c.Control(func(fd uintptr) { |
|||
// must be GlobalUnicast.
|
|||
host, _, _ := net.SplitHostPort(address) |
|||
if ip := net.ParseIP(host); ip != nil && !ip.IsGlobalUnicast() { |
|||
return |
|||
} |
|||
|
|||
if opts.InterfaceName != "" { |
|||
if innerErr = unix.BindToDevice(int(fd), opts.InterfaceName); innerErr != nil { |
|||
return |
|||
} |
|||
} |
|||
if opts.RoutingMark != 0 { |
|||
if innerErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, opts.RoutingMark); innerErr != nil { |
|||
return |
|||
} |
|||
} |
|||
}) |
|||
|
|||
if innerErr != nil { |
|||
err = innerErr |
|||
} |
|||
return |
|||
} |
@ -0,0 +1,7 @@ |
|||
//go:build !linux && !darwin
|
|||
|
|||
package dialer |
|||
|
|||
func setSocketOptions(network, address string, c syscall.RawConn, opts *Options) error { |
|||
return nil |
|||
} |
Loading…
Reference in new issue