Update RunInference documentation #22250
@@ -34,7 +34,7 @@ For more information, see the [`BatchElements` transform documentation](https://

### Shared helper class

- Instead of loading a model for each thread in the process, we use the `Shared` class, which allows us to load one model that is shared across all threads of each worker in a DoFn. For more information, see the
+ Using the `Shared` class within the RunInference implementation allows us to load the model only once per process and share it with all DoFn instances created in that process. This feature reduces memory consumption and model loading time. For more information, see the
[`Shared` class documentation](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/utils/shared.py#L20).
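The load-once behavior described above can be illustrated without Beam. The following is a minimal, hypothetical sketch of the pattern that `Shared` provides (a process-level cache guarded by a lock, so later callers reuse the first load); it is not Beam's actual `Shared` implementation, and `acquire_model`/`load_model` are illustrative names only.

```python
import threading

# Process-level cache: the first caller loads the model, later
# callers reuse the cached instance instead of loading again.
_model_cache = {}
_cache_lock = threading.Lock()

def acquire_model(key, load_fn):
    """Return the shared model for `key`, loading it at most once."""
    with _cache_lock:
        if key not in _model_cache:
            _model_cache[key] = load_fn()
        return _model_cache[key]

load_calls = []

def load_model():
    load_calls.append(1)          # record that an expensive load happened
    return {"weights": [0.1, 0.2]}  # stand-in for a real model

# Two "DoFn instances" acquiring the model in the same process:
m1 = acquire_model("my_model", load_model)
m2 = acquire_model("my_model", load_model)
assert m1 is m2              # both instances share one object
assert len(load_calls) == 1  # the model was loaded only once
```

In Beam itself, a DoFn would instead hold a `Shared` instance and acquire the model through it during setup, which gives the same one-load-per-process behavior across threads.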

### Multi-model pipelines
@@ -53,7 +53,7 @@ with pipeline as p:
```
Where `model_handler` is the model handler setup code.

- To import models, you need to wrap them around a `ModelHandler` object. Which `ModelHandler` you import depends on the framework and type of data structure that contains the inputs. The following examples show some ModelHandlers that you might want to import.
+ To import models, you need to configure a `ModelHandler` object that wraps the underlying model. Which `ModelHandler` you import depends on the framework and type of data structure that contains the inputs. The following examples show some ModelHandlers that you might want to import.

```
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
```
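To make the handler contract concrete, here is a framework-free, hypothetical sketch of what a `ModelHandler` does: it knows how to load its model once and how to map a batch of inputs to one prediction per element. `ToyModelHandler` is an illustrative name only; it mirrors the shape of the real interface's load and inference methods rather than Beam's actual class.

```python
from typing import Any, Iterable, Sequence

class ToyModelHandler:
    """Toy sketch of the handler contract: load once, infer per batch."""

    def load_model(self) -> Any:
        # A real handler would unpickle a sklearn model or load
        # framework weights here; we return a trivial "double it" model.
        return lambda x: 2 * x

    def run_inference(self, batch: Sequence[int], model: Any) -> Iterable[int]:
        # One prediction per input element, preserving order so the
        # runner can zip predictions back with their inputs.
        return [model(x) for x in batch]

handler = ToyModelHandler()
model = handler.load_model()
preds = list(handler.run_inference([1, 2, 3], model))
# preds == [2, 4, 6]
```

The framework-specific handlers (such as `SklearnModelHandlerNumpy` above) fill in these two responsibilities for their respective libraries and input types.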
@@ -171,7 +171,7 @@ In some cases, the `PredictionResults` output might not include the correct pred

The RunInference API currently expects outputs to be an `Iterable[Any]`. Example return types are `Iterable[Tensor]` or `Iterable[Dict[str, Tensor]]`. When RunInference zips the inputs with the predictions, the predictions iterate over the dictionary keys instead of the batch elements. The result is that the key name is preserved but the prediction tensors are discarded. For more information, see the [Pytorch RunInference PredictionResult is a Dict](https://github.com/apache/beam/issues/22240) issue in the Apache Beam GitHub project.

- To work with the current RunInference implementation, you can create a wrapper class that overrides the `model(input)` call. In PyTorch, for example, your wrapper would override the `forward()` function and return an output with the appropriate format of `List[Dict[str, torch.Tensor]]`. For more information, see our [HuggingFace language modeling example](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py#L49).
+ To work with the current RunInference implementation, you can create a wrapper class that overrides the `model(input)` call. In PyTorch, for example, your wrapper would override the `forward()` function and return an output with the appropriate format of `List[Dict[str, torch.Tensor]]`. For more information, see our [HuggingFace language modeling example](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py#L49) and our [Bert language modeling example](https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py).
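The wrapper idea can be sketched without PyTorch. In this hypothetical example, the base model returns a single dictionary of batch-sized lists (the problematic shape described above), and the wrapper's `forward()` re-slices it into one dictionary per batch element so the output is an iterable that zips cleanly with the inputs. `BaseModel` and `WrappedModel` are illustrative names, not Beam or PyTorch APIs.

```python
from typing import Dict, List, Sequence

class BaseModel:
    """Returns ONE dict whose values are batch-sized lists --
    the shape that RunInference mishandles when zipping."""

    def forward(self, batch: Sequence[float]) -> Dict[str, List[float]]:
        return {"logits": [x * 10 for x in batch],
                "scores": [x / 10 for x in batch]}

class WrappedModel:
    """Overrides forward() to emit one dict PER batch element."""

    def __init__(self, model: BaseModel):
        self._model = model

    def forward(self, batch: Sequence[float]) -> List[Dict[str, float]]:
        keyed = self._model.forward(batch)
        # Re-slice the keyed batch output into per-element dicts,
        # preserving input order.
        return [{k: v[i] for k, v in keyed.items()}
                for i in range(len(batch))]

wrapped = WrappedModel(BaseModel())
out = wrapped.forward([1.0, 2.0])
# out == [{"logits": 10.0, "scores": 0.1}, {"logits": 20.0, "scores": 0.2}]
```

A PyTorch version of this wrapper would do the same re-slicing over tensors inside `forward()` and return `List[Dict[str, torch.Tensor]]`.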
Review thread:
- these are the same links, looks like not the change we intended to make?
- my last comment referred to disable batching section
- Oops. This should be fixed now.

### Unable to batch tensor elements
Review thread:
- These little chunks of code below here seem out of place.
- @yeandy How do you want to handle this?
- Would it make sense to reword it to something like this, and keep the (refactored) code block? "To import models, you need to wrap them around a `ModelHandler` object. The `ModelHandler` you import will depend on the framework and type of data structure that contains the inputs. See the following examples on which ones you may want to import."
- I made some updates. Take a look and let me know if we need more changes.
- Thanks. By the way, the imports I originally wrote had some typos, so I fixed them.
- Updated to fix the typos.