diff --git a/executor/api/util/utils.go b/executor/api/util/utils.go index 9309d65e02..3a72a06f5e 100644 --- a/executor/api/util/utils.go +++ b/executor/api/util/utils.go @@ -10,7 +10,11 @@ func ExtractRouteFromSeldonMessage(msg *proto.SeldonMessage) []int { values := msg.GetData().GetNdarray().GetValues() routeArr := make([]int, len(values)) for i, value := range values { - routeArr[i] = int(value.GetNumberValue()) + if listValue := value.GetListValue(); listValue != nil { + routeArr[i] = int(listValue.GetValues()[0].GetNumberValue()) + } else { + routeArr[i] = int(value.GetNumberValue()) + } } return routeArr case *proto.DefaultData_Tensor: diff --git a/executor/api/util/utils_test.go b/executor/api/util/utils_test.go new file mode 100644 index 0000000000..b70d92d6ae --- /dev/null +++ b/executor/api/util/utils_test.go @@ -0,0 +1,47 @@ +package util + +import ( + "testing" + + "github.com/golang/protobuf/jsonpb" + . "github.com/onsi/gomega" + "github.com/seldonio/seldon-core/executor/api/grpc/seldon/proto" +) + +func TestExtractRouteFromSeldonMessage(t *testing.T) { + g := NewGomegaWithT(t) + + cases := []struct { + msg string + expected []int + }{ + { + msg: `{"data":{"names":["X1L"],"ndarray":[[1]]}}`, + expected: []int{1}, + }, + { + msg: `{"data":{"ndarray":[2]}}`, + expected: []int{2}, + }, + { + msg: `{"data":{"ndarray":[3,4]}}`, + expected: []int{3, 4}, + }, + { + msg: `{"data":{"names":["X1L","X2L"],"ndarray":[[1,2],[3,4]]}}`, + expected: []int{1, 3}, + }, + { + msg: `{"data":{"ndarray":[]}}`, + expected: []int{}, + }, + } + + for _, c := range cases { + var sm proto.SeldonMessage + jsonpb.UnmarshalString(c.msg, &sm) + routes := ExtractRouteFromSeldonMessage(&sm) + + g.Expect(routes).To(Equal(c.expected)) + } +}