-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update RunInference documentation #22250
Changes from 24 commits
c7b0d99
eb89211
1a29d9e
2be4739
f5a0f0c
1d404b6
2d954de
245dc49
99b1308
b0f430a
b02fd09
1234657
0973e45
5dd713e
8a50ee8
e49ebec
5db0b67
85d866b
f766888
40351c4
e02160f
2be34ea
090c356
c2995c9
dbe9b05
214ba7f
07e1f3e
7d8ce8e
c0c4548
580fa7f
651f52d
d3f80d5
380fcd3
033997b
335f1f1
489fce7
4b962b4
9e98188
46f5ebd
1f7ce97
07a99d7
49b0a7f
2d988c2
abde489
c340531
a3786c9
cde3380
67b5d42
4160d1b
4aa492c
c1d5643
4e89126
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,157 @@ | ||||
# coding=utf-8 | ||||
# | ||||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||||
# contributor license agreements. See the NOTICE file distributed with | ||||
# this work for additional information regarding copyright ownership. | ||||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||||
# (the "License"); you may not use this file except in compliance with | ||||
# the License. You may obtain a copy of the License at | ||||
# | ||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||
# | ||||
# Unless required by applicable law or agreed to in writing, software | ||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# See the License for the specific language governing permissions and | ||||
# limitations under the License. | ||||
# | ||||
|
||||
# pytype: skip-file | ||||
|
||||
def torch_unkeyed_model_handler(test=None):
  """Runs RunInference on unkeyed tensors with a PyTorch linear model.

  Args:
    test: optional callable invoked with the predictions PCollection,
      used by the snippet test harness.
  """
  # [START torch_unkeyed_model_handler]
  import apache_beam as beam
  import numpy
  import torch
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

  class LinearRegression(torch.nn.Module):
    """Single linear layer: y = W x + b."""
    def __init__(self, input_dim=1, output_dim=1):
      super().__init__()
      self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
      return self.linear(x)

  # Pre-trained "five times table" weights hosted in a public GCS bucket.
  model_state_dict_path = 'gs://apache-beam-samples/run_inference/five_times_table_torch.pt'  # pylint: disable=line-too-long
  model_handler = PytorchModelHandlerTensor(
      model_class=LinearRegression,
      model_params={'input_dim': 1, 'output_dim': 1},
      state_dict_path=model_state_dict_path)

  # Shape the inputs into a (4, 1) float32 column vector of examples.
  unkeyed_data = numpy.array([10, 40, 60, 90],
                             dtype=numpy.float32).reshape(-1, 1)

  with beam.Pipeline() as p:
    predictions = (
        p
        | 'InputData' >> beam.Create(unkeyed_data)
        | 'ConvertNumpyToTensor' >> beam.Map(torch.Tensor)
        | 'PytorchRunInference' >> RunInference(model_handler=model_handler)
        | beam.Map(print))
  # [END torch_unkeyed_model_handler]
  if test:
    test(predictions)
|
||||
|
||||
def torch_keyed_model_handler(test=None):
  """Runs RunInference on (key, tensor) pairs with a PyTorch linear model.

  Args:
    test: optional callable invoked with the predictions PCollection,
      used by the snippet test harness.
  """
  # [START torch_keyed_model_handler]
  import apache_beam as beam
  import torch
  from apache_beam.ml.inference.base import KeyedModelHandler
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

  class LinearRegression(torch.nn.Module):
    """Single linear layer: y = W x + b."""
    def __init__(self, input_dim=1, output_dim=1):
      super().__init__()
      self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
      return self.linear(x)

  # Pre-trained "five times table" weights hosted in a public GCS bucket.
  model_state_dict_path = 'gs://apache-beam-samples/run_inference/five_times_table_torch.pt'  # pylint: disable=line-too-long
  # KeyedModelHandler carries each example's key through inference untouched.
  keyed_model_handler = KeyedModelHandler(
      PytorchModelHandlerTensor(
          model_class=LinearRegression,
          model_params={'input_dim': 1, 'output_dim': 1},
          state_dict_path=model_state_dict_path))

  keyed_data = [("first_question", 105.00), ("second_question", 108.00),
                ("third_question", 1000.00), ("fourth_question", 1013.00)]

  with beam.Pipeline() as p:
    predictions = (
        p
        | 'KeyedInputData' >> beam.Create(keyed_data)
        | "ConvertIntToTensor" >>
        beam.Map(lambda x: (x[0], torch.Tensor([x[1]])))
        | 'PytorchRunInference' >>
        RunInference(model_handler=keyed_model_handler)
        | beam.Map(print))
  # [END torch_keyed_model_handler]
  if test:
    test(predictions)
