diff --git a/gen/gen.go b/gen/gen.go index 39f6a7dc1..01f611ef1 100644 --- a/gen/gen.go +++ b/gen/gen.go @@ -60,6 +60,9 @@ type Config struct { // ParseInternal whether swag should parse internal packages ParseInternal bool + // Strict whether swag should error or warn when it detects cases which are most likely user errors + Strict bool + // MarkdownFilesDir used to find markdownfiles, which can be used for tag descriptions MarkdownFilesDir string @@ -85,7 +88,8 @@ func (g *Gen) Build(config *Config) error { log.Println("Generate swagger docs....") p := swag.New(swag.SetMarkdownFileDirectory(config.MarkdownFilesDir), swag.SetExcludedDirsAndFiles(config.Excludes), - swag.SetCodeExamplesDirectory(config.CodeExampleFilesDir)) + swag.SetCodeExamplesDirectory(config.CodeExampleFilesDir), + swag.SetStrict(config.Strict)) p.PropNamingStrategy = config.PropNamingStrategy p.ParseVendor = config.ParseVendor p.ParseDependency = config.ParseDependency diff --git a/gen/gen_test.go b/gen/gen_test.go index dfee63696..555fd8dd3 100644 --- a/gen/gen_test.go +++ b/gen/gen_test.go @@ -423,3 +423,22 @@ func TestGen_cgoImports(t *testing.T) { os.Remove(expectedFile) } } + +func TestGen_duplicateRoute(t *testing.T) { + searchDir := "../testdata/duplicate_route" + + config := &Config{ + SearchDir: searchDir, + MainAPIFile: "./main.go", + OutputDir: "../testdata/duplicate_route/docs", + PropNamingStrategy: "", + ParseDependency: true, + } + err := New().Build(config) + assert.NoError(t, err) + + // with Strict enabled should cause an error instead of warning about the duplicate route + config.Strict = true + err = New().Build(config) + assert.EqualError(t, err, "route GET /testapi/endpoint is declared multiple times") +} diff --git a/packages.go b/packages.go index 4e99f00ee..6507032cf 100644 --- a/packages.go +++ b/packages.go @@ -3,6 +3,8 @@ package swag import ( "go/ast" "go/token" + "path/filepath" + "sort" "strings" ) @@ -23,26 +25,30 @@ func NewPackagesDefinitions() *PackagesDefinitions { } // CollectAstFile collect ast.file. -func (pkgs *PackagesDefinitions) CollectAstFile(packageDir, path string, astFile *ast.File) { +func (pkgs *PackagesDefinitions) CollectAstFile(packageDir, path string, astFile *ast.File) error { if pkgs.files == nil { pkgs.files = make(map[*ast.File]*AstFileInfo) } - pkgs.files[astFile] = &AstFileInfo{ - File: astFile, - Path: path, - PackagePath: packageDir, + if pkgs.packages == nil { + pkgs.packages = make(map[string]*PackageDefinitions) } + // return without storing the file if we lack a packageDir if len(packageDir) == 0 { - return + return nil } - if pkgs.packages == nil { - pkgs.packages = make(map[string]*PackageDefinitions) + path, err := filepath.Abs(path) + if err != nil { + return err } if pd, ok := pkgs.packages[packageDir]; ok { + // return without storing the file if it already exists + if _, exists := pd.Files[path]; exists { + return nil + } pd.Files[path] = astFile } else { pkgs.packages[packageDir] = &PackageDefinitions{ @@ -51,12 +57,29 @@ func (pkgs *PackagesDefinitions) CollectAstFile(packageDir, path string, astFile TypeDefinitions: make(map[string]*TypeSpecDef), } } + + pkgs.files[astFile] = &AstFileInfo{ + File: astFile, + Path: path, + PackagePath: packageDir, + } + + return nil } -// RangeFiles for range the collection of ast.File. +// RangeFiles for range the collection of ast.File in alphabetic order. func (pkgs *PackagesDefinitions) RangeFiles(handle func(filename string, file *ast.File) error) error { - for file, info := range pkgs.files { - if err := handle(info.Path, file); err != nil { + sortedFiles := make([]*AstFileInfo, 0, len(pkgs.files)) + for _, info := range pkgs.files { + sortedFiles = append(sortedFiles, info) + } + + sort.Slice(sortedFiles, func(i, j int) bool { + return strings.Compare(sortedFiles[i].Path, sortedFiles[j].Path) < 0 + }) + + for _, info := range sortedFiles { + if err := handle(info.Path, info.File); err != nil { return err } } diff --git a/parser.go b/parser.go index 4bae91ebc..d1a2bba4b 100644 --- a/parser.go +++ b/parser.go @@ -81,6 +81,9 @@ type Parser struct { // ParseInternal whether swag should parse internal packages ParseInternal bool + // Strict whether swag should error or warn when it detects cases which are most likely user errors + Strict bool + // structStack stores full names of the structures that were already parsed or are being parsed now structStack []*TypeSpecDef @@ -159,6 +162,13 @@ func SetExcludedDirsAndFiles(excludes string) func(*Parser) { } } +// SetStrict sets whether swag should error or warn when it detects cases which are most likely user errors +func SetStrict(strict bool) func(*Parser) { + return func(p *Parser) { + p.Strict = strict + } +} + // ParseAPI parses general api info for given searchDir and mainAPIFile. func (parser *Parser) ParseAPI(searchDir string, mainAPIFile string, parseDepth int) error { return parser.ParseAPIMultiSearchDir([]string{searchDir}, mainAPIFile, parseDepth) @@ -599,23 +609,18 @@ func (parser *Parser) ParseRouterAPIInfo(fileName string, astFile *ast.File) err if pathItem, ok = parser.swagger.Paths.Paths[routeProperties.Path]; !ok { pathItem = spec.PathItem{} } - switch strings.ToUpper(routeProperties.HTTPMethod) { - case http.MethodGet: - pathItem.Get = &operation.Operation - case http.MethodPost: - pathItem.Post = &operation.Operation - case http.MethodDelete: - pathItem.Delete = &operation.Operation - case http.MethodPut: - pathItem.Put = &operation.Operation - case http.MethodPatch: - pathItem.Patch = &operation.Operation - case http.MethodHead: - pathItem.Head = &operation.Operation - case http.MethodOptions: - pathItem.Options = &operation.Operation + + // check if we already have a operation for this path and method + if hasRouteMethodOp(pathItem, routeProperties.HTTPMethod) { + err := fmt.Errorf("route %s %s is declared multiple times", routeProperties.HTTPMethod, routeProperties.Path) + if parser.Strict { + return err + } + Printf("warning: %s\n", err) } + setRouteMethodOp(&pathItem, routeProperties.HTTPMethod, &operation.Operation) + parser.swagger.Paths.Paths[routeProperties.Path] = pathItem } } @@ -625,6 +630,46 @@ func (parser *Parser) ParseRouterAPIInfo(fileName string, astFile *ast.File) err return nil } +func setRouteMethodOp(pathItem *spec.PathItem, method string, op *spec.Operation) { + switch strings.ToUpper(method) { + case http.MethodGet: + pathItem.Get = op + case http.MethodPost: + pathItem.Post = op + case http.MethodDelete: + pathItem.Delete = op + case http.MethodPut: + pathItem.Put = op + case http.MethodPatch: + pathItem.Patch = op + case http.MethodHead: + pathItem.Head = op + case http.MethodOptions: + pathItem.Options = op + } +} + +func hasRouteMethodOp(pathItem spec.PathItem, method string) bool { + switch strings.ToUpper(method) { + case http.MethodGet: + return pathItem.Get != nil + case http.MethodPost: + return pathItem.Post != nil + case http.MethodDelete: + return pathItem.Delete != nil + case http.MethodPut: + return pathItem.Put != nil + case http.MethodPatch: + return pathItem.Patch != nil + case http.MethodHead: + return pathItem.Head != nil + case http.MethodOptions: + return pathItem.Options != nil + } + + return false +} + func convertFromSpecificToPrimitive(typeName string) (string, error) { name := typeName if strings.ContainsRune(name, '.') { diff --git a/parser_test.go b/parser_test.go index 79bda7590..c541cd2e2 100644 --- a/parser_test.go +++ b/parser_test.go @@ -5,10 +5,12 @@ import ( goparser "go/parser" "go/token" "io/ioutil" + "net/http" "os" "path/filepath" "testing" + "github.com/go-openapi/spec" "github.com/stretchr/testify/assert" ) @@ -33,6 +35,14 @@ func TestSetCodeExamplesDirectory(t *testing.T) { assert.Equal(t, expected, p.codeExampleFilesDir) } +func TestSetStrict(t *testing.T) { + p := New() + assert.Equal(t, false, p.Strict) + + p = New(SetStrict(true)) + assert.Equal(t, true, p.Strict) +} + func TestParser_ParseGeneralApiInfo(t *testing.T) { expected := `{ "schemes": [ @@ -2196,6 +2206,29 @@ func Test3(){ // } // } +func TestParser_ParseRouterApiDuplicateRoute(t *testing.T) { + src := ` +package test + +// @Router /api/{id} [get] +func Test1(){ +} +// @Router /api/{id} [get] +func Test2(){ +} +` + f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + assert.NoError(t, err) + + p := New(SetStrict(true)) + err = p.ParseRouterAPIInfo("", f) + assert.EqualError(t, err, "route GET /api/{id} is declared multiple times") + + p = New() + err = p.ParseRouterAPIInfo("", f) + assert.NoError(t, err) +} + func TestApiParseTag(t *testing.T) { searchDir := "testdata/tags" p := New(SetMarkdownFileDirectory(searchDir)) @@ -2410,6 +2443,52 @@ func Fun() { assert.Equal(t, "#/definitions/Teacher", ref.String()) } +func TestPackagesDefinitions_CollectAstFileInit(t *testing.T) { + src := ` +package main + +// @Router /test [get] +func Fun() { + +} +` + f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + assert.NoError(t, err) + + pkgs := NewPackagesDefinitions() + + // unset the .files and .packages and check that they're re-initialized by CollectAstFile + pkgs.packages = nil + pkgs.files = nil + + pkgs.CollectAstFile("api", "api/api.go", f) + assert.NotNil(t, pkgs.packages) + assert.NotNil(t, pkgs.files) +} + +func TestCollectAstFileMultipleTimes(t *testing.T) { + src := ` +package main + +// @Router /test [get] +func Fun() { + +} +` + f, err := goparser.ParseFile(token.NewFileSet(), "", src, goparser.ParseComments) + assert.NoError(t, err) + + p := New() + p.packages.CollectAstFile("api", "api/api.go", f) + assert.NotNil(t, p.packages.files[f]) + + astFileInfo := p.packages.files[f] + + // if we collect the same again nothing should happen + p.packages.CollectAstFile("api", "api/api.go", f) + assert.Equal(t, astFileInfo, p.packages.files[f]) +} + func TestParseJSONFieldString(t *testing.T) { expected := `{ "swagger": "2.0", @@ -2603,6 +2682,64 @@ func TestDefineTypeOfExample(t *testing.T) { assert.Nil(t, example) } +func TestSetRouteMethodOp(t *testing.T) { + op := spec.NewOperation("dummy") + + // choosing to test each method explicitly instead of table driven to avoid reliance on helpers + + pathItem := spec.PathItem{} + setRouteMethodOp(&pathItem, http.MethodGet, op) + assert.Equal(t, op, pathItem.Get) + + pathItem = spec.PathItem{} + setRouteMethodOp(&pathItem, http.MethodPost, op) + assert.Equal(t, op, pathItem.Post) + + pathItem = spec.PathItem{} + setRouteMethodOp(&pathItem, http.MethodDelete, op) + assert.Equal(t, op, pathItem.Delete) + + pathItem = spec.PathItem{} + setRouteMethodOp(&pathItem, http.MethodPut, op) + assert.Equal(t, op, pathItem.Put) + + pathItem = spec.PathItem{} + setRouteMethodOp(&pathItem, http.MethodPatch, op) + assert.Equal(t, op, pathItem.Patch) + + pathItem = spec.PathItem{} + setRouteMethodOp(&pathItem, http.MethodHead, op) + assert.Equal(t, op, pathItem.Head) + + pathItem = spec.PathItem{} + setRouteMethodOp(&pathItem, http.MethodOptions, op) + assert.Equal(t, op, pathItem.Options) +} + +func TestHasRouteMethodOp(t *testing.T) { + pathItem := spec.PathItem{} + + // assert that an invalid http method produces false + assert.False(t, hasRouteMethodOp(pathItem, "OOPSIE")) + + // test each (supported) http method + httpMethods := []string{ + http.MethodGet, http.MethodPost, http.MethodDelete, http.MethodPut, + http.MethodPatch, http.MethodHead, http.MethodOptions, + } + for _, httpMethod := range httpMethods { + pathItem = spec.PathItem{} + + // should be false before setting + assert.False(t, hasRouteMethodOp(pathItem, httpMethod)) + + // and true after we set it + // we rely on setRouteMethodOp, which is tested more thoroughly above + setRouteMethodOp(&pathItem, httpMethod, spec.NewOperation("dummy")) + assert.True(t, hasRouteMethodOp(pathItem, httpMethod)) + } +} + type mockFS struct { os.FileInfo FileName string diff --git a/testdata/duplicate_route/api/api.go b/testdata/duplicate_route/api/api.go new file mode 100644 index 000000000..cf7daed85 --- /dev/null +++ b/testdata/duplicate_route/api/api.go @@ -0,0 +1,17 @@ +package api + +import ( + "net/http" + + _ "github.com/swaggo/swag/testdata/simple/web" +) + +// @Router /testapi/endpoint [get] +func FunctionOne(w http.ResponseWriter, r *http.Request) { + //write your code +} + +// @Router /testapi/endpoint [get] +func FunctionTwo(w http.ResponseWriter, r *http.Request) { + //write your code +} diff --git a/testdata/duplicate_route/main.go b/testdata/duplicate_route/main.go new file mode 100644 index 000000000..1330c302b --- /dev/null +++ b/testdata/duplicate_route/main.go @@ -0,0 +1,13 @@ +package main + +import ( + "net/http" + + "github.com/swaggo/swag/testdata/duplicate_route/api" +) + +func main() { + http.HandleFunc("/testapi/endpoint", api.FunctionOne) + http.HandleFunc("/testapi/endpoint", api.FunctionTwo) + http.ListenAndServe(":8080", nil) +}