diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go index af6b58a17ce..524df47b56b 100644 --- a/internal/dnsforward/stats.go +++ b/internal/dnsforward/stats.go @@ -1,7 +1,6 @@ package dnsforward import ( - "net" "strings" "time" @@ -16,10 +15,10 @@ import ( func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) { elapsed := time.Since(ctx.startTime) s := ctx.srv - d := ctx.proxyCtx + pctx := ctx.proxyCtx shouldLog := true - msg := d.Req + msg := pctx.Req // don't log ANY request if refuseAny is enabled if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny { @@ -32,15 +31,15 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) { if shouldLog && s.queryLog != nil { p := querylog.AddParams{ Question: msg, - Answer: d.Res, + Answer: pctx.Res, OrigAnswer: ctx.origResp, Result: ctx.result, Elapsed: elapsed, - ClientIP: IPFromAddr(d.Addr), + ClientIP: IPFromAddr(pctx.Addr), ClientID: ctx.clientID, } - switch d.Proto { + switch pctx.Proto { case proxy.ProtoHTTPS: p.ClientProto = querylog.ClientProtoDOH case proxy.ProtoQUIC: @@ -54,46 +53,44 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) { // request. } - if d.Upstream != nil { - p.Upstream = d.Upstream.Address() + if pctx.Upstream != nil { + p.Upstream = pctx.Upstream.Address() } s.queryLog.Add(p) } - s.updateStats(d, elapsed, *ctx.result) + s.updateStats(ctx, elapsed, *ctx.result) s.RUnlock() return resultCodeSuccess } -func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dnsfilter.Result) { +func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res dnsfilter.Result) { if s.stats == nil { return } + pctx := ctx.proxyCtx e := stats.Entry{} - e.Domain = strings.ToLower(d.Req.Question[0].Name) + e.Domain = strings.ToLower(pctx.Req.Question[0].Name) e.Domain = e.Domain[:len(e.Domain)-1] // remove last "." - switch addr := d.Addr.(type) { - case *net.UDPAddr: - e.Client = addr.IP - case *net.TCPAddr: - e.Client = addr.IP + + if clientID := ctx.clientID; clientID != "" { + e.Client = clientID + } else if pctx.Addr != nil { + e.Client = pctx.Addr.String() } + e.Time = uint32(elapsed / 1000) e.Result = stats.RNotFiltered switch res.Reason { - case dnsfilter.FilteredSafeBrowsing: e.Result = stats.RSafeBrowsing - case dnsfilter.FilteredParental: e.Result = stats.RParental - case dnsfilter.FilteredSafeSearch: e.Result = stats.RSafeSearch - case dnsfilter.FilteredBlockList: fallthrough case dnsfilter.FilteredInvalid: diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 1addbebd0df..7ed1d320bea 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -76,10 +76,14 @@ const ( rLast ) -// Entry - data to add +// Entry is a statistics data entry. type Entry struct { + // Clients is the client's primary ID. + // + // TODO(a.garipov): Make this a {net.IP, string} enum? + Client string + Domain string - Client net.IP Result Result Time uint32 // processing time (msec) } diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index b4be4db0699..c4fbe1918e7 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -39,13 +39,13 @@ func TestStats(t *testing.T) { e := Entry{} e.Domain = "domain" - e.Client = net.IP{127, 0, 0, 1} + e.Client = "127.0.0.1" e.Result = RFiltered e.Time = 123456 s.Update(e) e.Domain = "domain" - e.Client = net.IP{127, 0, 0, 1} + e.Client = "127.0.0.1" e.Result = RNotFiltered e.Time = 123456 s.Update(e) @@ -113,9 +113,10 @@ func TestLargeNumbers(t *testing.T) { } for i := 0; i != n; i++ { e.Domain = fmt.Sprintf("domain%d", i) - e.Client = net.IP{127, 0, 0, 1} - e.Client[2] = byte((i & 0xff00) >> 8) - e.Client[3] = byte(i & 0xff) + ip := net.IP{127, 0, 0, 1} + ip[2] = byte((i & 0xff00) >> 8) + ip[3] = byte(i & 0xff) + e.Client = ip.String() e.Result = RNotFiltered e.Time = 123456 s.Update(e) diff --git a/internal/stats/unit.go b/internal/stats/unit.go index 962fe85b3f9..6f31cd5e0a1 100644 --- a/internal/stats/unit.go +++ b/internal/stats/unit.go @@ -223,6 +223,7 @@ func (s *statsCtx) periodicFlush() { s.unitLock.Lock() ptr := s.unit s.unitLock.Unlock() + if ptr == nil { break } @@ -230,6 +231,7 @@ func (s *statsCtx) periodicFlush() { id := s.conf.UnitID() if ptr.id == id { time.Sleep(time.Second) + continue } @@ -243,6 +245,7 @@ func (s *statsCtx) periodicFlush() { if tx == nil { continue } + ok1 := s.flushUnitToDB(tx, u.id, udb) ok2 := s.deleteUnit(tx, id-s.conf.limit) if ok1 || ok2 { @@ -251,6 +254,7 @@ func (s *statsCtx) periodicFlush() { _ = tx.Rollback() } } + log.Tracef("periodicFlush() exited") } @@ -265,7 +269,7 @@ func (s *statsCtx) deleteUnit(tx *bolt.Tx, id uint32) bool { return true } -func convertMapToArray(m map[string]uint64, max int) []countPair { +func convertMapToSlice(m map[string]uint64, max int) []countPair { a := []countPair{} for k, v := range m { pair := countPair{} @@ -283,7 +287,7 @@ func convertMapToArray(m map[string]uint64, max int) []countPair { return a[:max] } -func convertArrayToMap(a []countPair) map[string]uint64 { +func convertSliceToMap(a []countPair) map[string]uint64 { m := map[string]uint64{} for _, it := range a { m[it.Name] = it.Count @@ -301,9 +305,9 @@ func serialize(u *unit) *unitDB { udb.TimeAvg = uint32(u.timeSum / u.nTotal) } - udb.Domains = convertMapToArray(u.domains, maxDomains) - udb.BlockedDomains = convertMapToArray(u.blockedDomains, maxDomains) - udb.Clients = convertMapToArray(u.clients, maxClients) + udb.Domains = convertMapToSlice(u.domains, maxDomains) + udb.BlockedDomains = convertMapToSlice(u.blockedDomains, maxDomains) + udb.Clients = convertMapToSlice(u.clients, maxClients) return &udb } @@ -319,9 +323,9 @@ func deserialize(u *unit, udb *unitDB) { u.nResult[i] = udb.NResult[i] } - u.domains = convertArrayToMap(udb.Domains) - u.blockedDomains = convertArrayToMap(udb.BlockedDomains) - u.clients = convertArrayToMap(udb.Clients) + u.domains = convertSliceToMap(udb.Domains) + u.blockedDomains = convertSliceToMap(udb.BlockedDomains) + u.clients = convertSliceToMap(udb.Clients) u.timeSum = uint64(udb.TimeAvg) * u.nTotal } @@ -372,7 +376,7 @@ func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB { return &udb } -func convertTopArray(a []countPair) []map[string]uint64 { +func convertTopSlice(a []countPair) []map[string]uint64 { m := []map[string]uint64{} for _, it := range a { ent := map[string]uint64{} @@ -461,13 +465,20 @@ func (s *statsCtx) getClientIP(ip net.IP) (clientIP net.IP) { func (s *statsCtx) Update(e Entry) { if e.Result == 0 || e.Result >= rLast || - len(e.Domain) == 0 || - !(len(e.Client) == 4 || len(e.Client) == 16) { + e.Domain == "" || + e.Client == "" { return } - client := s.getClientIP(e.Client) + + clientID := e.Client + if ip := net.ParseIP(clientID); ip != nil { + ip = s.getClientIP(ip) + clientID = ip.String() + } s.unitLock.Lock() + defer s.unitLock.Unlock() + u := s.unit u.nResult[e.Result]++ @@ -478,10 +489,9 @@ func (s *statsCtx) Update(e Entry) { u.blockedDomains[e.Domain]++ } - u.clients[client.String()]++ + u.clients[clientID]++ u.timeSum += uint64(e.Time) u.nTotal++ - s.unitLock.Unlock() } func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) { @@ -594,8 +604,8 @@ func (s *statsCtx) getData() (statsResponse, bool) { m[it.Name] += it.Count } } - a2 := convertMapToArray(m, max) - return convertTopArray(a2) + a2 := convertMapToSlice(m, max) + return convertTopSlice(a2) } dnsQueries := statsCollector(func(u *unitDB) (num uint64) { return u.NTotal }) @@ -661,7 +671,7 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP { m[it.Name] += it.Count } } - a := convertMapToArray(m, int(maxCount)) + a := convertMapToSlice(m, int(maxCount)) d := []net.IP{} for _, it := range a { d = append(d, net.ParseIP(it.Name))