a recursive dns resolver
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

add rate limiting and caching

+781 -315
+22
docs/alky.toml
··· 12 12 # This is used only if logging.output is "file". 13 13 file_path = "/var/log/alky.log" 14 14 15 + [ratelimit] 16 + # rate: The steady-state rate of requests allowed (r in GCRA) 17 + # Defines how many requests are allowed per second in normal conditions 18 + # Type: Integer 19 + rate = 250 20 + 21 + # burst: Maximum number of requests allowed to exceed the steady-state rate temporarily 22 + # Allows for short bursts of traffic above the defined rate 23 + # Type: Integer 24 + burst = 500 25 + 26 + # window: The interval (in seconds) at which the rate limit is checked and potentially reset 27 + # Implements a sliding window rate limit mechanism 28 + # Type: Integer 29 + window = 3 30 + 31 + # expiration_time: Duration (in seconds) for keeping a client's rate limit data in memory 32 + # After this period of inactivity, a client's rate limit data is removed to free up memory 33 + # Type: Integer 34 + expiration_time = 300 35 + 36 + 15 37 [advanced] 16 38 # Timeout (in milliseconds) for outgoing queries before being cancelled. 17 39 query_timeout = 100
+1 -1
go.mod
··· 3 3 go 1.22.5 4 4 5 5 require ( 6 - code.kiri.systems/kiri/magna v0.0.0-20240721214902-8d0a079dbd84 6 + code.kiri.systems/kiri/magna v0.0.0-20240922043826-2c2a1c508469 7 7 github.com/BurntSushi/toml v1.4.0 8 8 )
+2
go.sum
··· 1 1 code.kiri.systems/kiri/magna v0.0.0-20240721214902-8d0a079dbd84 h1:igzBX4k3REg0WZExjGLWW7/wu/X+U6QlbMc8aeO2030= 2 2 code.kiri.systems/kiri/magna v0.0.0-20240721214902-8d0a079dbd84/go.mod h1:gSzCiTKyKlUEjGgl/qTb8rxF0QUVuWOEORAsTXA0qyI= 3 + code.kiri.systems/kiri/magna v0.0.0-20240922043826-2c2a1c508469 h1:LUvvGcJ7DuW3eo7yblNH2igCJzYsbWJQ08iZEXBWplc= 4 + code.kiri.systems/kiri/magna v0.0.0-20240922043826-2c2a1c508469/go.mod h1:gSzCiTKyKlUEjGgl/qTb8rxF0QUVuWOEORAsTXA0qyI= 3 5 github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= 4 6 github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= 5 7 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+31 -11
main.go
··· 1 1 package main 2 2 3 3 import ( 4 + "flag" 4 5 "log" 5 6 "log/slog" 6 7 "os" 7 - "flag" 8 + "time" 8 9 9 10 "code.kiri.systems/kiri/alky/pkg/config" 10 11 "code.kiri.systems/kiri/alky/pkg/dns" ··· 14 15 var configFlag string 15 16 16 17 func init() { 17 - flag.StringVar(&configFlag, "config", "/etc/alky/alky.toml", "config file path for alky") 18 + flag.StringVar(&configFlag, "config", "/etc/alky/alky.toml", "config file path for alky") 18 19 19 - flag.Parse() 20 + flag.Parse() 20 21 } 21 22 22 23 func main() { ··· 45 46 logger = slog.New(slog.NewJSONHandler(os.Stdout, nil)) 46 47 } 47 48 48 - s := dns.Server{ 49 - Address: cfg.Server.Address, 50 - Port: cfg.Server.Port, 51 - Timeout: cfg.Advanced.QueryTimeout, 49 + memCache := dns.NewMemoryCache() 50 + var cache dns.Cache = memCache 51 + 52 + handler := &dns.QueryHandler{ 52 53 RootServers: rootServers, 54 + Timeout: time.Duration(cfg.Advanced.QueryTimeout) * time.Second, 55 + Cache: &cache, 56 + } 57 + 58 + logConfig := &dns.LogConfig{Logger: logger} 59 + 60 + rateLimitHandler := dns.RateLimitMiddleware(&dns.RateLimitConfig{ 61 + Rate: float64(cfg.Ratelimit.Rate), 62 + Burst: cfg.Ratelimit.Burst, 63 + WindowLength: time.Duration(cfg.Ratelimit.Window) * time.Second, 64 + ExpirationTime: time.Duration(cfg.Ratelimit.ExpirationTime) * time.Second, 65 + })(handler) 66 + loggingHandler := dns.LoggingMiddleware(logConfig)(rateLimitHandler) 67 + 68 + s := dns.Server{ 69 + Address: cfg.Server.Address, 70 + Port: cfg.Server.Port, 71 + Handler: loggingHandler, 72 + UDPSize: 512, 73 + ReadTimeout: 2 * time.Second, 74 + WriteTimeout: 2 * time.Second, 53 75 54 76 Logger: logger, 55 77 } 56 78 57 - go s.TCPListenAndServe() 58 - go s.UDPListenAndServe() 59 - 60 - for { 79 + if err := s.ListenAndServe(); err != nil { 80 + slog.Error("Failed to start server", "error", err) 61 81 } 62 82 }
+11 -3
pkg/config/config.go
··· 17 17 FilePath string `toml:"file_path"` 18 18 } 19 19 20 + type RatelimitConfig struct { 21 + Rate int `toml:"rate"` 22 + Burst int `toml:"burst"` 23 + Window int `toml:"window"` 24 + ExpirationTime int `toml:"expiration_time"` 25 + } 26 + 20 27 type AdvancedConfig struct { 21 28 QueryTimeout int `toml:"query_timeout"` 22 29 } 23 30 24 31 type Config struct { 25 - Server ServerConfig `toml:"server"` 26 - Logging LoggingConfig `toml:"logging"` 27 - Advanced AdvancedConfig `toml:"advanced"` 32 + Server ServerConfig `toml:"server"` 33 + Logging LoggingConfig `toml:"logging"` 34 + Ratelimit RatelimitConfig `toml:"ratelimit"` 35 + Advanced AdvancedConfig `toml:"advanced"` 28 36 } 29 37 30 38 func LoadConfig(path string) (Config, error) {
+52
pkg/dns/cache.go
··· 1 + package dns 2 + 3 + import ( 4 + "sync" 5 + "time" 6 + 7 + "code.kiri.systems/kiri/magna" 8 + ) 9 + 10 + type CachedResourceRecord struct { 11 + Record magna.ResourceRecord 12 + ExpireAt time.Time 13 + } 14 + 15 + type CacheEntry struct { 16 + Answer []CachedResourceRecord 17 + } 18 + 19 + type Cache interface { 20 + Get(key string) (*CacheEntry, bool) 21 + Set(key string, entry *CacheEntry) 22 + } 23 + 24 + type MemoryCache struct { 25 + entries map[string]*CacheEntry 26 + mu sync.RWMutex 27 + } 28 + 29 + func NewMemoryCache() *MemoryCache { 30 + return &MemoryCache{ 31 + entries: make(map[string]*CacheEntry), 32 + } 33 + } 34 + 35 + func (c *MemoryCache) Get(key string) (*CacheEntry, bool) { 36 + c.mu.RLock() 37 + c.mu.RUnlock() 38 + 39 + entry, exists := c.entries[key] 40 + if !exists { 41 + return nil, false 42 + } 43 + 44 + return entry, true 45 + } 46 + 47 + func (c *MemoryCache) Set(key string, entry *CacheEntry) { 48 + c.mu.Lock() 49 + defer c.mu.Unlock() 50 + 51 + c.entries[key] = entry 52 + }
-300
pkg/dns/dns.go
··· 1 1 package dns 2 - 3 - import ( 4 - "context" 5 - "encoding/binary" 6 - "fmt" 7 - "io" 8 - "log/slog" 9 - "math/rand/v2" 10 - "net" 11 - "time" 12 - 13 - "code.kiri.systems/kiri/magna" 14 - ) 15 - 16 - type Server struct { 17 - Address string 18 - Port int 19 - Timeout int 20 - RootServers []string 21 - 22 - Logger *slog.Logger 23 - } 24 - 25 - type queryResponse struct { 26 - MSG magna.Message 27 - Server string 28 - Error error 29 - } 30 - 31 - func (s *Server) UDPListenAndServe() error { 32 - addr := net.UDPAddr{ 33 - Port: s.Port, 34 - IP: net.ParseIP(s.Address), 35 - } 36 - server, err := net.ListenUDP("udp", &addr) 37 - if err != nil { 38 - return err 39 - } 40 - defer server.Close() 41 - 42 - for { 43 - b := make([]byte, 512) 44 - _, remote_addr, err := server.ReadFromUDP(b) 45 - if err != nil { 46 - s.Logger.Warn(err.Error()) 47 - continue 48 - } 49 - 50 - start := time.Now() 51 - msg := s.processQuery(b) 52 - s.Logger.Info("query", "class", msg.Question[0].QClass.String(), "type", msg.Question[0].QType.String(), "name", msg.Question[0].QName, "rcode", msg.Header.RCode.String(), "remote_addr", remote_addr.IP, "time_taken", time.Since(start).Nanoseconds()) 53 - if err != nil { 54 - s.Logger.Warn(err.Error()) 55 - continue 56 - } 57 - 58 - ans := msg.Encode() 59 - // xxx: set the TC bit if the message is over 512 bytes 60 - if len(ans) > 512 { 61 - ans[3] |= 1 << 6 62 - } 63 - 64 - if _, err := server.WriteToUDP(ans, remote_addr); err != nil { 65 - s.Logger.Warn("sending response", "err", err.Error()) 66 - } 67 - } 68 - } 69 - 70 - func (s *Server) TCPListenAndServe() error { 71 - addr := net.TCPAddr{ 72 - Port: s.Port, 73 - IP: net.ParseIP(s.Address), 74 - } 75 - 76 - server, err := net.ListenTCP("tcp", &addr) 77 - if err != nil { 78 - return err 79 - } 80 - defer server.Close() 81 - 82 - for { 83 - conn, err := server.Accept() 84 - if err != nil { 85 - s.Logger.Warn("conn error:", err) 86 - continue 87 - } 88 - 89 - sizeBuffer := make([]byte, 2) 90 - if _, err := io.ReadFull(conn, sizeBuffer); err != nil { 91 - s.Logger.Warn("tcp-error", err) 92 - continue 93 - } 94 - 95 - size := binary.BigEndian.Uint16(sizeBuffer) 96 - 97 - data := make([]byte, size) 98 - if _, err := io.ReadFull(conn, data); err != nil { 99 - s.Logger.Warn("tcp-error", err) 100 - continue 101 - } 102 - 103 - start := time.Now() 104 - msg := s.processQuery(data) 105 - s.Logger.Info("query", "class", msg.Question[0].QClass.String(), "type", msg.Question[0].QType.String(), "name", msg.Question[0].QName, "rcode", msg.Header.RCode.String(), "remote_addr", conn.RemoteAddr(), "time_taken", time.Since(start).Nanoseconds()) 106 - 107 - ans := msg.Encode() 108 - conn.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(ans)))) 109 - if _, err := conn.Write(ans); err != nil { 110 - s.Logger.Error("tcp-error", err) 111 - } 112 - } 113 - } 114 - 115 - func (s *Server) processQuery(messageBuffer []byte) (msg magna.Message) { 116 - var query magna.Message 117 - if err := query.Decode(messageBuffer); err != nil { 118 - slog.Warn("decode", err) 119 - return 120 - } 121 - 122 - msg = magna.Message{ 123 - Header: magna.Header{ 124 - ID: query.Header.ID, 125 - QR: true, 126 - OPCode: 0, 127 - AA: false, 128 - TC: false, 129 - RD: query.Header.RD, 130 - RA: true, 131 - Z: 0, 132 - RCode: magna.NOERROR, 133 - QDCount: 1, 134 - ANCount: 0, 135 - NSCount: 0, 136 - ARCount: 0, 137 - }, 138 - Question: []magna.Question{}, 139 - Answer: []magna.ResourceRecord{}, 140 - Additional: []magna.ResourceRecord{}, 141 - Authority: []magna.ResourceRecord{}, 142 - } 143 - 144 - if len(query.Question) < 0 { 145 - msg.Header.RCode = magna.FORMERR 146 - return 147 - } 148 - question := query.Question[0] 149 - msg.Question = append(msg.Question, question) 150 - 151 - if question.QClass != magna.IN { 152 - msg.Header.RCode = magna.NOTIMP 153 - return 154 - } else { 155 - answer, err := s.resolveQuestion(question, s.RootServers) 156 - if err != nil { 157 - slog.Warn("resolve-question", err) 158 - msg.Header.RCode = magna.SERVFAIL 159 - return 160 - } 161 - 162 - msg.Header.ANCount = uint16(len(answer)) 163 - msg.Answer = answer 164 - 165 - if msg.Header.ANCount == 0 { 166 - msg.Header.RCode = magna.NXDOMAIN 167 - return 168 - } 169 - } 170 - 171 - return 172 - } 173 - 174 - func (s *Server) resolveQuestion(question magna.Question, servers []string) ([]magna.ResourceRecord, error) { 175 - ctx, cancel := context.WithCancel(context.Background()) 176 - defer cancel() 177 - 178 - ch := make(chan queryResponse, len(servers)) 179 - 180 - for _, s := range servers { 181 - go queryServer(ctx, question, s, ch) 182 - } 183 - 184 - for i := 0; i < len(servers); i++ { 185 - select { 186 - case res := <-ch: 187 - if res.Error != nil { 188 - slog.Warn("error", "question", question, "server", res.Server, "error", res.Error) 189 - break 190 - } 191 - 192 - msg := res.MSG 193 - if msg.Header.ANCount > 0 { 194 - if msg.Answer[0].RType == magna.CNAMEType { 195 - cname_answers, err := s.resolveQuestion(magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, s.RootServers) 196 - if err != nil { 197 - slog.Warn("error with cname request", err) 198 - continue 199 - } 200 - msg.Answer = append(msg.Answer, cname_answers...) 201 - } 202 - 203 - return msg.Answer, nil 204 - } 205 - 206 - if msg.Header.ARCount > 0 { 207 - var nextZone []string 208 - for _, ans := range msg.Additional { 209 - if ans.RType == magna.AType { 210 - nextZone = append(nextZone, ans.RData.String()) 211 - } 212 - } 213 - 214 - return s.resolveQuestion(question, nextZone) 215 - } 216 - 217 - if msg.Header.NSCount > 0 { 218 - var ns []string 219 - for _, a := range msg.Authority { 220 - if a.RType == magna.NSType { 221 - ans, err := s.resolveQuestion(magna.Question{QName: a.RData.String(), QType: magna.AType, QClass: magna.IN}, s.RootServers) 222 - if err != nil { 223 - slog.Warn("error with ns request", err) 224 - break 225 - } 226 - for _, x := range ans { 227 - ns = append(ns, x.RData.String()) 228 - } 229 - } 230 - } 231 - 232 - return s.resolveQuestion(question, ns) 233 - } 234 - 235 - return []magna.ResourceRecord{}, nil 236 - case <-time.After(time.Duration(s.Timeout) * time.Millisecond): 237 - cancel() 238 - } 239 - } 240 - 241 - return []magna.ResourceRecord{}, nil 242 - } 243 - 244 - func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse) { 245 - done := make(chan struct{}, 1) 246 - 247 - go func() { 248 - conn, err := net.Dial("udp", fmt.Sprintf("%s:53", server)) 249 - if err != nil { 250 - ch <- queryResponse{Error: err} 251 - return 252 - } 253 - defer conn.Close() 254 - 255 - query := magna.Message{ 256 - Header: magna.Header{ 257 - ID: uint16(rand.Int() % 65535), 258 - QR: false, 259 - OPCode: 0, 260 - AA: false, 261 - TC: false, 262 - RD: false, 263 - RA: false, 264 - Z: 0, 265 - RCode: magna.NOERROR, 266 - QDCount: 1, 267 - ARCount: 0, 268 - NSCount: 0, 269 - ANCount: 0, 270 - }, 271 - Question: []magna.Question{question}, 272 - } 273 - if _, err := conn.Write(query.Encode()); err != nil { 274 - ch <- queryResponse{Server: server, Error: err} 275 - return 276 - } 277 - 278 - p := make([]byte, 512) 279 - nn, err := conn.Read(p) 280 - 281 - // TODO: retry request with TCP 282 - if err != nil || nn > 512 { 283 - if err == nil { 284 - err = fmt.Errorf("truncated response") 285 - } 286 - ch <- queryResponse{Server: server, Error: err} 287 - return 288 - } 289 - 290 - var response magna.Message 291 - err = response.Decode(p) 292 - ch <- queryResponse{MSG: response, Server: server, Error: err} 293 - }() 294 - 295 - select { 296 - case <-ctx.Done(): 297 - ch <- queryResponse{Server: server, Error: ctx.Err()} 298 - case <-done: 299 - // goroutine finished with no cancellation 300 - } 301 - }
+46
pkg/dns/logging.go
··· 1 + package dns 2 + 3 + import ( 4 + "log/slog" 5 + "os" 6 + "time" 7 + ) 8 + 9 + type LogConfig struct { 10 + Logger *slog.Logger 11 + Level slog.Level 12 + } 13 + 14 + func NewDefaultLogConfig() *LogConfig { 15 + return &LogConfig{ 16 + Logger: slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ 17 + Level: slog.LevelInfo, 18 + })), 19 + Level: slog.LevelInfo, 20 + } 21 + } 22 + 23 + func LoggingMiddleware(config *LogConfig) func(Handler) Handler { 24 + if config == nil { 25 + config = NewDefaultLogConfig() 26 + } 27 + 28 + return func(next Handler) Handler { 29 + return HandlerFunc(func(w ResponseWriter, r *Request) { 30 + start := time.Now() 31 + 32 + next.ServeDNS(w, r) 33 + 34 + duration := time.Since(start) 35 + question := r.Message.Question[0] 36 + config.Logger.Info("query", 37 + "class", question.QClass.String(), 38 + "type", question.QType.String(), 39 + "name", question.QName, 40 + "rcode", r.Message.Header.RCode.String(), 41 + "remote_addr", r.RemoteAddr, 42 + "time_taken", duration.Nanoseconds(), 43 + ) 44 + }) 45 + } 46 + }
+137
pkg/dns/ratelimit.go
··· 1 + package dns 2 + 3 + import ( 4 + "net" 5 + "sync" 6 + "time" 7 + 8 + "code.kiri.systems/kiri/magna" 9 + ) 10 + 11 + type RateLimitConfig struct { 12 + Rate float64 13 + Burst int 14 + WindowLength time.Duration 15 + ExpirationTime time.Duration 16 + } 17 + 18 + type rateLimiter struct { 19 + config RateLimitConfig 20 + ipData map[string]*ipRateData 21 + mu sync.RWMutex 22 + } 23 + 24 + type ipRateData struct { 25 + time time.Time 26 + } 27 + 28 + func NewDefaultRateLimitConfig() *RateLimitConfig { 29 + return &RateLimitConfig{ 30 + Rate: 1, 31 + Burst: 1, 32 + WindowLength: time.Hour, 33 + ExpirationTime: time.Hour, 34 + } 35 + } 36 + 37 + func newRateLimiter(config RateLimitConfig) *rateLimiter { 38 + return &rateLimiter{ 39 + config: config, 40 + ipData: make(map[string]*ipRateData), 41 + } 42 + } 43 + 44 + func (rl *rateLimiter) allow(ip string) bool { 45 + rl.mu.Lock() 46 + defer rl.mu.Unlock() 47 + 48 + now := time.Now() 49 + cost := time.Duration(float64(time.Second) / rl.config.Rate) 50 + 51 + data, exists := rl.ipData[ip] 52 + if !exists { 53 + data = &ipRateData{time: now.Add(-rl.config.WindowLength)} 54 + rl.ipData[ip] = data 55 + } 56 + 57 + if data.time.Before(now.Add(-rl.config.WindowLength)) { 58 + data.time = now.Add(-rl.config.WindowLength) 59 + } 60 + 61 + nextTime := data.time.Add(cost) 62 + if now.Before(nextTime) { 63 + return false 64 + } 65 + 66 + if nextTime.Sub(now.Add(-rl.config.WindowLength)) > time.Duration(rl.config.Burst)*cost { 67 + nextTime = now.Add(cost) 68 + } 69 + 70 + data.time = nextTime 71 + return true 72 + } 73 + 74 + func (rl *rateLimiter) cleanup() { 75 + rl.mu.Lock() 76 + defer rl.mu.Unlock() 77 + 78 + now := time.Now() 79 + for ip, data := range rl.ipData { 80 + if data.time.Before(now.Add(-rl.config.WindowLength)) { 81 + delete(rl.ipData, ip) 82 + } 83 + } 84 + } 85 + 86 + func extractIP(addr net.Addr) string { 87 + switch v := addr.(type) { 88 + case *net.UDPAddr: 89 + return v.IP.String() 90 + case *net.TCPAddr: 91 + return v.IP.String() 92 + default: 93 + host, _, err := net.SplitHostPort(addr.String()) 94 + if err != nil { 95 + return addr.String() 96 + } 97 + return host 98 + } 99 + } 100 + 101 + func RateLimitMiddleware(config *RateLimitConfig) func(Handler) Handler { 102 + if config == nil { 103 + config = NewDefaultRateLimitConfig() 104 + } 105 + 106 + rl := newRateLimiter(*config) 107 + 108 + go func() { 109 + ticker := time.NewTicker(config.ExpirationTime) 110 + for range ticker.C { 111 + rl.cleanup() 112 + } 113 + }() 114 + 115 + return func(next Handler) Handler { 116 + return HandlerFunc(func(w ResponseWriter, r *Request) { 117 + if !rl.allow(extractIP(r.RemoteAddr)) { 118 + r.Message.Header.RA = true 119 + msg := r.Message.CreateReply(r.Message) 120 + msg = r.Message.SetRCode(magna.REFUSED) 121 + 122 + // XXX: dont support edns yet and these get copied over on responses 123 + msg.Header.ANCount = 0 124 + msg.Header.NSCount = 0 125 + msg.Header.ARCount = 0 126 + msg.Answer = []magna.ResourceRecord{} 127 + msg.Additional = []magna.ResourceRecord{} 128 + msg.Authority = []magna.ResourceRecord{} 129 + 130 + w.WriteMsg(msg) 131 + return 132 + } 133 + 134 + next.ServeDNS(w, r) 135 + }) 136 + } 137 + }
+240
pkg/dns/resolve.go
··· 1 + package dns 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + "net" 7 + "strings" 8 + "time" 9 + 10 + "code.kiri.systems/kiri/magna" 11 + ) 12 + 13 + type QueryHandler struct { 14 + RootServers []string 15 + Timeout time.Duration 16 + Cache *Cache 17 + } 18 + 19 + type queryResponse struct { 20 + MSG magna.Message 21 + Server string 22 + Error error 23 + } 24 + 25 + func (h *QueryHandler) ServeDNS(w ResponseWriter, r *Request) { 26 + msg := h.processQuery(r.Message.Encode()) 27 + w.WriteMsg(msg) 28 + } 29 + 30 + func (h *QueryHandler) processQuery(messageBuffer []byte) *magna.Message { 31 + var query magna.Message 32 + if err := query.Decode(messageBuffer); err != nil { 33 + return nil 34 + } 35 + 36 + msg := new(magna.Message) 37 + msg = msg.CreateReply(&query) 38 + 39 + if len(query.Question) < 1 { 40 + return msg.SetRCode(magna.FORMERR) 41 + } 42 + 43 + question := query.Question[0] 44 + msg = msg.AddQuestion(question) 45 + 46 + if question.QClass != magna.IN { 47 + return msg.SetRCode(magna.NOTIMP) 48 + } 49 + 50 + answer, err := h.resolveWithCache(question) 51 + if err != nil { 52 + return msg.SetRCode(magna.SERVFAIL) 53 + } 54 + 55 + if len(answer) == 0 { 56 + return msg.SetRCode(magna.NXDOMAIN) 57 + } 58 + 59 + msg.Header.ANCount = uint16(len(answer)) 60 + msg.Answer = answer 61 + return msg.SetRCode(magna.NOERROR) 62 + } 63 + 64 + func (h *QueryHandler) resolveWithCache(question magna.Question) ([]magna.ResourceRecord, error) { 65 + cacheKey := fmt.Sprintf("%s:%s:%s", strings.ToLower(question.QName), question.QType.String(), question.QClass.String()) 66 + 67 + if e, found := (*h.Cache).Get(cacheKey); found { 68 + now := time.Now() 69 + var updatedAnswer []magna.ResourceRecord 70 + var cname *magna.ResourceRecord 71 + hasAddressRecord := false 72 + 73 + for _, cachedRR := range e.Answer { 74 + if now.Before(cachedRR.ExpireAt) { 75 + updatedRR := cachedRR.Record 76 + updatedRR.TTL = uint32(cachedRR.ExpireAt.Sub(now).Seconds()) 77 + updatedAnswer = append(updatedAnswer, updatedRR) 78 + 79 + if updatedRR.RType == magna.CNAMEType && cname == nil { 80 + cname = &updatedRR 81 + } else if updatedRR.RType == question.QType { 82 + hasAddressRecord = true 83 + } 84 + } 85 + } 86 + 87 + if len(updatedAnswer) > 0 { 88 + // add AAAA types when magna supports those record types 89 + if cname != nil && !hasAddressRecord && (question.QType == magna.AType) { 90 + cnameTarget := cname.RData.String() 91 + aRecords, err := h.resolveWithCache(magna.Question{QName: cnameTarget, QType: question.QType, QClass: question.QClass}) 92 + if err == nil && len(aRecords) > 0 { 93 + updatedAnswer = append(updatedAnswer, aRecords...) 94 + } 95 + } 96 + return updatedAnswer, nil 97 + } 98 + } 99 + 100 + answer, err := h.resolveQuestion(question, h.RootServers) 101 + if err != nil { 102 + return nil, err 103 + } 104 + 105 + now := time.Now() 106 + cachedAnswer := make([]CachedResourceRecord, len(answer)) 107 + for i, rr := range answer { 108 + cachedAnswer[i] = CachedResourceRecord{ 109 + Record: rr, 110 + ExpireAt: now.Add(time.Duration(rr.TTL) * time.Second), 111 + } 112 + } 113 + 114 + entry := &CacheEntry{ 115 + Answer: cachedAnswer, 116 + } 117 + (*h.Cache).Set(cacheKey, entry) 118 + 119 + if len(answer) > 0 && answer[0].RType == magna.CNAMEType && question.QType == magna.AType { 120 + cnameTarget := answer[len(answer)-1].RData.String() 121 + addressRecords, err := h.resolveWithCache(magna.Question{QName: cnameTarget, QType: question.QType, QClass: question.QClass}) 122 + if err == nil && len(addressRecords) > 0 { 123 + answer = append(answer, addressRecords...) 124 + } 125 + } 126 + 127 + return answer, nil 128 + } 129 + 130 + func (h *QueryHandler) resolveQuestion(question magna.Question, servers []string) ([]magna.ResourceRecord, error) { 131 + ctx, cancel := context.WithCancel(context.Background()) 132 + defer cancel() 133 + 134 + ch := make(chan queryResponse, len(servers)) 135 + 136 + for _, s := range servers { 137 + go queryServer(ctx, question, s, ch, h.Timeout) 138 + } 139 + 140 + for i := 0; i < len(servers); i++ { 141 + select { 142 + case res := <-ch: 143 + if res.Error != nil { 144 + break 145 + } 146 + 147 + msg := res.MSG 148 + if msg.Header.ANCount > 0 { 149 + if msg.Answer[0].RType == magna.CNAMEType { 150 + cname_answers, err := h.resolveQuestion(magna.Question{QName: msg.Answer[0].RData.String(), QType: question.QType, QClass: question.QClass}, h.RootServers) 151 + if err != nil { 152 + continue 153 + } 154 + msg.Answer = append(msg.Answer, cname_answers...) 155 + } 156 + 157 + return msg.Answer, nil 158 + } 159 + 160 + if msg.Header.ARCount > 0 { 161 + var nextZone []string 162 + for _, ans := range msg.Additional { 163 + if ans.RType == magna.AType { 164 + nextZone = append(nextZone, ans.RData.String()) 165 + } 166 + } 167 + 168 + return h.resolveQuestion(question, nextZone) 169 + } 170 + 171 + if msg.Header.NSCount > 0 { 172 + var ns []string 173 + for _, a := range msg.Authority { 174 + if a.RType == magna.NSType { 175 + ans, err := h.resolveQuestion(magna.Question{QName: a.RData.String(), QType: magna.AType, QClass: magna.IN}, h.RootServers) 176 + if err != nil { 177 + break 178 + } 179 + for _, x := range ans { 180 + ns = append(ns, x.RData.String()) 181 + } 182 + } 183 + } 184 + 185 + return h.resolveQuestion(question, ns) 186 + } 187 + 188 + return []magna.ResourceRecord{}, nil 189 + case <-time.After(h.Timeout): 190 + cancel() 191 + } 192 + } 193 + 194 + return []magna.ResourceRecord{}, nil 195 + } 196 + 197 + func queryServer(ctx context.Context, question magna.Question, server string, ch chan<- queryResponse, timeout time.Duration) { 198 + done := make(chan struct{}, 1) 199 + 200 + go func() { 201 + conn, err := net.Dial("udp", fmt.Sprintf("%s:53", server)) 202 + if err != nil { 203 + ch <- queryResponse{Error: err} 204 + return 205 + } 206 + defer conn.Close() 207 + 208 + query := magna.CreateRequest(0, false) 209 + query = query.AddQuestion(question) 210 + if _, err := conn.Write(query.Encode()); err != nil { 211 + ch <- queryResponse{Server: server, Error: err} 212 + return 213 + } 214 + 215 + p := make([]byte, 512) 216 + nn, err := conn.Read(p) 217 + 218 + // TODO: retry request with TCP 219 + if err != nil || nn > 512 { 220 + if err == nil { 221 + err = fmt.Errorf("truncated response") 222 + } 223 + ch <- queryResponse{Server: server, Error: err} 224 + return 225 + } 226 + 227 + var response magna.Message 228 + err = response.Decode(p) 229 + ch <- queryResponse{MSG: response, Server: server, Error: err} 230 + }() 231 + 232 + select { 233 + case <-ctx.Done(): 234 + ch <- queryResponse{Server: server, Error: ctx.Err()} 235 + case <-done: 236 + // goroutine finished with no cancellation 237 + case <-time.After(timeout): 238 + ch <- queryResponse{Server: server, Error: fmt.Errorf("timeout")} 239 + } 240 + }
+239
pkg/dns/server.go
··· 1 + package dns 2 + 3 + import ( 4 + "encoding/binary" 5 + "fmt" 6 + "io" 7 + "log/slog" 8 + "net" 9 + "sync" 10 + "time" 11 + 12 + "code.kiri.systems/kiri/magna" 13 + ) 14 + 15 + type Handler interface { 16 + ServeDNS(ResponseWriter, *Request) 17 + } 18 + 19 + type HandlerFunc func(ResponseWriter, *Request) 20 + 21 + func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Request) { 22 + f(w, r) 23 + } 24 + 25 + type Request struct { 26 + RemoteAddr net.Addr 27 + Message *magna.Message 28 + } 29 + 30 + type ResponseWriter interface { 31 + WriteMsg(*magna.Message) 32 + } 33 + 34 + type udpResponseWriter struct { 35 + udpConn *net.UDPConn 36 + addr *net.UDPAddr 37 + logger *slog.Logger 38 + writeTimeout time.Duration 39 + } 40 + 41 + func (w *udpResponseWriter) WriteMsg(msg *magna.Message) { 42 + ans := msg.Encode() 43 + if len(ans) > 512 { 44 + ans[3] |= 1 << 6 // set the truncated bit 45 + } 46 + 47 + err := w.udpConn.SetWriteDeadline(time.Now().Add(w.writeTimeout)) 48 + if err != nil { 49 + w.logger.Warn("error setting write deadline for UDP", "error", err) 50 + } 51 + 52 + _, err = w.udpConn.WriteToUDP(ans, w.addr) 53 + if err != nil { 54 + w.logger.Error("error writing UDP response", "error", err) 55 + } 56 + } 57 + 58 + type tcpResponseWriter struct { 59 + tcpConn net.Conn 60 + logger *slog.Logger 61 + writeTiemout time.Duration 62 + } 63 + 64 + func (w *tcpResponseWriter) WriteMsg(msg *magna.Message) { 65 + ans := msg.Encode() 66 + 67 + err := w.tcpConn.SetWriteDeadline(time.Now().Add(w.writeTiemout)) 68 + if err != nil { 69 + w.logger.Warn("error setting write deadline for TCP", "error", err) 70 + } 71 + 72 + _, err = w.tcpConn.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(ans)))) 73 + if err != nil { 74 + w.logger.Error("error writing TCP message length", "error", err) 75 + return 76 + } 77 + 78 + _, err = w.tcpConn.Write(ans) 79 + if err != nil { 80 + w.logger.Error("error writing TCP response", "error", err) 81 + } 82 + } 83 + 84 + type Server struct { 85 + Address string 86 + Port int 87 + Handler Handler 88 + UDPSize int 89 + ReadTimeout time.Duration 90 + WriteTimeout time.Duration 91 + Logger *slog.Logger 92 + Cache Cache 93 + } 94 + 95 + func (srv *Server) ListenAndServe() error { 96 + var wg sync.WaitGroup 97 + errChan := make(chan error, 2) 98 + 99 + wg.Add(2) 100 + 101 + go func() { 102 + defer wg.Done() 103 + if err := srv.serveTCP(); err != nil { 104 + errChan <- fmt.Errorf("TCP server error: %w", err) 105 + } 106 + }() 107 + 108 + go func() { 109 + defer wg.Done() 110 + if err := srv.serveUDP(); err != nil { 111 + errChan <- fmt.Errorf("TCP server error: %w", err) 112 + } 113 + }() 114 + 115 + go func() { 116 + wg.Wait() 117 + close(errChan) 118 + }() 119 + 120 + for err := range errChan { 121 + return err 122 + } 123 + 124 + return nil 125 + } 126 + 127 + func (srv *Server) serveUDP() error { 128 + addr := net.UDPAddr{ 129 + Port: srv.Port, 130 + IP: net.ParseIP(srv.Address), 131 + } 132 + conn, err := net.ListenUDP("udp", &addr) 133 + if err != nil { 134 + return err 135 + } 136 + defer conn.Close() 137 + 138 + for { 139 + buf := make([]byte, srv.UDPSize) 140 + 141 + err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) 142 + if err != nil { 143 + return fmt.Errorf("error setting read deadline: %w", err) 144 + } 145 + 146 + n, remoteAddr, err := conn.ReadFromUDP(buf) 147 + if err != nil { 148 + // skip logging timeout errors 149 + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { 150 + continue 151 + } 152 + 153 + srv.Logger.Warn(err.Error()) 154 + continue 155 + } 156 + 157 + go srv.handleUDPQuery(conn, buf[:n], remoteAddr) 158 + } 159 + } 160 + 161 + func (srv *Server) handleUDPQuery(conn *net.UDPConn, query []byte, remoteAddr *net.UDPAddr) { 162 + w := &udpResponseWriter{ 163 + udpConn: conn, 164 + addr: remoteAddr, 165 + logger: srv.Logger, 166 + writeTimeout: srv.WriteTimeout, 167 + } 168 + 169 + srv.handleQuery(query, w, remoteAddr) 170 + } 171 + 172 + func (srv *Server) serveTCP() error { 173 + addr := net.TCPAddr{ 174 + Port: srv.Port, 175 + IP: net.ParseIP(srv.Address), 176 + } 177 + 178 + listener, err := net.ListenTCP("tcp", &addr) 179 + if err != nil { 180 + return err 181 + } 182 + defer listener.Close() 183 + 184 + for { 185 + conn, err := listener.Accept() 186 + if err != nil { 187 + srv.Logger.Warn("tcp accept error:", err) 188 + continue 189 + } 190 + 191 + go srv.handleTCPQuery(conn) 192 + } 193 + } 194 + 195 + func (srv *Server) handleTCPQuery(conn net.Conn) { 196 + defer conn.Close() 197 + 198 + err := conn.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) 199 + if err != nil { 200 + srv.Logger.Error("error setting read deadline", "error", err) 201 + return 202 + } 203 + 204 + sizeBuffer := make([]byte, 2) 205 + if _, err := io.ReadFull(conn, sizeBuffer); err != nil { 206 + srv.Logger.Warn("tcp-error", err) 207 + return 208 + } 209 + 210 + size := binary.BigEndian.Uint16(sizeBuffer) 211 + data := make([]byte, size) 212 + if _, err := io.ReadFull(conn, data); err != nil { 213 + srv.Logger.Warn("tcp-error", err) 214 + return 215 + } 216 + 217 + w := &tcpResponseWriter{ 218 + tcpConn: conn, 219 + logger: srv.Logger, 220 + writeTiemout: srv.WriteTimeout, 221 + } 222 + 223 + srv.handleQuery(data, w, conn.RemoteAddr()) 224 + } 225 + 226 + func (srv *Server) handleQuery(messageBuffer []byte, w ResponseWriter, remoteAddr net.Addr) { 227 + var query magna.Message 228 + if err := query.Decode(messageBuffer); err != nil { 229 + srv.Logger.Warn("decode error", err) 230 + return 231 + } 232 + 233 + r := &Request{ 234 + Message: &query, 235 + RemoteAddr: remoteAddr, 236 + } 237 + 238 + srv.Handler.ServeDNS(w, r) 239 + }