diff --git a/cmd/to-nft/internal/nft/cases/aggregate-local-rules.go b/cmd/to-nft/internal/nft/cases/aggregate-local-rules.go index ae8a5d65..69d6f83a 100644 --- a/cmd/to-nft/internal/nft/cases/aggregate-local-rules.go +++ b/cmd/to-nft/internal/nft/cases/aggregate-local-rules.go @@ -37,7 +37,7 @@ type ( // LocalRules - LocalRules struct { - LocalSGs + SGs Rules []model.SGRule Networks Sg2Networks } @@ -52,21 +52,19 @@ type ( ) // Load ... -func (rules *LocalRules) Load(ctx context.Context, client SGClient, locals LocalSGs) (err error) { +func (rules *LocalRules) Load(ctx context.Context, client SGClient, locals SGs) (err error) { const api = "LocalRules/Load" defer func() { err = errors.WithMessage(err, api) }() - rules.LocalSGs = make(LocalSGs) - rules.Networks = nil - rules.Rules = nil - localSgNames := locals.Names() if len(localSgNames) == 0 { return nil } + rules.Networks = nil + rules.Rules = nil reqs := []sgAPI.FindRulesReq{ {SgFrom: localSgNames}, {SgTo: localSgNames}, } @@ -95,57 +93,45 @@ func (rules *LocalRules) Load(ctx context.Context, client SGClient, locals Local Proto: v.Transport, } }).ToSlice(&rules.Rules) - for _, r := range rules.Rules { - for _, n := range [...]string{r.SgFrom.Name, r.SgTo.Name} { - if sg := locals[n]; sg != nil && rules.LocalSGs[n] == nil { - rules.LocalSGs[n] = sg - } - } + if err = rules.SGs.LoadFromRules(ctx, client, rules.Rules); err == nil { + err = rules.Networks.Load(ctx, client, rules.SGs.Names()) } - err = rules.Networks.Load(ctx, client, rules.LocalSGs.Names()) return err } // IterateNetworks ... func (rules LocalRules) IterateNetworks(f func(sgName string, nets []net.IPNet, isV6 bool) error) error { - type tk = struct { - sgName string - v6 bool - } - seen := make(map[tk]bool) - send := func(sgName string, isV6 bool, nets []net.IPNet) error { - k := tk{sgName, isV6} - if !seen[k] { - seen[k] = true - return f(sgName, nets, isV6) - } - return nil + var err error + type item = struct { + sg string + nw *SgNetworks } - for _, r := range rules.Rules { - nw1 := rules.Networks[r.SgFrom.Name] - nw2 := rules.Networks[r.SgTo.Name] - if nw1 != nil && nw2 != nil { - if len(nw1.V4) > 0 && len(nw2.V4) > 0 { - err := send(r.SgFrom.Name, false, nw1.V4) - if err == nil { - err = send(r.SgTo.Name, false, nw2.V4) - } - if err != nil { - return err - } - } - if len(nw1.V6) > 0 && len(nw2.V6) > 0 { - err := send(r.SgFrom.Name, false, nw1.V6) - if err == nil { - err = send(r.SgTo.Name, false, nw2.V6) + linq.From(rules.Rules). + SelectMany(func(i any) linq.Query { + r := i.(model.SGRule) + return linq.From([...]item{ + {sg: r.SgFrom.Name, nw: rules.Networks[r.SgFrom.Name]}, + {sg: r.SgTo.Name, nw: rules.Networks[r.SgTo.Name]}, + }) + }). + Where(func(i any) bool { + return i.(item).nw != nil + }). + DistinctBy(func(i any) any { + return i.(item).sg + }). + ForEach(func(i any) { + if err == nil { + v := i.(item) + if len(v.nw.V4) > 0 { + err = f(v.sg, v.nw.V4, false) } - if err != nil { - return err + if err == nil && len(v.nw.V6) > 0 { + err = f(v.sg, v.nw.V6, true) } } - } - } - return nil + }) + return err } // TemplatesOutRules - @@ -165,10 +151,14 @@ func (rules LocalRules) TemplatesOutRules() []RulesOutTemplate { //nolint:dupl return groupped{Sg: r.SgTo.Name, Proto: r.Transport} }, ). + Where(func(i any) bool { + v := i.(linq.Group) + return rules.SGs[v.Key.(string)] != nil + }). Select(func(i any) any { v := i.(linq.Group) item := RulesOutTemplate{ - SgOut: rules.LocalSGs[v.Key.(string)].SG, + SgOut: rules.SGs[v.Key.(string)].SecurityGroup, } for _, g := range v.Group { item.In = append(item.In, g.(groupped)) @@ -195,10 +185,14 @@ func (rules LocalRules) TemplatesInRules() []RulesInTemplate { //nolint:dupl return groupped{Sg: r.SgFrom.Name, Proto: r.Transport} }, ). + Where(func(i any) bool { + v := i.(linq.Group) + return rules.SGs[v.Key.(string)] != nil + }). Select(func(i any) any { v := i.(linq.Group) item := RulesInTemplate{ - SgIn: rules.LocalSGs[v.Key.(string)].SG, + SgIn: rules.SGs[v.Key.(string)].SecurityGroup, } for _, g := range v.Group { item.Out = append(item.Out, g.(groupped)) diff --git a/cmd/to-nft/internal/nft/cases/aggregate-local-sgs.go b/cmd/to-nft/internal/nft/cases/aggregate-local-sgs.go index 5325dd4a..01aa124c 100644 --- a/cmd/to-nft/internal/nft/cases/aggregate-local-sgs.go +++ b/cmd/to-nft/internal/nft/cases/aggregate-local-sgs.go @@ -12,6 +12,7 @@ import ( "github.com/H-BF/corlib/pkg/parallel" "github.com/H-BF/corlib/pkg/slice" + "github.com/ahmetb/go-linq/v3" "github.com/c-robinson/iplib" "github.com/pkg/errors" "google.golang.org/grpc/codes" @@ -28,27 +29,28 @@ type ( //SGClient is a type alias SGClient = sgAPI.SecGroupServiceClient - // LocalSG ... - LocalSG struct { - SG model.SecurityGroup + // SG ... + SG struct { + model.SecurityGroup IPsV4, IPsV6 iplib.ByIP } - // LocalSGs local SG(s) related to IP(s) - LocalSGs map[SgName]*LocalSG + // SGs local SG(s) related to IP(s) + SGs map[SgName]*SG ) -// Load it loads Local SGs by IPs -func (loc *LocalSGs) Load(ctx context.Context, client SGClient, srcIPs []net.IP) error { - const api = "LocalSG(s)/Load" - - *loc = make(LocalSGs) - if len(srcIPs) == 0 { +// LoadFromIPs it loads Local SGs by IPs +func (loc *SGs) LoadFromIPs(ctx context.Context, client SGClient, localIPs []net.IP) error { + const api = "SG(s)/LoadFromIPs" + if *loc == nil { + *loc = make(SGs) + } + if len(localIPs) == 0 { return nil } var mx sync.Mutex job := func(i int) error { - srcIP := srcIPs[i] + srcIP := localIPs[i] req := &sgAPI.GetSecGroupForAddressReq{ Address: srcIP.String(), } @@ -67,7 +69,7 @@ func (loc *LocalSGs) Load(ctx context.Context, client SGClient, srcIPs []net.IP) defer mx.Unlock() it := (*loc)[sg.Name] if it == nil { - it = &LocalSG{SG: sg} + it = &SG{SecurityGroup: sg} (*loc)[sg.Name] = it } switch len(srcIP) { @@ -78,7 +80,7 @@ func (loc *LocalSGs) Load(ctx context.Context, client SGClient, srcIPs []net.IP) } return nil } - if err := parallel.ExecAbstract(len(srcIPs), 8, job); err != nil { //nolint:gomnd + if err := parallel.ExecAbstract(len(localIPs), 8, job); err != nil { //nolint:gomnd return errors.WithMessage(err, api) } for _, it := range *loc { @@ -94,8 +96,44 @@ func (loc *LocalSGs) Load(ctx context.Context, client SGClient, srcIPs []net.IP) return nil } +// LoadFromRules it loads Local SGs from SG rules +func (loc *SGs) LoadFromRules(ctx context.Context, client SGClient, rules []model.SGRule) error { + const api = "SG(s)/LoadFromRules" + + if *loc == nil { + *loc = make(SGs) + } + usedSG := make([]string, 0, len(rules)*2) + linq.From(rules). + SelectMany(func(i any) linq.Query { + r := i.(model.SGRule) + return linq.From([...]string{r.SgFrom.Name, r.SgTo.Name}) + }).Distinct().ToSlice(&usedSG) + + if len(usedSG) == 0 { + return nil + } + sgsResp, err := client.ListSecurityGroups(ctx, &sgAPI.ListSecurityGroupsReq{SgNames: usedSG}) + if err != nil { + return errors.WithMessage(err, api) + } + linq.From(sgsResp.GetGroups()). + ForEach(func(i any) { + if err != nil { + return + } + g := i.(*sgAPI.SecGroup) + if sg, e := conv.Proto2ModelSG(g); e != nil { + err = e + } else { + (*loc)[sg.Name] = &SG{SecurityGroup: sg} + } + }) + return errors.WithMessage(err, api) +} + // Names get local SG(s) names -func (loc LocalSGs) Names() []SgName { +func (loc SGs) Names() []SgName { ret := make([]SgName, 0, len(loc)) for n := range loc { ret = append(ret, n) diff --git a/cmd/to-nft/internal/nft/nft-processor.go b/cmd/to-nft/internal/nft/nft-processor.go index 98934936..2b57e9bb 100644 --- a/cmd/to-nft/internal/nft/nft-processor.go +++ b/cmd/to-nft/internal/nft/nft-processor.go @@ -45,12 +45,12 @@ func (impl *nfTablesProcessorImpl) ApplyConf(ctx context.Context, conf NetConf) var ( err error localRules cases.LocalRules - localSGs cases.LocalSGs + localSGs cases.SGs ) localIPsV4, loaclIPsV6 := conf.LocalIPs() allLoaclIPs := append(localIPsV4, loaclIPsV6...) - if err = localSGs.Load(ctx, impl.sgClient, allLoaclIPs); err != nil { + if err = localSGs.LoadFromIPs(ctx, impl.sgClient, allLoaclIPs); err != nil { return multierr.Combine(ErrNfTablesProcessor, err, pkgErr.ErrDetails{Api: api, Details: allLoaclIPs}) }