Browse Source

use lru cache to avoid outdate

pull/15/head
Jason 5 years ago
parent
commit
5d54629cdc
  1. 30
      common/cache/cache_test.go
  2. 148
      common/cache/lrucache.go
  3. 117
      common/cache/lrucache_test.go
  4. 33
      common/dns/fakedns/middleware.go
  5. 27
      common/dns/fakedns/server.go
  6. 24
      common/dns/fakedns/utils.go
  7. 80
      common/fakeip/pool.go
  8. 51
      common/fakeip/pool_test.go

30
common/cache/cache_test.go

@ -4,6 +4,8 @@ import (
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCache_Basic(t *testing.T) {
@ -14,32 +16,30 @@ func TestCache_Basic(t *testing.T) {
c.Put("string", "a", ttl)
i := c.Get("int")
if i.(int) != 1 {
t.Error("should recv 1")
}
assert.Equal(t, i.(int), 1, "should recv 1")
s := c.Get("string")
if s.(string) != "a" {
t.Error("should recv 'a'")
}
assert.Equal(t, s.(string), "a", "should recv 'a'")
}
func TestCache_TTL(t *testing.T) {
interval := 200 * time.Millisecond
ttl := 20 * time.Millisecond
now := time.Now()
c := New(interval)
c.Put("int", 1, ttl)
c.Put("int2", 2, ttl)
i := c.Get("int")
if i.(int) != 1 {
t.Error("should recv 1")
}
_, expired := c.GetWithExpire("int2")
assert.Equal(t, i.(int), 1, "should recv 1")
assert.True(t, now.Before(expired))
time.Sleep(ttl * 2)
i = c.Get("int")
if i != nil {
t.Error("should recv nil")
}
j, _ := c.GetWithExpire("int2")
assert.Nil(t, i, "should recv nil")
assert.Nil(t, j, "should recv nil")
}
func TestCache_AutoCleanup(t *testing.T) {
@ -50,9 +50,9 @@ func TestCache_AutoCleanup(t *testing.T) {
time.Sleep(ttl * 2)
i := c.Get("int")
if i != nil {
t.Error("should recv nil")
}
j, _ := c.GetWithExpire("int")
assert.Nil(t, i, "should recv nil")
assert.Nil(t, j, "should recv nil")
}
func TestCache_AutoGC(t *testing.T) {

148
common/cache/lrucache.go

@ -0,0 +1,148 @@
package cache
// Modified by https://github.com/die-net/lrucache
import (
"container/list"
"sync"
"time"
)
// Option is part of Functional Options Pattern
type Option func(*LruCache)
// WithUpdateAgeOnGet update expires when Get element
func WithUpdateAgeOnGet() Option {
return func(l *LruCache) {
l.updateAgeOnGet = true
}
}
// WithAge defined element max age (second)
func WithAge(maxAge int64) Option {
return func(l *LruCache) {
l.maxAge = maxAge
}
}
// WithSize defined max length of LruCache
func WithSize(maxSize int) Option {
return func(l *LruCache) {
l.maxSize = maxSize
}
}
// LruCache is a thread-safe, in-memory lru-cache that evicts the
// least recently used entries from memory when (if set) the entries are
// older than maxAge (in seconds). Use the New constructor to create one.
type LruCache struct {
maxAge int64
maxSize int
mu sync.Mutex
cache map[interface{}]*list.Element
lru *list.List // Front is least-recent
updateAgeOnGet bool
}
// NewLRUCache creates an LruCache
func NewLRUCache(options ...Option) *LruCache {
lc := &LruCache{
lru: list.New(),
cache: make(map[interface{}]*list.Element),
}
for _, option := range options {
option(lc)
}
return lc
}
// Get returns the interface{} representation of a cached response and a bool
// set to true if the key was found.
func (c *LruCache) Get(key interface{}) (interface{}, bool) {
c.mu.Lock()
defer c.mu.Unlock()
le, ok := c.cache[key]
if !ok {
return nil, false
}
if c.maxAge > 0 && le.Value.(*entry).expires <= time.Now().Unix() {
c.deleteElement(le)
c.maybeDeleteOldest()
return nil, false
}
c.lru.MoveToBack(le)
entry := le.Value.(*entry)
if c.maxAge > 0 && c.updateAgeOnGet {
entry.expires = time.Now().Unix() + c.maxAge
}
value := entry.value
return value, true
}
// Set stores the interface{} representation of a response for a given key.
func (c *LruCache) Set(key interface{}, value interface{}) {
c.mu.Lock()
defer c.mu.Unlock()
expires := int64(0)
if c.maxAge > 0 {
expires = time.Now().Unix() + c.maxAge
}
if le, ok := c.cache[key]; ok {
c.lru.MoveToBack(le)
e := le.Value.(*entry)
e.value = value
e.expires = expires
} else {
e := &entry{key: key, value: value, expires: expires}
c.cache[key] = c.lru.PushBack(e)
if c.maxSize > 0 {
if len := c.lru.Len(); len > c.maxSize {
c.deleteElement(c.lru.Front())
}
}
}
c.maybeDeleteOldest()
}
// Delete removes the value associated with a key.
func (c *LruCache) Delete(key string) {
c.mu.Lock()
if le, ok := c.cache[key]; ok {
c.deleteElement(le)
}
c.mu.Unlock()
}
func (c *LruCache) maybeDeleteOldest() {
if c.maxAge > 0 {
now := time.Now().Unix()
for le := c.lru.Front(); le != nil && le.Value.(*entry).expires <= now; le = c.lru.Front() {
c.deleteElement(le)
}
}
}
func (c *LruCache) deleteElement(le *list.Element) {
c.lru.Remove(le)
e := le.Value.(*entry)
delete(c.cache, e.key)
}
type entry struct {
key interface{}
value interface{}
expires int64
}

117
common/cache/lrucache_test.go

@ -0,0 +1,117 @@
package cache
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var entries = []struct {
key string
value string
}{
{"1", "one"},
{"2", "two"},
{"3", "three"},
{"4", "four"},
{"5", "five"},
}
func TestLRUCache(t *testing.T) {
c := NewLRUCache()
for _, e := range entries {
c.Set(e.key, e.value)
}
c.Delete("missing")
_, ok := c.Get("missing")
assert.False(t, ok)
for _, e := range entries {
value, ok := c.Get(e.key)
if assert.True(t, ok) {
assert.Equal(t, e.value, value.(string))
}
}
for _, e := range entries {
c.Delete(e.key)
_, ok := c.Get(e.key)
assert.False(t, ok)
}
}
func TestLRUMaxAge(t *testing.T) {
c := NewLRUCache(WithAge(86400))
now := time.Now().Unix()
expected := now + 86400
// Add one expired entry
c.Set("foo", "bar")
c.lru.Back().Value.(*entry).expires = now
// Reset
c.Set("foo", "bar")
e := c.lru.Back().Value.(*entry)
assert.True(t, e.expires >= now)
c.lru.Back().Value.(*entry).expires = now
// Set a few and verify expiration times
for _, s := range entries {
c.Set(s.key, s.value)
e := c.lru.Back().Value.(*entry)
assert.True(t, e.expires >= expected && e.expires <= expected+10)
}
// Make sure we can get them all
for _, s := range entries {
_, ok := c.Get(s.key)
assert.True(t, ok)
}
// Expire all entries
for _, s := range entries {
le, ok := c.cache[s.key]
if assert.True(t, ok) {
le.Value.(*entry).expires = now
}
}
// Get one expired entry, which should clear all expired entries
_, ok := c.Get("3")
assert.False(t, ok)
assert.Equal(t, c.lru.Len(), 0)
}
func TestLRUpdateOnGet(t *testing.T) {
c := NewLRUCache(WithAge(86400), WithUpdateAgeOnGet())
now := time.Now().Unix()
expires := now + 86400/2
// Add one expired entry
c.Set("foo", "bar")
c.lru.Back().Value.(*entry).expires = expires
_, ok := c.Get("foo")
assert.True(t, ok)
assert.True(t, c.lru.Back().Value.(*entry).expires > expires)
}
func TestMaxSize(t *testing.T) {
c := NewLRUCache(WithSize(2))
// Add one expired entry
c.Set("foo", "bar")
_, ok := c.Get("foo")
assert.True(t, ok)
c.Set("bar", "foo")
c.Set("baz", "foo")
_, ok = c.Get("foo")
assert.False(t, ok)
}

33
common/dns/fakedns/middleware.go

@ -6,40 +6,27 @@ import (
"strings"
D "github.com/miekg/dns"
"github.com/xjasonlyu/tun2socks/common/cache"
trie "github.com/xjasonlyu/tun2socks/common/domain-trie"
"github.com/xjasonlyu/tun2socks/common/fakeip"
)
type handler func(w D.ResponseWriter, r *D.Msg)
func withFakeIP(cache *cache.Cache, pool *fakeip.Pool) handler {
func withFakeIP(pool *fakeip.Pool) handler {
return func(w D.ResponseWriter, r *D.Msg) {
q := r.Question[0]
if msg := getMsgFromCache(cache, "fakeip:"+q.String()); msg != nil {
// Update Cache TTL
ip := msg.Answer[0].(*D.A).A
putMsgToCache(cache, ip.String(), msg)
setMsgTTL(msg, dnsFakeTTL)
msg.SetReply(r)
_ = w.WriteMsg(msg)
return
}
host := strings.TrimRight(q.Name, ".")
rr := &D.A{}
rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
ip := pool.Get()
ip := pool.Lookup(host)
rr.A = ip
msg := r.Copy()
msg.Answer = []D.RR{rr}
putMsgToCache(cache, "fakeip:"+q.String(), msg)
putMsgToCache(cache, ip.String(), msg)
setMsgTTL(msg, dnsFakeTTL)
_ = w.WriteMsg(msg)
setMsgTTL(msg, 1)
msg.SetReply(r)
w.WriteMsg(msg)
return
}
}
@ -88,7 +75,7 @@ func withHost(hosts *trie.Trie, next handler) handler {
msg := r.Copy()
msg.Answer = []D.RR{rr}
msg.SetReply(r)
_ = w.WriteMsg(msg)
w.WriteMsg(msg)
return
}
}
@ -115,9 +102,9 @@ func lineToHosts(str string) *trie.Trie {
return tree
}
func newHandler(hosts *trie.Trie, cache *cache.Cache, pool *fakeip.Pool) handler {
func newHandler(hosts *trie.Trie, pool *fakeip.Pool) handler {
if hosts != nil {
return withHost(hosts, withFakeIP(cache, pool))
return withHost(hosts, withFakeIP(pool))
}
return withFakeIP(cache, pool)
return withFakeIP(pool)
}

27
common/dns/fakedns/server.go

@ -3,24 +3,21 @@ package fakedns
import (
"errors"
"net"
"strings"
"time"
D "github.com/miekg/dns"
"github.com/xjasonlyu/tun2socks/common/cache"
"github.com/xjasonlyu/tun2socks/common/fakeip"
)
const (
lruCacheSize = 1000
dnsFakeTTL uint32 = 1
dnsDefaultTTL uint32 = 600
)
var cacheDuration = time.Duration(dnsDefaultTTL) * time.Second
type Server struct {
*D.Server
c *cache.Cache
p *fakeip.Pool
h handler
}
@ -50,20 +47,13 @@ func (s *Server) StartServer(addr string) error {
}
s.Server = &D.Server{Addr: addr, PacketConn: p, Handler: s}
go s.ActivateAndServe()
go func() {
_ = s.ActivateAndServe()
}()
return nil
}
func (s *Server) IPToHost(ip net.IP) (string, bool) {
msg := getMsgFromCache(s.c, ip.String())
if msg == nil {
return "", false
}
fqdn := msg.Question[0].Name
return strings.TrimRight(fqdn, "."), true
return s.p.LookBack(ip)
}
func NewServer(fakeIPRange, hostsLine string) (*Server, error) {
@ -71,17 +61,16 @@ func NewServer(fakeIPRange, hostsLine string) (*Server, error) {
if err != nil {
return nil, err
}
pool, err := fakeip.New(ipnet)
pool, err := fakeip.New(ipnet, lruCacheSize)
if err != nil {
return nil, err
}
hosts := lineToHosts(hostsLine)
cacheItem := cache.New(cacheDuration)
handler := newHandler(hosts, cacheItem, pool)
handler := newHandler(hosts, pool)
return &Server{
c: cacheItem,
p: pool,
h: handler,
}, nil
}

24
common/dns/fakedns/utils.go

@ -1,33 +1,9 @@
package fakedns
import (
"strings"
"time"
D "github.com/miekg/dns"
"github.com/xjasonlyu/tun2socks/common/cache"
)
func putMsgToCache(c *cache.Cache, key string, msg *D.Msg) {
var ttl time.Duration
if strings.HasPrefix(key, "fakeip:") {
ttl = time.Duration(dnsDefaultTTL) * time.Second
} else {
ttl = 3 * time.Duration(dnsDefaultTTL) * time.Second
}
c.Put(key, msg.Copy(), ttl)
}
func getMsgFromCache(c *cache.Cache, key string) (msg *D.Msg) {
item := c.Get(key)
if item == nil {
return
}
msg = item.(*D.Msg).Copy()
putMsgToCache(c, key, msg)
return
}
func setMsgTTL(msg *D.Msg, ttl uint32) {
for _, answer := range msg.Answer {
answer.Header().Ttl = ttl

80
common/fakeip/pool.go

@ -4,22 +4,72 @@ import (
"errors"
"net"
"sync"
"github.com/xjasonlyu/tun2socks/common/cache"
)
// Pool is a implementation about fake ip generator without storage
type Pool struct {
max uint32
min uint32
offset uint32
mux *sync.Mutex
max uint32
min uint32
gateway uint32
offset uint32
mux *sync.Mutex
cache *cache.LruCache
}
// Lookup return a fake ip with host
func (p *Pool) Lookup(host string) net.IP {
p.mux.Lock()
defer p.mux.Unlock()
if ip, exist := p.cache.Get(host); exist {
return ip.(net.IP)
}
ip := p.get(host)
p.cache.Set(host, ip)
return ip
}
// Get return a new fake ip
func (p *Pool) Get() net.IP {
// LookBack return host with the fake ip
func (p *Pool) LookBack(ip net.IP) (string, bool) {
p.mux.Lock()
defer p.mux.Unlock()
ip := uintToIP(p.min + p.offset)
p.offset = (p.offset + 1) % (p.max - p.min)
if ip = ip.To4(); ip == nil {
return "", false
}
n := ipToUint(ip.To4())
offset := n - p.min + 1
if host, exist := p.cache.Get(offset); exist {
return host.(string), true
}
return "", false
}
// Gateway return gateway ip
func (p *Pool) Gateway() net.IP {
return uintToIP(p.gateway)
}
func (p *Pool) get(host string) net.IP {
current := p.offset
for {
p.offset = (p.offset + 1) % (p.max - p.min)
// Avoid infinite loops
if p.offset == current {
break
}
if _, exist := p.cache.Get(p.offset); !exist {
break
}
}
ip := uintToIP(p.min + p.offset - 1)
p.cache.Set(p.offset, host)
return ip
}
@ -36,8 +86,8 @@ func uintToIP(v uint32) net.IP {
}
// New return Pool instance
func New(ipnet *net.IPNet) (*Pool, error) {
min := ipToUint(ipnet.IP) + 1
func New(ipnet *net.IPNet, size int) (*Pool, error) {
min := ipToUint(ipnet.IP) + 2
ones, bits := ipnet.Mask.Size()
total := 1<<uint(bits-ones) - 2
@ -46,10 +96,12 @@ func New(ipnet *net.IPNet) (*Pool, error) {
return nil, errors.New("ipnet don't have valid ip")
}
max := min + uint32(total)
max := min + uint32(total) - 1
return &Pool{
min: min,
max: max,
mux: &sync.Mutex{},
min: min,
max: max,
gateway: min - 1,
mux: &sync.Mutex{},
cache: cache.NewLRUCache(cache.WithSize(size * 2)),
}, nil
}

51
common/fakeip/pool_test.go

@ -3,42 +3,49 @@ package fakeip
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestPool_Basic(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/30")
pool, _ := New(ipnet)
_, ipnet, _ := net.ParseCIDR("192.168.0.1/29")
pool, _ := New(ipnet, 10)
first := pool.Get()
last := pool.Get()
first := pool.Lookup("foo.com")
last := pool.Lookup("bar.com")
bar, exist := pool.LookBack(last)
if !first.Equal(net.IP{192, 168, 0, 1}) {
t.Error("should get right first ip, instead of", first.String())
}
if !last.Equal(net.IP{192, 168, 0, 2}) {
t.Error("should get right last ip, instead of", first.String())
}
assert.True(t, first.Equal(net.IP{192, 168, 0, 2}))
assert.True(t, last.Equal(net.IP{192, 168, 0, 3}))
assert.True(t, exist)
assert.Equal(t, bar, "bar.com")
}
func TestPool_Cycle(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/30")
pool, _ := New(ipnet)
pool, _ := New(ipnet, 10)
first := pool.Lookup("foo.com")
same := pool.Lookup("baz.com")
assert.True(t, first.Equal(same))
}
func TestPool_MaxCacheSize(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/24")
pool, _ := New(ipnet, 2)
first := pool.Get()
pool.Get()
same := pool.Get()
first := pool.Lookup("foo.com")
pool.Lookup("bar.com")
pool.Lookup("baz.com")
next := pool.Lookup("foo.com")
if !first.Equal(same) {
t.Error("should return same ip", first.String())
}
assert.False(t, first.Equal(next))
}
func TestPool_Error(t *testing.T) {
_, ipnet, _ := net.ParseCIDR("192.168.0.1/31")
_, err := New(ipnet)
_, err := New(ipnet, 10)
if err == nil {
t.Error("should return err")
}
assert.Error(t, err)
}

Loading…
Cancel
Save