From 5d54629cdc55656518a9cf0a9ddb21eb713b7ffb Mon Sep 17 00:00:00 2001 From: Jason Date: Sat, 27 Jul 2019 14:08:57 +0800 Subject: [PATCH] use lru cache to avoid outdate --- common/cache/cache_test.go | 30 +++---- common/cache/lrucache.go | 148 +++++++++++++++++++++++++++++++ common/cache/lrucache_test.go | 117 ++++++++++++++++++++++++ common/dns/fakedns/middleware.go | 33 +++---- common/dns/fakedns/server.go | 27 ++---- common/dns/fakedns/utils.go | 24 ----- common/fakeip/pool.go | 80 ++++++++++++++--- common/fakeip/pool_test.go | 51 ++++++----- 8 files changed, 393 insertions(+), 117 deletions(-) create mode 100755 common/cache/lrucache.go create mode 100755 common/cache/lrucache_test.go diff --git a/common/cache/cache_test.go b/common/cache/cache_test.go index 101ca86..cf4a391 100755 --- a/common/cache/cache_test.go +++ b/common/cache/cache_test.go @@ -4,6 +4,8 @@ import ( "runtime" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestCache_Basic(t *testing.T) { @@ -14,32 +16,30 @@ func TestCache_Basic(t *testing.T) { c.Put("string", "a", ttl) i := c.Get("int") - if i.(int) != 1 { - t.Error("should recv 1") - } + assert.Equal(t, i.(int), 1, "should recv 1") s := c.Get("string") - if s.(string) != "a" { - t.Error("should recv 'a'") - } + assert.Equal(t, s.(string), "a", "should recv 'a'") } func TestCache_TTL(t *testing.T) { interval := 200 * time.Millisecond ttl := 20 * time.Millisecond + now := time.Now() c := New(interval) c.Put("int", 1, ttl) + c.Put("int2", 2, ttl) i := c.Get("int") - if i.(int) != 1 { - t.Error("should recv 1") - } + _, expired := c.GetWithExpire("int2") + assert.Equal(t, i.(int), 1, "should recv 1") + assert.True(t, now.Before(expired)) time.Sleep(ttl * 2) i = c.Get("int") - if i != nil { - t.Error("should recv nil") - } + j, _ := c.GetWithExpire("int2") + assert.Nil(t, i, "should recv nil") + assert.Nil(t, j, "should recv nil") } func TestCache_AutoCleanup(t *testing.T) { @@ -50,9 +50,9 @@ func TestCache_AutoCleanup(t *testing.T) { time.Sleep(ttl * 2) i := c.Get("int") - if i != nil { - t.Error("should recv nil") - } + j, _ := c.GetWithExpire("int") + assert.Nil(t, i, "should recv nil") + assert.Nil(t, j, "should recv nil") } func TestCache_AutoGC(t *testing.T) { diff --git a/common/cache/lrucache.go b/common/cache/lrucache.go new file mode 100755 index 0000000..5a139bf --- /dev/null +++ b/common/cache/lrucache.go @@ -0,0 +1,148 @@ +package cache + +// Modified by https://github.com/die-net/lrucache + +import ( + "container/list" + "sync" + "time" +) + +// Option is part of Functional Options Pattern +type Option func(*LruCache) + +// WithUpdateAgeOnGet update expires when Get element +func WithUpdateAgeOnGet() Option { + return func(l *LruCache) { + l.updateAgeOnGet = true + } +} + +// WithAge defined element max age (second) +func WithAge(maxAge int64) Option { + return func(l *LruCache) { + l.maxAge = maxAge + } +} + +// WithSize defined max length of LruCache +func WithSize(maxSize int) Option { + return func(l *LruCache) { + l.maxSize = maxSize + } +} + +// LruCache is a thread-safe, in-memory lru-cache that evicts the +// least recently used entries from memory when (if set) the entries are +// older than maxAge (in seconds). Use the New constructor to create one. +type LruCache struct { + maxAge int64 + maxSize int + mu sync.Mutex + cache map[interface{}]*list.Element + lru *list.List // Front is least-recent + updateAgeOnGet bool +} + +// NewLRUCache creates an LruCache +func NewLRUCache(options ...Option) *LruCache { + lc := &LruCache{ + lru: list.New(), + cache: make(map[interface{}]*list.Element), + } + + for _, option := range options { + option(lc) + } + + return lc +} + +// Get returns the interface{} representation of a cached response and a bool +// set to true if the key was found. +func (c *LruCache) Get(key interface{}) (interface{}, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + le, ok := c.cache[key] + if !ok { + return nil, false + } + + if c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() { + c.deleteElement(le) + c.maybeDeleteOldest() + + return nil, false + } + + c.lru.MoveToBack(le) + entry := le.Value.(*entry) + if c.maxAge > 0 && c.updateAgeOnGet { + entry.expires = time.Now().Unix() + c.maxAge + } + value := entry.value + + return value, true +} + +// Set stores the interface{} representation of a response for a given key. +func (c *LruCache) Set(key interface{}, value interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + expires := int64(0) + if c.maxAge > 0 { + expires = time.Now().Unix() + c.maxAge + } + + if le, ok := c.cache[key]; ok { + c.lru.MoveToBack(le) + e := le.Value.(*entry) + e.value = value + e.expires = expires + } else { + e := &entry{key: key, value: value, expires: expires} + c.cache[key] = c.lru.PushBack(e) + + if c.maxSize > 0 { + if len := c.lru.Len(); len > c.maxSize { + c.deleteElement(c.lru.Front()) + } + } + } + + c.maybeDeleteOldest() +} + +// Delete removes the value associated with a key. +func (c *LruCache) Delete(key string) { + c.mu.Lock() + + if le, ok := c.cache[key]; ok { + c.deleteElement(le) + } + + c.mu.Unlock() +} + +func (c *LruCache) maybeDeleteOldest() { + if c.maxAge > 0 { + now := time.Now().Unix() + for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() { + c.deleteElement(le) + } + } +} + +func (c *LruCache) deleteElement(le *list.Element) { + c.lru.Remove(le) + e := le.Value.(*entry) + delete(c.cache, e.key) +} + +type entry struct { + key interface{} + value interface{} + expires int64 +} diff --git a/common/cache/lrucache_test.go b/common/cache/lrucache_test.go new file mode 100755 index 0000000..31f9a91 --- /dev/null +++ b/common/cache/lrucache_test.go @@ -0,0 +1,117 @@ +package cache + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var entries = []struct { + key string + value string +}{ + {"1", "one"}, + {"2", "two"}, + {"3", "three"}, + {"4", "four"}, + {"5", "five"}, +} + +func TestLRUCache(t *testing.T) { + c := NewLRUCache() + + for _, e := range entries { + c.Set(e.key, e.value) + } + + c.Delete("missing") + _, ok := c.Get("missing") + assert.False(t, ok) + + for _, e := range entries { + value, ok := c.Get(e.key) + if assert.True(t, ok) { + assert.Equal(t, e.value, value.(string)) + } + } + + for _, e := range entries { + c.Delete(e.key) + + _, ok := c.Get(e.key) + assert.False(t, ok) + } +} + +func TestLRUMaxAge(t *testing.T) { + c := NewLRUCache(WithAge(86400)) + + now := time.Now().Unix() + expected := now + 86400 + + // Add one expired entry + c.Set("foo", "bar") + c.lru.Back().Value.(*entry).expires = now + + // Reset + c.Set("foo", "bar") + e := c.lru.Back().Value.(*entry) + assert.True(t, e.expires >= now) + c.lru.Back().Value.(*entry).expires = now + + // Set a few and verify expiration times + for _, s := range entries { + c.Set(s.key, s.value) + e := c.lru.Back().Value.(*entry) + assert.True(t, e.expires >= expected && e.expires <= expected+10) + } + + // Make sure we can get them all + for _, s := range entries { + _, ok := c.Get(s.key) + assert.True(t, ok) + } + + // Expire all entries + for _, s := range entries { + le, ok := c.cache[s.key] + if assert.True(t, ok) { + le.Value.(*entry).expires = now + } + } + + // Get one expired entry, which should clear all expired entries + _, ok := c.Get("3") + assert.False(t, ok) + assert.Equal(t, c.lru.Len(), 0) +} + +func TestLRUpdateOnGet(t *testing.T) { + c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet()) + + now := time.Now().Unix() + expires := now + 86400/2 + + // Add one expired entry + c.Set("foo", "bar") + c.lru.Back().Value.(*entry).expires = expires + + _, ok := c.Get("foo") + assert.True(t, ok) + assert.True(t, c.lru.Back().Value.(*entry).expires > expires) +} + +func TestMaxSize(t *testing.T) { + c := NewLRUCache(WithSize(2)) + // Add one expired entry + c.Set("foo", "bar") + _, ok := c.Get("foo") + assert.True(t, ok) + + c.Set("bar", "foo") + c.Set("baz", "foo") + + _, ok = c.Get("foo") + assert.False(t, ok) +} diff --git a/common/dns/fakedns/middleware.go b/common/dns/fakedns/middleware.go index 6686dc6..f7f9c42 100644 --- a/common/dns/fakedns/middleware.go +++ b/common/dns/fakedns/middleware.go @@ -6,40 +6,27 @@ import ( "strings" D "github.com/miekg/dns" - "github.com/xjasonlyu/tun2socks/common/cache" trie "github.com/xjasonlyu/tun2socks/common/domain-trie" "github.com/xjasonlyu/tun2socks/common/fakeip" ) type handler func(w D.ResponseWriter, r *D.Msg) -func withFakeIP(cache *cache.Cache, pool *fakeip.Pool) handler { +func withFakeIP(pool *fakeip.Pool) handler { return func(w D.ResponseWriter, r *D.Msg) { q := r.Question[0] - - if msg := getMsgFromCache(cache, "fakeip:"+q.String()); msg != nil { - // Update Cache TTL - ip := msg.Answer[0].(*D.A).A - putMsgToCache(cache, ip.String(), msg) - - setMsgTTL(msg, dnsFakeTTL) - msg.SetReply(r) - _ = w.WriteMsg(msg) - return - } + host := strings.TrimRight(q.Name, ".") rr := &D.A{} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} - ip := pool.Get() + ip := pool.Lookup(host) rr.A = ip msg := r.Copy() msg.Answer = []D.RR{rr} - putMsgToCache(cache, "fakeip:"+q.String(), msg) - putMsgToCache(cache, ip.String(), msg) - - setMsgTTL(msg, dnsFakeTTL) - _ = w.WriteMsg(msg) + setMsgTTL(msg, 1) + msg.SetReply(r) + w.WriteMsg(msg) return } } @@ -88,7 +75,7 @@ func withHost(hosts *trie.Trie, next handler) handler { msg := r.Copy() msg.Answer = []D.RR{rr} msg.SetReply(r) - _ = w.WriteMsg(msg) + w.WriteMsg(msg) return } } @@ -115,9 +102,9 @@ func lineToHosts(str string) *trie.Trie { return tree } -func newHandler(hosts *trie.Trie, cache *cache.Cache, pool *fakeip.Pool) handler { +func newHandler(hosts *trie.Trie, pool *fakeip.Pool) handler { if hosts != nil { - return withHost(hosts, withFakeIP(cache, pool)) + return withHost(hosts, withFakeIP(pool)) } - return withFakeIP(cache, pool) + return withFakeIP(pool) } diff --git a/common/dns/fakedns/server.go b/common/dns/fakedns/server.go index 508f999..27a6035 100644 --- a/common/dns/fakedns/server.go +++ b/common/dns/fakedns/server.go @@ -3,24 +3,21 @@ package fakedns import ( "errors" "net" - "strings" - "time" D "github.com/miekg/dns" - "github.com/xjasonlyu/tun2socks/common/cache" "github.com/xjasonlyu/tun2socks/common/fakeip" ) const ( + lruCacheSize = 1000 + dnsFakeTTL uint32 = 1 dnsDefaultTTL uint32 = 600 ) -var cacheDuration = time.Duration(dnsDefaultTTL) * time.Second - type Server struct { *D.Server - c *cache.Cache + p *fakeip.Pool h handler } @@ -50,20 +47,13 @@ func (s *Server) StartServer(addr string) error { } s.Server = &D.Server{Addr: addr, PacketConn: p, Handler: s} + go s.ActivateAndServe() - go func() { - _ = s.ActivateAndServe() - }() return nil } func (s *Server) IPToHost(ip net.IP) (string, bool) { - msg := getMsgFromCache(s.c, ip.String()) - if msg == nil { - return "", false - } - fqdn := msg.Question[0].Name - return strings.TrimRight(fqdn, "."), true + return s.p.LookBack(ip) } func NewServer(fakeIPRange, hostsLine string) (*Server, error) { @@ -71,17 +61,16 @@ func NewServer(fakeIPRange, hostsLine string) (*Server, error) { if err != nil { return nil, err } - pool, err := fakeip.New(ipnet) + pool, err := fakeip.New(ipnet, lruCacheSize) if err != nil { return nil, err } hosts := lineToHosts(hostsLine) - cacheItem := cache.New(cacheDuration) - handler := newHandler(hosts, cacheItem, pool) + handler := newHandler(hosts, pool) return &Server{ - c: cacheItem, + p: pool, h: handler, }, nil } diff --git a/common/dns/fakedns/utils.go b/common/dns/fakedns/utils.go index b19d36a..dfbc9aa 100644 --- a/common/dns/fakedns/utils.go +++ b/common/dns/fakedns/utils.go @@ -1,33 +1,9 @@ package fakedns import ( - "strings" - "time" - D "github.com/miekg/dns" - "github.com/xjasonlyu/tun2socks/common/cache" ) -func putMsgToCache(c *cache.Cache, key string, msg *D.Msg) { - var ttl time.Duration - if strings.HasPrefix(key, "fakeip:") { - ttl = time.Duration(dnsDefaultTTL) * time.Second - } else { - ttl = 3 * time.Duration(dnsDefaultTTL) * time.Second - } - c.Put(key, msg.Copy(), ttl) -} - -func getMsgFromCache(c *cache.Cache, key string) (msg *D.Msg) { - item := c.Get(key) - if item == nil { - return - } - msg = item.(*D.Msg).Copy() - putMsgToCache(c, key, msg) - return -} - func setMsgTTL(msg *D.Msg, ttl uint32) { for _, answer := range msg.Answer { answer.Header().Ttl = ttl diff --git a/common/fakeip/pool.go b/common/fakeip/pool.go index 32d5d57..1427e80 100755 --- a/common/fakeip/pool.go +++ b/common/fakeip/pool.go @@ -4,22 +4,72 @@ import ( "errors" "net" "sync" + + "github.com/xjasonlyu/tun2socks/common/cache" ) // Pool is a implementation about fake ip generator without storage type Pool struct { - max uint32 - min uint32 - offset uint32 - mux *sync.Mutex + max uint32 + min uint32 + gateway uint32 + offset uint32 + mux *sync.Mutex + cache *cache.LruCache +} + +// Lookup return a fake ip with host +func (p *Pool) Lookup(host string) net.IP { + p.mux.Lock() + defer p.mux.Unlock() + if ip, exist := p.cache.Get(host); exist { + return ip.(net.IP) + } + + ip := p.get(host) + p.cache.Set(host, ip) + return ip } -// Get return a new fake ip -func (p *Pool) Get() net.IP { +// LookBack return host with the fake ip +func (p *Pool) LookBack(ip net.IP) (string, bool) { p.mux.Lock() defer p.mux.Unlock() - ip := uintToIP(p.min + p.offset) - p.offset = (p.offset + 1) % (p.max - p.min) + + if ip = ip.To4(); ip == nil { + return "", false + } + + n := ipToUint(ip.To4()) + offset := n - p.min + 1 + + if host, exist := p.cache.Get(offset); exist { + return host.(string), true + } + + return "", false +} + +// Gateway return gateway ip +func (p *Pool) Gateway() net.IP { + return uintToIP(p.gateway) +} + +func (p *Pool) get(host string) net.IP { + current := p.offset + for { + p.offset = (p.offset + 1) % (p.max - p.min) + // Avoid infinite loops + if p.offset == current { + break + } + + if _, exist := p.cache.Get(p.offset); !exist { + break + } + } + ip := uintToIP(p.min + p.offset - 1) + p.cache.Set(p.offset, host) return ip } @@ -36,8 +86,8 @@ func uintToIP(v uint32) net.IP { } // New return Pool instance -func New(ipnet *net.IPNet) (*Pool, error) { - min := ipToUint(ipnet.IP) + 1 +func New(ipnet *net.IPNet, size int) (*Pool, error) { + min := ipToUint(ipnet.IP) + 2 ones, bits := ipnet.Mask.Size() total := 1<