diff --git a/interfaces/generate.go b/interfaces/generate.go index bb17643..8c00ab4 100644 --- a/interfaces/generate.go +++ b/interfaces/generate.go @@ -8,6 +8,7 @@ import ( "os" "path" "reflect" + "regexp" "strings" "text/template" @@ -87,6 +88,19 @@ func Generate(clients []any, dir string, opts ...Option) error { return nil } +func normalizedGenericTypeName(str string) string { + // Generic output types have the full import path in the string value, so we need to normalize it + pattern := regexp.MustCompile(`\[(.*?)\]`) + groups := pattern.FindStringSubmatch((str)) + if len(groups) < 2 { + return str + } + + typeName := groups[1] + normalizedGenericTypeName := strings.Split(typeName, "/") + return pattern.ReplaceAllString(str, "["+normalizedGenericTypeName[len(normalizedGenericTypeName)-1]+"]") +} + // Adapted from https://stackoverflow.com/a/54129236 func signature(name string, f any) string { t := reflect.TypeOf(f) @@ -117,7 +131,7 @@ func signature(name string, f any) string { if i > 0 { buf.WriteString(", ") } - buf.WriteString(t.Out(i).String()) + buf.WriteString(normalizedGenericTypeName(t.Out(i).String())) } if numOut > 1 { buf.WriteString(")") diff --git a/interfaces/generate_test.go b/interfaces/generate_test.go index b346fd4..b483a7a 100644 --- a/interfaces/generate_test.go +++ b/interfaces/generate_test.go @@ -11,6 +11,11 @@ import ( "github.com/google/go-cmp/cmp" ) +type Response struct { +} +type Pager[T any] struct { +} + type Client struct{} func (*Client) ListThings() error { @@ -25,6 +30,10 @@ func (*Client) CreateTables(_ context.Context, _ []string) error { return nil } +func (*Client) NewPager(_ context.Context) *Pager[Response] { + return nil +} + var wantOutput = `// Code generated by codegen; DO NOT EDIT. package services @@ -36,6 +45,7 @@ import ( //go:generate mockgen -package=mocks -destination=../mocks/interfaces.go -source=interfaces.go InterfacesClient type InterfacesClient interface { ListTables(context.Context) error + NewPager(context.Context) *interfaces.Pager[interfaces.Response] } ` @@ -43,10 +53,9 @@ func TestGenerate(t *testing.T) { dir := t.TempDir() err := Generate([]any{&Client{}}, dir, WithIncludeFunc(func(m reflect.Method) bool { - return MethodHasAnyPrefix(m, []string{"List"}) && MethodHasAnySuffix(m, []string{"Tables"}) + return MethodHasAnyPrefix(m, []string{"List"}) && MethodHasAnySuffix(m, []string{"Tables"}) || MethodHasAnyPrefix(m, []string{"NewPager"}) }), - WithExtraImports(func(m reflect.Method) []string { return []string{"net/http"} }, - )) + WithExtraImports(func(m reflect.Method) []string { return []string{"net/http"} })) if err != nil { t.Fatalf("unexpected error: %v", err) }