Skip to content
This repository was archived by the owner on Nov 25, 2020. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmd/graphpipe-onnx/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
** Copyright © 2018, Oracle and/or its affiliates. All rights reserved.
** Licensed under the Universal Permissive License v 1.0 as shown at http://oss.oracle.com/licenses/upl.
*/
*/

package main

Expand Down Expand Up @@ -269,6 +269,8 @@ func serve(opts options) error {
for _, engine_ctx := range c2c.CEngineCtxs {
C.c2_engine_register_input(engine_ctx, C.CString(k), (*C.int64_t)(&dims[0]), (C.int(len(dims))), C.int(dtype))
}
} else {
logrus.Fatalf("Invalid value for value_input with key: %s. Format should be {\"k\": [dtype, [dim1, dim2, ... ]]}, eg {\"k\": [1, [1, 3, 217, 217]]}", k)
}
}

Expand Down
138 changes: 136 additions & 2 deletions cmd/graphpipe-tf/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
** Copyright © 2018, Oracle and/or its affiliates. All rights reserved.
** Licensed under the Universal Permissive License v 1.0 as shown at http://oss.oracle.com/licenses/upl.
*/
*/

package main

Expand All @@ -13,15 +13,21 @@ import (
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"time"

"github.com/Sirupsen/logrus"
"github.com/spf13/cobra"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
graphpipe "github.com/oracle/graphpipe-go"
tfproto "github.com/oracle/graphpipe-go/cmd/graphpipe-tf/internal/github.com/tensorflow/tensorflow/tensorflow/go/core/framework"
cproto "github.com/oracle/graphpipe-go/cmd/graphpipe-tf/internal/github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf"
Expand Down Expand Up @@ -185,7 +191,7 @@ func initializeMetadata(opts options, c *tfContext) *graphpipe.NativeMetadataRes
}
c.shapes = append(c.shapes, shape)

if op.Type() != "Const" && t != graphpipefb.TypeNull {
if op.Type() != "Const" && op.Type() != "CheckNumerics" && t != graphpipefb.TypeNull {
if len(node.Input) == 0 {
c.defaultInputs = append(c.defaultInputs, name)
} else if _, present := outputsThatAreInputs[node.Name]; !present {
Expand Down Expand Up @@ -275,6 +281,33 @@ func readModel(uri string) ([]byte, error) {
}
return ioutil.ReadAll(response.Body)
}
if strings.HasPrefix(uri, "s3://") {
u, err := url.Parse(uri)
if err != nil {
logrus.Errorf("Failed to parse uri '%s': %v", uri, err)
return nil, err
}
bucket := u.Host
item := u.Path
sess, err := session.NewSession()
if err != nil {
logrus.Errorf("Failed to create s3 session %v", err)
return nil, err
}
downloader := s3manager.NewDownloader(sess)
buf := &aws.WriteAtBuffer{}
_, err = downloader.Download(buf,
&s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(item),
})
if err != nil {
logrus.Errorf("Unable to download item %q, %v", item, err)
return nil, err
}
return buf.Bytes(), nil

}
return ioutil.ReadFile(uri)
}

Expand Down Expand Up @@ -378,6 +411,7 @@ func serve(opts options) error {
DefaultInputs: dIn,
DefaultOutputs: dOut,
Apply: c.apply,
RESTApply: c.restApply,
GetHandler: c.getHandler,
}

Expand Down Expand Up @@ -470,6 +504,29 @@ var gptype2tftype = []tf.DataType{
tf.DataType(tfproto.DataType_DT_DOUBLE), // Type_Float64 = 11,
tf.DataType(tfproto.DataType_DT_STRING), // Type_String = 12,
}
var tftype2type = map[tf.DataType]reflect.Type{
tf.Float: reflect.TypeOf(float32(0)),
tf.Double: reflect.TypeOf(float64(0)),
tf.Int32: reflect.TypeOf(int32(0)),
tf.Uint32: reflect.TypeOf(uint32(0)),
tf.Uint8: reflect.TypeOf(uint8(0)),
tf.Int16: reflect.TypeOf(int16(0)),
tf.Int8: reflect.TypeOf(int8(0)),
tf.String: reflect.TypeOf(string("")),
tf.Complex64: reflect.TypeOf(complex(float32(0), float32(0))),
tf.Int64: reflect.TypeOf(int64(0)),
tf.Uint64: reflect.TypeOf(uint64(0)),
tf.Bool: reflect.TypeOf(true),
tf.Qint8: reflect.TypeOf(nil), // not supported
tf.Quint8: reflect.TypeOf(nil), // not supported
tf.Qint32: reflect.TypeOf(nil), // not supported
tf.Bfloat16: reflect.TypeOf(nil), // not supported
tf.Qint16: reflect.TypeOf(nil), // not supported
tf.Quint16: reflect.TypeOf(nil), // not supported
tf.Uint16: reflect.TypeOf(nil), // not supported
tf.Complex128: reflect.TypeOf(complex(float64(0), float64(0))),
tf.Half: reflect.TypeOf(nil), // not supported
}

