diff --git a/CHANGES.md b/CHANGES.md index f737862ca2f3..310078a58067 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -62,6 +62,7 @@ ## New Features / Improvements +* RunInference Wrapper with Sklearn Model Handler support added in Go SDK ([#24497](https://github.com/apache/beam/issues/23382)). * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Breaking Changes diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate.go b/sdks/go/pkg/beam/core/runtime/graphx/translate.go index 498afeac8289..cd5946742fd3 100644 --- a/sdks/go/pkg/beam/core/runtime/graphx/translate.go +++ b/sdks/go/pkg/beam/core/runtime/graphx/translate.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "sort" + "strings" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder" @@ -687,6 +688,20 @@ func (m *marshaller) expandCrossLanguage(namedEdge NamedEdge) (string, error) { EnvironmentId: m.addDefaultEnv(), } + // Add the coders for output in the marshaller even if expanded is nil + // for output coder field in expansion request. + // We need this specifically for Python External Transforms. + names := strings.Split(spec.Urn, ":") + if len(names) > 2 && names[2] == "python" { + for _, out := range edge.Output { + id, err := m.coders.Add(out.To.Coder) + if err != nil { + return "", errors.Wrapf(err, "failed to add output coder to coder registry: %v", m.coders) + } + out.To.Coder.ID = id + } + } + if edge.External.Expanded != nil { // Outputs need to temporarily match format of unnamed Go SDK Nodes. 
// After the initial pipeline is constructed, these will be used to correctly diff --git a/sdks/go/pkg/beam/core/runtime/xlangx/expand.go b/sdks/go/pkg/beam/core/runtime/xlangx/expand.go index d6ef94711d96..9076b93e1f87 100644 --- a/sdks/go/pkg/beam/core/runtime/xlangx/expand.go +++ b/sdks/go/pkg/beam/core/runtime/xlangx/expand.go @@ -28,9 +28,9 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/xlangx/expansionx" "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" - "github.com/apache/beam/sdks/v2/go/pkg/beam/log" jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1" pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1" + "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/xlang" "google.golang.org/grpc" ) @@ -111,12 +111,30 @@ func expand( ext.ExpansionAddr = config } + // External transforms that need to specify the output coder + // in the expansion request send their output tagged as xlang.SetOutputCoder. + outputCoderID := make(map[string]string) + newOutputMap := make(map[string]int) + for tag, id := range edge.External.OutputsMap { + if tag == xlang.SetOutputCoder { + newOutputMap[graph.UnnamedOutputTag] = id + outputCoderID[tag] = edge.Output[id].To.Coder.ID + // Since only one output coder request can be specified, we break here. + // This is because graph.UnnamedOutputTag is used as the key in edge.External.OutputsMap. 
+ break + } + } + + if len(newOutputMap) > 0 { + edge.External.OutputsMap = newOutputMap + } return h(ctx, &HandlerParams{ Config: config, Req: &jobpb.ExpansionRequest{ - Components: comps, - Transform: transform, - Namespace: ext.Namespace, + Components: comps, + Transform: transform, + Namespace: ext.Namespace, + OutputCoderRequests: outputCoderID, }, edge: edge, ext: ext, @@ -239,7 +257,6 @@ func startPythonExpansionService(service, extraPackage string) (stopFunc func() if err != nil { return nil, "", err } - log.Debugf(context.Background(), "path: %v", venvPython) serviceRunner, err := expansionx.NewPyExpansionServiceRunner(venvPython, service, "") if err != nil { diff --git a/sdks/go/pkg/beam/core/runtime/xlangx/namespace.go b/sdks/go/pkg/beam/core/runtime/xlangx/namespace.go index e3887ff01d53..723d2e25cce7 100644 --- a/sdks/go/pkg/beam/core/runtime/xlangx/namespace.go +++ b/sdks/go/pkg/beam/core/runtime/xlangx/namespace.go @@ -100,7 +100,21 @@ func addNamespace(t *pipepb.PTransform, c *pipepb.Components, namespace string) } } - // c.Transforms = make(map[string]*pipepb.PTransform) + // update component coderIDs for other coders not present in t.Inputs, t.Outputs + for id, coder := range c.GetCoders() { + if _, exists := idMap[id]; exists { + continue + } + var updatedComponentCoderIDs []string + updatedComponentCoderIDs = append(updatedComponentCoderIDs, coder.ComponentCoderIds...) 
+ for i, ccid := range coder.GetComponentCoderIds() { + if _, exists := idMap[ccid]; exists { + updatedComponentCoderIDs[i] = idMap[ccid] + } + } + coder.ComponentCoderIds = updatedComponentCoderIDs + } + sourceName := t.UniqueName for _, t := range c.Transforms { if t.UniqueName != sourceName { diff --git a/sdks/go/pkg/beam/transforms/xlang/inference/inference.go b/sdks/go/pkg/beam/transforms/xlang/inference/inference.go new file mode 100644 index 000000000000..b55ccb276c4b --- /dev/null +++ b/sdks/go/pkg/beam/transforms/xlang/inference/inference.go @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package inference has the cross language implementation of RunInference API implemented in Python SDK. 
+// An expansion service for python external transforms can be started by running +// +// $ python -m apache_beam.runners.portability.expansion_service_main -p $PORT_FOR_EXPANSION_SERVICE +package inference + +import ( + "context" + "reflect" + "strings" + + "github.com/apache/beam/sdks/v2/go/pkg/beam" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/xlangx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex" + "github.com/apache/beam/sdks/v2/go/pkg/beam/log" + "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/xlang" + "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/xlang/python" +) + +func init() { + beam.RegisterType(reflect.TypeOf((*runInferenceConfig)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*argsStruct)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*sklearn)(nil)).Elem()) + beam.RegisterType(reflect.TypeOf((*PredictionResult)(nil)).Elem()) +} + +var outputT = reflect.TypeOf((*PredictionResult)(nil)).Elem() + +// PredictionResult represents the result of a prediction obtained from Python's RunInference API. +type PredictionResult struct { + Example []int64 `beam:"example"` + Inference int32 `beam:"inference"` +} + +type runInferenceConfig struct { + args argsStruct + expansionAddr string + extraPackages []string +} + +type runInferenceOption func(*runInferenceConfig) + +// WithArgs sets arguments for the RunInference transform parameters. +func WithArgs(args []string) runInferenceOption { + return func(c *runInferenceConfig) { + c.args.args = append(c.args.args, args...) + } +} + +// WithExpansionAddr provides URL for Python expansion service. +func WithExpansionAddr(expansionAddr string) runInferenceOption { + return func(c *runInferenceConfig) { + c.expansionAddr = expansionAddr + } +} + +// WithExtraPackages is used to specify additional packages when using an automated expansion service. +// Packages required to run the required Model are included implicitly, +// eg: scikit-learn, pandas for Sklearn Model Handler. 
+func WithExtraPackages(extraPackages []string) runInferenceOption { + return func(c *runInferenceConfig) { + c.extraPackages = extraPackages + } +} + +type argsStruct struct { + args []string +} + +func inferExtraPackages(modelHandler string) []string { + extraPackages := []string{} + + mhLowered := strings.ToLower(modelHandler) + if strings.Contains(mhLowered, "sklearn") { + extraPackages = append(extraPackages, "scikit-learn", "pandas") + } else if strings.Contains(mhLowered, "pytorch") { + extraPackages = append(extraPackages, "torch") + } + if len(extraPackages) > 0 { + log.Infof(context.Background(), "inferExtraPackages: %v", extraPackages) + } + return extraPackages +} + +// sklearn configures the parameters for the sklearn inference transform. +type sklearn struct { + // ModelHandlerProvider defines the model handler to be used. + ModelHandlerProvider python.CallableSource `beam:"model_handler_provider"` + // ModelURI indicates the model path to be used for Sklearn Model Handler. + ModelURI string `beam:"model_uri"` +} + +// sklearnConfig could be used to configure other optional parameters in future if necessary. +type sklearnConfig func(*sklearn) + +// SklearnModel configures the parameters required to perform RunInference transform +// on Sklearn Model. It returns an sklearn object which should be used to call +// RunInference transform. +// ModelURI is the required parameter indicating the path to the sklearn model. 
+// +// Example: +// modelURI := "gs://storage/model" +// model := inference.SklearnModel(modelURI) +// prediction := model.RunInference(s, input, inference.WithExpansionAddr("localhost:9000")) +func SklearnModel(modelURI string, opts ...sklearnConfig) sklearn { + sm := sklearn{ + ModelHandlerProvider: python.CallableSource("apache_beam.ml.inference.sklearn_inference.SklearnModelHandlerNumpy"), + ModelURI: modelURI, + } + for _, opt := range opts { + opt(&sm) + } + return sm +} + +// RunInference transforms the input pcollection by calling RunInference in Python SDK +// using Sklearn Model Handler with python expansion service. +// The expansion service address can be provided by using inference.WithExpansionAddr(address); otherwise an automated Python expansion service is used. +// NOTE: This wrapper doesn't work for keyed input PCollection. +// +// Example: +// inputRow := [][]int64{{0, 0}, {1, 1}} +// input := beam.CreateList(s, inputRow) +// modelURI := "gs://example.com/tmp/staged/sklearn_model" +// model := inference.SklearnModel(modelURI) +// prediction := model.RunInference(s, input, inference.WithExpansionAddr("localhost:9000")) +func (sk sklearn) RunInference(s beam.Scope, col beam.PCollection, opts ...runInferenceOption) beam.PCollection { + s = s.Scope("xlang.inference.sklearn.RunInference") // keep the returned sub-scope so the transform is nested under it + + cfg := runInferenceConfig{} + for _, opt := range opts { + opt(&cfg) + } + if cfg.expansionAddr == "" { + cfg.extraPackages = append(cfg.extraPackages, inferExtraPackages(string(sk.ModelHandlerProvider))...) 
+ } + + return runInference[sklearn](s, col, sk, cfg) +} + +func runInference[Kwargs any](s beam.Scope, col beam.PCollection, k Kwargs, cfg runInferenceConfig) beam.PCollection { + if cfg.expansionAddr == "" { + if len(cfg.extraPackages) > 0 { + cfg.expansionAddr = xlangx.UseAutomatedPythonExpansionService(python.ExpansionServiceModule, xlangx.AddExtraPackages(cfg.extraPackages)) + } else { + cfg.expansionAddr = xlangx.UseAutomatedPythonExpansionService(python.ExpansionServiceModule) + } + } + pet := python.NewExternalTransform[argsStruct, Kwargs]("apache_beam.ml.inference.base.RunInference.from_callable") + pet.WithKwargs(k) + pet.WithArgs(cfg.args) + pl := beam.CrossLanguagePayload(pet) + + // Since External RunInference Transform with Python Expansion Service will send encoded output, we need to specify + // output coder. We do this by setting the output tag as xlang.SetOutputCoder so that while sending + // an expansion request we populate the OutputCoderRequests field. If this is not done then the encoded output + // may not be decoded with coders known to Go SDK. + outputType := map[string]typex.FullType{xlang.SetOutputCoder: typex.New(outputT)} + + result := beam.CrossLanguage(s, "beam:transforms:python:fully_qualified_named", pl, cfg.expansionAddr, beam.UnnamedInput(col), outputType) + return result[beam.UnnamedOutputTag()] +} diff --git a/sdks/go/pkg/beam/transforms/xlang/xlang.go b/sdks/go/pkg/beam/transforms/xlang/xlang.go new file mode 100644 index 000000000000..7980c5619811 --- /dev/null +++ b/sdks/go/pkg/beam/transforms/xlang/xlang.go @@ -0,0 +1,21 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package xlang contains cross-language transforms. +package xlang + +// SetOutputCoder is the output tag used in cases of external transforms +// where the output coder request field needs to be specified. +const SetOutputCoder = "SetOutputCoder" diff --git a/sdks/go/test/integration/transforms/xlang/inference/inference.go b/sdks/go/test/integration/transforms/xlang/inference/inference.go new file mode 100644 index 000000000000..7fd89874fa94 --- /dev/null +++ b/sdks/go/test/integration/transforms/xlang/inference/inference.go @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package inference + +import ( + "github.com/apache/beam/sdks/v2/go/pkg/beam" + _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/dataflow" + _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/flink" + _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal" + "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" + "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/xlang/inference" +) + +func SklearnInference(expansionAddr string) *beam.Pipeline { + p, s := beam.NewPipelineWithRoot() + + inputRow := [][]int64{{0, 0}, {1, 1}} + input := beam.CreateList(s, inputRow) + output := []inference.PredictionResult{ + { + Example: []int64{0, 0}, + Inference: 0, + }, + { + Example: []int64{1, 1}, + Inference: 1, + }, + } + outCol := inference.SklearnModel("/tmp/staged/sklearn_model").RunInference(s, input, inference.WithExpansionAddr(expansionAddr)) + passert.Equals(s, outCol, output[0], output[1]) + return p +} diff --git a/sdks/go/test/integration/transforms/xlang/inference/inference_test.go b/sdks/go/test/integration/transforms/xlang/inference/inference_test.go new file mode 100644 index 000000000000..72892c7c0b15 --- /dev/null +++ b/sdks/go/test/integration/transforms/xlang/inference/inference_test.go @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package inference + +import ( + "flag" + "log" + "testing" + + "github.com/apache/beam/sdks/v2/go/pkg/beam" + _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/dataflow" + _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/flink" + _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal" + "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" + "github.com/apache/beam/sdks/v2/go/test/integration" +) + +var expansionAddr string // Populate with expansion address labelled "python_transform". + +func checkFlags(t *testing.T) { + if expansionAddr == "" { + t.Skip("No python transform expansion address provided.") + } +} + +func TestSklearnInference(t *testing.T) { + integration.CheckFilters(t) + checkFlags(t) + p := SklearnInference(expansionAddr) + ptest.RunAndValidate(t, p) +} + +func TestMain(m *testing.M) { + flag.Parse() + beam.Init() + + services := integration.NewExpansionServices() + defer func() { services.Shutdown() }() + addr, err := services.GetAddr("python_transform") + if err != nil { + log.Printf("skipping missing expansion service: %v", err) + } else { + expansionAddr = addr + } + + ptest.MainRet(m) +}