Skip to content

Commit

Permalink
feat(subtract): added Subtract function to wire
Browse files Browse the repository at this point in the history
Signed-off-by: Giau. Tran Minh <hello@giautm.dev>
  • Loading branch information
giautm committed Apr 22, 2023
1 parent 0675cdc commit 5bbd21a
Show file tree
Hide file tree
Showing 7 changed files with 295 additions and 21 deletions.
118 changes: 115 additions & 3 deletions internal/wire/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
case "NewSet":
pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName)
return pset, notePositionAll(exprPos, errs)
case "Subtract":
pset, errs := oc.processSubtract(info, pkgPath, call, nil, varName)
return pset, notePositionAll(exprPos, errs)
case "Bind":
b, err := processBind(oc.fset, info, call)
if err != nil {
Expand Down Expand Up @@ -590,6 +593,115 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
return nil, []error{notePosition(exprPos, errors.New("unknown pattern"))}
}

func (oc *objectCache) filterType(set *ProviderSet, t types.Type) []error {
hasType := func(outs []types.Type) bool {
for _, o := range outs {
if types.Identical(o, t) {
return true
}
pt, ok := o.(*types.Pointer)
if ok && types.Identical(pt.Elem(), t) {
return true
}
}
return false
}

providers := make([]*Provider, 0, len(set.Providers))
for _, p := range set.Providers {
if !hasType(p.Out) {
providers = append(providers, p)
}
}
set.Providers = providers

bindings := make([]*IfaceBinding, 0, len(set.Bindings))
for _, i := range set.Bindings {
if !types.Identical(i.Iface, t) {
bindings = append(bindings, i)
}
}
set.Bindings = bindings

values := make([]*Value, 0, len(set.Values))
for _, v := range set.Values {
if !types.Identical(v.Out, t) {
values = append(values, v)
}
}
set.Values = values

fields := make([]*Field, 0, len(set.Fields))
for _, f := range set.Fields {
if !hasType(f.Out) {
fields = append(fields, f)
}
}
set.Fields = fields

imports := make([]*ProviderSet, 0, len(set.Imports))
for _, p := range set.Imports {
clone := *p
if errs := oc.filterType(&clone, t); len(errs) > 0 {
return errs
}
imports = append(imports, &clone)
}
set.Imports = imports

var errs []error
set.providerMap, set.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, set)
if len(errs) > 0 {
return errs
}
return nil
}

