Skip to content

Commit

Permalink
Add semaphore locking to TensorFlow calls.
Browse files Browse the repository at this point in the history
  • Loading branch information
dchoi-viant committed Jul 10, 2023
1 parent 0b6d6c6 commit 8e28c7e
Show file tree
Hide file tree
Showing 14 changed files with 702 additions and 474 deletions.
34 changes: 20 additions & 14 deletions service/config/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,28 @@ import (

// Model represents model config
type Model struct {
ID string
Dir string
URL string
Debug bool
Location string `json:",omitempty" yaml:",omitempty"`
Tags []string
OutputType string `json:",omitempty" yaml:",omitempty"`
UseDict *bool `json:",omitempty" yaml:",omitempty"`
DictURL string
Transformer string `json:",omitempty" yaml:",omitempty"`
DataStore string `json:",omitempty" yaml:",omitempty"`
Modified *Modified `json:",omitempty" yaml:",omitempty"`
Stream *config.Stream `json:",omitempty" yaml:",omitempty"`
ID string
Dir string
URL string
Debug bool

Location string `json:",omitempty" yaml:",omitempty"`
Tags []string

OutputType string `json:",omitempty" yaml:",omitempty"` // Deprecated - we can infer output types from TF graph
UseDict *bool `json:",omitempty" yaml:",omitempty"`
DictURL string // Deprecated - we usually extract the dictionary/vocabulary from TF graph

Transformer string `json:",omitempty" yaml:",omitempty"`
DataStore string `json:",omitempty" yaml:",omitempty"`

Modified *Modified `json:",omitempty" yaml:",omitempty"`
Stream *config.Stream `json:",omitempty" yaml:",omitempty"`

shared.MetaInput `json:",omitempty" yaml:",inline"`
DictMeta DictionaryMeta
Test TestPayload `json:",omitempty" yaml:",omitempty"`

Test TestPayload `json:",omitempty" yaml:",omitempty"`
}

type TestPayload struct {
Expand Down
200 changes: 200 additions & 0 deletions service/endpoint/checker/self.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
package checker

import (
"context"
"fmt"
"log"
"math/rand"
"strconv"
"time"

"github.com/viant/mly/service/config"
"github.com/viant/mly/shared"
"github.com/viant/mly/shared/client"
"github.com/viant/toolbox"
)

func SelfTest(host []*client.Host, timeout time.Duration, modelID string, usesTransformer bool, inputs_ []*shared.Field, tp config.TestPayload, outputs []*shared.Field, debug bool) error {
cli, err := client.New(modelID, host, client.WithDebug(true))
if err != nil {
return fmt.Errorf("%s:%w", modelID, err)
}

inputs := cli.Config.Datastore.MetaInput.Inputs

// generate payload

var testData map[string]interface{}
var batchSize int
if len(tp.Batch) > 0 {
for k, v := range tp.Batch {
testData[k] = v
batchSize = len(v)
}
} else {
if len(tp.Single) > 0 {
testData = tp.Single

for _, field := range inputs {
n := field.Name
sv, ok := tp.Single[n]
switch field.DataType {
case "int", "int32", "int64":
if !ok {
testData[n] = rand.Int31()
} else {
switch tsv := sv.(type) {
case string:
testData[n], err = strconv.Atoi(tsv)
if err != nil {
return err
}
}
}
case "float", "float32", "float64":
testData[n] = rand.Float32()
default:
if !ok {
testData[n] = fmt.Sprintf("test-%d", rand.Int31())
} else {
testData[n] = toolbox.AsString(sv)
}
}
}
} else {
testData = make(map[string]interface{})
for _, field := range inputs {
n := field.Name
switch field.DataType {
case "int", "int32", "int64":
testData[n] = rand.Int31()
case "float", "float32", "float64":
testData[n] = rand.Float32()
default:
testData[n] = fmt.Sprintf("test-%d", rand.Int31())
}
}
}

if tp.SingleBatch {
for _, field := range inputs {
fn := field.Name
tv := testData[fn]
switch field.DataType {
case "int", "int32", "int64":
var v int
switch atv := tv.(type) {
case int:
v = atv
case int32:
case int64:
v = int(atv)
default:
return fmt.Errorf("test data malformed: %s expected int-like, found %T", fn, tv)
}

b := [1]int{v}
testData[fn] = b[:]
case "float", "float32", "float64":
var v float32
switch atv := tv.(type) {
case float32:
v = atv
case float64:
v = float32(atv)
default:
return fmt.Errorf("test data malformed: %s expected float32-like, found %T", fn, tv)
}

b := [1]float32{v}
testData[fn] = b[:]
default:
switch atv := tv.(type) {
case string:
b := [1]string{atv}
testData[fn] = b[:]
default:
return fmt.Errorf("test data malformed: %s expected string-like, found %T", fn, tv)
}
}
}

batchSize = 1
}
}

if debug {
log.Printf("[%s test] batchSize:%d %+v", modelID, batchSize, testData)
}

msg := cli.NewMessage()
defer msg.Release()

if batchSize > 0 {
msg.SetBatchSize(batchSize)
}

for k, vs := range testData {
switch at := vs.(type) {
case []float32:
msg.FloatsKey(k, at)
case []float64:
rat := make([]float32, len(at))
for i, v := range at {
rat[i] = float32(v)
}
msg.FloatsKey(k, rat)
case float32:
msg.FloatKey(k, at)
case float64:
msg.FloatKey(k, float32(at))

case []int:
msg.IntsKey(k, at)
case []int32:
rat := make([]int, len(at))
for i, v := range at {
rat[i] = int(v)
}
msg.IntsKey(k, rat)
case []int64:
rat := make([]int, len(at))
for i, v := range at {
rat[i] = int(v)
}
msg.IntsKey(k, rat)

case int:
msg.IntKey(k, at)
case int32:
msg.IntKey(k, int(at))
case int64:
msg.IntKey(k, int(at))

case []string:
msg.StringsKey(k, at)
case string:
msg.StringKey(k, at)

default:
return fmt.Errorf("%s:could not form payload %T (%+v)", modelID, at, at)
}
}

resp := new(client.Response)
// see if there is a transform
// if there is, trigger the transform with mock data?
resp.Data = Generated(outputs, batchSize, usesTransformer)()

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

// send response
err = cli.Run(ctx, msg, resp)

if err != nil {
return fmt.Errorf("%s:Run():%v", modelID, err)
}

return nil
}
12 changes: 6 additions & 6 deletions service/endpoint/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"io/ioutil"
"net/http"

"github.com/pkg/errors"
"github.com/viant/afs"
Expand All @@ -12,8 +13,6 @@ import (
sconfig "github.com/viant/mly/shared/config"
"github.com/viant/toolbox"
"gopkg.in/yaml.v2"

"net/http"
)

const (
Expand All @@ -24,10 +23,11 @@ const (
type Config struct {
config.ModelList `json:",omitempty" yaml:",inline"`
sconfig.DatastoreList `json:",omitempty" yaml:",inline"`
Endpoint econfig.Endpoint
EnableMemProf bool
EnableCPUProf bool
AllowedSubnet []string `json:",omitempty" yaml:",omitempty"`

Endpoint econfig.Endpoint
EnableMemProf bool
EnableCPUProf bool
AllowedSubnet []string `json:",omitempty" yaml:",omitempty"`
}

// Init initialise config
Expand Down
16 changes: 13 additions & 3 deletions service/endpoint/config/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@ import (
"time"
)

//Endpoint represents an endpoint
// Endpoint represents an endpoint
type Endpoint struct {
Port int
ReadTimeoutMs int `json:",omitempty" yaml:",omitempty"`
WriteTimeoutMs int `json:",omitempty" yaml:",omitempty"`
WriteTimeout time.Duration `json:",omitempty" yaml:",omitempty"`
MaxHeaderBytes int `json:",omitempty" yaml:",omitempty"`
PoolMaxSize int `json:",omitempty" yaml:",omitempty"`
BufferSize int `json:",omitempty" yaml:",omitempty"`

// HTTP data buffer pool - used when reading a payload, for saving memory
PoolMaxSize int `json:",omitempty" yaml:",omitempty"`
BufferSize int `json:",omitempty" yaml:",omitempty"`

MaxEvaluatorConcurrency int32 `json:",omitempty" yaml:",omitempty"`
}

//Init init applied default settings
Expand All @@ -29,13 +33,19 @@ func (e *Endpoint) Init() {
if e.WriteTimeout == 0 {
e.WriteTimeout = time.Duration(e.WriteTimeoutMs) * time.Millisecond
}

if e.MaxHeaderBytes == 0 {
e.MaxHeaderBytes = 8 * 1024
}

if e.PoolMaxSize == 0 {
e.PoolMaxSize = 512
}
if e.BufferSize == 0 {
e.BufferSize = 128 * 1024
}

if e.MaxEvaluatorConcurrency <= 0 {
e.MaxEvaluatorConcurrency = 3000
}
}
1 change: 1 addition & 0 deletions service/endpoint/config/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package config
32 changes: 0 additions & 32 deletions service/endpoint/health.go
Original file line number Diff line number Diff line change
@@ -1,33 +1 @@
package endpoint

import (
"encoding/json"
"net/http"
"sync"
)

const healthURI = "/v1/api/health"

type healthHandler struct {
healths map[string]*int32
mu *sync.Mutex
}

func NewHealthHandler() *healthHandler {
return &healthHandler{
mu: new(sync.Mutex),
healths: make(map[string]*int32),
}
}

func (h *healthHandler) RegisterHealthPoint(name string, isOkPtr *int32) {
h.mu.Lock()
defer h.mu.Unlock()
h.healths[name] = isOkPtr
}

func (h *healthHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
JSON, _ := json.Marshal(h.healths)
writer.Header().Set("Content-Type", "application/json")
writer.Write(JSON)
}
31 changes: 31 additions & 0 deletions service/endpoint/health/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package health

import (
"encoding/json"
"net/http"
"sync"
)

type HealthHandler struct {
healths map[string]*int32
mu *sync.Mutex
}

func NewHealthHandler() *HealthHandler {
return &HealthHandler{
mu: new(sync.Mutex),
healths: make(map[string]*int32),
}
}

func (h *HealthHandler) RegisterHealthPoint(name string, isOkPtr *int32) {
h.mu.Lock()
defer h.mu.Unlock()
h.healths[name] = isOkPtr
}

func (h *HealthHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
JSON, _ := json.Marshal(h.healths)
writer.Header().Set("Content-Type", "application/json")
writer.Write(JSON)
}
Loading

0 comments on commit 8e28c7e

Please sign in to comment.