diff --git a/internal/registry.go b/internal/registry.go index a09ec275a..84a9eec5e 100644 --- a/internal/registry.go +++ b/internal/registry.go @@ -92,7 +92,7 @@ func (r *registry) RegisterWorkflowWithOptions( defer r.Unlock() if !options.DisableAlreadyRegisteredCheck { - if _, ok := r.workflowFuncMap[registerName]; ok { + if _, ok := r.getWorkflowNoLock(registerName); ok { panic(fmt.Sprintf("workflow name \"%v\" is already registered", registerName)) } } @@ -141,7 +141,7 @@ func (r *registry) registerActivityFunction(af interface{}, options RegisterActi defer r.Unlock() if !options.DisableAlreadyRegisteredCheck { - if _, ok := r.activityFuncMap[registerName]; ok { + if _, ok := r.getActivityNoLock(registerName); ok { return fmt.Errorf("activity type \"%v\" is already registered", registerName) } } @@ -231,6 +231,14 @@ func (r *registry) getWorkflowFn(fnName string) (interface{}, bool) { return fn, ok } +func (r *registry) getWorkflowNoLock(registerName string) (interface{}, bool) { + a, ok := r.workflowFuncMap[registerName] + if !ok && r.next != nil { + return r.next.getWorkflowNoLock(registerName) + } + return a, ok +} + func (r *registry) getRegisteredWorkflowTypes() []string { r.Lock() // do not defer for Unlock to call next.getRegisteredWorkflowTypes without lock var result []string @@ -277,10 +285,10 @@ func (r *registry) GetActivity(fnName string) (activity, bool) { return a, ok } -func (r *registry) getActivityNoLock(fnName string) (activity, bool) { - a, ok := r.activityFuncMap[fnName] +func (r *registry) getActivityNoLock(registerName string) (activity, bool) { + a, ok := r.activityFuncMap[registerName] if !ok && r.next != nil { - return r.next.getActivityNoLock(fnName) + return r.next.getActivityNoLock(registerName) } return a, ok } diff --git a/internal/registry_test.go b/internal/registry_test.go index dccf4e504..4e9c19d01 100644 --- a/internal/registry_test.go +++ b/internal/registry_test.go @@ -31,6 +31,7 @@ func TestWorkflowRegistration(t *testing.T) { tests := []struct { msg string register func(r *registry) + registerPanic bool workflowType string altWorkflowType string resolveByFunction interface{} @@ -66,11 +67,41 @@ func TestWorkflowRegistration(t *testing.T) { altWorkflowType: "go.uber.org/cadence/internal.(*testWorkflowStruct).Method-fm", resolveByFunction: w.Method, }, + { + msg: "register duplicated workflow in one registry (should panic)", + register: func(r *registry) { + r.RegisterWorkflow(testWorkflowFunction) + r.RegisterWorkflow(testWorkflowFunction) + }, + registerPanic: true, + }, + { + msg: "register duplicated workflow with already registered check disabled", + register: func(r *registry) { + r.RegisterWorkflow(testWorkflowFunction) + r.RegisterWorkflowWithOptions(testWorkflowFunction, RegisterWorkflowOptions{DisableAlreadyRegisteredCheck: true}) + }, + workflowType: "go.uber.org/cadence/internal.testWorkflowFunction", + resolveByFunction: testWorkflowFunction, + }, + { + msg: "register duplicated workflow in chained registry (should panic)", + register: func(r *registry) { + r.next.RegisterWorkflow(testWorkflowFunction) + r.RegisterWorkflow(testWorkflowFunction) + }, + registerPanic: true, + }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { r := newRegistry() + if tt.registerPanic { + require.Panics(t, func() { tt.register(r) }, "register should panic") + return + } + tt.register(r) // Verify registered workflow type @@ -104,6 +135,7 @@ func TestActivityRegistration(t *testing.T) { tests := []struct { msg string register func(r *registry) + registerPanic bool activityType string altActivityType string resolveByFunction interface{} @@ -156,10 +188,41 @@ func TestActivityRegistration(t *testing.T) { resolveByFunction: (&testActivityStruct{}).Method, resolveByAlias: "prefix.Method", }, + { + msg: "register duplicated activity function in one registry (should panic)", + register: func(r *registry) { + duplicatedActivityAlias := "activity.alias" + r.RegisterActivityWithOptions(testActivityFunction, RegisterActivityOptions{Name: duplicatedActivityAlias}) + r.RegisterActivityWithOptions(testActivityFunction, RegisterActivityOptions{Name: duplicatedActivityAlias}) + }, + registerPanic: true, + }, + { + msg: "register duplicated activity struct with already registered check disabled", + register: func(r *registry) { + r.RegisterActivity(&testActivityStruct{}) + r.RegisterActivityWithOptions(&testActivityStruct{}, RegisterActivityOptions{DisableAlreadyRegisteredCheck: true}) + }, + activityType: "go.uber.org/cadence/internal.(*testActivityStruct).Method", + resolveByFunction: (&testActivityStruct{}).Method, + }, + { + msg: "register duplicated activity function in chained registry (should panic)", + register: func(r *registry) { + r.next.RegisterActivity(testActivityFunction) + r.RegisterActivity(testActivityFunction) + }, + registerPanic: true, + }, } for _, tt := range tests { t.Run(tt.msg, func(t *testing.T) { r := newRegistry() + if tt.registerPanic { + require.Panics(t, func() { tt.register(r) }, "register should panic") + return + } + tt.register(r) // Verify registered activity type