Skip to content

Commit

Permalink
Fix up mutation hooks to ensure additionalSubjects are present in the…
Browse files Browse the repository at this point in the history
… change message (#231)

* query all for hooks and add additionalSubjects to msg

Signed-off-by: Matt Siwiec <rizzza@users.noreply.github.com>

* origin delete additional subjects

Signed-off-by: Matt Siwiec <rizzza@users.noreply.github.com>

---------

Signed-off-by: Matt Siwiec <rizzza@users.noreply.github.com>
  • Loading branch information
rizzza authored Oct 2, 2023
1 parent d51c67c commit c02b206
Showing 1 changed file with 99 additions and 61 deletions.
160 changes: 99 additions & 61 deletions internal/manualhooks/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,12 @@ func LoadBalancerHooks() []ent.Hook {
return retValue, err
}

addSubjPort, err := m.Client().Port.Query().Where(port.HasLoadBalancerWith(loadbalancer.IDEQ(objID))).Only(ctx)
addSubjPortIDs, err := m.Client().Port.Query().Where(port.HasLoadBalancerWith(loadbalancer.IDEQ(objID))).IDs(ctx)
if err == nil {
if !slices.Contains(msg.AdditionalSubjectIDs, addSubjPort.ID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjPort.ID)
for _, portID := range addSubjPortIDs {
if !slices.Contains(msg.AdditionalSubjectIDs, portID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, portID)
}
}
}

Expand Down Expand Up @@ -494,21 +496,25 @@ func OriginHooks() []ent.Hook {
return retValue, err
}

addSubjPool, err := m.Client().Pool.Query().Where(pool.HasOriginsWith(origin.IDEQ(objID))).Only(ctx)
addSubjPools, err := m.Client().Pool.Query().Where(pool.HasOriginsWith(origin.IDEQ(objID))).All(ctx)
if err == nil {
if !slices.Contains(msg.AdditionalSubjectIDs, addSubjPool.ID) && objID != addSubjPool.ID {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjPool.ID)
}
for _, pool := range addSubjPools {
if !slices.Contains(msg.AdditionalSubjectIDs, pool.ID) && objID != pool.ID {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.ID)
}

if !slices.Contains(msg.AdditionalSubjectIDs, addSubjPool.OwnerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjPool.OwnerID)
if !slices.Contains(msg.AdditionalSubjectIDs, pool.OwnerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.OwnerID)
}
}
}

addSubjPort, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.HasOriginsWith(origin.IDEQ(objID)))).Only(ctx)
addSubjPorts, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.HasOriginsWith(origin.IDEQ(objID)))).All(ctx)
if err == nil {
if !slices.Contains(msg.AdditionalSubjectIDs, addSubjPort.LoadBalancerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjPort.LoadBalancerID)
for _, port := range addSubjPorts {
if !slices.Contains(msg.AdditionalSubjectIDs, port.LoadBalancerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.LoadBalancerID)
}
}
}

Expand Down Expand Up @@ -559,6 +565,28 @@ func OriginHooks() []ent.Hook {

additionalSubjects = append(additionalSubjects, dbObj.PoolID)

addSubjPools, err := m.Client().Pool.Query().Where(pool.HasOriginsWith(origin.IDEQ(objID))).All(ctx)
if err == nil {
for _, pool := range addSubjPools {
if !slices.Contains(additionalSubjects, pool.ID) && objID != pool.ID {
additionalSubjects = append(additionalSubjects, pool.ID)
}

if !slices.Contains(additionalSubjects, pool.OwnerID) {
additionalSubjects = append(additionalSubjects, pool.OwnerID)
}
}
}

addSubjPorts, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.HasOriginsWith(origin.IDEQ(objID)))).All(ctx)
if err == nil {
for _, port := range addSubjPorts {
if !slices.Contains(additionalSubjects, port.LoadBalancerID) {
additionalSubjects = append(additionalSubjects, port.LoadBalancerID)
}
}
}

relationships = append(relationships, events.AuthRelationshipRelation{
Relation: "pool",
SubjectID: dbObj.PoolID,
Expand Down Expand Up @@ -759,25 +787,29 @@ func PoolHooks() []ent.Hook {
return retValue, err
}

addSubjPort, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.IDEQ(objID))).Only(ctx)
addSubjPorts, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.IDEQ(objID))).All(ctx)
if err == nil {
if !slices.Contains(msg.AdditionalSubjectIDs, addSubjPort.ID) && objID != addSubjPort.ID {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjPort.ID)
}
for _, port := range addSubjPorts {
if !slices.Contains(msg.AdditionalSubjectIDs, port.ID) && objID != port.ID {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.ID)
}

if !slices.Contains(msg.AdditionalSubjectIDs, addSubjPort.LoadBalancerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjPort.LoadBalancerID)
if !slices.Contains(msg.AdditionalSubjectIDs, port.LoadBalancerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, port.LoadBalancerID)
}
}
}

addSubjOrigin, err := m.Client().Origin.Query().Where(origin.HasPoolWith(pool.IDEQ(objID))).Only(ctx)
addSubjOrigins, err := m.Client().Origin.Query().Where(origin.HasPoolWith(pool.IDEQ(objID))).All(ctx)
if err == nil {
if !slices.Contains(msg.AdditionalSubjectIDs, addSubjOrigin.ID) && objID != addSubjOrigin.ID {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjOrigin.ID)
}
for _, origin := range addSubjOrigins {
if !slices.Contains(msg.AdditionalSubjectIDs, origin.ID) && objID != origin.ID {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, origin.ID)
}

