Skip to content

Commit

Permalink
Add tensorflow test code
Browse files Browse the repository at this point in the history
Signed-off-by: datelier <57349093+datelier@users.noreply.github.com>
  • Loading branch information
datelier authored and actions-user committed Jun 2, 2020
1 parent 7bf7dbf commit 3d57c14
Show file tree
Hide file tree
Showing 3 changed files with 674 additions and 281 deletions.
17 changes: 15 additions & 2 deletions internal/core/converter/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ type TF interface {
GetVector(inputs ...string) ([]float64, error)
GetValue(inputs ...string) (interface{}, error)
GetValues(inputs ...string) (values []interface{}, err error)
Closer
}

type session interface {
Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*Operation) ([]*tf.Tensor, error)
Closer
}

type Closer interface {
Close() error
}

Expand All @@ -42,7 +51,7 @@ type tensorflow struct {
sessionConfig []byte
options *SessionOptions
graph *tf.Graph
session *tf.Session
session session
ndim uint8
}

Expand All @@ -56,6 +65,10 @@ const (
ThreeDim
)

var loadFunc = func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) {
return tf.LoadSavedModel(exportDir, tags, options)
}

func New(opts ...Option) (TF, error) {
t := new(tensorflow)
for _, opt := range append(defaultOpts, opts...) {
Expand All @@ -69,7 +82,7 @@ func New(opts ...Option) (TF, error) {
}
}

model, err := tf.LoadSavedModel(t.exportDir, t.tags, t.options)
model, err := loadFunc(t.exportDir, t.tags, t.options)
if err != nil {
return nil, err
}
Expand Down
35 changes: 35 additions & 0 deletions internal/core/converter/tensorflow/tensorflow_mock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//
// Copyright (C) 2019-2020 Vdaas.org Vald team ( kpango, rinx, kmrmt )
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

// Package tensorflow provides implementation of Go API for extract data to vector
package tensorflow

import (
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

type mockSession struct {
RunFunc func(map[tf.Output]*tf.Tensor, []tf.Output, []*Operation) ([]*tf.Tensor, error)
CloseFunc func() error
}

func (m *mockSession) Run(feeds map[tf.Output]*tf.Tensor, fetches []tf.Output, operations []*Operation) ([]*tf.Tensor, error) {
return m.RunFunc(feeds, fetches, operations)
}

func (m *mockSession) Close() error {
return m.CloseFunc()
}
Loading

0 comments on commit 3d57c14

Please sign in to comment.