From 112f1bd13acf992b0ba9562c29365b22d5374ec2 Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Fri, 29 Dec 2023 18:29:11 +0300 Subject: [PATCH] home: imp code --- internal/home/client.go | 70 +++++++++++++++++++++++++--------------- internal/home/clients.go | 6 ++-- 2 files changed, 47 insertions(+), 29 deletions(-) diff --git a/internal/home/client.go b/internal/home/client.go index 03b960ea88e..731245ef5da 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -65,7 +65,8 @@ type persistentClient struct { Tags []string Upstreams []string - IPs []netip.Addr + IPs []netip.Addr + // TODO(s.chzhen): Use netutil.Prefix. Subnets []netip.Prefix MACs []net.HardwareAddr ClientIDs []string @@ -88,45 +89,57 @@ type persistentClient struct { // parseIDs parses a list of strings into typed fields. func (c *persistentClient) parseIDs(ids []string) (err error) { for _, id := range ids { - if id == "" { - return errors.Error("clientid is empty") + err = c.checkID(id) + if err != nil { + return err } + } - var ip netip.Addr - if ip, err = netip.ParseAddr(id); err == nil { - c.IPs = append(c.IPs, ip) + return nil +} - continue - } +// checkID parses id into typed field if there is no error. +func (c *persistentClient) checkID(id string) (err error) { + if id == "" { + return errors.Error("clientid is empty") + } - var subnet netip.Prefix - if subnet, err = netip.ParsePrefix(id); err == nil { - c.Subnets = append(c.Subnets, subnet) + var ip netip.Addr + if ip, err = netip.ParseAddr(id); err == nil { + c.IPs = append(c.IPs, ip) - continue - } + return nil + } - var mac net.HardwareAddr - if mac, err = net.ParseMAC(id); err == nil { - c.MACs = append(c.MACs, mac) + var subnet netip.Prefix + if subnet, err = netip.ParsePrefix(id); err == nil { + c.Subnets = append(c.Subnets, subnet) - continue - } + return nil + } - err = dnsforward.ValidateClientID(id) - if err != nil { - // Don't wrap the error, because it's informative enough as is. - return err - } + var mac net.HardwareAddr + if mac, err = net.ParseMAC(id); err == nil { + c.MACs = append(c.MACs, mac) - c.ClientIDs = append(c.ClientIDs, strings.ToLower(id)) + return nil } + err = dnsforward.ValidateClientID(id) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + c.ClientIDs = append(c.ClientIDs, strings.ToLower(id)) + return nil } // ids returns a list of client ids. func (c *persistentClient) ids() (ids []string) { + ids = make([]string, 0, c.idsLen()) + for _, ip := range c.IPs { ids = append(ids, ip.String()) } @@ -142,9 +155,14 @@ func (c *persistentClient) ids() (ids []string) { return append(ids, c.ClientIDs...) } -// shallowClone returns a deep copy of the client, except upstreamConfig, +// idsLen returns a length of client ids. +func (c *persistentClient) idsLen() (n int) { + return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs) +} + +// clone returns a deep copy of the client, except upstreamConfig, // safeSearchConf, SafeSearch fields, because it's difficult to copy them. -func (c *persistentClient) shallowClone() (sh *persistentClient) { +func (c *persistentClient) clone() (sh *persistentClient) { clone := *c clone.BlockedServices = c.BlockedServices.Clone() diff --git a/internal/home/clients.go b/internal/home/clients.go index ce92a606407..7df46cdb038 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -314,7 +314,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { BlockedServices: cli.BlockedServices.Clone(), - IDs: stringutil.CloneSlice(cli.ids()), + IDs: cli.ids(), Tags: stringutil.CloneSlice(cli.Tags), Upstreams: stringutil.CloneSlice(cli.Upstreams), @@ -453,7 +453,7 @@ func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool) return nil, false } - return c.shallowClone(), true + return c.clone(), true } // shouldCountClient is a wrapper around [clientsContainer.find] to make it a @@ -608,7 +608,7 @@ func (clients *clientsContainer) check(c *persistentClient) (err error) { return errors.Error("client is nil") case c.Name == "": return errors.Error("invalid name") - case len(c.IPs)+len(c.Subnets)+len(c.MACs)+len(c.ClientIDs) == 0: + case c.idsLen() == 0: return errors.Error("id required") default: // Go on.