diff --git a/internal/core/converter/tensorflow/tensorflow.go b/internal/core/converter/tensorflow/tensorflow.go index 799a1268f6..841b2e97d3 100644 --- a/internal/core/converter/tensorflow/tensorflow.go +++ b/internal/core/converter/tensorflow/tensorflow.go @@ -29,6 +29,15 @@ type TF interface { GetVector(inputs ...string) ([]float64, error) GetValue(inputs ...string) (interface{}, error) GetValues(inputs ...string) (values []interface{}, err error) + Closer +} + +type session interface { + Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*Operation) ([]*tf.Tensor, error) + Closer +} + +type Closer interface { Close() error } @@ -42,7 +51,7 @@ type tensorflow struct { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } @@ -56,6 +65,10 @@ const ( ThreeDim ) +var loadFunc = func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) { + return tf.LoadSavedModel(exportDir, tags, options) +} + func New(opts ...Option) (TF, error) { t := new(tensorflow) for _, opt := range append(defaultOpts, opts...) { @@ -69,7 +82,7 @@ func New(opts ...Option) (TF, error) { } } - model, err := tf.LoadSavedModel(t.exportDir, t.tags, t.options) + model, err := loadFunc(t.exportDir, t.tags, t.options) if err != nil { return nil, err } diff --git a/internal/core/converter/tensorflow/tensorflow_mock_test.go b/internal/core/converter/tensorflow/tensorflow_mock_test.go new file mode 100644 index 0000000000..b9b6e4bc85 --- /dev/null +++ b/internal/core/converter/tensorflow/tensorflow_mock_test.go @@ -0,0 +1,35 @@ +// +// Copyright (C) 2019-2020 Vdaas.org Vald team ( kpango, rinx, kmrmt ) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package tensorflow provides implementation of Go API for extract data to vector +package tensorflow + +import ( + tf "github.com/tensorflow/tensorflow/tensorflow/go" +) + +type mockSession struct { + RunFunc func(map[tf.Output]*tf.Tensor, []tf.Output, []*Operation) ([]*tf.Tensor, error) + CloseFunc func() error +} + +func (m *mockSession) Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*Operation) ([]*tf.Tensor, error) { + return m.RunFunc(feeds, fetches, operations) +} + +func (m *mockSession) Close() error { + return m.CloseFunc() +} diff --git a/internal/core/converter/tensorflow/tensorflow_test.go b/internal/core/converter/tensorflow/tensorflow_test.go index 58b9f67118..8a5fb7dfad 100644 --- a/internal/core/converter/tensorflow/tensorflow_test.go +++ b/internal/core/converter/tensorflow/tensorflow_test.go @@ -53,31 +53,74 @@ func TestNew(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - opts: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - opts: nil, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return (t, nil): default options", + args: args{ + opts: nil, + }, + want: want{ + want: &tensorflow{ + graph: nil, + session: (&tf.SavedModel{}).Session, + }, + err: nil, + }, + checkFunc: defaultCheckFunc, + beforeFunc: func(args args) { + defaultOpts = []Option{} + loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { + return &tf.SavedModel{}, nil + } + }, + }, + { + name: "return (t, nil): args options", + args: args{ + opts: []Option{ + WithSessionTarget("test"), + WithSessionConfig([]byte{}), + WithNdim(1), + }, + }, + want: want{ + want: &tensorflow{ + sessionTarget: "test", + sessionConfig: []byte{}, + options: &tf.SessionOptions{ + Target: "test", + Config: []byte{}, + }, + graph: nil, + session: (&tf.SavedModel{}).Session, + ndim: 1, + }, + err: nil, + }, + checkFunc: defaultCheckFunc, + beforeFunc: func(args args) { + defaultOpts = []Option{} + loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { + return &tf.SavedModel{}, nil + } + }, + }, + { + name: "return (nil, error)", + args: args{ + nil, + }, + want: want{ + want: nil, + err: errors.New("load error"), + }, + checkFunc: defaultCheckFunc, + beforeFunc: func(args args) { + defaultOpts = []Option{} + loadFunc = func(s string, ss []string, o *SessionOptions) (*tf.SavedModel, error) { + return nil, errors.New("load error") + } + }, + }, } for _, test := range tests { @@ -113,7 +156,7 @@ func Test_tensorflow_Close(t *testing.T) { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } type want struct { @@ -134,51 +177,54 @@ func Test_tensorflow_Close(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return nil", + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + CloseFunc: func() error { + return nil + }, + }, + ndim: 0, + }, + want: want{ + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "return error", + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + CloseFunc: func() error { + return errors.New("fail") + }, + }, + ndim: 0, + }, + want: want{ + err: errors.New("fail"), + }, + checkFunc: defaultCheckFunc, + }, } for _, test := range tests { @@ -230,7 +276,7 @@ func Test_tensorflow_run(t *testing.T) { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } type want struct { @@ -256,57 +302,121 @@ func Test_tensorflow_run(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return ([], nil): inputs=nil", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: []*tf.Tensor{}, + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "return ([], nil): inputs={\"test\"}", + args: args{ + inputs: []string{ + "test", + }, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: []OutputSpec{ + OutputSpec{ + operationName: "test", + outputIndex: 0, + }, + }, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: tf.NewGraph(), + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: []*tf.Tensor{}, + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "length error", + args: args{ + inputs: []string{ + "", + }, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: nil, + ndim: 0, + }, + want: want{ + err: errors.ErrInputLength(1, 0), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "session.Run() error", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: tf.NewGraph(), + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + ndim: 0, + }, + want: want{ + err: errors.New("session.Run() error"), + }, + checkFunc: defaultCheckFunc, + }, } for _, test := range tests { @@ -358,7 +468,7 @@ func Test_tensorflow_GetVector(t *testing.T) { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } type want struct { @@ -384,57 +494,218 @@ func Test_tensorflow_GetVector(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return (vector, nil)", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor([]float64{1, 2, 3}) + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: []float64{1, 2, 3}, + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "run() error", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.New("session.Run() error"), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "nil tensor error: run() return nil", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, nil + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.ErrNilTensorTF([]*tf.Tensor{}), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "nil tensor error: run() return [nil]", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{nil}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.ErrNilTensorTF([]*tf.Tensor{nil}), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "failed to cast error: ndim=TwoDim", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 2, + }, + want: want{ + want: nil, + err: errors.ErrFailedToCastTF("test"), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "failed to cast error: ndim=ThreeDim", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 3, + }, + want: want{ + want: nil, + err: errors.ErrFailedToCastTF("test"), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "failed to cast error: ndim=default", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.ErrFailedToCastTF("test"), + }, + checkFunc: defaultCheckFunc, + }, } for _, test := range tests { @@ -486,7 +757,7 @@ func Test_tensorflow_GetValue(t *testing.T) { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } type want struct { @@ -512,57 +783,122 @@ func Test_tensorflow_GetValue(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return (value, nil)", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: "test", + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "run() error", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.New("session.Run() error"), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "nil tensor error: run() return nil", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, nil + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.ErrNilTensorTF([]*tf.Tensor{}), + }, + checkFunc: defaultCheckFunc, + }, + { + name: "nil tensor error: run() return [nil]", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return []*tf.Tensor{nil}, nil + }, + }, + ndim: 0, + }, + want: want{ + want: nil, + err: errors.ErrNilTensorTF([]*tf.Tensor{nil}), + }, + checkFunc: defaultCheckFunc, + }, } for _, test := range tests { @@ -614,7 +950,7 @@ func Test_tensorflow_GetValues(t *testing.T) { sessionConfig []byte options *SessionOptions graph *tf.Graph - session *tf.Session + session session ndim uint8 } type want struct { @@ -640,57 +976,66 @@ func Test_tensorflow_GetValues(t *testing.T) { return nil } tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - inputs: nil, - }, - fields: fields { - exportDir: "", - tags: nil, - feeds: nil, - fetches: nil, - operations: nil, - sessionTarget: "", - sessionConfig: nil, - options: nil, - graph: nil, - session: nil, - ndim: 0, - }, - want: want{}, - checkFunc: defaultCheckFunc, - } - }(), - */ + { + name: "return (values, nil)", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + tensor, err := tf.NewTensor("test") + if err != nil { + return nil, errors.New("NewTensor error") + } + return []*tf.Tensor{tensor, tensor}, nil + }, + }, + ndim: 0, + }, + want: want{ + wantValues: []interface{}{"test", "test"}, + err: nil, + }, + checkFunc: defaultCheckFunc, + }, + { + name: "run() error", + args: args{ + inputs: nil, + }, + fields: fields{ + exportDir: "", + tags: nil, + feeds: nil, + fetches: nil, + operations: nil, + sessionTarget: "", + sessionConfig: nil, + options: nil, + graph: nil, + session: &mockSession{ + RunFunc: func(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*tf.Operation) ([]*tf.Tensor, error) { + return nil, errors.New("session.Run() error") + }, + }, + ndim: 0, + }, + want: want{ + wantValues: nil, + err: errors.New("session.Run() error"), + }, + checkFunc: defaultCheckFunc, + }, } for _, test := range tests {