Skip to content

Commit

Permalink
fix golangci-lint issue
Browse files Browse the repository at this point in the history
Signed-off-by: datelier <57349093+datelier@users.noreply.github.com>
  • Loading branch information
datelier authored and actions-user committed Jun 2, 2020
1 parent 015a30a commit 17304e0
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
12 changes: 12 additions & 0 deletions internal/core/converter/tensorflow/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
// Package tensorflow provides implementation of Go API for extract data to vector
package tensorflow

// Option is tensorflow configure.
type Option func(*tensorflow)

var (
Expand All @@ -27,6 +28,7 @@ var (
}
)

// WithSessionOptions returns Option that sets options.
func WithSessionOptions(opts *SessionOptions) Option {
return func(t *tensorflow) {
if opts != nil {
Expand All @@ -35,6 +37,7 @@ func WithSessionOptions(opts *SessionOptions) Option {
}
}

// WithSessionTarget returns Option that sets target.
func WithSessionTarget(tgt string) Option {
return func(t *tensorflow) {
if tgt != "" {
Expand All @@ -49,6 +52,7 @@ func WithSessionTarget(tgt string) Option {
}
}

// WithSessionConfig returns Option that sets config.
func WithSessionConfig(cfg []byte) Option {
return func(t *tensorflow) {
if cfg != nil {
Expand All @@ -63,6 +67,7 @@ func WithSessionConfig(cfg []byte) Option {
}
}

// WithOperations returns Option that sets operations.
func WithOperations(opes ...*Operation) Option {
return func(t *tensorflow) {
if opes != nil {
Expand All @@ -75,6 +80,7 @@ func WithOperations(opes ...*Operation) Option {
}
}

// WithExportPath returns Option that sets exportDir.
func WithExportPath(path string) Option {
return func(t *tensorflow) {
if path != "" {
Expand All @@ -83,6 +89,7 @@ func WithExportPath(path string) Option {
}
}

// WithTags returns Option that sets tags.
func WithTags(tags ...string) Option {
return func(t *tensorflow) {
if tags != nil {
Expand All @@ -95,12 +102,14 @@ func WithTags(tags ...string) Option {
}
}

// WithFeed returns Option that sets feeds.
func WithFeed(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
t.feeds = append(t.feeds, OutputSpec{operationName, outputIndex})
}
}

// WithFeeds returns Option that sets feeds.
func WithFeeds(operationNames []string, outputIndexes []int) Option {
return func(t *tensorflow) {
if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) {
Expand All @@ -111,12 +120,14 @@ func WithFeeds(operationNames []string, outputIndexes []int) Option {
}
}

// WithFetch returns Option that sets fetches.
func WithFetch(operationName string, outputIndex int) Option {
return func(t *tensorflow) {
t.fetches = append(t.fetches, OutputSpec{operationName, outputIndex})
}
}

// WithFetches returns Option that sets fetches.
func WithFetches(operationNames []string, outputIndexes []int) Option {
return func(t *tensorflow) {
if operationNames != nil && outputIndexes != nil && len(operationNames) == len(outputIndexes) {
Expand All @@ -127,6 +138,7 @@ func WithFetches(operationNames []string, outputIndexes []int) Option {
}
}

// WithNdim returns Option that sets ndim.
func WithNdim(ndim uint8) Option {
return func(t *tensorflow) {
t.ndim = ndim
Expand Down
15 changes: 11 additions & 4 deletions internal/core/converter/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ import (
"github.com/vdaas/vald/internal/errors"
)

// SessionOptions is a type alias for tensorflow.SessionOptions.
type SessionOptions = tf.SessionOptions

// Operation is a type alias for tensorflow.Operation.
type Operation = tf.Operation

// TF represents a tensorflow interface.
type TF interface {
GetVector(inputs ...string) ([]float64, error)
GetValue(inputs ...string) (interface{}, error)
Expand All @@ -37,6 +41,7 @@ type session interface {
Closer
}

// Closer close a tensorflow.Session.
type Closer interface {
Close() error
}
Expand All @@ -53,20 +58,22 @@ type tensorflow struct {
ndim uint8
}

// OutputSpec is the specification of an feed/fetch.
type OutputSpec struct {
operationName string
outputIndex int
}

const (
TwoDim uint8 = iota + 2
ThreeDim
twoDim uint8 = iota + 2
threeDim
)

var loadFunc = func(exportDir string, tags []string, options *SessionOptions) (*tf.SavedModel, error) {
return tf.LoadSavedModel(exportDir, tags, options)
}

// New load a tensorlfow model and returns a new tensorflow struct.
func New(opts ...Option) (TF, error) {
t := new(tensorflow)

Expand Down Expand Up @@ -124,7 +131,7 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) {
}

switch t.ndim {
case TwoDim:
case twoDim:
value, ok := tensors[0].Value().([][]float64)
if ok {
if value == nil {
Expand All @@ -135,7 +142,7 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) {
}

return nil, errors.ErrFailedToCastTF(tensors[0].Value())
case ThreeDim:
case threeDim:
value, ok := tensors[0].Value().([][][]float64)
if ok {
if len(value) == 0 || value[0] == nil {
Expand Down
6 changes: 0 additions & 6 deletions internal/core/converter/tensorflow/tensorflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ func TestNew(t *testing.T) {
if err := test.checkFunc(test.want, got, err); err != nil {
tt.Errorf("error = %v", err)
}

})
}
}
Expand Down Expand Up @@ -196,7 +195,6 @@ func Test_tensorflow_Close(t *testing.T) {
if err := test.checkFunc(test.want, err); err != nil {
tt.Errorf("error = %v", err)
}

})
}
}
Expand Down Expand Up @@ -320,7 +318,6 @@ func Test_tensorflow_run(t *testing.T) {
if err := test.checkFunc(test.want, got, err); err != nil {
tt.Errorf("error = %v", err)
}

})
}
}
Expand Down Expand Up @@ -554,7 +551,6 @@ func Test_tensorflow_GetVector(t *testing.T) {
if err := test.checkFunc(test.want, got, err); err != nil {
tt.Errorf("error = %v", err)
}

})
}
}
Expand Down Expand Up @@ -667,7 +663,6 @@ func Test_tensorflow_GetValue(t *testing.T) {
if err := test.checkFunc(test.want, got, err); err != nil {
tt.Errorf("error = %v", err)
}

})
}
}
Expand Down Expand Up @@ -757,7 +752,6 @@ func Test_tensorflow_GetValues(t *testing.T) {
if err := test.checkFunc(test.want, gotValues, err); err != nil {
tt.Errorf("error = %v", err)
}

})
}
}

0 comments on commit 17304e0

Please sign in to comment.