Skip to content
This repository was archived by the owner on Nov 25, 2020. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
vendor/*/
.idea/
*.sw*
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ func MultiRemote(client *http.Client, uri string, config string, ins []interface
// request using NativeTensor objects. The raw call is provided
// for requests that need optimal performance and do not need to
// be converted into native go types.
func MultiRemoteRaw(client *http.Client, uri string, config string, inputs []*NativeTensor, inputNames, outputNames []string) ([]*NativeTensor, error) {
func MultiRemoteRaw(client graphpipe.Client, uri string, config string, inputs []*NativeTensor, inputNames, outputNames []string) ([]*NativeTensor, error) {
```
In similar fashion to the serving model, the client for making remote
calls is made up of three functions, Remote, MultiRemote, and
MultiRemoteRaw.
In similar fashion to the serving model, the API for making remote
calls has three functions (Remote, MultiRemote, and MultiRemoteRaw).

The functions range from simple to complex. The first three of those will
convert your native Go types into tensors and back, while the last one uses
Expand Down
23 changes: 20 additions & 3 deletions cached.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/Sirupsen/logrus"
bolt "github.com/coreos/bbolt"
graphpipefb "github.com/oracle/graphpipe-go/graphpipefb"
"sync"
)

const (
Expand Down Expand Up @@ -110,7 +111,12 @@ func (t *Nt) tensorFromIndexes(indexes []int) *NativeTensor {
return nt
}

func getKey(c *appContext, inputs []*Nt, index int) []byte {
type hashPair struct {
index int
hash []byte
}

func getKey(inputs []*Nt, index int) []byte {
numInputs := len(inputs)
if numInputs == 0 {
return []byte(emptyKey)
Expand Down Expand Up @@ -377,7 +383,7 @@ func getInputTensors(req *graphpipefb.InferRequest) ([]*NativeTensor, error) {
tensor := &graphpipefb.Tensor{}

if !req.InputTensors(tensor, i) {
err := fmt.Errorf("Bad input tensor")
err := fmt.Errorf("Bad input tensor #%d", i)
return nil, err
}

Expand Down Expand Up @@ -409,8 +415,19 @@ func getResultsCached(c *appContext, requestContext *RequestContext, req *graphp
}

keys := make([][]byte, numChunks)
ch := make(chan hashPair, numChunks)
var wg sync.WaitGroup
wg.Add(numChunks)
for i := 0; i < numChunks; i++ {
keys[i] = getKey(c, inputs, i)
go func(i int) {
ch <- hashPair{i, getKey(inputs, i)}
wg.Done()
}(i)
}
wg.Wait()
close(ch)
for elem := range ch {
keys[elem.index] = elem.hash
}

outputNames, err := getOutputNames(c, req)
Expand Down
124 changes: 124 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package graphpipe

import (
"net"
"net/http"
fb "github.com/google/flatbuffers/go"
"bytes"
"github.com/Sirupsen/logrus"
"io/ioutil"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"github.com/gen2brain/shm"
)

type Client interface {
call(*fb.Builder, []byte) ([]byte, error)
builder() *fb.Builder
}

type ShmClient struct {
Conn *net.Conn
Shm []byte
}

type HttpClient struct {
NetHttpClient *http.Client
Uri string
}


func (sc ShmClient) builder() *fb.Builder {
b := fb.NewBuilder(0)
b.Bytes = sc.Shm
b.Reset()
return b
}

func (sc ShmClient) call(builder *fb.Builder, request []byte) ([]byte, error) {
startPos := builder.Head()
length := len(request)
WriteInt(sc.Conn, uint32(startPos))
WriteInt(sc.Conn, uint32(length))
respStartPos, err := ReadInt(sc.Conn)
respSize, err := ReadInt(sc.Conn)
if err != nil {
return nil, err
}
body := sc.Shm[respStartPos:respStartPos + respSize]
return body, nil
}


func (hc HttpClient) builder() *fb.Builder {
return fb.NewBuilder(1024)
}

func (hc HttpClient) call(builder *fb.Builder, request []byte) ([]byte, error) {
rq, err := http.NewRequest("POST", hc.Uri, bytes.NewReader(request))
if err != nil {
logrus.Errorf("Failed to create request: %v", err)
return nil, err
}

// send the request
rs, err := hc.NetHttpClient.Do(rq)
if err != nil {
logrus.Errorf("Failed to send request: %v", err)
return nil, err
}
defer rs.Body.Close()

body, err := ioutil.ReadAll(rs.Body)
if err != nil {
logrus.Errorf("Failed to read body: %v", err)
return nil, err
}
if rs.StatusCode != 200 {
return nil, fmt.Errorf("Remote failed with %d: %s", rs.StatusCode, string(body))
}

return body, nil
}



// Opens the socket, creates the shared memory, communicates the shm id over
// the socket, and installs a signal handler to close the socket and remove
// the shm.
func CreateShmClient(socket string, shmSize int) Client {
conn, err := net.Dial("unix", socket)
if err != nil {
logrus.Fatal("Dial error", err)
}

shmId, err := shm.Get(shm.IPC_PRIVATE, shmSize, shm.IPC_CREAT|0777)
if err != nil || shmId < 0 {
panic(fmt.Sprintf("Could not shmget %d bytes", shmSize))
}

sigc := make(chan os.Signal, 1)
signal.Notify(sigc, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
go func(conn *net.Conn, c chan os.Signal) {
sig := <-c
log.Printf("Caught signal %s: shutting down", sig)
shm.Rm(shmId)
(*conn).Close()
os.Exit(-1)
}(&conn, sigc)

shmBytes, err := shm.At(shmId, 0, 0)
// Communicate our shm id to the server.
WriteInt(&conn, uint32(shmId))
if err != nil {
panic(err)
}

return ShmClient{
Conn: &conn,
Shm: shmBytes,
}
}
8 changes: 6 additions & 2 deletions cmd/graphpipe-batcher/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,14 @@ func serve(opts options) error {
Proxy: http.ProxyFromEnvironment,
TLSHandshakeTimeout: 5 * time.Second,
}
var client = &http.Client{
var httpClient = &http.Client{
Timeout: time.Second * 60,
Transport: transport,
}
var client = graphpipe.HttpClient{
NetHttpClient: httpClient,
Uri: opts.targetURL,
}

data := []*ioData{}

Expand Down Expand Up @@ -313,7 +317,7 @@ func serve(opts options) error {
inputs = append(inputs, &nt)
}
//ship it!
tensors, err := graphpipe.MultiRemoteRaw(client, opts.targetURL, "", inputs, inputNames, outputNames)
tensors, err := graphpipe.MultiRemoteRaw(client, "", inputs, inputNames, outputNames)
if err != nil {
for _, io := range data {
io.Error = err
Expand Down
23 changes: 14 additions & 9 deletions cmd/graphpipe-tf/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"strconv"
"strings"
"time"
_ "net/http/pprof"

"github.com/Sirupsen/logrus"
"github.com/spf13/cobra"
Expand All @@ -43,15 +44,16 @@ func version() string {
}

type options struct {
verbose bool
version bool
cache bool
cacheDir string
listen string
model string
inputs string
shape string
outputs string
verbose bool
version bool
cache bool
cacheDir string
domainSocket string
listen string
model string
inputs string
shape string
outputs string
}

func main() {
Expand Down Expand Up @@ -94,6 +96,8 @@ func main() {
}
f := cmd.Flags()
f.StringVarP(&opts.cacheDir, "cache-dir", "d", "~/.graphpipe", "directory for local cache state")
f.StringVarP(&opts.domainSocket, "domain-socket", "s", "graphpipe-go.sock",
"where to create the domain socket.")
f.StringVarP(&opts.listen, "listen", "l", "127.0.0.1:9000", "listen string")
f.StringVarP(&opts.model, "model", "m", "", "tensorflow model to load. Accepts local file or http(s) url.")
f.StringVarP(&opts.inputs, "inputs", "i", "", "comma seprated default inputs")
Expand Down Expand Up @@ -372,6 +376,7 @@ func serve(opts options) error {
logrus.Infof("Using default outputs %s", dOut)

serveOpts := &graphpipe.ServeRawOptions{
DomainSocket: opts.domainSocket,
Listen: opts.listen,
CacheFile: cachePath,
Meta: c.meta,
Expand Down
23 changes: 23 additions & 0 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
fb "github.com/google/flatbuffers/go"

graphpipefb "github.com/oracle/graphpipe-go/graphpipefb"
"net"
"encoding/binary"
"io"
)

// Serialize writes a builder object to a byte array
Expand Down Expand Up @@ -470,3 +473,23 @@ func buildTensorRaw(b *fb.Builder, dataFb, stringValFb fb.UOffsetT, shape []int6
}
return graphpipefb.TensorEnd(b)
}

func ReadInt(fd *net.Conn) (uint32, error) {
buf := make([]byte, 4)
_, err := io.ReadFull(*fd, buf)
if err != nil {
return 0, err
}
num := binary.LittleEndian.Uint32(buf)
return num, nil
}

func WriteInt(fd *net.Conn, num uint32) error {
bytes := make([]byte, 4)
binary.LittleEndian.PutUint32(bytes, num)
_, err := (*fd).Write(bytes)
return err
}



Loading