Skip to content

Commit

Permalink
audit port manual hook db hits
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Siwiec <rizzza@users.noreply.github.com>
  • Loading branch information
rizzza committed Oct 24, 2023
1 parent 7e364de commit 7749151
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 73 deletions.
92 changes: 27 additions & 65 deletions internal/manualhooks/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -1014,48 +1014,37 @@ func PortHooks() []ent.Hook {
return retValue, err
}

addSubjPools, err := m.Client().Pool.Query().Where(pool.HasPortsWith(port.IDEQ(objID))).All(ctx)
// Ensure we have additional relevant subjects in the event msg
addSubjPools, err := m.Client().Pool.Query().WithPorts(func(q *generated.PortQuery) {
q.WithLoadBalancer()
}).Where(pool.HasPortsWith(port.IDEQ(objID))).All(ctx)
if err == nil {
for _, pool := range addSubjPools {
if !slices.Contains(msg.AdditionalSubjectIDs, pool.ID) && objID != pool.ID {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.ID)
}
for _, port := range pool.Edges.Ports {
if !slices.Contains(msg.AdditionalSubjectIDs, port.LoadBalancerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.LoadBalancerID)
}

if !slices.Contains(msg.AdditionalSubjectIDs, pool.OwnerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.OwnerID)
}
}
}
addSubjLoadBalancers, err := m.Client().LoadBalancer.Query().Where(loadbalancer.HasPortsWith(port.IDEQ(objID))).All(ctx)
if err == nil {
for _, lb := range addSubjLoadBalancers {
if !slices.Contains(msg.AdditionalSubjectIDs, lb.ID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.ID)
}
if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.LocationID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.LocationID)
}

if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID)
}
if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.ProviderID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.ProviderID)
}

if !slices.Contains(msg.AdditionalSubjectIDs, lb.OwnerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.OwnerID)
if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.OwnerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.OwnerID)
}
}

if !slices.Contains(msg.AdditionalSubjectIDs, lb.ProviderID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.ProviderID)
if !slices.Contains(msg.AdditionalSubjectIDs, pool.ID) && objID != pool.ID {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.ID)
}
}
}

lbs := getLoadBalancerIDs(ctx, objID, msg.AdditionalSubjectIDs)
for _, lb := range lbs {
lb, err := m.Client().LoadBalancer.Get(ctx, lb)
if err != nil {
return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb)
}

if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID)
if !slices.Contains(msg.AdditionalSubjectIDs, pool.OwnerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.OwnerID)
}
}
}

Expand Down Expand Up @@ -1087,12 +1076,16 @@ func PortHooks() []ent.Hook {
return nil, fmt.Errorf("object doesn't have an id %s", objID)
}

dbObj, err := m.Client().Port.Get(ctx, objID)
dbObj, err := m.Client().Port.Query().WithLoadBalancer().Where(port.IDEQ(objID)).Only(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load object to get values for event, err %w", err)
}

// Ensure we have additional relevant subjects in the event msg
additionalSubjects = append(additionalSubjects, dbObj.LoadBalancerID)
additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.LocationID)
additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.OwnerID)
additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.ProviderID)

// we have all the info we need, now complete the mutation before we process the event
retValue, err := next.Mutate(ctx, m)
Expand All @@ -1106,37 +1099,6 @@ func PortHooks() []ent.Hook {
}
}

addSubjLoadBalancer, err := m.Client().LoadBalancer.Get(ctx, dbObj.LoadBalancerID)
if err == nil {
if !slices.Contains(additionalSubjects, addSubjLoadBalancer.ID) {
additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.ID)
}

if !slices.Contains(additionalSubjects, addSubjLoadBalancer.LocationID) {
additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.LocationID)
}

if !slices.Contains(additionalSubjects, addSubjLoadBalancer.OwnerID) {
additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.OwnerID)
}

if !slices.Contains(additionalSubjects, addSubjLoadBalancer.ProviderID) {
additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.ProviderID)
}
}

lbs := getLoadBalancerIDs(ctx, objID, additionalSubjects)
for _, lb := range lbs {
lb, err := m.Client().LoadBalancer.Get(ctx, lb)
if err != nil {
return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb)
}

if !slices.Contains(additionalSubjects, lb.LocationID) {
additionalSubjects = append(additionalSubjects, lb.LocationID)
}
}

msg := events.ChangeMessage{
EventType: eventType(m.Op()),
SubjectID: objID,
Expand Down
18 changes: 10 additions & 8 deletions internal/manualhooks/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,9 @@ func Test_MultipleLoadbalancersSharedPoolAddOrigin(t *testing.T) {

// create 2 loadbalancers with a shared pool of origins
prov := (&testutils.ProviderBuilder{}).MustNew(ctx)
lb1 := (&testutils.LoadBalancerBuilder{OwnerID: prov.OwnerID}).MustNew(ctx)
lb2 := (&testutils.LoadBalancerBuilder{OwnerID: prov.OwnerID}).MustNew(ctx)
pool := (&testutils.PoolBuilder{OwnerID: prov.OwnerID}).MustNew(ctx)
lb1 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx)
lb2 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx)
pool := (&testutils.PoolBuilder{OwnerID: "tnttent-testing"}).MustNew(ctx)
_ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb1.ID}).MustNew(ctx)
_ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb2.ID}).MustNew(ctx)
_ = (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx)
Expand All @@ -384,7 +384,8 @@ func Test_MultipleLoadbalancersSharedPoolAddOrigin(t *testing.T) {

// Assert
expectedAdditionalSubjectIDs := []gidx.PrefixedID{
prov.OwnerID,
prov.ID,
lb1.OwnerID,
lb1.ID,
lb2.ID,
lb1.LocationID,
Expand All @@ -408,9 +409,9 @@ func Test_MultipleLoadbalancersSharedPoolDeleteOrigin(t *testing.T) {

// create 2 loadbalancers with a shared pool of origins
prov := (&testutils.ProviderBuilder{}).MustNew(ctx)
lb1 := (&testutils.LoadBalancerBuilder{OwnerID: prov.OwnerID}).MustNew(ctx)
lb2 := (&testutils.LoadBalancerBuilder{OwnerID: prov.OwnerID}).MustNew(ctx)
pool := (&testutils.PoolBuilder{OwnerID: prov.OwnerID}).MustNew(ctx)
lb1 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx)
lb2 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx)
pool := (&testutils.PoolBuilder{OwnerID: "tnttent-testing"}).MustNew(ctx)
_ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb1.ID}).MustNew(ctx)
_ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb2.ID}).MustNew(ctx)
_ = (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx)
Expand All @@ -428,7 +429,8 @@ func Test_MultipleLoadbalancersSharedPoolDeleteOrigin(t *testing.T) {

// Assert
expectedAdditionalSubjectIDs := []gidx.PrefixedID{
prov.OwnerID,
prov.ID,
lb1.OwnerID,
lb1.ID,
lb2.ID,
lb1.LocationID,
Expand Down

0 comments on commit 7749151

Please sign in to comment.