diff --git a/sdk/go/go.mod b/sdk/go/go.mod index a61212a296..7c029da109 100644 --- a/sdk/go/go.mod +++ b/sdk/go/go.mod @@ -6,6 +6,7 @@ require ( github.com/golang/protobuf v1.3.2 github.com/google/go-cmp v0.3.0 github.com/opentracing/opentracing-go v1.1.0 + github.com/stretchr/testify v1.4.0 // indirect go.opencensus.io v0.22.1 google.golang.org/grpc v1.24.0 ) diff --git a/sdk/go/go.sum b/sdk/go/go.sum index f5ad986fc7..56df48673e 100644 --- a/sdk/go/go.sum +++ b/sdk/go/go.sum @@ -1,6 +1,9 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM39fZuwSd1LwSqqSW0hOdXCYYDX0R3I= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -15,6 +18,11 @@ github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= go.opencensus.io v0.22.1 h1:8dP3SGL7MPB94crU3bEPplMPe83FI4EouesJUeFHv50= go.opencensus.io v0.22.1/go.mod h1:Ap50jQcDJrx6rB6VgeeFPtuPIf3wMRvRfrfYDO6+BmA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -36,6 +44,7 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd h1:r7DufRZuZbWB7j439YfAzP8RPDa9unLkpwQKUYbIMPI= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -56,5 +65,9 @@ google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZi google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.24.0 h1:vb/1TCsVn3DcJlQ0Gs1yB1pKI6Do2/QNwxdKqmc/b0s= google.golang.org/grpc v1.24.0/go.mod h1:XDChyiUovWa60DnaeDeZmSW86xtLtjtZbwvSiRnRtcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/sdk/go/request.go b/sdk/go/request.go index cf84852a0c..a69ab8b55f 100644 --- a/sdk/go/request.go +++ b/sdk/go/request.go @@ -3,11 +3,8 @@ package feast import ( "fmt" "github.com/gojek/feast/sdk/go/protos/feast/serving" - "github.com/golang/protobuf/ptypes/duration" - "github.com/golang/protobuf/ptypes/timestamp" "strconv" "strings" - "time" ) var ( @@ -21,26 +18,21 @@ type OnlineFeaturesRequest struct { // in the format featureSet:version:featureName. Features []string - // MaxAgeSeconds is the maximum allowed staleness of the features in seconds. This max age will be applied to all - // featureSets in this request. - // Setting this value to 0 will cause feast to default to the max age specified on the feature set spec, if any. - MaxAgeSeconds int - // Entities is the list of entity rows to retrieve features on. Each row is a map of entity name to entity value. Entities []Row } // Builds the feast-specified request payload from the wrapper. func (r OnlineFeaturesRequest) buildRequest() (*serving.GetOnlineFeaturesRequest, error) { - featureSets, err := buildFeatureSets(r.Features, r.MaxAgeSeconds) + featureSets, err := buildFeatureSets(r.Features) if err != nil { return nil, err } + entityRows := make([]*serving.GetOnlineFeaturesRequest_EntityRow, len(r.Entities)) for i := range r.Entities { entityRows[i] = &serving.GetOnlineFeaturesRequest_EntityRow{ - EntityTimestamp: ×tamp.Timestamp{Seconds: time.Now().Unix()}, Fields: r.Entities[i], } } @@ -50,14 +42,14 @@ func (r OnlineFeaturesRequest) buildRequest() (*serving.GetOnlineFeaturesRequest }, nil } -func buildFeatureSets(features []string, maxAgeSeconds int) ([]*serving.GetOnlineFeaturesRequest_FeatureSet, error) { +func buildFeatureSets(features []string) ([]*serving.GetOnlineFeaturesRequest_FeatureSet, error) { featureSetMap := map[string]*serving.GetOnlineFeaturesRequest_FeatureSet{} for _, feature := range features { split := strings.Split(feature, ":") - featureSetName, featureSetVersion, featureName := split[0], split[1], split[2] if len(split) != 3 { return nil, fmt.Errorf(ErrInvalidFeatureName, feature) } + featureSetName, featureSetVersion, featureName := split[0], split[1], split[2] key := featureSetName + ":" + featureSetVersion if fs, ok := featureSetMap[key]; !ok { version, err := strconv.Atoi(featureSetVersion) @@ -68,7 +60,6 @@ func buildFeatureSets(features []string, maxAgeSeconds int) ([]*serving.GetOnlin Name: featureSetName, Version: int32(version), FeatureNames: []string{featureName}, - MaxAge: &duration.Duration{Seconds: int64(maxAgeSeconds)}, } } else { fs.FeatureNames = append(fs.GetFeatureNames(), featureName) diff --git a/sdk/go/request_test.go b/sdk/go/request_test.go index 57fcfb7980..2cf67be151 100644 --- a/sdk/go/request_test.go +++ b/sdk/go/request_test.go @@ -5,7 +5,6 @@ import ( "github.com/gojek/feast/sdk/go/protos/feast/serving" "github.com/gojek/feast/sdk/go/protos/feast/types" json "github.com/golang/protobuf/jsonpb" - "github.com/golang/protobuf/ptypes/duration" "github.com/google/go-cmp/cmp" "testing" ) @@ -22,7 +21,6 @@ func TestGetOnlineFeaturesRequest(t *testing.T) { name: "valid", req: OnlineFeaturesRequest{ Features: []string{"fs:1:feature1", "fs:1:feature2", "fs:2:feature1"}, - MaxAgeSeconds: 10, Entities: []Row{ {"entity1": Int64Val(1), "entity2": StrVal("bob")}, {"entity1": Int64Val(1), "entity2": StrVal("annie")}, @@ -35,13 +33,11 @@ func TestGetOnlineFeaturesRequest(t *testing.T) { Name: "fs", Version: 1, FeatureNames: []string{"feature1", "feature2"}, - MaxAge: &duration.Duration{Seconds: 10}, }, { Name: "fs", Version: 2, FeatureNames: []string{"feature1"}, - MaxAge: &duration.Duration{Seconds: 10}, }, }, EntityRows: []*serving.GetOnlineFeaturesRequest_EntityRow{ @@ -73,7 +69,6 @@ func TestGetOnlineFeaturesRequest(t *testing.T) { name: "invalid_feature_name/wrong_format", req: OnlineFeaturesRequest{ Features: []string{"fs1:feature1"}, - MaxAgeSeconds: 10, Entities: []Row{}, }, wantErr: true, @@ -83,7 +78,6 @@ func TestGetOnlineFeaturesRequest(t *testing.T) { name: "invalid_feature_name/invalid_version", req: OnlineFeaturesRequest{ Features: []string{"fs:a:feature1"}, - MaxAgeSeconds: 10, Entities: []Row{}, }, wantErr: true, @@ -101,9 +95,6 @@ func TestGetOnlineFeaturesRequest(t *testing.T) { t.Errorf("error = %v, expected err = %v", err, tc.err) return } - for i, er := range got.GetEntityRows() { - tc.want.EntityRows[i].EntityTimestamp = er.GetEntityTimestamp(); - } if !cmp.Equal(got, tc.want) { m := json.Marshaler{} gotJson, _ := m.MarshalToString(got) diff --git a/sdk/go/response.go b/sdk/go/response.go index fed3ed53b4..086321cacd 100644 --- a/sdk/go/response.go +++ b/sdk/go/response.go @@ -3,11 +3,13 @@ package feast import ( "fmt" "github.com/gojek/feast/sdk/go/protos/feast/serving" + "github.com/gojek/feast/sdk/go/protos/feast/types" ) var ( ErrLengthMismatch = "Length mismatch; number of na values (%d) not equal to number of features requested (%d)." ErrFeatureNotFound = "Feature %s not found in response." + ErrTypeMismatch = "Requested output of type %s does not match type of feature value returned." ) // OnlineFeaturesResponse is a wrapper around serving.GetOnlineFeaturesResponse. @@ -38,10 +40,40 @@ func (r OnlineFeaturesResponse) Int64Arrays(order []string, fillNa []int64) ([][ if !exists { return nil, fmt.Errorf(ErrFeatureNotFound, fname) } - if fValue.GetVal() == nil { + val := fValue.GetVal() + if val == nil { rows[i][j] = fillNa[j] + } else if int64Val, ok := val.(*types.Value_Int64Val); ok { + rows[i][j] = int64Val.Int64Val } else { - rows[i][j] = fValue.GetInt64Val() + return nil, fmt.Errorf(ErrTypeMismatch, "int64") + } + } + } + return rows, nil +} + +// Float64Arrays retrieves the result of the request as a list of float64 slices. Any missing values will be filled +// with the missing values provided. +func (r OnlineFeaturesResponse) Float64Arrays(order []string, fillNa []float64) ([][]float64, error) { + rows := make([][]float64, len(r.RawResponse.FieldValues)) + if len(fillNa) != len(order) { + return nil, fmt.Errorf(ErrLengthMismatch, len(fillNa), len(order)) + } + for i, val := range r.RawResponse.FieldValues { + rows[i] = make([]float64, len(order)) + for j, fname := range order { + fValue, exists := val.Fields[fname] + if !exists { + return nil, fmt.Errorf(ErrFeatureNotFound, fname) + } + val := fValue.GetVal() + if val == nil { + rows[i][j] = fillNa[j] + } else if doubleVal, ok := val.(*types.Value_DoubleVal); ok { + rows[i][j] = doubleVal.DoubleVal + } else { + return nil, fmt.Errorf(ErrTypeMismatch, "float64") } } }