Browse Source

update to statsConn

pull/15/head
Jason 5 years ago
parent
commit
bb59dfed62
  1. 8
      common/stats/stats.go
  2. 63
      proxy/socks/tcp.go
  3. 81
      proxy/socks/udp.go

8
common/stats/stats.go

@ -25,14 +25,6 @@ type Session struct {
SessionStart time.Time
}
func (s *Session) AddUploadBytes(n int64) {
atomic.AddInt64(&s.UploadBytes, n)
}
func (s *Session) AddDownloadBytes(n int64) {
atomic.AddInt64(&s.DownloadBytes, n)
}
type SessionConn struct {
net.Conn
*Session

63
proxy/socks/tcp.go

@ -43,48 +43,6 @@ func NewTCPHandler(proxyHost string, proxyPort uint16, fakeDns dns.FakeDns, sess
}
}
type direction byte
const (
dirUplink direction = iota
dirDownlink
)
func statsCopy(dst io.Writer, src io.Reader, sess *stats.Session, dir direction) (written int64, err error) {
buf := make([]byte, 64*1024)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
switch dir {
case dirUplink:
sess.AddUploadBytes(int64(nw))
case dirDownlink:
sess.AddDownloadBytes(int64(nw))
default:
}
written += int64(nw)
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return
}
func (h *tcpHandler) relay(localConn, remoteConn net.Conn, sess *stats.Session) {
upCh := make(chan struct{})
@ -96,21 +54,13 @@ func (h *tcpHandler) relay(localConn, remoteConn net.Conn, sess *stats.Session)
// UpLink
go func() {
if h.sessionStater != nil && sess != nil {
statsCopy(remoteConn, localConn, sess, dirUplink)
} else {
io.Copy(remoteConn, localConn)
}
io.Copy(remoteConn, localConn)
remoteConn.SetReadDeadline(time.Now())
upCh <- struct{}{}
}()
// DownLink
if h.sessionStater != nil && sess != nil {
statsCopy(localConn, remoteConn, sess, dirDownlink)
} else {
io.Copy(localConn, remoteConn)
}
io.Copy(localConn, remoteConn)
localConn.SetReadDeadline(time.Now())
<-upCh // Wait for UpLink done.
@ -143,16 +93,13 @@ func (h *tcpHandler) Handle(localConn net.Conn, target *net.TCPAddr) error {
return err
}
var process string
var process = "N/A"
var sess *stats.Session
if h.sessionStater != nil {
// Get name of the process.
localHost, localPortStr, _ := net.SplitHostPort(localConn.LocalAddr().String())
localPortInt, _ := strconv.Atoi(localPortStr)
process, err = lsof.GetCommandNameBySocket(target.Network(), localHost, uint16(localPortInt))
if err != nil {
process = "N/A"
}
process, _ = lsof.GetCommandNameBySocket(target.Network(), localHost, uint16(localPortInt))
sess = &stats.Session{
ProcessName: process,
@ -165,6 +112,8 @@ func (h *tcpHandler) Handle(localConn net.Conn, target *net.TCPAddr) error {
SessionStart: time.Now(),
}
h.sessionStater.AddSession(localConn, sess)
remoteConn = stats.NewSessionConn(remoteConn, sess)
}
// set keepalive

81
proxy/socks/udp.go

@ -74,7 +74,7 @@ func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn) {
input.SetDeadline(time.Now().Add(h.timeout))
n, _, err := input.ReadFrom(buf)
if err != nil {
// log.Printf("read remote failed: %v", err)
log.Warnf("read remote failed: %v", err)
return
}
@ -83,12 +83,8 @@ func (h *udpHandler) fetchUDPInput(conn core.UDPConn, input net.PacketConn) {
if err != nil {
return
}
n, err = conn.WriteFrom(buf[int(3+len(addr)):n], resolvedAddr)
if n > 0 && h.sessionStater != nil {
if sess := h.sessionStater.GetSession(conn); sess != nil {
sess.AddDownloadBytes(int64(n))
}
}
_, err = conn.WriteFrom(buf[int(3+len(addr)):n], resolvedAddr)
if err != nil {
log.Warnf("write local failed: %v", err)
return
@ -109,6 +105,9 @@ func (h *udpHandler) Connect(conn core.UDPConn, target *net.UDPAddr) error {
}
}
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(target.Port))
if len(targetAddr) == 0 {
return errors.New("target address is empty")
}
return h.connectInternal(conn, targetAddr)
}
@ -118,7 +117,7 @@ func (h *udpHandler) connectInternal(conn core.UDPConn, targetAddr string) error
if err != nil {
return err
}
remoteConn.SetDeadline(time.Now().Add(4 * time.Second))
remoteConn.SetDeadline(time.Now().Add(5 * time.Second))
// send VER, NMETHODS, METHODS
remoteConn.Write([]byte{5, 1, 0})
@ -129,12 +128,13 @@ func (h *udpHandler) connectInternal(conn core.UDPConn, targetAddr string) error
return err
}
if len(targetAddr) != 0 {
targetAddr := ParseAddr(targetAddr)
// write VER CMD RSV ATYP DST.ADDR DST.PORT
_, _ = remoteConn.Write(append([]byte{5, socks5UDPAssociate, 0}, targetAddr...))
} else {
switch len(targetAddr) {
case 0:
_, _ = remoteConn.Write(append([]byte{5, socks5UDPAssociate, 0}, []byte{1, 0, 0, 0, 0, 0, 0}...))
default:
destination := ParseAddr(targetAddr)
// write VER CMD RSV ATYP DST.ADDR DST.PORT
_, _ = remoteConn.Write(append([]byte{5, socks5UDPAssociate, 0}, destination...))
}
// read VER REP RSV ATYP BND.ADDR BND.PORT
@ -164,6 +164,28 @@ func (h *udpHandler) connectInternal(conn core.UDPConn, targetAddr string) error
return err
}
var process = "N/A"
if h.sessionStater != nil {
// Get name of the process.
localHost, localPortStr, _ := net.SplitHostPort(conn.LocalAddr().String())
localPortInt, _ := strconv.Atoi(localPortStr)
process, _ = lsof.GetCommandNameBySocket(conn.LocalAddr().Network(), localHost, uint16(localPortInt))
sess := &stats.Session{
ProcessName: process,
Network: conn.LocalAddr().Network(),
DialerAddr: remoteConn.LocalAddr().String(),
ClientAddr: conn.LocalAddr().String(),
TargetAddr: targetAddr,
UploadBytes: 0,
DownloadBytes: 0,
SessionStart: time.Now(),
}
h.sessionStater.AddSession(conn, sess)
remoteUDPConn = stats.NewSessionPacketConn(remoteUDPConn, sess)
}
h.Lock()
h.tcpConns[conn] = remoteConn
h.udpConns[conn] = remoteUDPConn
@ -172,31 +194,7 @@ func (h *udpHandler) connectInternal(conn core.UDPConn, targetAddr string) error
go h.fetchUDPInput(conn, remoteUDPConn)
if len(targetAddr) != 0 {
var process string
if h.sessionStater != nil {
// Get name of the process.
localHost, localPortStr, _ := net.SplitHostPort(conn.LocalAddr().String())
localPortInt, _ := strconv.Atoi(localPortStr)
process, err = lsof.GetCommandNameBySocket(conn.LocalAddr().Network(), localHost, uint16(localPortInt))
if err != nil {
process = "N/A"
}
sess := &stats.Session{
ProcessName: process,
Network: conn.LocalAddr().Network(),
DialerAddr: remoteConn.LocalAddr().String(),
ClientAddr: conn.LocalAddr().String(),
TargetAddr: targetAddr,
UploadBytes: 0,
DownloadBytes: 0,
SessionStart: time.Now(),
}
h.sessionStater.AddSession(conn, sess)
}
log.Access(process, "proxy", "udp", conn.LocalAddr().String(), targetAddr)
}
log.Access(process, "proxy", "udp", conn.LocalAddr().String(), targetAddr)
return nil
}
@ -218,12 +216,7 @@ func (h *udpHandler) ReceiveTo(conn core.UDPConn, data []byte, addr *net.UDPAddr
targetAddr := net.JoinHostPort(targetHost, strconv.Itoa(addr.Port))
buf := append([]byte{0, 0, 0}, ParseAddr(targetAddr)...)
buf = append(buf, data[:]...)
n, err := remoteUDPConn.WriteTo(buf, remoteAddr)
if n > 0 && h.sessionStater != nil {
if sess := h.sessionStater.GetSession(conn); sess != nil {
sess.AddUploadBytes(int64(n))
}
}
_, err := remoteUDPConn.WriteTo(buf, remoteAddr)
if err != nil {
h.Close(conn)
return errors.New(fmt.Sprintf("write remote failed: %v", err))

Loading…
Cancel
Save