if !slices.Contains(msg.AdditionalSubjectIDs, addSubjOrigin.PoolID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjOrigin.PoolID)
if !slices.Contains(msg.AdditionalSubjectIDs, origin.PoolID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, origin.PoolID)
}
}
}

Expand Down Expand Up @@ -828,30 +860,20 @@ func PoolHooks() []ent.Hook {

additionalSubjects = append(additionalSubjects, dbObj.OwnerID)

addSubjPorts, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.IDEQ(objID))).All(ctx)
if err == nil {
for _, port := range addSubjPorts {
if !slices.Contains(additionalSubjects, port.LoadBalancerID) {
additionalSubjects = append(additionalSubjects, port.LoadBalancerID)
}
}
}

relationships = append(relationships, events.AuthRelationshipRelation{
Relation: "owner",
SubjectID: dbObj.OwnerID,
})

// we have all the info we need, now complete the mutation before we process the event
retValue, err := next.Mutate(ctx, m)
if err != nil {
return retValue, err
}

if len(relationships) != 0 {
if err := permissions.DeleteAuthRelationships(ctx, "load-balancer-pool", objID, relationships...); err != nil {
return nil, fmt.Errorf("relationship request failed with error: %w", err)
}
}

addSubjPort, err := m.Client().Port.Query().Where(port.HasPoolsWith(pool.IDEQ(objID))).Only(ctx)
if err == nil {
if !slices.Contains(additionalSubjects, addSubjPort.LoadBalancerID) {
additionalSubjects = append(additionalSubjects, addSubjPort.LoadBalancerID)
}
}

lb_lookup := getLoadBalancerID(ctx, objID, additionalSubjects)
if lb_lookup != "" {
lb, err := m.Client().LoadBalancer.Get(ctx, lb_lookup)
Expand All @@ -864,6 +886,18 @@ func PoolHooks() []ent.Hook {
}
}

// we have all the info we need, now complete the mutation before we process the event
retValue, err := next.Mutate(ctx, m)
if err != nil {
return retValue, err
}

if len(relationships) != 0 {
if err := permissions.DeleteAuthRelationships(ctx, "load-balancer-pool", objID, relationships...); err != nil {
return nil, fmt.Errorf("relationship request failed with error: %w", err)
}
}

msg := events.ChangeMessage{
EventType: eventType(m.Op()),
SubjectID: objID,
Expand Down Expand Up @@ -1021,32 +1055,36 @@ func PortHooks() []ent.Hook {
return retValue, err
}

addSubjPool, err := m.Client().Pool.Query().Where(pool.HasPortsWith(port.IDEQ(objID))).Only(ctx)
addSubjPools, err := m.Client().Pool.Query().Where(pool.HasPortsWith(port.IDEQ(objID))).All(ctx)
if err == nil {
if !slices.Contains(msg.AdditionalSubjectIDs, addSubjPool.ID) && objID != addSubjPool.ID {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjPool.ID)
}
for _, pool := range addSubjPools {
if !slices.Contains(msg.AdditionalSubjectIDs, pool.ID) && objID != pool.ID {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.ID)
}

if !slices.Contains(msg.AdditionalSubjectIDs, addSubjPool.OwnerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjPool.OwnerID)
if !slices.Contains(msg.AdditionalSubjectIDs, pool.OwnerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, pool.OwnerID)
}
}
}
addSubjLoadBalancer, err := m.Client().LoadBalancer.Query().Where(loadbalancer.HasPortsWith(port.IDEQ(objID))).Only(ctx)
addSubjLoadBalancers, err := m.Client().LoadBalancer.Query().Where(loadbalancer.HasPortsWith(port.IDEQ(objID))).All(ctx)
if err == nil {
if !slices.Contains(msg.AdditionalSubjectIDs, addSubjLoadBalancer.ID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjLoadBalancer.ID)
}
for _, lb := range addSubjLoadBalancers {
if !slices.Contains(msg.AdditionalSubjectIDs, lb.ID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.ID)
}

if !slices.Contains(msg.AdditionalSubjectIDs, addSubjLoadBalancer.LocationID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjLoadBalancer.LocationID)
}
if !slices.Contains(msg.AdditionalSubjectIDs, lb.LocationID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.LocationID)
}

if !slices.Contains(msg.AdditionalSubjectIDs, addSubjLoadBalancer.OwnerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjLoadBalancer.OwnerID)
}
if !slices.Contains(msg.AdditionalSubjectIDs, lb.OwnerID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.OwnerID)
}

if !slices.Contains(msg.AdditionalSubjectIDs, addSubjLoadBalancer.ProviderID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, addSubjLoadBalancer.ProviderID)
if !slices.Contains(msg.AdditionalSubjectIDs, lb.ProviderID) {
msg.AdditionalSubjectIDs = append(msg.AdditionalSubjectIDs, lb.ProviderID)
}
}
}

Expand Down

0 comments on commit c02b206

Please sign in to comment.