Browse Source

Refactor: main.go

pull/133/head
xjasonlyu 3 years ago
parent
commit
c6ca52326a
  1. 58
      engine/engine.go
  2. 38
      main.go

58
engine/engine.go

@ -2,20 +2,15 @@ package engine
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"os"
"github.com/xjasonlyu/tun2socks/v2/component/dialer" "github.com/xjasonlyu/tun2socks/v2/component/dialer"
"github.com/xjasonlyu/tun2socks/v2/core/device" "github.com/xjasonlyu/tun2socks/v2/core/device"
"github.com/xjasonlyu/tun2socks/v2/core/stack" "github.com/xjasonlyu/tun2socks/v2/core/stack"
"github.com/xjasonlyu/tun2socks/v2/internal/version"
"github.com/xjasonlyu/tun2socks/v2/log" "github.com/xjasonlyu/tun2socks/v2/log"
"github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/stats" "github.com/xjasonlyu/tun2socks/v2/stats"
"github.com/xjasonlyu/tun2socks/v2/tunnel" "github.com/xjasonlyu/tun2socks/v2/tunnel"
"gopkg.in/yaml.v3"
) )
var _engine = &engine{} var _engine = &engine{}
@ -45,8 +40,6 @@ type Key struct {
Device string `yaml:"device"` Device string `yaml:"device"`
LogLevel string `yaml:"loglevel"` LogLevel string `yaml:"loglevel"`
Interface string `yaml:"interface"` Interface string `yaml:"interface"`
Config string `yaml:"-"`
Version bool `yaml:"-"`
} }
type engine struct { type engine struct {
@ -62,22 +55,15 @@ func (e *engine) start() error {
return errors.New("empty key") return errors.New("empty key")
} }
if e.Version {
fmt.Println(version.String())
fmt.Println(version.BuildString())
os.Exit(0)
}
for _, f := range []func() error{ for _, f := range []func() error{
e.setConfig, e.applyLogLevel,
e.setLogLevel, e.applyMark,
e.setMark, e.applyInterface,
e.setInterface, e.applyStats,
e.setStats, e.applyUDPTimeout,
e.setUDPTimeout, e.applyProxy,
e.setProxy, e.applyDevice,
e.setDevice, e.applyStack,
e.setStack,
} { } {
if err := f(); err != nil { if err := f(); err != nil {
return err return err
@ -97,19 +83,7 @@ func (e *engine) insert(k *Key) {
e.Key = k e.Key = k
} }
func (e *engine) setConfig() error { func (e *engine) applyLogLevel() error {
if e.Config == "" {
return nil
}
data, err := os.ReadFile(e.Config)
if err != nil {
return err
}
return yaml.Unmarshal(data, e.Key)
}
func (e *engine) setLogLevel() error {
level, err := log.ParseLevel(e.LogLevel) level, err := log.ParseLevel(e.LogLevel)
if err != nil { if err != nil {
return err return err
@ -118,7 +92,7 @@ func (e *engine) setLogLevel() error {
return nil return nil
} }
func (e *engine) setMark() error { func (e *engine) applyMark() error {
if e.Mark != 0 { if e.Mark != 0 {
dialer.SetMark(e.Mark) dialer.SetMark(e.Mark)
log.Infof("[DIALER] set fwmark: %#x", e.Mark) log.Infof("[DIALER] set fwmark: %#x", e.Mark)
@ -126,7 +100,7 @@ func (e *engine) setMark() error {
return nil return nil
} }
func (e *engine) setInterface() error { func (e *engine) applyInterface() error {
if e.Interface != "" { if e.Interface != "" {
if err := dialer.BindToInterface(e.Interface); err != nil { if err := dialer.BindToInterface(e.Interface); err != nil {
return err return err
@ -136,7 +110,7 @@ func (e *engine) setInterface() error {
return nil return nil
} }
func (e *engine) setStats() error { func (e *engine) applyStats() error {
if e.Stats != "" { if e.Stats != "" {
addr, err := net.ResolveTCPAddr("tcp", e.Stats) addr, err := net.ResolveTCPAddr("tcp", e.Stats)
if err != nil { if err != nil {
@ -153,14 +127,14 @@ func (e *engine) setStats() error {
return nil return nil
} }
func (e *engine) setUDPTimeout() error { func (e *engine) applyUDPTimeout() error {
if e.UDPTimeout > 0 { if e.UDPTimeout > 0 {
tunnel.SetUDPTimeout(e.UDPTimeout) tunnel.SetUDPTimeout(e.UDPTimeout)
} }
return nil return nil
} }
func (e *engine) setProxy() (err error) { func (e *engine) applyProxy() (err error) {
if e.Proxy == "" { if e.Proxy == "" {
return errors.New("empty proxy") return errors.New("empty proxy")
} }
@ -170,7 +144,7 @@ func (e *engine) setProxy() (err error) {
return return
} }
func (e *engine) setDevice() (err error) { func (e *engine) applyDevice() (err error) {
if e.Device == "" { if e.Device == "" {
return errors.New("empty device") return errors.New("empty device")
} }
@ -179,7 +153,7 @@ func (e *engine) setDevice() (err error) {
return return
} }
func (e *engine) setStack() (err error) { func (e *engine) applyStack() (err error) {
defer func() { defer func() {
if err == nil { if err == nil {
log.Infof( log.Infof(

38
main.go

@ -2,24 +2,32 @@ package main
import ( import (
"flag" "flag"
"fmt"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/xjasonlyu/tun2socks/v2/engine" "github.com/xjasonlyu/tun2socks/v2/engine"
"github.com/xjasonlyu/tun2socks/v2/internal/version"
"github.com/xjasonlyu/tun2socks/v2/log" "github.com/xjasonlyu/tun2socks/v2/log"
"go.uber.org/automaxprocs/maxprocs" "go.uber.org/automaxprocs/maxprocs"
"gopkg.in/yaml.v3"
) )
var key = new(engine.Key) var (
key = new(engine.Key)
configFile string
versionFlag bool
)
func init() { func init() {
flag.BoolVar(&versionFlag, "version", false, "Show version and then quit")
flag.IntVar(&key.Mark, "fwmark", 0, "Set firewall MARK (Linux only)") flag.IntVar(&key.Mark, "fwmark", 0, "Set firewall MARK (Linux only)")
flag.IntVar(&key.MTU, "mtu", 0, "Set device maximum transmission unit (MTU)") flag.IntVar(&key.MTU, "mtu", 0, "Set device maximum transmission unit (MTU)")
flag.IntVar(&key.UDPTimeout, "udp-timeout", 0, "Set timeout for each UDP session") flag.IntVar(&key.UDPTimeout, "udp-timeout", 0, "Set timeout for each UDP session")
flag.BoolVar(&key.Version, "version", false, "Show version information and quit") flag.StringVar(&configFile, "config", "", "YAML format configuration file")
flag.StringVar(&key.Config, "config", "", "YAML format configuration file")
flag.StringVar(&key.Device, "device", "", "Use this device [driver://]name") flag.StringVar(&key.Device, "device", "", "Use this device [driver://]name")
flag.StringVar(&key.Interface, "interface", "", "Use network INTERFACE (Linux/MacOS only)") flag.StringVar(&key.Interface, "interface", "", "Use network INTERFACE (Linux/MacOS only)")
flag.StringVar(&key.LogLevel, "loglevel", "info", "Log level [debug|info|warning|error|silent]") flag.StringVar(&key.LogLevel, "loglevel", "info", "Log level [debug|info|warning|error|silent]")
@ -32,16 +40,32 @@ func init() {
func main() { func main() {
maxprocs.Set(maxprocs.Logger(func(string, ...any) {})) maxprocs.Set(maxprocs.Logger(func(string, ...any) {}))
if versionFlag {
fmt.Println(version.String())
fmt.Println(version.BuildString())
os.Exit(0)
}
if configFile != "" {
data, err := os.ReadFile(configFile)
if err != nil {
log.Fatalf("Failed to read config %s: %v", configFile, err)
}
if err = yaml.Unmarshal(data, key); err != nil {
log.Fatalf("Failed to unmarshal config %s: %v", configFile, err)
}
}
engine.Insert(key) engine.Insert(key)
checkErr := func(msg string, f func() error) { assert := func(msg string, err error) {
if err := f(); err != nil { if err != nil {
log.Fatalf("Failed to %s: %v", msg, err) log.Fatalf("Failed to %s: %v", msg, err)
} }
} }
checkErr("start engine", engine.Start) assert("start engine", engine.Start())
defer checkErr("stop engine", engine.Stop) defer assert("stop engine", engine.Stop())
sigCh := make(chan os.Signal, 1) sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

Loading…
Cancel
Save