From 6d215f660d20c264ffac2d6e6c945df790d99d72 Mon Sep 17 00:00:00 2001 From: datelier <57349093+datelier@users.noreply.github.com> Date: Thu, 23 Jul 2020 14:05:39 +0900 Subject: [PATCH] tensorflow savedmodel warmup (#539) * tensorflow savedmodel warmup Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> * fix warmup Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> * fix DeepSource issue: Empty string test can be improved Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> * fix test checkFunc Signed-off-by: datelier <57349093+datelier@users.noreply.github.com> Co-authored-by: Yusuke Kato --- internal/core/converter/tensorflow/option.go | 33 ++- .../core/converter/tensorflow/option_test.go | 206 +++++++++++++++++- .../core/converter/tensorflow/tensorflow.go | 42 ++-- .../converter/tensorflow/tensorflow_test.go | 191 ++++++++++++++-- 4 files changed, 424 insertions(+), 48 deletions(-) diff --git a/internal/core/converter/tensorflow/option.go b/internal/core/converter/tensorflow/option.go index a8156252e5..70ed8e32ec 100644 --- a/internal/core/converter/tensorflow/option.go +++ b/internal/core/converter/tensorflow/option.go @@ -17,14 +17,19 @@ // Package tensorflow provides implementation of Go API for extract data to vector package tensorflow +import ( + tf "github.com/tensorflow/tensorflow/tensorflow/go" +) + // Option is tensorflow configure. type Option func(*tensorflow) var ( defaultOpts = []Option{ - WithOperations(), // set to default - WithSessionOptions(nil), // set to default - WithNdim(0), // set to default + withLoadFunc(tf.LoadSavedModel), // set to default + WithOperations(), // set to default + WithSessionOptions(nil), // set to default + WithNdim(0), // set to default } ) @@ -102,6 +107,15 @@ func WithTags(tags ...string) Option { } } +func withLoadFunc( + loadFunc func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error)) Option { + return func(t *tensorflow) { + if loadFunc != nil { + t.loadFunc = loadFunc + } + } +} + // WithFeed returns Option that sets feeds. func WithFeed(operationName string, outputIndex int) Option { return func(t *tensorflow) { @@ -138,6 +152,19 @@ func WithFetches(operationNames []string, outputIndexes []int) Option { } } +// WithWarmupInputs returns Option that sets warmupInputs. +func WithWarmupInputs(warmupInputs ...string) Option { + return func(t *tensorflow) { + if warmupInputs != nil { + if t.warmupInputs != nil { + t.warmupInputs = append(t.warmupInputs, warmupInputs...) + } else { + t.warmupInputs = warmupInputs + } + } + } +} + // WithNdim returns Option that sets ndim. func WithNdim(ndim uint8) Option { return func(t *tensorflow) { diff --git a/internal/core/converter/tensorflow/option_test.go b/internal/core/converter/tensorflow/option_test.go index 0c522faf4e..c22600329d 100644 --- a/internal/core/converter/tensorflow/option_test.go +++ b/internal/core/converter/tensorflow/option_test.go @@ -21,6 +21,9 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + tf "github.com/tensorflow/tensorflow/tensorflow/go" "github.com/vdaas/vald/internal/errors" "go.uber.org/goleak" ) @@ -71,7 +74,7 @@ func TestWithSessionOptions(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -140,7 +143,7 @@ func TestWithSessionTarget(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -209,7 +212,7 @@ func TestWithSessionConfig(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -294,7 +297,7 @@ func TestWithOperations(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -363,7 +366,7 @@ func TestWithExportPath(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -459,7 +462,7 @@ func TestWithTags(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -482,6 +485,89 @@ func TestWithTags(t *testing.T) { } } +func TestWithLoadFunc(t *testing.T) { + type T = tensorflow + type args struct { + loadFunc func(string, []string, *SessionOptions) (*tf.SavedModel, error) + } + type want struct { + obj *T + } + type test struct { + name string + args args + want want + checkFunc func(want, *T) error + beforeFunc func(args) + afterFunc func(args) + } + + defaultCheckFunc := func(w want, obj *T) error { + opts := []cmp.Option{ + cmp.AllowUnexported(tensorflow{}), + cmp.AllowUnexported(OutputSpec{}), + cmpopts.IgnoreFields(tensorflow{}, "loadFunc"), + cmp.Comparer(func(want, obj T) bool { + p1 := reflect.ValueOf(want).FieldByName("loadFunc").Pointer() + p2 := reflect.ValueOf(obj).FieldByName("loadFunc").Pointer() + return p1 == p2 + }), + } + if diff := cmp.Diff(w.obj, obj, opts...); diff != "" { + return errors.Errorf("err: %s", diff) + } + return nil + } + + loadFunc := func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) { + return nil, nil + } + tests := []test{ + { + name: "set success when loadFunc is not nil", + args: args{ + loadFunc: loadFunc, + }, + want: want{ + obj: &T{ + loadFunc: loadFunc, + }, + }, + }, + { + name: "do nothing when loadFunc is nil", + args: args{ + loadFunc: nil, + }, + want: want{ + obj: &T{}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(tt *testing.T) { + defer goleak.VerifyNone(tt) + if test.beforeFunc != nil { + test.beforeFunc(test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(test.args) + } + + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := withLoadFunc(test.args.loadFunc) + obj := new(T) + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + func TestWithFeed(t *testing.T) { type T = tensorflow type args struct { @@ -529,7 +615,7 @@ func TestWithFeed(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -635,7 +721,7 @@ func TestWithFeeds(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -703,7 +789,7 @@ func TestWithFetch(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -809,7 +895,7 @@ func TestWithFetches(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -830,6 +916,104 @@ func TestWithFetches(t *testing.T) { } } +func TestWithWarmupInputs(t *testing.T) { + type T = tensorflow + type args struct { + warmupInputs []string + } + type fields struct { + warmupInputs []string + } + type want struct { + obj *T + } + type test struct { + name string + args args + want want + fields fields + checkFunc func(want, *T) error + beforeFunc func(args) + afterFunc func(args) + } + + defaultCheckFunc := func(w want, obj *T) error { + if !reflect.DeepEqual(obj, w.obj) { + return errors.Errorf("got = %v, want %v", obj, w.obj) + } + return nil + } + + tests := []test{ + { + name: "set nothing when warmupInputs is nil", + want: want{ + obj: new(T), + }, + }, + { + name: "set success when warmupInputs is not nil and warmupInputs field is not nil", + args: args{ + warmupInputs: []string{ + "test", + }, + }, + fields: fields{ + warmupInputs: []string{ + "test", + }, + }, + want: want{ + obj: &T{ + warmupInputs: []string{ + "test", + "test", + }, + }, + }, + }, + { + name: "set success when warmupInputs is not nil and warmupInputs field is nil", + args: args{ + warmupInputs: []string{ + "test", + }, + }, + want: want{ + obj: &T{ + warmupInputs: []string{ + "test", + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(tt *testing.T) { + defer goleak.VerifyNone(tt) + if test.beforeFunc != nil { + test.beforeFunc(test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(test.args) + } + + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + got := WithWarmupInputs(test.args.warmupInputs...) + obj := &T{ + warmupInputs: test.fields.warmupInputs, + } + got(obj) + if err := test.checkFunc(test.want, obj); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + func TestWithNdim(t *testing.T) { type T = tensorflow type args struct { @@ -870,7 +1054,7 @@ func TestWithNdim(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index 6f90c12b7a..a8ff5aa666 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -30,7 +30,7 @@ type SessionOptions = tf.SessionOptions // Operation is a type alias for tensorflow.Operation. type Operation = tf.Operation -// Closer is a type alias io.Closer +// Closer is a type alias io.Closer. type Closer = io.Closer // TF represents a tensorflow interface. @@ -47,15 +47,17 @@ type session interface { } type tensorflow struct { - exportDir string - tags []string - feeds []OutputSpec - fetches []OutputSpec - operations []*Operation - options *SessionOptions - graph *tf.Graph - session session - ndim uint8 + exportDir string + tags []string + loadFunc func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) + feeds []OutputSpec + fetches []OutputSpec + operations []*Operation + options *SessionOptions + graph *tf.Graph + session session + warmupInputs []string + ndim uint8 } // OutputSpec is the specification of an feed/fetch. @@ -69,8 +71,6 @@ const ( threeDim ) -var loadFunc = tf.LoadSavedModel - // New load a tensorlfow model and returns a new tensorflow struct. func New(opts ...Option) (TF, error) { t := new(tensorflow) @@ -79,7 +79,7 @@ func New(opts ...Option) (TF, error) { opt(t) } - model, err := loadFunc(t.exportDir, t.tags, t.options) + model, err := t.loadFunc(t.exportDir, t.tags, t.options) if err != nil { return nil, err } @@ -87,9 +87,25 @@ func New(opts ...Option) (TF, error) { t.graph = model.Graph t.session = model.Session + err = t.warmup() + if err != nil { + return nil, err + } + return t, nil } +func (t *tensorflow) warmup() error { + if t.warmupInputs != nil { + _, err := t.run(t.warmupInputs...) + if err != nil { + return err + } + } + + return nil +} + func (t *tensorflow) Close() error { return t.session.Close() } diff --git a/internal/core/converter/tensorflow/tensorflow_test.go b/internal/core/converter/tensorflow/tensorflow_test.go index 8954b1b687..49cb28bded 100644 --- a/internal/core/converter/tensorflow/tensorflow_test.go +++ b/internal/core/converter/tensorflow/tensorflow_test.go @@ -21,6 +21,8 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" tf "github.com/tensorflow/tensorflow/tensorflow/go" "github.com/vdaas/vald/internal/errors" "go.uber.org/goleak" @@ -42,74 +44,127 @@ func TestNew(t *testing.T) { beforeFunc func(args) afterFunc func(args) } + savedModel := &tf.SavedModel{} + loadFunc := func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) { + return savedModel, nil + } defaultCheckFunc := func(w want, got TF, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got error = %v, want %v", err, w.err) } - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got = %v, want %v", got, w.want) + + opts := []cmp.Option{ + cmp.AllowUnexported(tensorflow{}), + cmp.AllowUnexported(OutputSpec{}), + cmpopts.IgnoreFields(tensorflow{}, "loadFunc"), + cmp.Comparer(func(want, got TF) bool { + p1 := reflect.ValueOf(want).Elem().FieldByName("loadFunc").Pointer() + p2 := reflect.ValueOf(got).Elem().FieldByName("loadFunc").Pointer() + return p1 == p2 + }), + } + if diff := cmp.Diff(w.want, got, opts...); diff != "" { + return errors.Errorf("err: %s", diff) } return nil } tests := []test{ { name: "returns (t, nil) when opts is nil", + args: args{ + opts: []Option{ + withLoadFunc(loadFunc), + }, + }, want: want{ want: &tensorflow{ - session: (&tf.SavedModel{}).Session, + loadFunc: loadFunc, + graph: savedModel.Graph, + session: savedModel.Session, }, }, beforeFunc: func(args args) { defaultOpts = []Option{} - loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { - return &tf.SavedModel{}, nil - } }, }, { name: "returns (t, nil) when args is not nil", args: args{ opts: []Option{ + withLoadFunc(loadFunc), + WithFeed("test", 0), + WithFetch("test", 0), WithSessionTarget("test"), WithSessionConfig([]byte{}), + WithWarmupInputs(), WithNdim(1), }, }, want: want{ want: &tensorflow{ + loadFunc: loadFunc, + feeds: []OutputSpec{ + { + operationName: "test", + outputIndex: 0, + }, + }, + fetches: []OutputSpec{ + { + operationName: "test", + outputIndex: 0, + }, + }, options: &tf.SessionOptions{ Target: "test", Config: []byte{}, }, - graph: nil, - session: (&tf.SavedModel{}).Session, - ndim: 1, + graph: savedModel.Graph, + session: savedModel.Session, + warmupInputs: nil, + ndim: 1, }, }, beforeFunc: func(args args) { defaultOpts = []Option{} - loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { - return &tf.SavedModel{}, nil - } }, }, { name: "returns (nil, error) when loadFunc function returns error", + args: args{ + opts: []Option{ + withLoadFunc(func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) { + return nil, errors.New("load error") + }), + }, + }, want: want{ err: errors.New("load error"), }, beforeFunc: func(args args) { defaultOpts = []Option{} - loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { - return nil, errors.New("load error") - } + }, + }, + { + name: "returns (nil, error) when warmup error", + args: args{ + opts: []Option{ + withLoadFunc(loadFunc), + WithWarmupInputs("test"), + }, + }, + want: want{ + err: errors.ErrInputLength(1, 0), + }, + beforeFunc: func(args args) { + defaultOpts = []Option{} }, }, } for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -128,6 +183,100 @@ func TestNew(t *testing.T) { } } +func Test_tensorflow_warmup(t *testing.T) { + type fields struct { + feeds []OutputSpec + graph *tf.Graph + session session + warmupInputs []string + } + type want struct { + err error + } + type test struct { + name string + fields fields + want want + checkFunc func(want, error) error + beforeFunc func() + afterFunc func() + } + defaultCheckFunc := func(w want, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got error = %v, want %v", err, w.err) + } + return nil + } + tests := []test{ + { + name: "return nil when warmupInputs is nil", + want: want{ + err: nil, + }, + }, + { + name: "return nil when warmupInputs is not nil", + fields: fields{ + feeds: []OutputSpec{ + { + operationName: "test", + outputIndex: 0, + }, + }, + graph: tf.NewGraph(), + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{}, nil + }, + }, + warmupInputs: []string{ + "test", + }, + }, + want: want{ + err: nil, + }, + }, + { + name: "return error", + fields: fields{ + warmupInputs: []string{ + "test", + }, + }, + want: want{ + err: errors.ErrInputLength(1, 0), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(tt *testing.T) { + defer goleak.VerifyNone(tt) + if test.beforeFunc != nil { + test.beforeFunc() + } + if test.afterFunc != nil { + defer test.afterFunc() + } + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + t := &tensorflow{ + feeds: test.fields.feeds, + graph: test.fields.graph, + session: test.fields.session, + warmupInputs: test.fields.warmupInputs, + } + + err := t.warmup() + if err := test.checkFunc(test.want, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + func Test_tensorflow_Close(t *testing.T) { type fields struct { session session @@ -177,7 +326,7 @@ func Test_tensorflow_Close(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc() } @@ -298,7 +447,7 @@ func Test_tensorflow_run(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -532,7 +681,7 @@ func Test_tensorflow_GetVector(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -645,7 +794,7 @@ func Test_tensorflow_GetValue(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) } @@ -734,7 +883,7 @@ func Test_tensorflow_GetValues(t *testing.T) { for _, test := range tests { t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(t) + defer goleak.VerifyNone(tt) if test.beforeFunc != nil { test.beforeFunc(test.args) }