Skip to content

Commit

Permalink
fix golangci-lint issue
Browse files Browse the repository at this point in the history
Signed-off-by: datelier <57349093+datelier@users.noreply.github.com>
  • Loading branch information
datelier committed May 25, 2020
1 parent 33740e1 commit b4a0498
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions internal/core/converter/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ var loadFunc = func(exportDir string, tags []string, options *SessionOptions) (*

func New(opts ...Option) (TF, error) {
t := new(tensorflow)

for _, opt := range append(defaultOpts, opts...) {
opt(t)
}
Expand All @@ -77,8 +78,10 @@ func New(opts ...Option) (TF, error) {
if err != nil {
return nil, err
}

t.graph = model.Graph
t.session = model.Session

return t, nil
}

Expand All @@ -92,11 +95,13 @@ func (t *tensorflow) run(inputs ...string) ([]*tf.Tensor, error) {
}

feeds := make(map[tf.Output]*tf.Tensor, len(inputs))

for i, val := range inputs {
inputTensor, err := tf.NewTensor(val)
if err != nil {
return nil, err
}

feeds[t.graph.Operation(t.feeds[i].operationName).Output(t.feeds[i].outputIndex)] = inputTensor
}

Expand All @@ -113,6 +118,7 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) {
if err != nil {
return nil, err
}

if len(tensors) == 0 || tensors[0] == nil || tensors[0].Value() == nil {
return nil, errors.ErrNilTensorTF(tensors)
}
Expand All @@ -124,27 +130,29 @@ func (t *tensorflow) GetVector(inputs ...string) ([]float64, error) {
if value == nil {
return nil, errors.ErrNilTensorValueTF(value)
}

return value[0], nil
} else {
return nil, errors.ErrFailedToCastTF(tensors[0].Value())
}

return nil, errors.ErrFailedToCastTF(tensors[0].Value())
case ThreeDim:
value, ok := tensors[0].Value().([][][]float64)
if ok {
if len(value) == 0 || value[0] == nil {
return nil, errors.ErrNilTensorValueTF(value)
}

return value[0][0], nil
} else {
return nil, errors.ErrFailedToCastTF(tensors[0].Value())
}

return nil, errors.ErrFailedToCastTF(tensors[0].Value())
default:
value, ok := tensors[0].Value().([]float64)
if ok {
return value, nil
} else {
return nil, errors.ErrFailedToCastTF(tensors[0].Value())
}

return nil, errors.ErrFailedToCastTF(tensors[0].Value())
}
}

Expand All @@ -153,9 +161,11 @@ func (t *tensorflow) GetValue(inputs ...string) (interface{}, error) {
if err != nil {
return nil, err
}

if len(tensors) == 0 || tensors[0] == nil {
return nil, errors.ErrNilTensorTF(tensors)
}

return tensors[0].Value(), nil
}

Expand All @@ -164,9 +174,11 @@ func (t *tensorflow) GetValues(inputs ...string) (values []interface{}, err erro
if err != nil {
return nil, err
}

values = make([]interface{}, 0, len(tensors))
for _, tensor := range tensors {
values = append(values, tensor.Value())
}

return values, nil
}

0 comments on commit b4a0498

Please sign in to comment.