MLTransform #26795
@@ -0,0 +1,118 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

""" | ||
This example demonstrates how to use MLTransform. | ||
MLTransform is a PTransform that applies multiple data transformations on the | ||
incoming data. | ||
|
||
This example computes the vocabulary on the incoming data. Then, it computes | ||
the TF-IDF of the incoming data using the vocabulary computed in the previous | ||
step. | ||
|
||
1. ComputeAndApplyVocabulary computes the vocabulary on the incoming data and | ||
overrides the incoming data with the vocabulary indices. | ||
2. TFIDF computes the TF-IDF of the incoming data using the vocabulary and | ||
provides vocab_index and tf-idf weights. vocab_index is suffixed with | ||
'_vocab_index' and tf-idf weights are suffixed with '_tfidf' to the | ||
original column name(which is the output of ComputeAndApplyVocabulary). | ||
|
||
MLTransform produces artifacts, for example: ComputeAndApplyVocabulary produces | ||
a text file that contains vocabulary which is saved in `artifact_location`. | ||
ComputeAndApplyVocabulary outputs vocab indices associated with the saved vocab | ||
file. This mode of MLTransform is called artifact `produce` mode. | ||
This will be useful when the data is preprocessed before ML model training. | ||
|
||
The second mode of MLTransform is artifact `consume` mode. In this mode, the | ||
transformations are applied on the incoming data using the artifacts produced | ||
by the previous run of MLTransform. This mode will be useful when the data is | ||
preprocessed before ML model inference. | ||
""" | ||
|
||
import argparse
import logging
import tempfile

import apache_beam as beam
from apache_beam.ml.transforms.base import ArtifactMode
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.tft_transforms import ComputeAndApplyVocabulary
from apache_beam.ml.transforms.tft_transforms import TFIDF
from apache_beam.ml.transforms.utils import ArtifactsFetcher


def parse_args():
  parser = argparse.ArgumentParser()
  parser.add_argument('--artifact_location', type=str, default='')
  return parser.parse_known_args()


def run(args):
  data = [
      dict(x=["Let's", "go", "to", "the", "park"]),
      dict(x=["I", "enjoy", "going", "to", "the", "park"]),
      dict(x=["I", "enjoy", "reading", "books"]),
      dict(x=["Beam", "can", "be", "fun"]),
      dict(x=["The", "weather", "is", "really", "nice", "today"]),
      dict(x=["I", "love", "to", "go", "to", "the", "park"]),
      dict(x=["I", "love", "to", "read", "books"]),
      dict(x=["I", "love", "to", "program"]),
  ]

  with beam.Pipeline() as p:
    input_data = p | beam.Create(data)

    # Artifacts produce mode.
    input_data |= (
        'MLTransform' >> MLTransform(
            artifact_location=args.artifact_location,
            artifact_mode=ArtifactMode.PRODUCE,
        ).with_transform(ComputeAndApplyVocabulary(
            columns=['x'])).with_transform(TFIDF(columns=['x'])))

    # Uncomment to log the transformed output.
    # _ = input_data | beam.Map(logging.info)

  with beam.Pipeline() as p:
    input_data = [
        dict(x=['I', 'love', 'books']), dict(x=['I', 'love', 'Apache', 'Beam'])
    ]
    input_data = p | beam.Create(input_data)

    # Artifacts consume mode.
    input_data |= (
        MLTransform(
            artifact_location=args.artifact_location,
            artifact_mode=ArtifactMode.CONSUME,
            # You don't need to specify the transforms since they are
            # already saved in the artifacts.
        ))

    _ = input_data | beam.Map(logging.info)

  # Fetch the artifacts after the pipeline runs.
  artifacts_fetcher = ArtifactsFetcher(artifact_location=args.artifact_location)
  vocab_list = artifacts_fetcher.get_vocab_list()
  assert vocab_list[22] == 'Beam'


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  args, pipeline_args = parse_args()
  # For this example, create a temp artifact location if not provided.
  if args.artifact_location == '':
    args.artifact_location = tempfile.mkdtemp()
  run(args)
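
To try the example locally, an invocation like the following would work (the
script file name here is illustrative, not the path used in the PR; pass any
writable path as the artifact location, or omit the flag to fall back to a
fresh tempfile.mkdtemp() directory, as the __main__ block above does):

    python ml_transform_example.py --artifact_location=/tmp/ml_transform_artifacts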
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@@ -0,0 +1,165 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Generic
from typing import List
from typing import Optional
from typing import TypeVar

import apache_beam as beam

# TODO: Abstract methods are not getting pickled with dill.
# https://github.com/uqfoundation/dill/issues/332
# import abc

__all__ = ['MLTransform']

TransformedDatasetT = TypeVar('TransformedDatasetT')
TransformedMetadataT = TypeVar('TransformedMetadataT')

# Input/Output types of the MLTransform.
ExampleT = TypeVar('ExampleT')
MLTransformOutputT = TypeVar('MLTransformOutputT')

# Input to process_data. This could be the same as or different from ExampleT.
ProcessInputT = TypeVar('ProcessInputT')
# Output of process_data. This could be the same as or different from
# MLTransformOutputT.
ProcessOutputT = TypeVar('ProcessOutputT')

# Input to the apply() method of BaseOperation.
OperationInputT = TypeVar('OperationInputT')
# Output of the apply() method of BaseOperation.
OperationOutputT = TypeVar('OperationOutputT')


class ArtifactMode(object):
  PRODUCE = 'produce'
  CONSUME = 'consume'


class BaseOperation(Generic[OperationInputT, OperationOutputT]):
  def apply(
      self, inputs: OperationInputT, column_name: str, *args,
      **kwargs) -> OperationOutputT:
    """
    Define any processing logic in the apply() method.
    The processing logic is applied to the inputs and returns the
    transformed output.
    Args:
      inputs: input data.
      column_name: name of the column on which the operation is applied.
    """
    raise NotImplementedError

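To make the contract concrete, a custom operation could subclass BaseOperation
as in the following sketch; LowercaseTokens is a hypothetical illustration,
not an operation shipped in this PR:

class LowercaseTokens(BaseOperation[List[str], List[str]]):
  """Hypothetical operation: lowercases every token in a column."""
  def apply(self, inputs: List[str], column_name: str, *args,
            **kwargs) -> List[str]:
    # Element-wise logic: return the transformed value for this column.
    return [token.lower() for token in inputs]
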
class _ProcessHandler(Generic[ProcessInputT, ProcessOutputT]):
  """
  Only for internal use. No backwards compatibility guarantees.
  """
  def process_data(
      self, pcoll: beam.PCollection[ProcessInputT]
  ) -> beam.PCollection[ProcessOutputT]:
    """
    Logic to process the data. This will be the entrypoint in
    beam.MLTransform to process incoming data.
    """
    raise NotImplementedError

  def append_transform(self, transform: BaseOperation):
    raise NotImplementedError

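For intuition only, a minimal handler satisfying this interface might apply
each appended operation element-wise, as in the hypothetical sketch below; the
actual handler in this PR, TFTProcessHandler, instead delegates to
tensorflow_transform:

class _NaiveProcessHandler(_ProcessHandler[dict, dict]):
  """Hypothetical handler: applies appended operations with beam.Map."""
  def __init__(self):
    self._transforms: List[BaseOperation] = []

  def append_transform(self, transform: BaseOperation):
    self._transforms.append(transform)

  def process_data(
      self, pcoll: beam.PCollection[dict]) -> beam.PCollection[dict]:
    for i, transform in enumerate(self._transforms):
      # Apply each operation to every column of each dict-shaped element,
      # giving each step a unique label.
      pcoll = pcoll | f'ApplyOperation{i}' >> beam.Map(
          lambda row, op=transform:
          {col: op.apply(value, col) for col, value in row.items()})
    return pcoll
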
class MLTransform(beam.PTransform[beam.PCollection[ExampleT],
                                  beam.PCollection[MLTransformOutputT]],
                  Generic[ExampleT, MLTransformOutputT]):
  def __init__(
      self,
      *,
      artifact_location: str,
      artifact_mode: str = ArtifactMode.PRODUCE,
      transforms: Optional[List[BaseOperation]] = None,
      is_input_record_batches: bool = False,
      output_record_batches: bool = False,
  ):

Review thread on the `is_input_record_batches` / `output_record_batches`
parameters:

- "If we need something like this in …"
- "One option would be to make these properties of the operations instead of
  the top-level transform. Chained TFT operations could then just use the
  values from the first/last operation. I also expect that we're going to run
  into this problem with other frameworks in the future, so I think we need a
  way for adding additional framework- or transform-specific parameters."
- "I think that would be my preferred experience. So users could do something
  like: … For TFT, we'd then fuse together any consecutive TFT transforms
  into a single TFTProcessHandler, resolve any conflicting arguments (e.g.
  throw if one transform says output_record_batches and the next doesn't take
  RecordBatches or something), and construct the graph."
- "Info about record_batch lives within the context of TFTProcessHandler. I
  don't like the idea of passing this arg via operations, since we don't use
  this arg anywhere in the operations. It would just be a different way of
  inferring this arg in TFTProcessHandler. An alternative would be for
  MLTransform to take a …"
- "I think from a user's point of view we use it in the first one."
- "Right now, the columns are specified at the operation level instead of the
  transform level. The entry point for column x could be at the beginning,
  but the entry point for column y could be in the middle of the list. The
  user might pass … If we ask the user to provide it like this, I feel like
  it could get a little complicated. We can also iterate on this in v2, since
  I guess this needs another discussion, and remove the option for
  RecordBatches for now. We would support Dict[str, Any] in v1. What do you
  think?"
- "I'm comfortable iterating on it in v2."

""" | ||
Args: | ||
artifact_location: A storage location for artifacts resulting from | ||
MLTransform. These artifacts include transformations applied to | ||
the dataset and generated values like min, max from ScaleTo01, | ||
and mean, var from ScaleToZScore. Artifacts are produced and stored | ||
in this location when the `artifact_mode` is set to 'produce'. | ||
Conversely, when `artifact_mode` is set to 'consume', artifacts are | ||
retrieved from this location. Note that when consuming artifacts, | ||
it is not necessary to pass the transforms since they are inherently | ||
stored within the artifacts themselves. The value assigned to | ||
`artifact_location` should be a valid storage path where the artifacts | ||
can be written to or read from. | ||
transforms: A list of transforms to apply to the data. All the transforms | ||
are applied in the order they are specified. The input of the | ||
i-th transform is the output of the (i-1)-th transform. Multi-input | ||
transforms are not supported yet. | ||
is_input_record_batches: Whether the input is a RecordBatch. | ||
output_record_batches: Output RecordBatches instead of beam.Row(). | ||
artifact_mode: Whether to produce or consume artifacts. If set to | ||
'consume', the handler will assume that the artifacts are already | ||
computed and stored in the artifact_location. Pass the same artifact | ||
location that was passed during produce phase to ensure that the | ||
right artifacts are read. If set to 'produce', the handler | ||
will compute the artifacts and store them in the artifact_location. | ||
The artifacts will be read from this location during the consume phase. | ||
There is no need to pass the transforms in this case since they are | ||
already embedded in the stored artifacts. | ||
""" | ||
    # Avoid circular import.
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.ml.transforms.handlers import TFTProcessHandler
    process_handler = TFTProcessHandler(
        artifact_location=artifact_location,
        artifact_mode=artifact_mode,
        transforms=transforms,
        is_input_record_batches=is_input_record_batches,
        output_record_batches=output_record_batches)

    self._process_handler = process_handler

  def expand(
      self, pcoll: beam.PCollection[ExampleT]
  ) -> beam.PCollection[MLTransformOutputT]:
    """
    This is the entrypoint for the MLTransform. This method will
    invoke the process_data() method of the _ProcessHandler instance
    to process the incoming data.

    process_data takes in a PCollection and applies the PTransforms
    necessary to process the data; it returns a PCollection of
    transformed data.
    Args:
      pcoll: A PCollection of ExampleT type.
    Returns:
      A PCollection of MLTransformOutputT type.
    """
    return self._process_handler.process_data(pcoll)

  def with_transform(self, transform: BaseOperation):
    """
    Add a transform to the MLTransform pipeline.
    Args:
      transform: A BaseOperation instance.
    Returns:
      The MLTransform instance, so that calls can be chained.
    """
    self._process_handler.append_transform(transform)
    return self
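
Because with_transform returns self, operations can be chained fluently, as
the example pipeline at the top of this PR does (ComputeAndApplyVocabulary
and TFIDF come from apache_beam.ml.transforms.tft_transforms; the artifact
path here is illustrative):

mltransform = MLTransform(
    artifact_location='/tmp/artifacts',
    artifact_mode=ArtifactMode.PRODUCE,
).with_transform(ComputeAndApplyVocabulary(columns=['x'])).with_transform(
    TFIDF(columns=['x']))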
Review thread on the dill TODO near the top of this file:

- "Does this TODO still apply? What are the consequences?"
- "Not relevant anymore. I tried today and I wasn't able to reproduce it."