Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

Commit

Permalink
fix: Support array parsing with length using binary expression and pa…
Browse files Browse the repository at this point in the history
…renthesis (#603)

Fixes #575
  • Loading branch information
sryoya authored Dec 30, 2021
1 parent d0edad8 commit cf7e215
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 22 deletions.
3 changes: 3 additions & 0 deletions mockgen/internal/tests/const_array_length/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ type I interface {
Foo() [C]int
Bar() [2]int
Baz() [math.MaxInt8]int
Qux() [1 + 2]int
Quux() [(1 + 2)]int
Corge() [math.MaxInt8 - 120]int
}
42 changes: 42 additions & 0 deletions mockgen/internal/tests/const_array_length/mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 44 additions & 21 deletions mockgen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,31 +418,14 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
case *ast.ArrayType:
ln := -1
if v.Len != nil {
var value string
switch val := v.Len.(type) {
case (*ast.BasicLit):
value = val.Value
case (*ast.Ident):
// when the length is a const defined locally
value = val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value
case (*ast.SelectorExpr):
// when the length is a const defined in an external package
usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X))
if err != nil {
return nil, p.errorf(v.Len.Pos(), "unknown package in array length: %v", err)
}
ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name)
if err != nil {
return nil, p.errorf(v.Len.Pos(), "unknown constant in array length: %v", err)
}
value = ev.Value.String()
value, err := p.parseArrayLength(v.Len)
if err != nil {
return nil, err
}

x, err := strconv.Atoi(value)
ln, err = strconv.Atoi(value)
if err != nil {
return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
}
ln = x
}
t, err := p.parseType(pkg, v.Elt)
if err != nil {
Expand Down Expand Up @@ -525,6 +508,46 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
return nil, fmt.Errorf("don't know how to parse type %T", typ)
}

func (p *fileParser) parseArrayLength(expr ast.Expr) (string, error) {
switch val := expr.(type) {
case (*ast.BasicLit):
return val.Value, nil
case (*ast.Ident):
// when the length is a const defined locally
return val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value, nil
case (*ast.SelectorExpr):
// when the length is a const defined in an external package
usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X))
if err != nil {
return "", p.errorf(expr.Pos(), "unknown package in array length: %v", err)
}
ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name)
if err != nil {
return "", p.errorf(expr.Pos(), "unknown constant in array length: %v", err)
}
return ev.Value.String(), nil
case (*ast.ParenExpr):
return p.parseArrayLength(val.X)
case (*ast.BinaryExpr):
x, err := p.parseArrayLength(val.X)
if err != nil {
return "", err
}
y, err := p.parseArrayLength(val.Y)
if err != nil {
return "", err
}
biExpr := fmt.Sprintf("%s%v%s", x, val.Op, y)
tv, err := types.Eval(token.NewFileSet(), nil, token.NoPos, biExpr)
if err != nil {
return "", p.errorf(expr.Pos(), "invalid expression in array length: %v", err)
}
return tv.Value.String(), nil
default:
return "", p.errorf(expr.Pos(), "invalid expression in array length: %v", val)
}
}

// importsOfFile returns a map of package name to import path
// of the imports in file.
func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) {
Expand Down
2 changes: 1 addition & 1 deletion mockgen/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func TestParseArrayWithConstLength(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}

expects := []string{"[2]int", "[2]int", "[127]int"}
expects := []string{"[2]int", "[2]int", "[127]int", "[3]int", "[3]int", "[7]int"}
for i, e := range expects {
got := pkg.Interfaces[0].Methods[i].Out[0].Type.String(nil, "")
if got != e {
Expand Down

0 comments on commit cf7e215

Please sign in to comment.