Skip to content

Commit

Permalink
Merge branch 'SWARM-65'
Browse files Browse the repository at this point in the history
  • Loading branch information
EvgenyGri committed Aug 11, 2023
2 parents 64cac77 + 93b06b3 commit a70a1ce
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 66 deletions.
92 changes: 43 additions & 49 deletions cmd/to-nft/internal/nft/cases/aggregate-local-rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type (

// LocalRules -
LocalRules struct {
LocalSGs
SGs
Rules []model.SGRule
Networks Sg2Networks
}
Expand All @@ -52,21 +52,19 @@ type (
)

// Load ...
func (rules *LocalRules) Load(ctx context.Context, client SGClient, locals LocalSGs) (err error) {
func (rules *LocalRules) Load(ctx context.Context, client SGClient, locals SGs) (err error) {
const api = "LocalRules/Load"

defer func() {
err = errors.WithMessage(err, api)
}()

rules.LocalSGs = make(LocalSGs)
rules.Networks = nil
rules.Rules = nil

localSgNames := locals.Names()
if len(localSgNames) == 0 {
return nil
}
rules.Networks = nil
rules.Rules = nil
reqs := []sgAPI.FindRulesReq{
{SgFrom: localSgNames}, {SgTo: localSgNames},
}
Expand Down Expand Up @@ -95,57 +93,45 @@ func (rules *LocalRules) Load(ctx context.Context, client SGClient, locals Local
Proto: v.Transport,
}
}).ToSlice(&rules.Rules)
for _, r := range rules.Rules {
for _, n := range [...]string{r.SgFrom.Name, r.SgTo.Name} {
if sg := locals[n]; sg != nil && rules.LocalSGs[n] == nil {
rules.LocalSGs[n] = sg
}
}
if err = rules.SGs.LoadFromRules(ctx, client, rules.Rules); err == nil {
err = rules.Networks.Load(ctx, client, rules.SGs.Names())
}
err = rules.Networks.Load(ctx, client, rules.LocalSGs.Names())
return err
}

// IterateNetworks ...
func (rules LocalRules) IterateNetworks(f func(sgName string, nets []net.IPNet, isV6 bool) error) error {
type tk = struct {
sgName string
v6 bool
}
seen := make(map[tk]bool)
send := func(sgName string, isV6 bool, nets []net.IPNet) error {
k := tk{sgName, isV6}
if !seen[k] {
seen[k] = true
return f(sgName, nets, isV6)
}
return nil
var err error
type item = struct {
sg string
nw *SgNetworks
}
for _, r := range rules.Rules {
nw1 := rules.Networks[r.SgFrom.Name]
nw2 := rules.Networks[r.SgTo.Name]
if nw1 != nil && nw2 != nil {
if len(nw1.V4) > 0 && len(nw2.V4) > 0 {
err := send(r.SgFrom.Name, false, nw1.V4)
if err == nil {
err = send(r.SgTo.Name, false, nw2.V4)
}
if err != nil {
return err
}
}
if len(nw1.V6) > 0 && len(nw2.V6) > 0 {
err := send(r.SgFrom.Name, false, nw1.V6)
if err == nil {
err = send(r.SgTo.Name, false, nw2.V6)
linq.From(rules.Rules).
SelectMany(func(i any) linq.Query {
r := i.(model.SGRule)
return linq.From([...]item{
{sg: r.SgFrom.Name, nw: rules.Networks[r.SgFrom.Name]},
{sg: r.SgTo.Name, nw: rules.Networks[r.SgTo.Name]},
})
}).
Where(func(i any) bool {
return i.(item).nw != nil
}).
DistinctBy(func(i any) any {
return i.(item).sg
}).
ForEach(func(i any) {
if err == nil {
v := i.(item)
if len(v.nw.V4) > 0 {
err = f(v.sg, v.nw.V4, false)
}
if err != nil {
return err
if err == nil && len(v.nw.V6) > 0 {
err = f(v.sg, v.nw.V6, true)
}
}
}
}
return nil
})
return err
}

// TemplatesOutRules -
Expand All @@ -165,10 +151,14 @@ func (rules LocalRules) TemplatesOutRules() []RulesOutTemplate { //nolint:dupl
return groupped{Sg: r.SgTo.Name, Proto: r.Transport}
},
).
Where(func(i any) bool {
v := i.(linq.Group)
return rules.SGs[v.Key.(string)] != nil
}).
Select(func(i any) any {
v := i.(linq.Group)
item := RulesOutTemplate{
SgOut: rules.LocalSGs[v.Key.(string)].SG,
SgOut: rules.SGs[v.Key.(string)].SecurityGroup,
}
for _, g := range v.Group {
item.In = append(item.In, g.(groupped))
Expand All @@ -195,10 +185,14 @@ func (rules LocalRules) TemplatesInRules() []RulesInTemplate { //nolint:dupl
return groupped{Sg: r.SgFrom.Name, Proto: r.Transport}
},
).
Where(func(i any) bool {
v := i.(linq.Group)
return rules.SGs[v.Key.(string)] != nil
}).
Select(func(i any) any {
v := i.(linq.Group)
item := RulesInTemplate{
SgIn: rules.LocalSGs[v.Key.(string)].SG,
SgIn: rules.SGs[v.Key.(string)].SecurityGroup,
}
for _, g := range v.Group {
item.Out = append(item.Out, g.(groupped))
Expand Down
68 changes: 53 additions & 15 deletions cmd/to-nft/internal/nft/cases/aggregate-local-sgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/H-BF/corlib/pkg/parallel"
"github.com/H-BF/corlib/pkg/slice"
"github.com/ahmetb/go-linq/v3"
"github.com/c-robinson/iplib"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
Expand All @@ -28,27 +29,28 @@ type (
//SGClient is a type alias
SGClient = sgAPI.SecGroupServiceClient

// LocalSG ...
LocalSG struct {
SG model.SecurityGroup
// SG ...
SG struct {
model.SecurityGroup
IPsV4, IPsV6 iplib.ByIP
}

// LocalSGs local SG(s) related to IP(s)
LocalSGs map[SgName]*LocalSG
// SGs local SG(s) related to IP(s)
SGs map[SgName]*SG
)

// Load it loads Local SGs by IPs
func (loc *LocalSGs) Load(ctx context.Context, client SGClient, srcIPs []net.IP) error {
const api = "LocalSG(s)/Load"

*loc = make(LocalSGs)
if len(srcIPs) == 0 {
// LoadFromIPs it loads Local SGs by IPs
func (loc *SGs) LoadFromIPs(ctx context.Context, client SGClient, localIPs []net.IP) error {
const api = "SG(s)/LoadFromIPs"
if *loc == nil {
*loc = make(SGs)
}
if len(localIPs) == 0 {
return nil
}
var mx sync.Mutex
job := func(i int) error {
srcIP := srcIPs[i]
srcIP := localIPs[i]
req := &sgAPI.GetSecGroupForAddressReq{
Address: srcIP.String(),
}
Expand All @@ -67,7 +69,7 @@ func (loc *LocalSGs) Load(ctx context.Context, client SGClient, srcIPs []net.IP)
defer mx.Unlock()
it := (*loc)[sg.Name]
if it == nil {
it = &LocalSG{SG: sg}
it = &SG{SecurityGroup: sg}
(*loc)[sg.Name] = it
}
switch len(srcIP) {
Expand All @@ -78,7 +80,7 @@ func (loc *LocalSGs) Load(ctx context.Context, client SGClient, srcIPs []net.IP)
}
return nil
}
if err := parallel.ExecAbstract(len(srcIPs), 8, job); err != nil { //nolint:gomnd
if err := parallel.ExecAbstract(len(localIPs), 8, job); err != nil { //nolint:gomnd
return errors.WithMessage(err, api)
}
for _, it := range *loc {
Expand All @@ -94,8 +96,44 @@ func (loc *LocalSGs) Load(ctx context.Context, client SGClient, srcIPs []net.IP)
return nil
}

// LoadFromRules it loads Local SGs from SG rules
func (loc *SGs) LoadFromRules(ctx context.Context, client SGClient, rules []model.SGRule) error {
const api = "SG(s)/LoadFromRules"

if *loc == nil {
*loc = make(SGs)
}
usedSG := make([]string, 0, len(rules)*2)
linq.From(rules).
SelectMany(func(i any) linq.Query {
r := i.(model.SGRule)
return linq.From([...]string{r.SgFrom.Name, r.SgTo.Name})
}).Distinct().ToSlice(&usedSG)

if len(usedSG) == 0 {
return nil
}
sgsResp, err := client.ListSecurityGroups(ctx, &sgAPI.ListSecurityGroupsReq{SgNames: usedSG})
if err != nil {
return errors.WithMessage(err, api)
}
linq.From(sgsResp.GetGroups()).
ForEach(func(i any) {
if err != nil {
return
}
g := i.(*sgAPI.SecGroup)
if sg, e := conv.Proto2ModelSG(g); e != nil {
err = e
} else {
(*loc)[sg.Name] = &SG{SecurityGroup: sg}
}
})
return errors.WithMessage(err, api)
}

// Names get local SG(s) names
func (loc LocalSGs) Names() []SgName {
func (loc SGs) Names() []SgName {
ret := make([]SgName, 0, len(loc))
for n := range loc {
ret = append(ret, n)
Expand Down
4 changes: 2 additions & 2 deletions cmd/to-nft/internal/nft/nft-processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ func (impl *nfTablesProcessorImpl) ApplyConf(ctx context.Context, conf NetConf)
var (
err error
localRules cases.LocalRules
localSGs cases.LocalSGs
localSGs cases.SGs
)

localIPsV4, loaclIPsV6 := conf.LocalIPs()
allLoaclIPs := append(localIPsV4, loaclIPsV6...)
if err = localSGs.Load(ctx, impl.sgClient, allLoaclIPs); err != nil {
if err = localSGs.LoadFromIPs(ctx, impl.sgClient, allLoaclIPs); err != nil {
return multierr.Combine(ErrNfTablesProcessor,
err, pkgErr.ErrDetails{Api: api, Details: allLoaclIPs})
}
Expand Down

0 comments on commit a70a1ce

Please sign in to comment.