|
||||
|
||||
def sklearn_unkeyed_model_handler(test=None):
  """Runs RunInference on unkeyed numpy arrays with a pickled sklearn model.

  Args:
    test: optional callable invoked with the predictions PCollection,
      used by the snippet test harness.
  """
  # [START sklearn_unkeyed_model_handler]
  import apache_beam as beam
  import numpy
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.sklearn_inference import ModelFileType
  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  # Pickled "five times table" sklearn model hosted in a public GCS bucket.
  sklearn_model_filename = 'gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl'  # pylint: disable=line-too-long
  sklearn_model_handler = SklearnModelHandlerNumpy(
      model_uri=sklearn_model_filename, model_file_type=ModelFileType.PICKLE)

  # Shape the inputs into a (4, 1) float32 column vector of examples.
  unkeyed_data = numpy.array([20, 40, 60, 90],
                             dtype=numpy.float32).reshape(-1, 1)
  with beam.Pipeline() as p:
    predictions = (
        p
        | "ReadInputs" >> beam.Create(unkeyed_data)
        | "RunInferenceSklearn" >>
        RunInference(model_handler=sklearn_model_handler)
        | beam.Map(print))
  # [END sklearn_unkeyed_model_handler]
  if test:
    test(predictions)
|
||||
|
||||
def sklearn_keyed_model_handler(test=None):
  """Runs RunInference on (key, value) pairs with a pickled sklearn model.

  Args:
    test: optional callable invoked with the predictions PCollection,
      used by the snippet test harness.
  """
  # [START sklearn_keyed_model_handler]
  import apache_beam as beam
  from apache_beam.ml.inference.base import KeyedModelHandler
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.sklearn_inference import ModelFileType
  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  # Pickled "five times table" sklearn model hosted in a public GCS bucket;
  # the KeyedModelHandler wrapper carries each example's key through inference.
  sklearn_model_filename = 'gs://apache-beam-samples/run_inference/five_times_table_sklearn.pkl'  # pylint: disable=line-too-long
  sklearn_model_handler = KeyedModelHandler(
      SklearnModelHandlerNumpy(
          model_uri=sklearn_model_filename,
          model_file_type=ModelFileType.PICKLE))

  keyed_data = [("first_question", 105.00), ("second_question", 108.00),
                ("third_question", 1000.00), ("fourth_question", 1013.00)]

  with beam.Pipeline() as p:
    predictions = (
        p
        | "ReadInputs" >> beam.Create(keyed_data)
        | "ConvertDataToList" >> beam.Map(lambda x: (x[0], [x[1]]))
        | "RunInferenceSklearn" >>
        RunInference(model_handler=sklearn_model_handler)
        | beam.Map(print))
  # [END sklearn_keyed_model_handler]
  if test:
    test(predictions)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# coding=utf-8 | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
# pytype: skip-file | ||
|
||
import unittest | ||
from io import StringIO | ||
|
||
import mock | ||
|
||
from apache_beam.examples.snippets.util import assert_matches_stdout | ||
from apache_beam.testing.test_pipeline import TestPipeline | ||
|
||
from . import runinference | ||
|
||
def check_torch_keyed_model_handler(actual):
  """Asserts the keyed PyTorch snippet prints these four keyed results."""
  expected = [
      "('first_question', PredictionResult(example=tensor([105.]), inference=tensor([523.6982], grad_fn=<UnbindBackward0>)))",
      "('second_question', PredictionResult(example=tensor([108.]), inference=tensor([538.5867], grad_fn=<UnbindBackward0>)))",
      "('third_question', PredictionResult(example=tensor([1000.]), inference=tensor([4965.4019], grad_fn=<UnbindBackward0>)))",
      "('fourth_question', PredictionResult(example=tensor([1013.]), inference=tensor([5029.9180], grad_fn=<UnbindBackward0>)))",
  ]
  assert_matches_stdout(actual, expected)