func (oc *objectCache) processSubtract(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (interface{}, []error) {
// Assumes that call.Fun is wire.Subtract.
if len(call.Args) < 2 {
return nil, []error{notePosition(oc.fset.Position(call.Pos()),
errors.New("call to Subtract must specify types to be subtracted"))}
}
firstArg, errs := oc.processExpr(info, pkgPath, call.Args[0], "")
if len(errs) > 0 {
return nil, errs
}
set, ok := firstArg.(*ProviderSet)
if !ok {
return nil, []error{notePosition(oc.fset.Position(call.Pos()),
fmt.Errorf("first argument to Subtract must be a Set")),
}
}
pset := &ProviderSet{
Pos: call.Pos(),
InjectorArgs: args,
PkgPath: pkgPath,
VarName: varName,
// Copy the other fields.
Providers: set.Providers,
Bindings: set.Bindings,
Values: set.Values,
Fields: set.Fields,
Imports: set.Imports,
}
ec := new(errorCollector)
for _, arg := range call.Args[1:] {
ptr, ok := info.TypeOf(arg).(*types.Pointer)
if !ok {
ec.add(notePosition(oc.fset.Position(arg.Pos()),
errors.New("argument to Subtract must be a pointer"),
))
continue
}
ec.add(oc.filterType(pset, ptr.Elem())...)
}
if len(ec.errors) > 0 {
return nil, ec.errors
}
return pset, nil
}

func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (*ProviderSet, []error) {
// Assumes that call.Fun is wire.NewSet or wire.Build.

Expand Down Expand Up @@ -1173,9 +1285,9 @@ func (pt ProvidedType) IsNil() bool {
//
// - For a function provider, this is the first return value type.
// - For a struct provider, this is either the struct type or the pointer type
// whose element type is the struct type.
// - For a value, this is the type of the expression.
// - For an argument, this is the type of the argument.
// whose element type is the struct type.
// - For a value, this is the type of the expression.
// - For an argument, this is the type of the argument.
func (pt ProvidedType) Type() types.Type {
return pt.t
}
Expand Down
66 changes: 66 additions & 0 deletions internal/wire/testdata/Subtract/foo/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"github.com/google/wire"
)

type context struct{}

func main() {}

type FooOptions struct{}
type Foo string
type Bar struct{}
type BarName string

func (b *Bar) Bar() {}

func provideFooOptions() *FooOptions {
return &FooOptions{}
}

func provideFoo(*FooOptions) Foo {
return Foo("foo")
}

func provideBar(Foo, BarName) *Bar {
return &Bar{}
}

type BarService interface {
Bar()
}

type FooBar struct {
BarService
Foo
}

var Set = wire.NewSet(
provideFooOptions,
provideFoo,
provideBar,
)

var SuperSet = wire.NewSet(Set,
wire.Struct(new(FooBar), "*"),
wire.Bind(new(BarService), new(*Bar)),
)

type FakeBarService struct{}

func (f *FakeBarService) Bar() {}
49 changes: 49 additions & 0 deletions internal/wire/testdata/Subtract/foo/wire.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright 2018 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build wireinject
// +build wireinject

package main

import (
// "strings"

"github.com/google/wire"
)

func inject(name BarName, opts *FooOptions) *Bar {
panic(wire.Build(wire.Subtract(Set, new(FooOptions))))
}

func injectBarService(name BarName, opts *FakeBarService) *FooBar {
panic(wire.Build(
wire.Subtract(SuperSet, new(BarService)),
wire.Bind(new(BarService), new(*FakeBarService)),
))
}

func injectFooBarService(name BarName, opts *FooOptions, bar *FakeBarService) *FooBar {
panic(wire.Build(
wire.Subtract(SuperSet, new(FooOptions), new(BarService)),
wire.Bind(new(BarService), new(*FakeBarService)),
))
}

func injectNone(name BarName, foo Foo, bar *FakeBarService) *FooBar {
panic(wire.Build(
wire.Subtract(SuperSet, new(Foo), new(BarService)),
wire.Bind(new(BarService), new(*FakeBarService)),
))
}
1 change: 1 addition & 0 deletions internal/wire/testdata/Subtract/pkg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
example.com/foo
Empty file.
42 changes: 42 additions & 0 deletions internal/wire/testdata/Subtract/want/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 22 additions & 18 deletions wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ func NewSet(...interface{}) ProviderSet {
return ProviderSet{}
}

func Subtract(...interface{}) ProviderSet {
return ProviderSet{}
}

// Build is placed in the body of an injector function template to declare the
// providers to use. The Wire code generation tool will fill in an
// implementation of the function. The arguments to Build are interpreted the
Expand Down Expand Up @@ -156,12 +160,12 @@ type StructProvider struct{}
//
// For example:
//
// type S struct {
// MyFoo *Foo
// MyBar *Bar
// }
// var Set = wire.NewSet(wire.Struct(new(S), "MyFoo")) -> inject only S.MyFoo
// var Set = wire.NewSet(wire.Struct(new(S), "*")) -> inject all fields
// type S struct {
// MyFoo *Foo
// MyBar *Bar
// }
// var Set = wire.NewSet(wire.Struct(new(S), "MyFoo")) -> inject only S.MyFoo
// var Set = wire.NewSet(wire.Struct(new(S), "*")) -> inject all fields
func Struct(structType interface{}, fieldNames ...string) StructProvider {
return StructProvider{}
}
Expand All @@ -175,22 +179,22 @@ type StructFields struct{}
//
// The following example would provide Foo and Bar using S.MyFoo and S.MyBar respectively:
//
// type S struct {
// MyFoo Foo
// MyBar Bar
// }
// type S struct {
// MyFoo Foo
// MyBar Bar
// }
//
// func NewStruct() S { /* ... */ }
// var Set = wire.NewSet(wire.FieldsOf(new(S), "MyFoo", "MyBar"))
// func NewStruct() S { /* ... */ }
// var Set = wire.NewSet(wire.FieldsOf(new(S), "MyFoo", "MyBar"))
//
// or
// or
//
// func NewStruct() *S { /* ... */ }
// var Set = wire.NewSet(wire.FieldsOf(new(*S), "MyFoo", "MyBar"))
// func NewStruct() *S { /* ... */ }
// var Set = wire.NewSet(wire.FieldsOf(new(*S), "MyFoo", "MyBar"))
//
// If the structType argument is a pointer to a pointer to a struct, then FieldsOf
// additionally provides a pointer to each field type (e.g., *Foo and *Bar in the
// example above).
// If the structType argument is a pointer to a pointer to a struct, then FieldsOf
// additionally provides a pointer to each field type (e.g., *Foo and *Bar in the
// example above).
func FieldsOf(structType interface{}, fieldNames ...string) StructFields {
return StructFields{}
}

0 comments on commit 5bbd21a

Please sign in to comment.