Skip to content

Commit

Permalink
Adds a Java RunInference example (#23619)
Browse files Browse the repository at this point in the history
* Adds a Java RunInference example

* Fixes lint

* Fix spotless

* Fix Java PreCommit

* Address reviewer comments

* Addresses reviewer comments
  • Loading branch information
chamikaramj authored Oct 17, 2022
1 parent 693725d commit 10e15a9
Show file tree
Hide file tree
Showing 7 changed files with 310 additions and 6 deletions.
1 change: 1 addition & 0 deletions examples/java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dependencies {
implementation library.java.kafka_clients
implementation project(path: ":sdks:java:core", configuration: "shadow")
implementation project(":sdks:java:extensions:google-cloud-platform-core")
implementation project(":sdks:java:extensions:python")
implementation project(":sdks:java:io:google-cloud-platform")
implementation project(":sdks:java:io:kafka")
implementation project(":sdks:java:extensions:ml")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.examples.multilanguage;

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.extensions.python.transforms.RunInference;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.Validation.Required;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;

/**
* An example Java multi-language pipeline that performs image classification on handwritten digits
* from the <a href="https://en.wikipedia.org/wiki/MNIST_database">MNIST</a> database.
*
* <p>For more details and instructions for running this please see <a
* href="https://github.com/apache/beam/tree/master/examples/multi-language">here</a>.
*/
public class SklearnMnistClassification {

  /**
   * Returns the source of a Python function that produces a KV sklearn model loader; the script is
   * used to instantiate {@link RunInference}. Note that {@code RunInference} can be instantiated
   * with any arbitrary function that produces a model loader.
   *
   * @return the Python script text, terminated by a newline
   */
  private static String getModelLoaderScript() {
    // A single constant expression: the script never changes between calls.
    return "from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy\n"
        + "from apache_beam.ml.inference.base import KeyedModelHandler\n"
        + "def get_model_handler(model_uri):\n"
        + " return KeyedModelHandler(SklearnModelHandlerNumpy(model_uri))\n";
  }

  /**
   * Filters out lines of the dataset that should not be used for the computation: the header line
   * (which starts with {@code label}) and blank lines.
   */
  static class FilterNonRecordsFn implements SerializableFunction<String, Boolean> {

    @Override
    public Boolean apply(String input) {
      // Blank lines carry no record and would otherwise fail numeric parsing downstream.
      return !input.trim().isEmpty() && !input.startsWith("label");
    }
  }

  /**
   * Separates our input records to label and data. Each input record is a set of comma separated
   * string digits where first digit is the label and rest are data (pixels that represent the
   * digit). Whitespace surrounding each comma separated value is ignored.
   */
  static class RecordsToLabeledPixelsFn extends SimpleFunction<String, KV<Long, Iterable<Long>>> {

    /**
     * Parses one CSV record.
     *
     * @param input a comma separated line whose first value is the label
     * @return the label paired with the remaining values, in their original order
     * @throws NumberFormatException if any value is not a valid {@code long}
     */
    @Override
    public KV<Long, Iterable<Long>> apply(String input) {
      // splitToList always yields at least one element, so index 0 is safe.
      List<String> fields = Splitter.on(',').trimResults().splitToList(input);
      Long label = Long.valueOf(fields.get(0));
      List<Long> pixels = new ArrayList<>(fields.size() - 1);
      for (String pixel : fields.subList(1, fields.size())) {
        pixels.add(Long.valueOf(pixel));
      }

      return KV.of(label, pixels);
    }
  }

  /** Formats the output to a mapping from the expected digit to the inferred digit. */
  static class FormatOutput extends SimpleFunction<KV<Long, Row>, String> {

    @Override
    public String apply(KV<Long, Row> input) {
      return input.getKey() + "," + input.getValue().getString("inference");
    }
  }

  /**
   * Builds the pipeline and runs it to completion: reads the input lines, drops non-record lines,
   * converts each record to a label/pixels pair, applies {@link RunInference} and writes the
   * formatted results to the configured output.
   *
   * @param options pipeline options providing the input, output and model paths
   * @param expansionService address handed to {@code RunInference.withExpansionService}
   */
  void runExample(SklearnMnistClassificationOptions options, String expansionService) {
    // Schema of the output PCollection Row type to be provided to the RunInference transform.
    Schema schema =
        Schema.of(
            Schema.Field.of("example", Schema.FieldType.array(Schema.FieldType.INT64)),
            Schema.Field.of("inference", FieldType.STRING));

    Pipeline pipeline = Pipeline.create(options);
    PCollection<KV<Long, Iterable<Long>>> col =
        pipeline
            .apply(TextIO.read().from(options.getInput()))
            .apply(Filter.by(new FilterNonRecordsFn()))
            .apply(MapElements.via(new RecordsToLabeledPixelsFn()));
    col.apply(
            RunInference.ofKVs(getModelLoaderScript(), schema, VarLongCoder.of())
                .withKwarg("model_uri", options.getModelPath())
                .withExpansionService(expansionService))
        .apply(MapElements.via(new FormatOutput()))
        .apply(TextIO.write().to(options.getOutput()));

    pipeline.run().waitUntilFinish();
  }

  /** Command line options accepted by this example. */
  public interface SklearnMnistClassificationOptions extends PipelineOptions {

    @Description("Path to an input file that contains labels and pixels to feed into the model")
    @Default.String("gs://apache-beam-samples/multi-language/mnist/example_input.csv")
    String getInput();

    void setInput(String value);

    @Description("Path for storing the output")
    @Required
    String getOutput();

    void setOutput(String value);

    @Description(
        "Path to a model file that contains the pickled file of a scikit-learn model trained on MNIST data")
    @Default.String("gs://apache-beam-samples/multi-language/mnist/example_model")
    String getModelPath();

    void setModelPath(String value);

    /** Set this option to specify Python expansion service URL. */
    @Description("URL of Python expansion service")
    @Default.String("")
    String getExpansionService();

    void setExpansionService(String value);
  }

  /**
   * Parses the command line into {@link SklearnMnistClassificationOptions} and runs the example.
   *
   * @param args command line arguments
   */
  public static void main(String[] args) {
    // withValidation() enforces @Required options (e.g. --output) while parsing, so a missing
    // argument is reported up front instead of surfacing later as a null output path.
    SklearnMnistClassificationOptions options =
        PipelineOptionsFactory.fromArgs(args)
            .withValidation()
            .as(SklearnMnistClassificationOptions.class);
    SklearnMnistClassification example = new SklearnMnistClassification();
    example.runExample(options, options.getExpansionService());
  }
}
126 changes: 122 additions & 4 deletions examples/multi-language/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,147 @@
This project provides examples of Apache Beam
[multi-language pipelines](https://beam.apache.org/documentation/programming-guide/#multi-language-pipelines):

## Using Java transforms from Python

* **python/addprefix** - A Python pipeline that reads a text file and attaches a prefix on the Java side to each input.
* **python/javacount** - A Python pipeline that counts words using the Java `Count.perElement()` transform.
* **python/javadatagenerator** - A Python pipeline that produces a set of strings generated from Java.
This example demonstrates the `JavaExternalTransform` API.

## Instructions for running the pipelines
### Instructions for running the pipelines

### 1) Start the expansion service
#### 1) Start the expansion service

1. Download the latest 'beam-examples-multi-language' JAR. Starting with Apache Beam 2.36.0,
you can find it in [the Maven Central Repository](https://search.maven.org/search?q=g:org.apache.beam).
2. Run the following command, replacing `<version>` and `<port>` with valid values:
`java -jar beam-examples-multi-language-<version>.jar <port> --javaClassLookupAllowlistFile='*'`

### 2) Set up a Python virtual environment for Beam
#### 2) Set up a Python virtual environment for Beam

1. See [the Python quickstart](https://beam.apache.org/get-started/quickstart-py/)
for more information.

### 3) Execute the Python pipeline
#### 3) Execute the Python pipeline

1. In a new shell, run a pipeline in the **python** directory using a Beam runner that supports
multi-language pipelines.

The Python files contain details about the actual commands to run.

## Using Python transforms from Java

### Sklearn Mnist Classification

Performs image classification on handwritten digits from the [MNIST](https://en.wikipedia.org/wiki/MNIST_database)
database.

Please see [here](https://github.com/apache/beam/tree/master/sdks/python/apache_beam/examples/inference) for
context and information regarding the corresponding Python pipeline.

Please note that the Java pipeline is
[available in the Beam Java examples module](https://github.com/apache/beam/tree/master/examples/java/src/main/java/org/apache/beam/examples/multilanguage/SklearnMnistClassification.java).

#### Setup

* Obtain/generate a csv input file that contains labels and pixels to feed into the model and store it in
GCS. An example input is available
[here](https://storage.googleapis.com/apache-beam-samples/multi-language/mnist/example_input.csv).

* Create a model file that contains the pickled file of a scikit-learn model
trained on MNIST data and store it in GCS. An example model file is available
[here](https://storage.googleapis.com/apache-beam-samples/multi-language/mnist/example_model).
This model was generated by running the program given
[here](https://python-course.eu/machine-learning/training-and-testing-with-mnist.php)
on the
[example input dataset](https://storage.googleapis.com/apache-beam-samples/multi-language/mnist/example_input.csv).

* Perform Beam runner specific setup according to instructions
[here](https://beam.apache.org/get-started/quickstart-java/#run-a-pipeline).

The following instructions are for running the pipeline with the Dataflow runner. For other portable runners,
please modify the instructions according to the guidelines
[here](https://beam.apache.org/documentation/sdks/java-multi-language-pipelines/#run-with-directrunner).

#### Instructions for running the Java pipeline on released Beam (Beam 2.43.0 and later).

* Checkout the Beam examples Maven archetype for the relevant Beam version.

```
export BEAM_VERSION=<Beam version>
mvn archetype:generate \
-DarchetypeGroupId=org.apache.beam \
-DarchetypeArtifactId=beam-sdks-java-maven-archetypes-examples \
-DarchetypeVersion=$BEAM_VERSION \
-DgroupId=org.example \
-DartifactId=multi-language-beam \
-Dversion="0.1" \
-Dpackage=org.apache.beam.examples \
-DinteractiveMode=false
```

* Run the pipeline.

```
export GCP_PROJECT=<GCP project>
export GCP_BUCKET=<GCP bucket>
export GCP_REGION=<GCP region>
mvn compile exec:java -Dexec.mainClass=org.apache.beam.examples.multilanguage.SklearnMnistClassification \
-Dexec.args="--runner=DataflowRunner --project=$GCP_PROJECT \
--region=$GCP_REGION \
--gcpTempLocation=gs://$GCP_BUCKET/multi-language-beam/tmp \
--output=gs://$GCP_BUCKET/multi-language-beam/output" \
-Pdataflow-runner
```

* Inspect the output. Each line has data separated by a comma ",". The first item is the actual label of
the digit. The second item is the predicted label of the digit.

```
gsutil cat gs://$GCP_BUCKET/multi-language-beam/output*
```

#### Instructions for running the Java pipeline at HEAD (Beam 2.41.0 and 2.42.0).

* Make sure that Docker is installed and available on your system.

* Build and push Python and Java Docker containers.

```
export DOCKER_ROOT=<Docker root>
./gradlew :sdks:python:container:py38:docker -Pdocker-repository-root=$DOCKER_ROOT -Pdocker-tag=latest
docker push $DOCKER_ROOT/beam_python3.8_sdk:latest
./gradlew :sdks:java:container:java11:docker -Pdocker-repository-root=$DOCKER_ROOT -Pdocker-tag=latest
docker push $DOCKER_ROOT/beam_java11_sdk:latest
```

* Run the pipeline using the following Gradle command (this guide assumes Dataflow runner).
Note that we override both the Java and Python SDK harness containers here.

```
export GCP_PROJECT=<GCP project>
export GCP_BUCKET=<GCP bucket>
export GCP_REGION=<GCP region>
./gradlew :examples:multi-language:sklearnMinstClassification --args=" \
--runner=DataflowRunner \
--project=$GCP_PROJECT \
--gcpTempLocation=gs://$GCP_BUCKET/multi-language-beam/tmp \
--output=gs://$GCP_BUCKET/multi-language-beam/output \
--sdkContainerImage=$DOCKER_ROOT/beam_java11_sdk:latest \
--sdkHarnessContainerImageOverrides=.*python.*,$DOCKER_ROOT/beam_python3.8_sdk:latest \
--region=${GCP_REGION}"
```

* Inspect the output. Each line has data separated by a comma ",". The first item is the actual label
of the digit. The second item is the predicted label of the digit.

```
gsutil cat gs://$GCP_BUCKET/multi-language-beam/output*
```
10 changes: 9 additions & 1 deletion examples/multi-language/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ ext.summary = "Java Classes for Multi-language Examples"
dependencies {
implementation library.java.vendored_guava_26_0_jre
implementation project(path: ":sdks:java:core", configuration: "shadow")
runtimeOnly project(path: ":examples:java")
runtimeOnly project(path: ":runners:direct-java", configuration: "shadow")
runtimeOnly project(path: ":runners:google-cloud-dataflow-java")
runtimeOnly project(path: ":runners:portability:java")
Expand All @@ -47,4 +48,11 @@ task pythonDataframeWordCount(type: JavaExec) {
description "Run the Java word count example using external Python DataframeTransform"
mainClass = "org.apache.beam.examples.multilanguage.PythonDataframeWordCount"
classpath = sourceSets.main.runtimeClasspath
}
}

// Runs the SklearnMnistClassification example pipeline with this project's runtime classpath.
// NOTE: the task name spells "Minst"; it is kept unchanged so existing invocations keep working.
task sklearnMinstClassification(type: JavaExec) {
  description "Run the Java pipeline that performs image classification on handwritten digits from the MNIST database"
  mainClass = "org.apache.beam.examples.multilanguage.SklearnMnistClassification"
  classpath = sourceSets.main.runtimeClasspath
}

Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
* ./gradlew :examples:multi-language:pythonDataframeWordCount --args=" \
* --runner=DataflowRunner \
* --output=gs://{$OUTPUT_BUCKET}/count \
* --experiments=use_runner_v2 \
* --sdkHarnessContainerImageOverrides=.*python.*,gcr.io/apache-beam-testing/beam-sdk/beam_python{$PYTHON_VERSION}_sdk:latest"
* }</pre>
*/
Expand Down
10 changes: 10 additions & 0 deletions sdks/java/maven-archetypes/examples/generate-sources.sh
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ rsync -a \
"${EXAMPLES_ROOT}"/src/test/java/org/apache/beam/examples/complete/game/ \
"${ARCHETYPE_ROOT}/src/test/java/complete/game"

#
# Copy the Java multi-language examples.
#

# Create the destination under src/main, the same tree that rsync writes to below.
mkdir -p "${ARCHETYPE_ROOT}/src/main/java/multilanguage/"

rsync -a \
    "${EXAMPLES_ROOT}"/src/main/java/org/apache/beam/examples/multilanguage/ \
    "${ARCHETYPE_ROOT}/src/main/java/multilanguage"

#
# Replace 'package org.apache.beam.examples' with 'package ${package}' in all Java code
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,13 @@
<version>${beam.version}</version>
</dependency>

<!-- Adds a dependency on the Python Multi-language pipelines API module. -->
<dependency>
<groupId>org.apache.beam</groupId>
<artifactId>beam-sdks-java-extensions-python</artifactId>
<version>${beam.version}</version>
</dependency>

<!-- Dependencies below this line are specific dependencies needed by the examples code. -->
<dependency>
<groupId>com.google.api-client</groupId>
Expand Down

0 comments on commit 10e15a9

Please sign in to comment.