-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcache.go
More file actions
149 lines (130 loc) · 2.75 KB
/
cache.go
File metadata and controls
149 lines (130 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package main
import (
"container/list"
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
// cacheKey identifies a cached response by the first question of a DNS query.
// Two queries with the same key share one cache entry.
type cacheKey struct {
	Name  string // question name, lowercased by extractKey so lookups ignore letter case
	Type  uint16 // question type (Qtype of the first question)
	Class uint16 // question class (Qclass of the first question)
}
// cacheEntry is the per-key value stored in DNSCache.items.
type cacheEntry struct {
	response  []byte        // private copy of the packed DNS response passed to Put
	expiresAt time.Time     // Get treats the entry as a miss (and deletes it) once time.Now() is after this
	element   *list.Element // this entry's node in DNSCache.order; its Value is the entry's cacheKey
}
// DNSCache is an LRU cache for DNS responses with TTL-based expiry.
//
// Entries are keyed by the first question of the query (see cacheKey). Get and
// Put lock mu before touching items or order, so concurrent calls to them are
// serialized. Expired entries are not purged proactively: Get deletes one when
// it is looked up, and Put's LRU eviction may also remove it.
type DNSCache struct {
	mu      sync.Mutex               // guards items and order
	maxSize int                      // Put evicts from the back of order while order.Len() >= maxSize
	items   map[cacheKey]*cacheEntry // key -> entry lookup
	order   *list.List               // recency list of cacheKeys; front = most recently used, back = next to evict
}
// NewDNSCache returns an empty cache that Put keeps at no more than maxSize
// entries by evicting the least recently used ones. Because Put always inserts
// after evicting, a maxSize of zero or less results in a cache holding at most
// one entry.
func NewDNSCache(maxSize int) *DNSCache {
	c := new(DNSCache)
	c.maxSize = maxSize
	c.items = map[cacheKey]*cacheEntry{}
	c.order = list.New()
	return c
}
// extractKey parses a packed DNS message and builds the cache key from its
// first question, lowercasing the name. It reports false when the message
// cannot be unpacked or carries no question.
func extractKey(raw []byte) (cacheKey, bool) {
	var m dns.Msg
	if m.Unpack(raw) != nil || len(m.Question) == 0 {
		return cacheKey{}, false
	}
	q0 := m.Question[0]
	key := cacheKey{Name: strings.ToLower(q0.Name), Type: q0.Qtype, Class: q0.Qclass}
	return key, true
}
// minTTL returns how long a packed DNS response may be cached. The result is the
// smallest TTL among the records in the answer, authority and additional
// sections. OPT pseudo-records are skipped because their TTL field carries EDNS
// flags rather than a time to live.
//
// For an SOA record in the authority section (the record that carries the TTL
// of NXDOMAIN/NODATA answers), the SOA MINIMUM field also bounds the result.
// This follows the negative-caching rule of RFC 2308 §5, which bounds the
// negative TTL by both the SOA's own TTL and MINIMUM. That adjustment can only
// shorten the TTL, never lengthen it.
//
// It returns 0 when raw cannot be unpacked or contains no TTL-bearing records.
// Callers treat 0 as "do not cache".
func minTTL(raw []byte) time.Duration {
	msg := new(dns.Msg)
	if err := msg.Unpack(raw); err != nil {
		return 0
	}

	// lowest tracks the smallest TTL seen so far; found reports whether any
	// TTL has been observed at all (0 is a legitimate TTL).
	var lowest uint32
	found := false
	observe := func(ttl uint32) {
		if !found || ttl < lowest {
			lowest, found = ttl, true
		}
	}

	for _, rrs := range [][]dns.RR{msg.Answer, msg.Ns, msg.Extra} {
		for _, rr := range rrs {
			if _, isOPT := rr.(*dns.OPT); isOPT {
				continue
			}
			observe(rr.Header().Ttl)
		}
	}

	// RFC 2308 §5: the negative-cache TTL is the minimum of the SOA's TTL
	// (observed above) and its MINIMUM field.
	for _, rr := range msg.Ns {
		if soa, isSOA := rr.(*dns.SOA); isSOA {
			observe(soa.Minttl)
		}
	}

	if !found {
		return 0 // no TTL-bearing RRs
	}
	return time.Duration(lowest) * time.Second
}
// Get looks up the cached response for query's first question (lowercased
// name, type, class). It returns (nil, false) when query has no parsable
// question or nothing usable is cached.
//
// On a hit it marks the entry as most recently used and returns a fresh copy
// of the stored bytes. In that copy the 2-byte transaction ID is overwritten
// with query's ID, so the caller may modify the result freely. An entry whose
// expiry time has passed is deleted and reported as a miss.
//
// NOTE(review): only the ID is rewritten. Record TTLs in the returned copy are
// the values seen when the response was stored, not reduced by the time the
// entry has spent in the cache.
func (c *DNSCache) Get(query []byte) ([]byte, bool) {
	key, valid := extractKey(query)
	if !valid {
		return nil, false
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	cached, hit := c.items[key]
	switch {
	case !hit:
		return nil, false
	case time.Now().After(cached.expiresAt):
		// Stale entries are dropped lazily, on access.
		delete(c.items, key)
		c.order.Remove(cached.element)
		return nil, false
	}

	c.order.MoveToFront(cached.element)

	out := append(make([]byte, 0, len(cached.response)), cached.response...)
	// The first two header bytes are the transaction ID; echo the client's.
	if len(query) >= 2 && len(out) >= 2 {
		copy(out[:2], query[:2])
	}
	return out, true
}
// Put stores a private copy of a DNS response in the cache. The entry is keyed
// by query's first question (lowercased name, type, class) and expires after
// the response's minimum record TTL (see minTTL).
//
// Uncacheable inputs are silently ignored:
//   - responses shorter than a DNS header;
//   - truncated responses (TC bit set);
//   - queries without a parsable question;
//   - responses whose minimum TTL is zero, including those with no records.
//
// An existing entry for the same key is replaced. Before inserting, Put evicts
// least-recently-used entries until the cache is below maxSize. Because the new
// entry is inserted even when maxSize <= 0, such a cache holds at most one
// entry.
func (c *DNSCache) Put(query []byte, response []byte) {
	// DNS header (RFC 1035 §4.1.1): bytes 0-1 are the ID, and byte 2 holds
	// QR|Opcode|AA|TC|RD, with TC = 0x02. A truncated reply must not be cached.
	// Otherwise every later client would be handed TC=1, and a TCP retry that
	// also consults this cache would never receive the full answer.
	const (
		dnsHeaderLen = 12
		flagTC       = 0x02
	)
	if len(response) < dnsHeaderLen || response[2]&flagTC != 0 {
		return
	}

	key, ok := extractKey(query)
	if !ok {
		return
	}
	ttl := minTTL(response)
	if ttl <= 0 {
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	// Drop any previous entry so the new one is re-inserted at the front.
	if entry, exists := c.items[key]; exists {
		c.order.Remove(entry.element)
		delete(c.items, key)
	}

	// Evict from the LRU end until there is room for one more entry.
	for c.order.Len() >= c.maxSize {
		back := c.order.Back()
		if back == nil {
			break
		}
		evictKey := back.Value.(cacheKey)
		c.order.Remove(back)
		delete(c.items, evictKey)
	}

	// Copy so later mutation of the caller's buffer cannot corrupt the cache.
	stored := make([]byte, len(response))
	copy(stored, response)
	elem := c.order.PushFront(key)
	c.items[key] = &cacheEntry{
		response:  stored,
		expiresAt: time.Now().Add(ttl),
		element:   elem,
	}
}