Skip to content
This repository has been archived by the owner on Jun 27, 2023. It is now read-only.

fix: Support array parsing with length using binary expression and parenthesis #603

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mockgen/internal/tests/const_array_length/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ type I interface {
Foo() [C]int
Bar() [2]int
Baz() [math.MaxInt8]int
Qux() [1 + 2]int
Quux() [(1 + 2)]int
Corge() [math.MaxInt8 - 120]int
}
42 changes: 42 additions & 0 deletions mockgen/internal/tests/const_array_length/mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 44 additions & 21 deletions mockgen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,31 +418,14 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
case *ast.ArrayType:
ln := -1
if v.Len != nil {
var value string
switch val := v.Len.(type) {
case (*ast.BasicLit):
value = val.Value
case (*ast.Ident):
// when the length is a const defined locally
value = val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value
case (*ast.SelectorExpr):
// when the length is a const defined in an external package
usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X))
if err != nil {
return nil, p.errorf(v.Len.Pos(), "unknown package in array length: %v", err)
}
ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name)
if err != nil {
return nil, p.errorf(v.Len.Pos(), "unknown constant in array length: %v", err)
}
value = ev.Value.String()
value, err := p.parseArrayLength(v.Len)
if err != nil {
return nil, err
}

x, err := strconv.Atoi(value)
ln, err = strconv.Atoi(value)
if err != nil {
return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
}
ln = x
}
t, err := p.parseType(pkg, v.Elt)
if err != nil {
Expand Down Expand Up @@ -525,6 +508,46 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
return nil, fmt.Errorf("don't know how to parse type %T", typ)
}

func (p *fileParser) parseArrayLength(expr ast.Expr) (string, error) {
switch val := expr.(type) {
case (*ast.BasicLit):
return val.Value, nil
case (*ast.Ident):
// when the length is a const defined locally
return val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value, nil
case (*ast.SelectorExpr):
// when the length is a const defined in an external package
usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X))
if err != nil {
return "", p.errorf(expr.Pos(), "unknown package in array length: %v", err)
}
ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name)
if err != nil {
return "", p.errorf(expr.Pos(), "unknown constant in array length: %v", err)
}
return ev.Value.String(), nil
case (*ast.ParenExpr):
return p.parseArrayLength(val.X)
case (*ast.BinaryExpr):
x, err := p.parseArrayLength(val.X)
if err != nil {
return "", err
}
y, err := p.parseArrayLength(val.Y)
if err != nil {
return "", err
}
biExpr := fmt.Sprintf("%s%v%s", x, val.Op, y)
tv, err := types.Eval(token.NewFileSet(), nil, token.NoPos, biExpr)
if err != nil {
return "", p.errorf(expr.Pos(), "invalid expression in array length: %v", err)
}
return tv.Value.String(), nil
default:
return "", p.errorf(expr.Pos(), "invalid expression in array length: %v", val)
}
}

// importsOfFile returns a map of package name to import path
// of the imports in file.
func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) {
Expand Down
2 changes: 1 addition & 1 deletion mockgen/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func TestParseArrayWithConstLength(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}

expects := []string{"[2]int", "[2]int", "[127]int"}
expects := []string{"[2]int", "[2]int", "[127]int", "[3]int", "[3]int", "[7]int"}
for i, e := range expects {
got := pkg.Interfaces[0].Methods[i].Out[0].Type.String(nil, "")
if got != e {
Expand Down