Skip to content

Commit

Permalink
Add typed nil validation to dsl.Security (#3574)
Browse files Browse the repository at this point in the history
* Add a test case for eval.InvalidArgError()

* Add typed nil validation to dsl.Security

* Remove unnecessary blocks
  • Loading branch information
tchssk authored Aug 9, 2024
1 parent f0108a7 commit 57a4260
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 37 deletions.
51 changes: 25 additions & 26 deletions dsl/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,35 +228,34 @@ func JWTSecurity(name string, fn ...func()) *expr.SchemeExpr {
// })
func Security(args ...any) {
var dsl func()
{
if d, ok := args[len(args)-1].(func()); ok {
args = args[:len(args)-1]
dsl = d
}
}

var schemes []*expr.SchemeExpr
{
schemes = make([]*expr.SchemeExpr, len(args))
for i, arg := range args {
switch val := arg.(type) {
case string:
for _, s := range expr.Root.Schemes {
if s.SchemeName == val {
schemes[i] = expr.DupScheme(s)
break
}
}
if schemes[i] == nil {
eval.ReportError("security scheme %q not found", val)
return
if d, ok := args[len(args)-1].(func()); ok {
args = args[:len(args)-1]
dsl = d
}

schemes := make([]*expr.SchemeExpr, len(args))
for i, arg := range args {
switch val := arg.(type) {
case string:
for _, s := range expr.Root.Schemes {
if s.SchemeName == val {
schemes[i] = expr.DupScheme(s)
break
}
case *expr.SchemeExpr:
schemes[i] = expr.DupScheme(val)
default:
eval.InvalidArgError("security scheme or security scheme name", val)
}
if schemes[i] == nil {
eval.ReportError("security scheme %q not found", val)
return
}
case *expr.SchemeExpr:
if val == nil {
eval.InvalidArgError("security scheme", val)
return
}
schemes[i] = expr.DupScheme(val)
default:
eval.InvalidArgError("security scheme or security scheme name", val)
return
}
}

Expand Down
23 changes: 12 additions & 11 deletions eval/eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@ func TestInvalidArgError(t *testing.T) {
dsl func()
want string
}{
"Attribute": {func() { Type("name", func() { Attribute("name", String, "description", 1) }) }, "cannot use 1 (type int) as type func()"},
"Body": {func() { Service("s", func() { Method("m", func() { HTTP(func() { Body(1) }) }) }) }, "cannot use 1 (type int) as type attribute name, user type or DSL"},
"ErrorName (bool)": {func() { Type("name", func() { ErrorName(true) }) }, "cannot use true (type bool) as type name or position"},
"ErrorName (int)": {func() { Type("name", func() { ErrorName(1, 2) }) }, "cannot use 2 (type int) as type name"},
"Example": {func() { Example(1, 2) }, "cannot use 1 (type int) as type summary (string)"},
"Headers": {func() { Headers(1) }, "cannot use 1 (type int) as type function"},
"Param": {func() { API("name", func() { HTTP(func() { Params(1) }) }) }, "cannot use 1 (type int) as type function"},
"Response": {func() { Service("s", func() { HTTP(func() { Response(1) }) }) }, "cannot use 1 (type int) as type name of error"},
"ResultType": {func() { ResultType("identifier", 1) }, "cannot use 1 (type int) as type function or string"},
"Security": {func() { Security(1) }, "cannot use 1 (type int) as type security scheme or security scheme name"},
"Type": {func() { Type("name", 1) }, "cannot use 1 (type int) as type type or function"},
"Attribute": {func() { Type("name", func() { Attribute("name", String, "description", 1) }) }, "cannot use 1 (type int) as type func()"},
"Body": {func() { Service("s", func() { Method("m", func() { HTTP(func() { Body(1) }) }) }) }, "cannot use 1 (type int) as type attribute name, user type or DSL"},
"ErrorName (bool)": {func() { Type("name", func() { ErrorName(true) }) }, "cannot use true (type bool) as type name or position"},
"ErrorName (int)": {func() { Type("name", func() { ErrorName(1, 2) }) }, "cannot use 2 (type int) as type name"},
"Example": {func() { Example(1, 2) }, "cannot use 1 (type int) as type summary (string)"},
"Headers": {func() { Headers(1) }, "cannot use 1 (type int) as type function"},
"Param": {func() { API("name", func() { HTTP(func() { Params(1) }) }) }, "cannot use 1 (type int) as type function"},
"Response": {func() { Service("s", func() { HTTP(func() { Response(1) }) }) }, "cannot use 1 (type int) as type name of error"},
"ResultType": {func() { ResultType("identifier", 1) }, "cannot use 1 (type int) as type function or string"},
"Security": {func() { Security(1) }, "cannot use 1 (type int) as type security scheme or security scheme name"},
"Security (typed nil)": {func() { Security((*expr.SchemeExpr)(nil)) }, "cannot use (*expr.SchemeExpr)(nil) (type *expr.SchemeExpr) as type security scheme"},
"Type": {func() { Type("name", 1) }, "cannot use 1 (type int) as type type or function"},
}
for name, tc := range dsls {
t.Run(name, func(t *testing.T) {
Expand Down

0 comments on commit 57a4260

Please sign in to comment.