|
||
|
||
def check_sklearn_keyed_model_handler(actual):
  """Asserts the keyed sklearn snippet prints these four keyed results."""
  expected = [
      "('first_question', PredictionResult(example=[105.0], inference=array([525.])))",
      "('second_question', PredictionResult(example=[108.0], inference=array([540.])))",
      "('third_question', PredictionResult(example=[1000.0], inference=array([5000.])))",
      "('fourth_question', PredictionResult(example=[1013.0], inference=array([5065.])))",
  ]
  assert_matches_stdout(actual, expected)
|
||
|
||
def check_torch_unkeyed_model_handler(actual):
  """Asserts the unkeyed PyTorch snippet prints these four results."""
  expected = [
      'PredictionResult(example=tensor([10.]), inference=tensor([52.2325], grad_fn=<UnbindBackward0>))',
      'PredictionResult(example=tensor([40.]), inference=tensor([201.1165], grad_fn=<UnbindBackward0>))',
      'PredictionResult(example=tensor([60.]), inference=tensor([300.3724], grad_fn=<UnbindBackward0>))',
      'PredictionResult(example=tensor([90.]), inference=tensor([449.2563], grad_fn=<UnbindBackward0>))',
  ]
  assert_matches_stdout(actual, expected)
|
||
|
||
def check_sklearn_unkeyed_model_handler(actual):
  """Asserts the unkeyed sklearn snippet prints these four results."""
  expected = [
      'PredictionResult(example=array([20.], dtype=float32), inference=array([100.], dtype=float32))',
      'PredictionResult(example=array([40.], dtype=float32), inference=array([200.], dtype=float32))',
      'PredictionResult(example=array([60.], dtype=float32), inference=array([300.], dtype=float32))',
      'PredictionResult(example=array([90.], dtype=float32), inference=array([450.], dtype=float32))',
  ]
  assert_matches_stdout(actual, expected)
|
||
@mock.patch('apache_beam.Pipeline', TestPipeline)
@mock.patch(
    'apache_beam.examples.snippets.transforms.elementwise.runinference.print', str)
class RunInferenceTest(unittest.TestCase):
  """Runs each RunInference snippet and verifies its printed output.

  ``apache_beam.Pipeline`` is patched to ``TestPipeline`` so the snippets
  execute on the test runner, and the snippet module's ``print`` is patched
  to ``str`` so elements flow through for ``assert_matches_stdout``.
  """
  def test_torch_unkeyed_model_handler(self):
    runinference.torch_unkeyed_model_handler(check_torch_unkeyed_model_handler)

  def test_torch_keyed_model_handler(self):
    runinference.torch_keyed_model_handler(check_torch_keyed_model_handler)

  def test_sklearn_unkeyed_model_handler(self):
    runinference.sklearn_unkeyed_model_handler(
        check_sklearn_unkeyed_model_handler)

  def test_sklearn_keyed_model_handler(self):
    runinference.sklearn_keyed_model_handler(check_sklearn_keyed_model_handler)

  # NOTE: the former test_images / test_digits methods were removed: they
  # referenced runinference.images / runinference.digits and check_images /
  # check_digits, none of which are defined in this module or the snippet
  # module, so both tests could only fail with NameError/AttributeError.
|
||
rszper marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Standard unittest entry point: run this module's tests when executed directly.
if __name__ == '__main__':
  unittest.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@AnandInguva I see `ModuleNotFoundError: No module named 'torch'`. Do you know where we can install extra deps for the snippet examples?

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In which environment do you run the snippet examples?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tox.