Skip to content

Commit

Permalink
Auto parse external models when flag parseDependency is set (#1027)
Browse files Browse the repository at this point in the history
* auto parse external models when flag parseDependency is set

* remove log in test

* fix test

* rephrase the description for flag parseDependency
  • Loading branch information
sdghchj authored Oct 14, 2021
1 parent 732ca2c commit 9981d9f
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 47 deletions.
5 changes: 3 additions & 2 deletions cmd/swag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ var initFlags = []cli.Flag{
Usage: "Parse go files in 'vendor' folder, disabled by default",
},
&cli.BoolFlag{
Name: parseDependencyFlag,
Usage: "Parse go files in outside dependency folder, disabled by default",
Name: parseDependencyFlag,
Aliases: []string{"pd"},
Usage: "Parse go files inside dependency folder, disabled by default",
},
&cli.StringFlag{
Name: markdownFilesFlag,
Expand Down
149 changes: 105 additions & 44 deletions packages.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package swag

import (
"go/ast"
goparser "go/parser"
"go/token"
"golang.org/x/tools/go/loader"
"os"
"path/filepath"
"sort"
"strings"
Expand Down Expand Up @@ -95,51 +98,58 @@ func (pkgs *PackagesDefinitions) RangeFiles(handle func(filename string, file *a
func (pkgs *PackagesDefinitions) ParseTypes() (map[*TypeSpecDef]*Schema, error) {
parsedSchemas := make(map[*TypeSpecDef]*Schema)
for astFile, info := range pkgs.files {
for _, astDeclaration := range astFile.Decls {
generalDeclaration, ok := astDeclaration.(*ast.GenDecl)
if ok && generalDeclaration.Tok == token.TYPE {
for _, astSpec := range generalDeclaration.Specs {
typeSpec, ok := astSpec.(*ast.TypeSpec)
if ok {
typeSpecDef := &TypeSpecDef{
PkgPath: info.PackagePath,
File: astFile,
TypeSpec: typeSpec,
}
pkgs.parseTypesFromFile(astFile, info.PackagePath, parsedSchemas)
}
return parsedSchemas, nil
}

idt, ok := typeSpec.Type.(*ast.Ident)
if ok && IsGolangPrimitiveType(idt.Name) {
parsedSchemas[typeSpecDef] = &Schema{
PkgPath: typeSpecDef.PkgPath,
Name: astFile.Name.Name,
Schema: PrimitiveSchema(TransToValidSchemeType(idt.Name)),
}
}
func (pkgs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packagePath string, parsedSchemas map[*TypeSpecDef]*Schema) {
for _, astDeclaration := range astFile.Decls {
if generalDeclaration, ok := astDeclaration.(*ast.GenDecl); ok && generalDeclaration.Tok == token.TYPE {
for _, astSpec := range generalDeclaration.Specs {
if typeSpec, ok := astSpec.(*ast.TypeSpec); ok {
typeSpecDef := &TypeSpecDef{
PkgPath: packagePath,
File: astFile,
TypeSpec: typeSpec,
}

if pkgs.uniqueDefinitions == nil {
pkgs.uniqueDefinitions = make(map[string]*TypeSpecDef)
if idt, ok := typeSpec.Type.(*ast.Ident); ok && IsGolangPrimitiveType(idt.Name) && parsedSchemas != nil {
parsedSchemas[typeSpecDef] = &Schema{
PkgPath: typeSpecDef.PkgPath,
Name: astFile.Name.Name,
Schema: PrimitiveSchema(TransToValidSchemeType(idt.Name)),
}
}

fullName := typeSpecDef.FullName()
anotherTypeDef, ok := pkgs.uniqueDefinitions[fullName]
if ok {
if typeSpecDef.PkgPath == anotherTypeDef.PkgPath {
continue
} else {
delete(pkgs.uniqueDefinitions, fullName)
}
if pkgs.uniqueDefinitions == nil {
pkgs.uniqueDefinitions = make(map[string]*TypeSpecDef)
}

fullName := typeSpecDef.FullName()
anotherTypeDef, ok := pkgs.uniqueDefinitions[fullName]
if ok {
if typeSpecDef.PkgPath == anotherTypeDef.PkgPath {
continue
} else {
pkgs.uniqueDefinitions[fullName] = typeSpecDef
delete(pkgs.uniqueDefinitions, fullName)
}
} else {
pkgs.uniqueDefinitions[fullName] = typeSpecDef
}

if pkgs.packages[typeSpecDef.PkgPath] == nil {
pkgs.packages[typeSpecDef.PkgPath] = &PackageDefinitions{
Name: astFile.Name.Name,
TypeDefinitions: map[string]*TypeSpecDef{typeSpecDef.Name(): typeSpecDef},
}
} else if _, ok = pkgs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()]; !ok {
pkgs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()] = typeSpecDef
}
}
}
}
}

return parsedSchemas, nil
}

func (pkgs *PackagesDefinitions) findTypeSpec(pkgPath string, typeName string) *TypeSpecDef {
Expand All @@ -157,11 +167,43 @@ func (pkgs *PackagesDefinitions) findTypeSpec(pkgPath string, typeName string) *
return nil
}

func (pkgs *PackagesDefinitions) loadExternalPackage(importPath string) error {
cwd, err := os.Getwd()
if err != nil {
return err
}

conf := loader.Config{
ParserMode: goparser.ParseComments,
Cwd: cwd,
}

conf.Import(importPath)

lprog, err := conf.Load()
if err != nil {
return err
}

for _, info := range lprog.AllPackages {
pkgPath := info.Pkg.Path()
if strings.HasPrefix(pkgPath, "vendor/") {
pkgPath = pkgPath[7:]
}
for _, astFile := range info.Files {
pkgs.parseTypesFromFile(astFile, pkgPath, nil)
}
}

return nil
}

// findPackagePathFromImports finds out the package path of a package via ranging imports of a ast.File
// @pkg the name of the target package
// @file current ast.File in which to search imports
// @fuzzy search for the package path that the last part matches the @pkg if true
// @return the package path of a package of @pkg.
func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *ast.File) string {
func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *ast.File, fuzzy bool) string {
if file == nil {
return ""
}
Expand All @@ -172,6 +214,14 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as

hasAnonymousPkg := false

matchLastPathPart := func(pkgPath string) bool {
paths := strings.Split(pkgPath, "/")
if paths[len(paths)-1] == pkg {
return true
}
return false
}

// prior to match named package
for _, imp := range file.Imports {
if imp.Name != nil {
Expand All @@ -186,11 +236,12 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as
}
if pkgs.packages != nil {
path := strings.Trim(imp.Path.Value, `"`)
pd, ok := pkgs.packages[path]
if ok {
if pd.Name == pkg {
if fuzzy {
if matchLastPathPart(path) {
return path
}
} else if pd, ok := pkgs.packages[path]; ok && pd.Name == pkg {
return path
}
}
}
Expand All @@ -203,11 +254,12 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as
}
if imp.Name.Name == "_" {
path := strings.Trim(imp.Path.Value, `"`)
pd, ok := pkgs.packages[path]
if ok {
if pd.Name == pkg {
if fuzzy {
if matchLastPathPart(path) {
return path
}
} else if pd, ok := pkgs.packages[path]; ok && pd.Name == pkg {
return path
}
}
}
Expand All @@ -220,7 +272,7 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as
// @typeName the name of the target type, if it starts with a package name, find its own package path from imports on top of @file
// @file the ast.file in which @typeName is used
// @pkgPath the package path of @file.
func (pkgs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File) *TypeSpecDef {
func (pkgs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File, parseDependency bool) *TypeSpecDef {
if IsGolangPrimitiveType(typeName) {
return nil
}
Expand Down Expand Up @@ -248,10 +300,19 @@ func (pkgs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File) *
return typeDef
}
}

pkgPath := pkgs.findPackagePathFromImports(parts[0], file)
if len(pkgPath) == 0 && parts[0] == file.Name.Name {
pkgPath = pkgs.files[file].PackagePath
pkgPath := pkgs.findPackagePathFromImports(parts[0], file, false)
if len(pkgPath) == 0 {
//check if the current package
if parts[0] == file.Name.Name {
pkgPath = pkgs.files[file].PackagePath
} else if parseDependency {
//take it as an external package, needs to be loaded
if pkgPath = pkgs.findPackagePathFromImports(parts[0], file, true); len(pkgPath) > 0 {
if err := pkgs.loadExternalPackage(pkgPath); err != nil {
return nil
}
}
}
}

return pkgs.findTypeSpec(pkgPath, parts[1])
Expand Down
2 changes: 1 addition & 1 deletion parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) (
return PrimitiveSchema(schemaType), nil
}

typeSpecDef := parser.packages.FindTypeSpec(typeName, file)
typeSpecDef := parser.packages.FindTypeSpec(typeName, file, parser.ParseDependency)
if typeSpecDef == nil {
return nil, fmt.Errorf("cannot find type definition: %s", typeName)
}
Expand Down
14 changes: 14 additions & 0 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1988,6 +1988,20 @@ func TestParseConflictSchemaName(t *testing.T) {
assert.Equal(t, string(expected), string(b))
}

func TestParseExternalModels(t *testing.T) {
searchDir := "testdata/external_models/main"
mainAPIFile := "main.go"
p := New()
p.ParseDependency = true
err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth)
assert.NoError(t, err)
b, _ := json.MarshalIndent(p.swagger, "", " ")
//ioutil.WriteFile("./testdata/external_models/main/expected.json",b,0777)
expected, err := ioutil.ReadFile(filepath.Join(searchDir, "expected.json"))
assert.NoError(t, err)
assert.Equal(t, string(expected), string(b))
}

func TestParser_ParseStructArrayObject(t *testing.T) {
t.Parallel()

Expand Down
7 changes: 7 additions & 0 deletions testdata/external_models/external/model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package external

import "github.com/urfave/cli/v2"

type MyError struct {
cli.Author
}
18 changes: 18 additions & 0 deletions testdata/external_models/main/api/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package api

import (
"net/http"
)

// GetExternalModels example
// @Summary parse external models
// @Description get string by ID
// @ID get_external_models
// @Accept json
// @Produce json
// @Success 200 {string} string "ok"
// @Failure 400 {object} http.Header "from internal pkg"
// @Router /testapi/external_models [get]
func GetExternalModels(w http.ResponseWriter, r *http.Request) {

}
50 changes: 50 additions & 0 deletions testdata/external_models/main/expected.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"swagger": "2.0",
"info": {
"description": "Parse external models.",
"title": "Swagger Example API",
"contact": {},
"version": "1.0"
},
"basePath": "/v1",
"paths": {
"/testapi/external_models": {
"get": {
"description": "get string by ID",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"summary": "parse external models",
"operationId": "get_external_models",
"responses": {
"200": {
"description": "ok",
"schema": {
"type": "string"
}
},
"400": {
"description": "from internal pkg",
"schema": {
"$ref": "#/definitions/http.Header"
}
}
}
}
}
},
"definitions": {
"http.Header": {
"type": "object",
"additionalProperties": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
}
8 changes: 8 additions & 0 deletions testdata/external_models/main/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package main

// @title Swagger Example API
// @version 1.0
// @description Parse external models.
// @BasePath /v1
func main() {
}

0 comments on commit 9981d9f

Please sign in to comment.