func tensorFromNT(nt *graphpipe.NativeTensor) (*tf.Tensor, error) {
if nt.Type == graphpipefb.TypeString {
Expand Down Expand Up @@ -503,6 +560,81 @@ func getInputMap(c *tfContext, inputs map[string]*graphpipe.NativeTensor) (map[t
return inputMap, nil
}

func tensorFromJSON(output tf.Output, data []byte) (interface{}, error) {
t := tftype2type[output.DataType()]
shape := output.Shape()
dims := shape.NumDimensions()
for i := dims - 1; i >= 0; i-- {
size := shape.Size(i)
if size == -1 {
t = reflect.SliceOf(t)
} else {
t = reflect.ArrayOf(int(size), t)
}
}
tensor := reflect.New(t).Interface()
err := json.Unmarshal(data, tensor)
if err != nil {
return nil, err
}
return reflect.ValueOf(tensor).Elem().Interface(), nil
}

func getRESTInputMap(c *tfContext, inputs map[string]json.RawMessage) (map[tf.Output]*tf.Tensor, error) {
inputMap := map[tf.Output]*tf.Tensor{}
for name, input := range inputs {
output := tf.Output{}
var ok bool
if !strings.Contains(name, ":") {
name += ":0"
}
output, ok = c.outputs[name]
if !ok {
msg := "Could not find input '%s'"
logrus.Errorf(msg, name)
return nil, fmt.Errorf(msg, name)
}
tensor, err := tensorFromJSON(output, input)
if err != nil {
logrus.Errorf("Failed to create raw tensor: %v", err)
return nil, err
}
inputTensor, err := tf.NewTensor(tensor)
if err != nil {
logrus.Errorf("Failed to create tensor: %v", err)
return nil, err
}
inputMap[output] = inputTensor
}

return inputMap, nil
}

func (tfc *tfContext) restApply(inputs map[string]json.RawMessage, outputNames []string) (interface{}, error) {
outputRequests, err := getOutputRequests(tfc, outputNames)
if err != nil {
return nil, err
}
inputMap, err := getRESTInputMap(tfc, inputs)
if err != nil {
return nil, err
}
tensors, err := tfc.model.Session.Run(
inputMap,
outputRequests,
nil,
)
if err != nil {
logrus.Errorf("Failed to run session: %v", err)
return nil, err
}
res := []interface{}{}
for _, tensor := range tensors {
res = append(res, tensor.Value())
}
return res, err
}

func (tfc *tfContext) apply(requestContext *graphpipe.RequestContext, config string, inputs map[string]*graphpipe.NativeTensor, outputNames []string) ([]*graphpipe.NativeTensor, error) {
outputIndexes := []int{}
outputTps := make([]*graphpipe.NativeTensor, len(outputNames))
Expand All @@ -522,11 +654,13 @@ func (tfc *tfContext) apply(requestContext *graphpipe.RequestContext, config str
if err != nil {
return nil, err
}

tensors, err := tfc.model.Session.Run(
inputMap,
outputRequests,
nil,
)

if err != nil {
logrus.Errorf("Failed to run session: %v", err)
return nil, err
Expand Down
7 changes: 7 additions & 0 deletions rest_apply.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package graphpipe

import (
"encoding/json"
)

type RESTApplier func(inputs map[string]json.RawMessage, outputNames []string) (interface{}, error)
2 changes: 1 addition & 1 deletion results.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
** Copyright © 2018, Oracle and/or its affiliates. All rights reserved.
** Licensed under the Universal Permissive License v 1.0 as shown at http://oss.oracle.com/licenses/upl.
*/
*/

package graphpipe

Expand Down
28 changes: 27 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
/*
** Copyright © 2018, Oracle and/or its affiliates. All rights reserved.
** Licensed under the Universal Permissive License v 1.0 as shown at http://oss.oracle.com/licenses/upl.
*/
*/

package graphpipe

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -119,6 +120,7 @@ type ServeRawOptions struct {
DefaultInputs []string
DefaultOutputs []string
Apply Applier
RESTApply RESTApplier
GetHandler GetHandlerFunc
}

Expand All @@ -130,6 +132,7 @@ func ServeRaw(opts *ServeRawOptions) error {
c := &appContext{
meta: opts.Meta,
apply: opts.Apply,
restApply: opts.RESTApply,
getHandler: opts.GetHandler,
defaultInputs: opts.DefaultInputs,
defaultOutputs: opts.DefaultOutputs,
Expand All @@ -145,6 +148,7 @@ func ServeRaw(opts *ServeRawOptions) error {
defer c.db.Close()
}
setupLifecycleRoutes(c)
http.Handle("/rest", appHandler{c, RESTHandler})
http.Handle("/", appHandler{c, Handler})
logrus.Infof("Listening on '%s'", opts.Listen)
err = ListenAndServe(opts.Listen, nil)
Expand All @@ -159,6 +163,7 @@ func ServeRaw(opts *ServeRawOptions) error {
type appContext struct {
meta *NativeMetadataResponse
apply Applier
restApply RESTApplier
getHandler GetHandlerFunc
defaultInputs []string
defaultOutputs []string
Expand Down Expand Up @@ -217,6 +222,27 @@ func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
logrus.Infof("Request for %s took %s", r.URL.Path, duration)
}

// RESTHandler handle rest http requests.
func RESTHandler(c *appContext, w http.ResponseWriter, r *http.Request) error {
dec := json.NewDecoder(r.Body)
inputs := make(map[string]json.RawMessage)
err := dec.Decode(&inputs)
if err != nil {
return StatusError{400, err}
}
ret, err := c.restApply(inputs,
c.defaultOutputs)
if err != nil {
return StatusError{400, err}
}
enc := json.NewEncoder(w)
err = enc.Encode(ret)
if err != nil {
return StatusError{400, err}
}
return nil
}

// Handler handles our http requests.
func Handler(c *appContext, w http.ResponseWriter, r *http.Request) error {
body, err := ioutil.ReadAll(r.Body)
Expand Down
Loading