Skip to content

Commit

Permalink
fix: Nested generic fields not fully working, if generic type is from… (
Browse files Browse the repository at this point in the history
#1305)

* fix: Nested generic fields not fully working, if generic type is from another package

- change full name generation and support SelectorExpr
- prepend package only, if no name does not contain package

fixes #1304

* test: New tests added increase code coverage for generics
  • Loading branch information
FabianMartin authored Aug 25, 2022
1 parent 732c087 commit 9d34a76
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 20 deletions.
46 changes: 32 additions & 14 deletions generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
// splitStructName splits a generic struct name in his parts
func splitStructName(fullGenericForm string) (string, []string) {
// split only at the first '[' and remove the last ']'
if fullGenericForm[len(fullGenericForm)-1] != ']' {
return "", nil
}

genericParams := strings.SplitN(strings.TrimSpace(fullGenericForm)[:len(fullGenericForm)-1], "[", 2)
if len(genericParams) == 1 {
return "", nil
Expand Down Expand Up @@ -224,12 +228,11 @@ func resolveType(expr ast.Expr, field *ast.Field, genericParamTypeDefs map[strin
func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
switch fieldType := field.(type) {
case *ast.IndexListExpr:
spec := &TypeSpecDef{
File: file,
TypeSpec: getGenericTypeSpec(fieldType.X),
PkgPath: file.Name.Name,
fullName, err := getGenericTypeName(file, fieldType.X)
if err != nil {
return "", err
}
fullName := spec.FullName() + "["
fullName += "["

for _, index := range fieldType.Indices {
var fieldName string
Expand All @@ -252,11 +255,6 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {

return strings.TrimRight(fullName, ",") + "]", nil
case *ast.IndexExpr:
if file.Name == nil {
return "", errors.New("file name is nil")
}
packageName, _ := getFieldType(file, file.Name)

x, err := getFieldType(file, fieldType.X)
if err != nil {
return "", err
Expand All @@ -267,18 +265,38 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
return "", err
}

packageName := ""
if !strings.Contains(x, ".") {
if file.Name == nil {
return "", errors.New("file name is nil")
}
packageName, _ = getFieldType(file, file.Name)
}

return strings.TrimLeft(fmt.Sprintf("%s.%s[%s]", packageName, x, i), "."), nil
}

return "", fmt.Errorf("unknown field type %#v", field)
}

func getGenericTypeSpec(field ast.Expr) *ast.TypeSpec {
func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) {
switch indexType := field.(type) {
case *ast.Ident:
return indexType.Obj.Decl.(*ast.TypeSpec)
spec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return spec.FullName(), nil
case *ast.ArrayType:
return indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec)
spec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return spec.FullName(), nil
case *ast.SelectorExpr:
return fmt.Sprintf("%s.%s", indexType.X.(*ast.Ident).Name, indexType.Sel.Name), nil
}
return nil
return "", fmt.Errorf("unknown type %#v", field)
}
145 changes: 145 additions & 0 deletions generics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,87 @@ func TestParseGenericsNames(t *testing.T) {
assert.Equal(t, string(expected), string(b))
}

func TestParametrizeStruct(t *testing.T) {
pd := PackagesDefinitions{
packages: make(map[string]*PackageDefinitions),
}
// valid
typeSpec := pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string, []string]", false)
assert.Equal(t, "$test.Field-string-array_string", typeSpec.Name())

// definition contains one type params, but two type params are provided
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string, string]", false)
assert.Nil(t, typeSpec)

// definition contains two type params, but only one is used
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string]", false)
assert.Nil(t, typeSpec)

// name is not a valid type name
typeSpec = pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string", false)
assert.Nil(t, typeSpec)

typeSpec = pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string, [string]", false)
assert.Nil(t, typeSpec)

typeSpec = pd.parametrizeStruct(&TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
TypeParams: &ast.FieldList{List: []*ast.Field{{Names: []*ast.Ident{{Name: "T"}}}, {Names: []*ast.Ident{{Name: "T2"}}}}},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
}}, "test.Field[string, ]string]", false)
assert.Nil(t, typeSpec)
}

func TestSplitStructNames(t *testing.T) {
t.Parallel()

field, params := splitStructName("test.Field")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field]")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field[string")
assert.Empty(t, field)
assert.Nil(t, params)

field, params = splitStructName("test.Field[string]")
assert.Equal(t, "test.Field", field)
assert.Equal(t, []string{"string"}, params)

field, params = splitStructName("test.Field[string, []string]")
assert.Equal(t, "test.Field", field)
assert.Equal(t, []string{"string", "[]string"}, params)
}

func TestGetGenericFieldType(t *testing.T) {
field, err := getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
Expand Down Expand Up @@ -124,6 +205,34 @@ func TestGetGenericFieldType(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "test.Field[string,int]", field)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexListExpr{
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.Ident{Name: "int"}}},
},
)
assert.NoError(t, err)
assert.Equal(t, "test.Field[string,[]int]", field)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexListExpr{
X: &ast.BadExpr{},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.Ident{Name: "int"}},
},
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexListExpr{
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.BadExpr{}}},
},
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.Ident{Name: "string"}},
Expand All @@ -148,4 +257,40 @@ func TestGetGenericFieldType(t *testing.T) {
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.BadExpr{}},
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}}, Index: &ast.Ident{Name: "string"}},
)
assert.NoError(t, err)
assert.Equal(t, "field.Name[string]", field)
}

func TestGetGenericTypeName(t *testing.T) {
field, err := getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
)
assert.NoError(t, err)
assert.Equal(t, "test.Field", field)

field, err = getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.ArrayType{Elt: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}}},
)
assert.NoError(t, err)
assert.Equal(t, "test.Field", field)

field, err = getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}},
)
assert.NoError(t, err)
assert.Equal(t, "field.Name", field)

_, err = getGenericTypeName(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.BadExpr{},
)
assert.Error(t, err)
}
6 changes: 6 additions & 0 deletions testdata/generics_property/api/api.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package api

import (
"github.com/swaggo/swag/testdata/generics_property/web"
"net/http"
)

type NestedResponse struct {
web.GenericResponse[[]string, *uint8]
}

// @Summary List Posts
// @Description Get All of the Posts
// @Accept json
Expand All @@ -12,6 +17,7 @@ import (
// @Success 200 {object} web.PostResponse "ok"
// @Success 201 {object} web.PostResponses "ok"
// @Success 202 {object} web.StringResponse "ok"
// @Success 203 {object} NestedResponse "ok"
// @Router /posts [get]
func GetPosts(w http.ResponseWriter, r *http.Request) {
}
25 changes: 20 additions & 5 deletions testdata/generics_property/expected.json
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@
"type": "integer",
"name": "rows",
"in": "query"
},
{
"type": "string",
"name": "search",
"in": "query"
}
],
"responses": {
Expand All @@ -64,12 +59,32 @@
"schema": {
"$ref": "#/definitions/web.StringResponse"
}
},
"203": {
"description": "ok",
"schema": {
"$ref": "#/definitions/api.NestedResponse"
}
}
}
}
}
},
"definitions": {
"api.NestedResponse": {
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "string"
}
},
"items2": {
"type": "integer"
}
}
},
"types.Field-string": {
"type": "object",
"properties": {
Expand Down
2 changes: 1 addition & 1 deletion testdata/generics_property/web/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (String) Where(ps ...PostSelector) String {

type PostPager struct {
Pager[String, PostSelector]
Search string `json:"search" form:"search"`
Search types.Field[string] `json:"search" form:"search"`
}

type PostResponse struct {
Expand Down

0 comments on commit 9d34a76

Please sign in to comment.