Skip to content

Commit

Permalink
[Go SDK] RunInference wrapper supporting Sklearn Model Handler (apach…
Browse files Browse the repository at this point in the history
  • Loading branch information
riteshghorse authored and lostluck committed Dec 22, 2022
1 parent 95f48b2 commit 5761590
Show file tree
Hide file tree
Showing 8 changed files with 356 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@

## New Features / Improvements

* RunInference Wrapper with Sklearn Model Handler support added in Go SDK ([#24497](https://github.com/apache/beam/issues/23382)).
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).

## Breaking Changes
Expand Down
15 changes: 15 additions & 0 deletions sdks/go/pkg/beam/core/runtime/graphx/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"context"
"fmt"
"sort"
"strings"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
Expand Down Expand Up @@ -687,6 +688,20 @@ func (m *marshaller) expandCrossLanguage(namedEdge NamedEdge) (string, error) {
EnvironmentId: m.addDefaultEnv(),
}

// Add the coders for output in the marshaller even if expanded is nil
// for output coder field in expansion request.
// We need this specifically for Python External Transforms.
names := strings.Split(spec.Urn, ":")
if len(names) > 2 && names[2] == "python" {
for _, out := range edge.Output {
id, err := m.coders.Add(out.To.Coder)
if err != nil {
return "", errors.Wrapf(err, "failed to add output coder to coder registry: %v", m.coders)
}
out.To.Coder.ID = id
}
}

if edge.External.Expanded != nil {
// Outputs need to temporarily match format of unnamed Go SDK Nodes.
// After the initial pipeline is constructed, these will be used to correctly
Expand Down
27 changes: 22 additions & 5 deletions sdks/go/pkg/beam/core/runtime/xlangx/expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/xlangx/expansionx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
"github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/xlang"
"google.golang.org/grpc"
)

Expand Down Expand Up @@ -111,12 +111,30 @@ func expand(
ext.ExpansionAddr = config
}

// External transforms that need to specify the output coder in the
// expansion request send a tagged output named xlang.SetOutputCoder.
outputCoderID := make(map[string]string)
newOutputMap := make(map[string]int)
for tag, id := range edge.External.OutputsMap {
if tag == xlang.SetOutputCoder {
newOutputMap[graph.UnnamedOutputTag] = id
outputCoderID[tag] = edge.Output[id].To.Coder.ID
// Since only one output coder request can be specified, we break here.
// This is because graph.UnnamedOutputTag is used as the key in edge.External.OutputsMap.
break
}
}

if len(newOutputMap) > 0 {
edge.External.OutputsMap = newOutputMap
}
return h(ctx, &HandlerParams{
Config: config,
Req: &jobpb.ExpansionRequest{
Components: comps,
Transform: transform,
Namespace: ext.Namespace,
Components: comps,
Transform: transform,
Namespace: ext.Namespace,
OutputCoderRequests: outputCoderID,
},
edge: edge,
ext: ext,
Expand Down Expand Up @@ -239,7 +257,6 @@ func startPythonExpansionService(service, extraPackage string) (stopFunc func()
if err != nil {
return nil, "", err
}
log.Debugf(context.Background(), "path: %v", venvPython)

serviceRunner, err := expansionx.NewPyExpansionServiceRunner(venvPython, service, "")
if err != nil {
Expand Down
16 changes: 15 additions & 1 deletion sdks/go/pkg/beam/core/runtime/xlangx/namespace.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,21 @@ func addNamespace(t *pipepb.PTransform, c *pipepb.Components, namespace string)
}
}

// c.Transforms = make(map[string]*pipepb.PTransform)
// update component coderIDs for other coders not present in t.Inputs, t.Outputs
for id, coder := range c.GetCoders() {
if _, exists := idMap[id]; exists {
continue
}
var updatedComponentCoderIDs []string
updatedComponentCoderIDs = append(updatedComponentCoderIDs, coder.ComponentCoderIds...)
for i, ccid := range coder.GetComponentCoderIds() {
if _, exists := idMap[ccid]; exists {
updatedComponentCoderIDs[i] = idMap[ccid]
}
}
coder.ComponentCoderIds = updatedComponentCoderIDs
}

sourceName := t.UniqueName
for _, t := range c.Transforms {
if t.UniqueName != sourceName {
Expand Down
177 changes: 177 additions & 0 deletions sdks/go/pkg/beam/transforms/xlang/inference/inference.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package inference has the cross language implementation of RunInference API implemented in Python SDK.
// An expansion service for Python external transforms can be started by running
//
// $ python -m apache_beam.runners.portability.expansion_service_main -p $PORT_FOR_EXPANSION_SERVICE
package inference

import (
"context"
"reflect"
"strings"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/xlangx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
"github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/xlang"
"github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/xlang/python"
)

func init() {
beam.RegisterType(reflect.TypeOf((*runInferenceConfig)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*argsStruct)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*sklearn)(nil)).Elem())
beam.RegisterType(reflect.TypeOf((*PredictionResult)(nil)).Elem())
}

// outputT is the reflected type of PredictionResult. It is the element type
// requested for the output of the RunInference cross-language transform.
var outputT = reflect.TypeOf(PredictionResult{})

// PredictionResult is one row of output produced by Python's RunInference API,
// mapped onto the "example" and "inference" schema fields.
type PredictionResult struct {
	// Example is the input example the prediction refers to.
	Example []int64 `beam:"example"`
	// Inference is the value predicted for Example.
	Inference int32 `beam:"inference"`
}

// runInferenceConfig collects the settings supplied through runInferenceOption
// values before the cross-language transform is built.
type runInferenceConfig struct {
	// args holds the positional arguments passed to the Python transform.
	args argsStruct
	// expansionAddr is the address of the Python expansion service. When it is
	// empty an automated expansion service is requested instead.
	expansionAddr string
	// extraPackages lists additional packages for an automated expansion service.
	extraPackages []string
}

// runInferenceOption mutates a runInferenceConfig; it is the functional-option
// type accepted by RunInference.
type runInferenceOption func(*runInferenceConfig)

// WithArgs sets arguments for the RunInference transform parameters.
// Repeated uses accumulate: each call appends to the arguments already set.
func WithArgs(args []string) runInferenceOption {
	return func(c *runInferenceConfig) {
		c.args.args = append(c.args.args, args...)
	}
}

// WithExpansionAddr provides the URL of the Python expansion service.
func WithExpansionAddr(expansionAddr string) runInferenceOption {
	return func(c *runInferenceConfig) {
		c.expansionAddr = expansionAddr
	}
}

// WithExtraPackages is used to specify additional packages when using an automated expansion service.
// Packages required to run the required Model are included implicitly,
// eg: scikit-learn, pandas for Sklearn Model Handler.
// A later use replaces the packages set by an earlier one.
func WithExtraPackages(extraPackages []string) runInferenceOption {
	return func(c *runInferenceConfig) {
		// Store a private copy: the config is appended to afterwards, and
		// sharing the caller's backing array would let that append overwrite
		// elements of the caller's slice beyond its length.
		c.extraPackages = append([]string(nil), extraPackages...)
	}
}

// argsStruct wraps the positional arguments handed to the Python external
// transform.
type argsStruct struct {
	args []string
}

// inferExtraPackages returns the Python packages implied by the model handler
// name. The match is case-insensitive: a name containing "sklearn" yields
// scikit-learn and pandas, otherwise a name containing "pytorch" yields torch.
// Any other name yields an empty, non-nil slice. A non-empty result is logged
// at info level.
func inferExtraPackages(modelHandler string) []string {
	handler := strings.ToLower(modelHandler)
	pkgs := []string{}

	switch {
	case strings.Contains(handler, "sklearn"):
		pkgs = append(pkgs, "scikit-learn", "pandas")
	case strings.Contains(handler, "pytorch"):
		pkgs = append(pkgs, "torch")
	}

	if len(pkgs) == 0 {
		return pkgs
	}
	log.Infof(context.Background(), "inferExtraPackages: %v", pkgs)
	return pkgs
}

// sklearn carries the keyword arguments for the sklearn inference transform.
// Values are created with SklearnModel.
type sklearn struct {
	// ModelHandlerProvider names the Python callable that provides the model handler.
	ModelHandlerProvider python.CallableSource `beam:"model_handler_provider"`
	// ModelURI is the path of the model given to the Sklearn Model Handler.
	ModelURI string `beam:"model_uri"`
}

// sklearnConfig is an option applied to an sklearn value by SklearnModel. It is
// the hook for configuring further optional parameters if they become necessary.
type sklearnConfig func(*sklearn)

// SklearnModel builds the sklearn value used to run the RunInference transform
// on an Sklearn model. modelURI is required and is the path to the sklearn
// model; the model handler is SklearnModelHandlerNumpy from the Python SDK.
// Each option in opts is applied, in order, after those fields are set.
//
// Example:
//
//	modelURI := "gs://storage/model"
//	model := inference.SklearnModel(modelURI)
//	prediction := model.RunInference(s, input, inference.WithExpansionAddr("localhost:9000"))
func SklearnModel(modelURI string, opts ...sklearnConfig) sklearn {
	var model sklearn
	model.ModelHandlerProvider = python.CallableSource("apache_beam.ml.inference.sklearn_inference.SklearnModelHandlerNumpy")
	model.ModelURI = modelURI

	for i := range opts {
		opts[i](&model)
	}
	return model
}

// RunInference transforms the input PCollection by calling RunInference in the Python SDK
// using the Sklearn Model Handler with a Python expansion service.
// The expansion address can be provided with inference.WithExpansionAddr(address); when it is
// not, the packages implied by the model handler are added to any extra packages
// before the transform is built.
// NOTE: This wrapper doesn't work for keyed input PCollection.
//
// Example:
//
//	inputRow := [][]int64{{0, 0}, {1, 1}}
//	input := beam.CreateList(s, inputRow)
//	modelURI := "gs://example.com/tmp/staged/sklearn_model"
//	model := inference.SklearnModel(modelURI)
//	prediction := model.RunInference(s, input, inference.WithExpansionAddr("localhost:9000"))
func (sk sklearn) RunInference(s beam.Scope, col beam.PCollection, opts ...runInferenceOption) beam.PCollection {
	// Scope returns a new sub-scope instead of modifying s, so the result has
	// to be kept for the transform to be created under this name.
	s = s.Scope("xlang.inference.sklearn.RunInference")

	cfg := runInferenceConfig{}
	for _, opt := range opts {
		opt(&cfg)
	}
	if cfg.expansionAddr == "" {
		// Cap the slice at its length so the append always allocates and can
		// never write into a backing array still owned by the caller.
		n := len(cfg.extraPackages)
		cfg.extraPackages = append(cfg.extraPackages[:n:n], inferExtraPackages(string(sk.ModelHandlerProvider))...)
	}

	return runInference(s, col, sk, cfg)
}

// runInference builds the Python RunInference external transform with k as its
// keyword arguments and cfg.args as its positional arguments, applies it to col
// and returns the unnamed output. If cfg carries no expansion address, an
// automated Python expansion service is requested, with cfg.extraPackages
// passed along when there are any.
func runInference[Kwargs any](s beam.Scope, col beam.PCollection, k Kwargs, cfg runInferenceConfig) beam.PCollection {
	addr := cfg.expansionAddr
	switch {
	case addr != "":
		// An explicit expansion service address was supplied; use it as is.
	case len(cfg.extraPackages) > 0:
		addr = xlangx.UseAutomatedPythonExpansionService(python.ExpansionServiceModule, xlangx.AddExtraPackages(cfg.extraPackages))
	default:
		addr = xlangx.UseAutomatedPythonExpansionService(python.ExpansionServiceModule)
	}

	transform := python.NewExternalTransform[argsStruct, Kwargs]("apache_beam.ml.inference.base.RunInference.from_callable")
	transform.WithKwargs(k)
	transform.WithArgs(cfg.args)
	payload := beam.CrossLanguagePayload(transform)

	// The external RunInference transform sends encoded output, so the output
	// coder has to be specified. Tagging the output as xlang.SetOutputCoder makes
	// the expansion request populate its OutputCoderRequests field; without it
	// the encoded output may not be decodable with coders known to the Go SDK.
	outputs := map[string]typex.FullType{xlang.SetOutputCoder: typex.New(outputT)}

	expanded := beam.CrossLanguage(s, "beam:transforms:python:fully_qualified_named", payload, addr, beam.UnnamedInput(col), outputs)
	return expanded[beam.UnnamedOutputTag()]
}
21 changes: 21 additions & 0 deletions sdks/go/pkg/beam/transforms/xlang/xlang.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package xlang contains cross-language transforms.
package xlang

// SetOutputCoder is the output tag used for external transforms whose
// expansion request needs the output coder request field to be specified.
const SetOutputCoder = "SetOutputCoder"
45 changes: 45 additions & 0 deletions sdks/go/test/integration/transforms/xlang/inference/inference.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package inference

import (
"github.com/apache/beam/sdks/v2/go/pkg/beam"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/dataflow"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/flink"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
"github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/xlang/inference"
)

// SklearnInference builds a pipeline that runs the rows {0, 0} and {1, 1}
// through the sklearn RunInference wrapper, using the model staged at
// /tmp/staged/sklearn_model and the expansion service at expansionAddr, and
// asserts that the predictions are 0 and 1 respectively.
func SklearnInference(expansionAddr string) *beam.Pipeline {
	p, s := beam.NewPipelineWithRoot()

	input := beam.CreateList(s, [][]int64{{0, 0}, {1, 1}})

	wantZero := inference.PredictionResult{Example: []int64{0, 0}, Inference: 0}
	wantOne := inference.PredictionResult{Example: []int64{1, 1}, Inference: 1}

	model := inference.SklearnModel("/tmp/staged/sklearn_model")
	predictions := model.RunInference(s, input, inference.WithExpansionAddr(expansionAddr))

	passert.Equals(s, predictions, wantZero, wantOne)
	return p
}
Loading

0 comments on commit 5761590

Please sign in to comment.