#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import json
import logging
import os
import sys
import time
import types
import warnings
from functools import singledispatchmethod
from typing import TYPE_CHECKING, Optional, Union

import cloudpickle
import fsspec
import pandas as pd
from merlin.core.compat import cudf
from merlin.dag import Graph
from merlin.dag.executors import DaskExecutor, LocalExecutor
from merlin.dag.node import iter_nodes
from merlin.dag.ops.stat_operator import StatOperator
from merlin.io import Dataset
from merlin.schema import Schema

from nvtabular.ops import LambdaOp
from nvtabular.workflow.node import WorkflowNode

LOG = logging.getLogger("nvtabular")

if TYPE_CHECKING:
    import distributed

class Workflow:
    """
    The Workflow class applies a graph of operations onto a dataset, letting you transform
    datasets to do feature engineering and preprocessing operations. This class follows an API
    similar to Transformers in sklearn: we first ``fit`` the workflow by calculating statistics
    on the dataset, and then once fit we can ``transform`` datasets by applying these statistics.

    Example usage::

        # define a graph of operations
        cat_features = CAT_COLUMNS >> nvtabular.ops.Categorify()
        cont_features = CONT_COLUMNS >> nvtabular.ops.FillMissing() >> nvtabular.ops.Normalize()
        workflow = nvtabular.Workflow(cat_features + cont_features + "label")

        # calculate statistics on the training dataset
        workflow.fit(merlin.io.Dataset(TRAIN_PATH))

        # transform the training and validation datasets and write out as parquet
        workflow.transform(merlin.io.Dataset(TRAIN_PATH)).to_parquet(output_path=TRAIN_OUT_PATH)
        workflow.transform(merlin.io.Dataset(VALID_PATH)).to_parquet(output_path=VALID_OUT_PATH)

    Parameters
    ----------
    output_node: WorkflowNode
        The last node in the graph of operators this workflow should apply
    """

    def __init__(self, output_node: WorkflowNode, client: Optional["distributed.Client"] = None):
        self.graph = Graph(output_node)
        self.executor = DaskExecutor(client)

    @singledispatchmethod
    def transform(self, data):
        """Transforms the data by applying the graph of operators to it.

        Requires the ``fit`` method to have already been called, or a Workflow
        that has already been fit and re-loaded from disk (using the ``load``
        method).

        This method returns data of the same type as its input. In the case of
        a `Dataset`, the computation is lazy: it won't happen until the produced
        Dataset is consumed or written out to disk, e.g. with `dataset.compute()`.

        Parameters
        ----------
        data: Union[Dataset, DataFrameType]
            Input Dataset or DataFrame to transform

        Returns
        -------
        Dataset or DataFrame
            Transformed Dataset or DataFrame with the workflow graph applied to it

        Raises
        ------
        NotImplementedError
            If passed an unsupported data type to transform.
        """
        raise NotImplementedError(
            f"Workflow.transform received an unsupported type: {type(data)}. "
            "Supported types are a `merlin.io.Dataset` or DataFrame (pandas or cudf)"
        )

    @transform.register
    def _(self, dataset: Dataset) -> Dataset:
        return self._transform_impl(dataset)

    @transform.register
    def _(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        return self._transform_df(dataframe)

    if cudf:

        @transform.register
        def _(self, dataframe: cudf.DataFrame) -> cudf.DataFrame:
            return self._transform_df(dataframe)

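    # Illustrative sketch (not part of the library): ``transform`` dispatches on the input
    # type, so the same fitted workflow can be applied lazily to a ``Dataset`` or eagerly to
    # an in-memory DataFrame. The names ``my_workflow`` and ``train_df`` are hypothetical.
    #
    #     out_dataset = my_workflow.transform(Dataset(train_df))  # lazy; returns a Dataset
    #     out_df = my_workflow.transform(train_df)                # eager; returns a DataFrame
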
    def fit_schema(self, input_schema: Schema):
        """Computes input and output schemas for each node in the Workflow graph

        Parameters
        ----------
        input_schema : Schema
            The input schema to use

        Returns
        -------
        Workflow
            This workflow where each node in the graph has a fitted schema
        """
        self.graph.construct_schema(input_schema)
        return self

    @property
    def subworkflows(self):
        return list(self.graph.subgraphs.keys())

    @property
    def input_dtypes(self):
        return self.graph.input_dtypes

    @property
    def input_schema(self):
        return self.graph.input_schema

    @property
    def output_schema(self):
        return self.graph.output_schema

    @property
    def output_dtypes(self):
        return self.graph.output_dtypes

    @property
    def output_node(self):
        return self.graph.output_node

    def _input_columns(self):
        return self.graph._input_columns()

    def get_subworkflow(self, subgraph_name):
        subgraph = self.graph.subgraph(subgraph_name)
        return Workflow(subgraph.output_node)

    def remove_inputs(self, input_cols) -> "Workflow":
        """Removes input columns from the workflow.

        This is useful for the case of inference where you might need to remove label columns
        from the processed set.

        Parameters
        ----------
        input_cols : list of str
            List of column names to remove from the workflow's inputs

        Returns
        -------
        Workflow
            This workflow with the input columns removed from it

        See Also
        --------
        merlin.dag.Graph.remove_inputs
        """
        self.graph.remove_inputs(input_cols)
        return self

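    # Illustrative sketch (not part of the library): at inference time the label column is
    # typically absent, so it can be dropped from a fitted workflow before serving. The names
    # ``my_workflow``, ``inference_df``, and the column "label" are hypothetical placeholders.
    #
    #     my_workflow.remove_inputs(["label"])
    #     serving_output = my_workflow.transform(inference_df)
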
    def fit(self, dataset: Dataset) -> "Workflow":
        """Calculates statistics for this workflow on the input dataset

        Parameters
        ----------
        dataset: Dataset
            The input dataset to calculate statistics for. If there is a train/test split this
            data should be the training dataset only.

        Returns
        -------
        Workflow
            This Workflow with statistics calculated on it
        """
        self.executor.fit(dataset, self.graph)
        return self

    def fit_transform(self, dataset: Dataset) -> Dataset:
        """Convenience method to both fit the workflow and transform the dataset in a single
        call. Equivalent to calling ``workflow.fit(dataset)`` followed by
        ``workflow.transform(dataset)``

        Parameters
        ----------
        dataset: Dataset
            Input dataset to calculate statistics on, and transform results

        Returns
        -------
        Dataset
            Transformed Dataset with the workflow graph applied to it

        See Also
        --------
        fit
        transform
        """
        self.fit(dataset)
        return self.transform(dataset)

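    # Illustrative sketch (not part of the library): fitting and transforming the training
    # data in one call. ``TRAIN_PATH`` and ``TRAIN_OUT_PATH`` are hypothetical placeholders.
    #
    #     workflow.fit_transform(Dataset(TRAIN_PATH)).to_parquet(output_path=TRAIN_OUT_PATH)
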
    def _transform_impl(self, dataset: Dataset, capture_dtypes=False):
        if not self.graph.output_schema:
            self.graph.construct_schema(dataset.schema)

        ddf = dataset.to_ddf(columns=self._input_columns())

        return Dataset(
            self.executor.transform(
                ddf, self.output_node, self.output_dtypes, capture_dtypes=capture_dtypes
            ),
            cpu=dataset.cpu,
            base_dataset=dataset.base_dataset,
            schema=self.output_schema,
        )

    def _transform_df(self, df):
        if not self.graph.output_schema:
            raise ValueError(
                "Workflow has no output schema. Fit the workflow (or call `fit_schema`) "
                "before transforming a DataFrame."
            )
        return LocalExecutor().transform(df, self.output_node, self.output_dtypes)

    @classmethod
    def _getmodules(cls, fs):
        """
        Returns an imprecise but useful approximation of the list of modules
        necessary to execute a given list of functions. This approximation is
        sound (all modules listed are required by the supplied functions) but not
        necessarily complete (not all modules required will necessarily be returned).

        For function literals (lambda expressions), this returns

        1. the names of every module referenced in the lambda expression, e.g.,
           `m` for `lambda x: m.f(x)` and

        2. the names of the declaring module for every function referenced in
           the lambda expression, e.g. `m` for `import m.f; lambda x: f(x)`

        For declared functions, this returns the names of their declaring modules.

        The return value will exclude all built-in modules and (on Python 3.10 or later)
        all standard library modules.
        """
        result = set()

        exclusions = set(sys.builtin_module_names)
        if hasattr(sys, "stdlib_module_names"):
            # sys.stdlib_module_names is only available in Python 3.10 and beyond
            exclusions = exclusions | sys.stdlib_module_names

        for f in fs:
            if f.__name__ == "<lambda>":
                for closurevars in [
                    inspect.getclosurevars(f).globals,
                    inspect.getclosurevars(f).nonlocals,
                ]:
                    for val in closurevars.values():
                        if isinstance(val, types.ModuleType):
                            result.add(val)
                        elif isinstance(val, types.FunctionType):
                            mod = inspect.getmodule(val)
                            if mod is not None:
                                result.add(mod)
            else:
                mod = inspect.getmodule(f)
                if mod is not None:
                    result.add(mod)

        return [mod for mod in result if mod.__name__ not in exclusions]

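    # Illustrative sketch (not part of the library): for a lambda that references an imported
    # third-party module, ``_getmodules`` reports that module so it can later be registered
    # for pickle-by-value. ``np`` is just an example of such a reference.
    #
    #     import numpy as np
    #     fn = lambda col: np.log1p(col)
    #     Workflow._getmodules([fn])  # expected to include the ``numpy`` module object
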
    def save(self, path: Union[str, os.PathLike], modules_byvalue=None):
        """Save this workflow to disk

        Parameters
        ----------
        path: Union[str, os.PathLike]
            The path to save the workflow to
        modules_byvalue:
            A list of modules that should be serialized by value. This
            should include any modules that will not be available on
            the host where this workflow is ultimately deserialized.

            In lieu of an explicit list, pass None to serialize all modules
            by reference or pass "auto" to use a heuristic to infer which
            modules to serialize by value.
        """
        # avoid a circular import getting the version
        from nvtabular import __version__ as nvt_version

        path = str(path)
        fs = fsspec.get_fs_token_paths(path)[0]

        fs.makedirs(path, exist_ok=True)

        # point all stat ops to store intermediate output (parquet etc) at the path
        # this lets us easily bundle the statistics alongside the saved workflow
        for stat in Graph.get_nodes_by_op_type([self.output_node], StatOperator):
            stat.op.set_storage_path(path, copy=True)

        # generate a file of all versions used to generate this bundle
        lib = cudf if cudf else pd
        with fs.open(fs.sep.join([path, "metadata.json"]), "w") as o:
            json.dump(
                {
                    "versions": {
                        "nvtabular": nvt_version,
                        lib.__name__: lib.__version__,
                        "python": sys.version,
                    },
                    "generated_timestamp": int(time.time()),
                },
                o,
            )

        # track existing by-value modules
        preexisting_modules_byvalue = set(cloudpickle.list_registry_pickle_by_value())

        # direct cloudpickle to serialize selected modules by value
        if modules_byvalue is None:
            modules_byvalue = []
        elif modules_byvalue == "auto":
            l_nodes = self.graph.get_nodes_by_op_type(
                list(iter_nodes([self.graph.output_node])), LambdaOp
            )

            try:
                modules_byvalue = Workflow._getmodules([ln.op.f for ln in l_nodes])
            except RuntimeError as ex:
                warnings.warn(
                    "Failed to automatically infer modules to serialize by value. "
                    f'Reason given was "{str(ex)}"'
                )

        try:
            for m in modules_byvalue:
                if isinstance(m, types.ModuleType):
                    cloudpickle.register_pickle_by_value(m)
                elif isinstance(m, str) and m in sys.modules:
                    cloudpickle.register_pickle_by_value(sys.modules[m])
        except RuntimeError as ex:
            warnings.warn(
                f'Failed to register modules to serialize by value. Reason given was "{str(ex)}"'
            )

        try:
            # dump out the full workflow (graph/stats/operators etc) using cloudpickle
            with fs.open(fs.sep.join([path, "workflow.pkl"]), "wb") as o:
                cloudpickle.dump(self, o)
        finally:
            # return all modules that we set to serialize by value to by-reference
            # (i.e., retain modules that were set to serialize by value before this invocation)
            for m in modules_byvalue:
                if isinstance(m, types.ModuleType):
                    if m.__name__ not in preexisting_modules_byvalue:
                        cloudpickle.unregister_pickle_by_value(m)
                elif isinstance(m, str) and m in sys.modules:
                    if m not in preexisting_modules_byvalue:
                        cloudpickle.unregister_pickle_by_value(sys.modules[m])

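    # Illustrative sketch (not part of the library): saving a workflow whose LambdaOps
    # reference locally defined helpers. Passing "auto" asks save() to infer which modules
    # to pickle by value. The path "/tmp/my_workflow" is a hypothetical placeholder.
    #
    #     workflow.save("/tmp/my_workflow", modules_byvalue="auto")
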
    @classmethod
    def load(cls, path: Union[str, os.PathLike], client=None) -> "Workflow":
        """Load up a saved workflow object from disk

        Parameters
        ----------
        path: Union[str, os.PathLike]
            The path to load the workflow from
        client: distributed.Client, optional
            The Dask distributed client to use for multi-gpu processing and multi-node processing

        Returns
        -------
        Workflow
            The Workflow loaded from disk
        """
        # avoid a circular import getting the version
        from nvtabular import __version__ as nvt_version

        path = str(path)
        fs = fsspec.get_fs_token_paths(path)[0]

        # check version information from the metadata blob, and warn if we have a mismatch
        meta = json.load(fs.open(fs.sep.join([path, "metadata.json"])))

        def parse_version(version):
            return version.split(".")[:2]

        def check_version(stored, current, name):
            if parse_version(stored) != parse_version(current):
                warnings.warn(
                    f"Loading workflow generated with {name} version {stored} "
                    f"- but we are running {name} {current}. This might cause issues"
                )

        # make sure we don't have any major/minor version conflicts between the stored workflow
        # and the current environment
        lib = cudf if cudf else pd
        versions = meta["versions"]
        check_version(versions["nvtabular"], nvt_version, "nvtabular")
        check_version(versions["python"], sys.version, "python")

        if lib.__name__ in versions:
            check_version(versions[lib.__name__], lib.__version__, lib.__name__)
        else:
            expected = "GPU" if "cudf" in versions else "CPU"
            warnings.warn(f"Loading workflow generated on {expected}")

        # load up the workflow object directly from the pickled file
        workflow = cloudpickle.load(fs.open(fs.sep.join([path, "workflow.pkl"]), "rb"))
        workflow.client = client

        # we might have been copied since saving, update all the stat ops
        # with the new path to their storage locations
        for stat in Graph.get_nodes_by_op_type([workflow.output_node], StatOperator):
            stat.op.set_storage_path(path, copy=False)

        return workflow

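    # Illustrative sketch (not part of the library): reloading a previously saved workflow
    # and applying it to new data. "/tmp/my_workflow", ``VALID_PATH``, and ``VALID_OUT_PATH``
    # are hypothetical placeholders.
    #
    #     restored = Workflow.load("/tmp/my_workflow")
    #     restored.transform(Dataset(VALID_PATH)).to_parquet(output_path=VALID_OUT_PATH)
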
    def clear_stats(self):
        """Removes calculated statistics from each node in the workflow graph

        See Also
        --------
        nvtabular.ops.stat_operator.StatOperator.clear
        """
        for stat in Graph.get_nodes_by_op_type([self.graph.output_node], StatOperator):
            stat.op.clear()