diff --git a/api/turing/webhook/constant.go b/api/turing/webhook/constant.go index 68c0f6ce7..ff5702a3f 100644 --- a/api/turing/webhook/constant.go +++ b/api/turing/webhook/constant.go @@ -1,26 +1,41 @@ package webhook -import "github.com/caraml-dev/mlp/api/pkg/webhooks" +import ( + "errors" + + "github.com/caraml-dev/mlp/api/pkg/webhooks" +) + +var ( + OnRouterCreated = webhooks.EventType("OnRouterCreated") + OnRouterUpdated = webhooks.EventType("OnRouterUpdated") + OnRouterDeleted = webhooks.EventType("OnRouterDeleted") + OnRouterDeployed = webhooks.EventType("OnRouterDeployed") + OnRouterUndeployed = webhooks.EventType("OnRouterUndeployed") + + OnEnsemblerCreated = webhooks.EventType("OnEnsemblerCreated") + OnEnsemblerUpdated = webhooks.EventType("OnEnsemblerUpdated") + OnEnsemblerDeleted = webhooks.EventType("OnEnsemblerDeleted") +) var ( - OnRouterCreated = webhooks.EventType("on-router-created") - OnRouterUpdated = webhooks.EventType("on-router-updated") - OnRouterDeleted = webhooks.EventType("on-router-deleted") - OnRouterDeployed = webhooks.EventType("on-router-deployed") - OnRouterUndeployed = webhooks.EventType("on-router-undeployed") + // event list for router event + eventListRouter = map[webhooks.EventType]bool{ + OnRouterCreated: true, + OnRouterUpdated: true, + OnRouterDeleted: true, + OnRouterDeployed: true, + OnRouterUndeployed: true, + } - OnEnsemblerCreated = webhooks.EventType("on-ensembler-created") - OnEnsemblerUpdated = webhooks.EventType("on-ensembler-updated") - OnEnsemblerDeleted = webhooks.EventType("on-ensembler-deleted") + // event list for ensembler event + eventListEnsembler = map[webhooks.EventType]bool{ + OnEnsemblerCreated: true, + OnEnsemblerUpdated: true, + OnEnsemblerDeleted: true, + } ) -var eventList = []webhooks.EventType{ - OnRouterCreated, - OnRouterUpdated, - OnRouterDeleted, - OnRouterDeployed, - OnRouterUndeployed, - OnEnsemblerCreated, - OnEnsemblerUpdated, - OnEnsemblerDeleted, -} +var ( + ErrInvalidEventType = errors.New("invalid event type") +) diff --git a/api/turing/webhook/webhook.go b/api/turing/webhook/webhook.go index 7421bb3fe..173b68011 100644 --- a/api/turing/webhook/webhook.go +++ b/api/turing/webhook/webhook.go @@ -14,7 +14,17 @@ type Client interface { } func NewWebhook(cfg *webhooks.Config) (Client, error) { - webhookManager, err := webhooks.InitializeWebhooks(cfg, eventList) + var eventTypeList []webhooks.EventType + + for eventType := range eventListRouter { + eventTypeList = append(eventTypeList, eventType) + } + + for eventType := range eventListEnsembler { + eventTypeList = append(eventTypeList, eventType) + } + + webhookManager, err := webhooks.InitializeWebhooks(cfg, eventTypeList) if err != nil { return nil, err } @@ -51,6 +61,10 @@ func (w *webhook) isEventConfigured(eventType webhooks.EventType) bool { } func (w *webhook) TriggerRouterEvent(ctx context.Context, eventType webhooks.EventType, router *models.Router) error { + if isValid := eventListRouter[eventType]; !isValid { + return ErrInvalidEventType + } + body := &routerRequest{ EventType: eventType, Router: router, @@ -64,6 +78,10 @@ func (w *webhook) TriggerEnsemblerEvent( eventType webhooks.EventType, ensembler models.EnsemblerLike, ) error { + if isValid := eventListEnsembler[eventType]; !isValid { + return ErrInvalidEventType + } + body := &ensemblerRequest{ EventType: eventType, Ensembler: ensembler, diff --git a/api/turing/webhook/webhook_test.go b/api/turing/webhook/webhook_test.go index 4f4c5b706..c2c002492 100644 --- a/api/turing/webhook/webhook_test.go +++ b/api/turing/webhook/webhook_test.go @@ -190,6 +190,19 @@ func Test_webhook_TriggerRouterEvent(t *testing.T) { ).Once().Return(nil) }, }, + { + name: "negative - invalid event type", + fields: fields{ + manager: mockWebhookManager, + }, + args: args{ + ctx: context.TODO(), + eventType: OnEnsemblerCreated, + router: &models.Router{}, + }, + mockFunc: func(args args) {}, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -245,6 +258,19 @@ func Test_webhook_TriggerEnsemblerEvent(t *testing.T) { ).Once().Return(nil) }, }, + { + name: "negative - invalid event type", + fields: fields{ + manager: mockWebhookManager, + }, + args: args{ + ctx: context.TODO(), + eventType: OnRouterCreated, + ensembler: &models.GenericEnsembler{}, + }, + mockFunc: func(args args) {}, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {