diff --git a/internal/manualhooks/hooks.go b/internal/manualhooks/hooks.go index 8274304b0..289115627 100644 --- a/internal/manualhooks/hooks.go +++ b/internal/manualhooks/hooks.go @@ -1014,48 +1014,37 @@ func PortHooks() []ent.Hook { return retValue, err } - addSubjPools, err := m.Client().Pool.Query().Where(pool.HasPortsWith(port.IDEQ(objID))).All(ctx) + // Ensure we have additional relevant subjects in the event msg + addSubjPools, err := m.Client().Pool.Query().WithPorts(func(q *generated.PortQuery) { + q.WithLoadBalancer() + }).Where(pool.HasPortsWith(port.IDEQ(objID))).All(ctx) if err == nil { for _, pool := range addSubjPools { - if !slices.Contains(msg.AdditionalSubjectIDs, pool.ID) && objID != pool.ID { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.ID) - } + for _, port := range pool.Edges.Ports { + if !slices.Contains(msg.AdditionalSubjectIDs, port.LoadBalancerID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.LoadBalancerID) + } - if !slices.Contains(msg.AdditionalSubjectIDs, pool.OwnerID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.OwnerID) - } - } - } - addSubjLoadBalancers, err := m.Client().LoadBalancer.Query().Where(loadbalancer.HasPortsWith(port.IDEQ(objID))).All(ctx) - if err == nil { - for _, lb := range addSubjLoadBalancers { - if !slices.Contains(msg.AdditionalSubjectIDs, lb.ID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.ID) - } + if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.LocationID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.LocationID) + } - if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID) - } + if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.ProviderID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.ProviderID) + } - if !slices.Contains(msg.AdditionalSubjectIDs, lb.OwnerID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.OwnerID) + if !slices.Contains(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.OwnerID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.Edges.LoadBalancer.OwnerID) + } } - if !slices.Contains(msg.AdditionalSubjectIDs, lb.ProviderID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.ProviderID) + if !slices.Contains(msg.AdditionalSubjectIDs, pool.ID) && objID != pool.ID { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.ID) } - } - } - lbs := getLoadBalancerIDs(ctx, objID, msg.AdditionalSubjectIDs) - for _, lb := range lbs { - lb, err := m.Client().LoadBalancer.Get(ctx, lb) - if err != nil { - return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb) - } - - if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) { - msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID) + if !slices.Contains(msg.AdditionalSubjectIDs, pool.OwnerID) { + msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.OwnerID) + } } } @@ -1087,12 +1076,16 @@ func PortHooks() []ent.Hook { return nil, fmt.Errorf("object doesn't have an id %s", objID) } - dbObj, err := m.Client().Port.Get(ctx, objID) + dbObj, err := m.Client().Port.Query().WithLoadBalancer().Where(port.IDEQ(objID)).Only(ctx) if err != nil { return nil, fmt.Errorf("failed to load object to get values for event, err %w", err) } + // Ensure we have additional relevant subjects in the event msg additionalSubjects = append(additionalSubjects, dbObj.LoadBalancerID) + additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.LocationID) + additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.OwnerID) + additionalSubjects = append(additionalSubjects, dbObj.Edges.LoadBalancer.ProviderID) // we have all the info we need, now complete the mutation before we process the event retValue, err := next.Mutate(ctx, m) @@ -1106,37 +1099,6 @@ func PortHooks() []ent.Hook { } } - addSubjLoadBalancer, err := m.Client().LoadBalancer.Get(ctx, dbObj.LoadBalancerID) - if err == nil { - if !slices.Contains(additionalSubjects, addSubjLoadBalancer.ID) { - additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.ID) - } - - if !slices.Contains(additionalSubjects, addSubjLoadBalancer.LocationID) { - additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.LocationID) - } - - if !slices.Contains(additionalSubjects, addSubjLoadBalancer.OwnerID) { - additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.OwnerID) - } - - if !slices.Contains(additionalSubjects, addSubjLoadBalancer.ProviderID) { - additionalSubjects = append(additionalSubjects, addSubjLoadBalancer.ProviderID) - } - } - - lbs := getLoadBalancerIDs(ctx, objID, additionalSubjects) - for _, lb := range lbs { - lb, err := m.Client().LoadBalancer.Get(ctx, lb) - if err != nil { - return nil, fmt.Errorf("failed to get loadbalancer to lookup location %s", lb) - } - - if !slices.Contains(additionalSubjects, lb.LocationID) { - additionalSubjects = append(additionalSubjects, lb.LocationID) - } - } - msg := events.ChangeMessage{ EventType: eventType(m.Op()), SubjectID: objID, diff --git a/internal/manualhooks/hooks_test.go b/internal/manualhooks/hooks_test.go index be8ca1e3b..4f3ae162f 100644 --- a/internal/manualhooks/hooks_test.go +++ b/internal/manualhooks/hooks_test.go @@ -365,9 +365,9 @@ func Test_MultipleLoadbalancersSharedPoolAddOrigin(t *testing.T) { // create 2 loadbalancers with a shared pool of origins prov := (&testutils.ProviderBuilder{}).MustNew(ctx) - lb1 := (&testutils.LoadBalancerBuilder{OwnerID: prov.OwnerID}).MustNew(ctx) - lb2 := (&testutils.LoadBalancerBuilder{OwnerID: prov.OwnerID}).MustNew(ctx) - pool := (&testutils.PoolBuilder{OwnerID: prov.OwnerID}).MustNew(ctx) + lb1 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx) + lb2 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx) + pool := (&testutils.PoolBuilder{OwnerID: "tnttent-testing"}).MustNew(ctx) _ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb1.ID}).MustNew(ctx) _ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb2.ID}).MustNew(ctx) _ = (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx) @@ -384,7 +384,8 @@ func Test_MultipleLoadbalancersSharedPoolAddOrigin(t *testing.T) { // Assert expectedAdditionalSubjectIDs := []gidx.PrefixedID{ - prov.OwnerID, + prov.ID, + lb1.OwnerID, lb1.ID, lb2.ID, lb1.LocationID, @@ -408,9 +409,9 @@ func Test_MultipleLoadbalancersSharedPoolDeleteOrigin(t *testing.T) { // create 2 loadbalancers with a shared pool of origins prov := (&testutils.ProviderBuilder{}).MustNew(ctx) - lb1 := (&testutils.LoadBalancerBuilder{OwnerID: prov.OwnerID}).MustNew(ctx) - lb2 := (&testutils.LoadBalancerBuilder{OwnerID: prov.OwnerID}).MustNew(ctx) - pool := (&testutils.PoolBuilder{OwnerID: prov.OwnerID}).MustNew(ctx) + lb1 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx) + lb2 := (&testutils.LoadBalancerBuilder{OwnerID: "tnttent-testing", Provider: prov}).MustNew(ctx) + pool := (&testutils.PoolBuilder{OwnerID: "tnttent-testing"}).MustNew(ctx) _ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb1.ID}).MustNew(ctx) _ = (&testutils.PortBuilder{PoolIDs: []gidx.PrefixedID{pool.ID}, LoadBalancerID: lb2.ID}).MustNew(ctx) _ = (&testutils.OriginBuilder{PoolID: pool.ID}).MustNew(ctx) @@ -428,7 +429,8 @@ func Test_MultipleLoadbalancersSharedPoolDeleteOrigin(t *testing.T) { // Assert expectedAdditionalSubjectIDs := []gidx.PrefixedID{ - prov.OwnerID, + prov.ID, + lb1.OwnerID, lb1.ID, lb2.ID, lb1.LocationID,