Add support for serializing modules involved in LambdaOp execution by value #1741

Merged · 14 commits · Feb 15, 2023
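Previously, `Workflow.save` pickled `LambdaOp` functions with cloudpickle's default behavior, which serializes the modules those functions reference by reference: the pickle records only module names, so loading fails with a `ModuleNotFoundError` on any host where those modules are not installed. This PR adds a `modules_byvalue` argument to `save` that directs cloudpickle to embed selected modules' code directly in the pickle. A minimal sketch of the failure mode and the fix (`my_udfs` is a hypothetical user module that exists only on the saving host):

```python
import nvtabular as nvt
import my_udfs  # hypothetical module; not installed where the workflow is later loaded

workflow = nvt.Workflow(["col_a"] >> nvt.ops.LambdaOp(my_udfs.identity))

# Default behavior: `my_udfs` is pickled by reference, so Workflow.load
# on another host raises ModuleNotFoundError.
workflow.save("/tmp/wf")

# New behavior: embed the module's code in the pickle itself.
workflow.save("/tmp/wf", modules_byvalue=[my_udfs])
```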
111 changes: 106 additions & 5 deletions nvtabular/workflow/workflow.py
@@ -13,10 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import inspect
import json
import logging
import sys
import time
import types
import warnings
from typing import TYPE_CHECKING, Optional

@@ -31,9 +34,10 @@

from merlin.dag import Graph
from merlin.dag.executors import DaskExecutor, LocalExecutor
from merlin.dag.node import iter_nodes
from merlin.io import Dataset
from merlin.schema import Schema
from nvtabular.ops import StatOperator
from nvtabular.ops import LambdaOp, StatOperator
from nvtabular.workflow.node import WorkflowNode

LOG = logging.getLogger("nvtabular")
@@ -255,13 +259,68 @@ def _transform_df(self, df):

    return LocalExecutor().transform(df, self.output_node, self.output_dtypes)

def save(self, path):
@classmethod
def _getmodules(cls, fs):
    """
    Returns an imprecise but useful approximation of the list of modules
    necessary to execute a given list of functions. This approximation is
    sound (every module listed is required by the supplied functions) but not
    necessarily complete (not every required module will necessarily be returned).

    For function literals (lambda expressions), this returns
    1. the names of every module referenced in the lambda expression, e.g.,
       `m` for `lambda x: m.f(x)`, and
    2. the names of the declaring module for every function referenced in
       the lambda expression, e.g., `m` for `from m import f; lambda x: f(x)`.

    For declared functions, this returns the names of their declaring modules.

    The return value will exclude all built-in modules and (on Python 3.10 or later)
    all standard library modules.
    """
    result = set()

    exclusions = set(sys.builtin_module_names)
    if hasattr(sys, "stdlib_module_names"):
        # sys.stdlib_module_names is only available in Python 3.10 and beyond
        exclusions = exclusions | sys.stdlib_module_names

    for f in fs:
        if f.__name__ == "<lambda>":
            for closurevars in [
                inspect.getclosurevars(f).globals,
                inspect.getclosurevars(f).nonlocals,
            ]:
                for val in closurevars.values():
                    if isinstance(val, types.ModuleType):
                        result.add(val)
                    elif isinstance(val, types.FunctionType):
                        mod = inspect.getmodule(val)
                        if mod is not None:
                            result.add(mod)
        else:
            mod = inspect.getmodule(f)
            if mod is not None:
                result.add(mod)

    return [mod for mod in result if mod.__name__ not in exclusions]
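To illustrate the two lambda cases from the docstring, a hedged sketch of how `_getmodules` resolves closure variables (`pandas` stands in for any non-stdlib module and is not part of this diff):

```python
import pandas as pd
from pandas import to_numeric

f_module = lambda col: pd.to_numeric(col)  # case 1: references the module `pd` by name
f_function = lambda col: to_numeric(col)   # case 2: references a function; its declaring module is inferred

for f in (f_module, f_function):
    mods = Workflow._getmodules([f])
    # the declaring module of `to_numeric` is a pandas submodule, so compare
    # top-level package names rather than exact module names
    assert any(m.__name__.split(".")[0] == "pandas" for m in mods)
```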

def save(self, path, modules_byvalue=None):
    """Save this workflow to disk

    Parameters
    ----------
    path: str
        The path to save the workflow to
    modules_byvalue: list of modules or module names, "auto", or None
        A list of modules that should be serialized by value. This
        should include any modules that will not be available on
        the host where this workflow is ultimately deserialized.

        In lieu of an explicit list, pass None to serialize all modules
        by reference or pass "auto" to use a heuristic to infer which
        modules to serialize by value.
    """
    # avoid a circular import getting the version
    from nvtabular import __version__ as nvt_version
@@ -290,9 +349,51 @@ def save(self, path):
o,
)

    # dump out the full workflow (graph/stats/operators etc) using cloudpickle
    with fs.open(fs.sep.join([path, "workflow.pkl"]), "wb") as o:
        cloudpickle.dump(self, o)

    # track modules that are already registered to serialize by value
    preexisting_modules_byvalue = set(cloudpickle.list_registry_pickle_by_value())

    # direct cloudpickle to serialize selected modules by value
    if modules_byvalue is None:
        modules_byvalue = []
    elif modules_byvalue == "auto":
        l_nodes = self.graph.get_nodes_by_op_type(
            list(iter_nodes([self.graph.output_node])), LambdaOp
        )

        try:
            modules_byvalue = Workflow._getmodules([ln.op.f for ln in l_nodes])
        except RuntimeError as ex:
            # fall back to serializing every module by reference
            modules_byvalue = []
            warnings.warn(
                "Failed to automatically infer modules to serialize by value. "
                f'Reason given was "{str(ex)}"'
            )

    try:
        for m in modules_byvalue:
            if isinstance(m, types.ModuleType):
                cloudpickle.register_pickle_by_value(m)
            elif isinstance(m, str) and m in sys.modules:
                cloudpickle.register_pickle_by_value(sys.modules[m])
    except RuntimeError as ex:
        warnings.warn(
            f'Failed to register modules to serialize by value. Reason given was "{str(ex)}"'
        )

    try:
        # dump out the full workflow (graph/stats/operators etc) using cloudpickle
        with fs.open(fs.sep.join([path, "workflow.pkl"]), "wb") as o:
            cloudpickle.dump(self, o)
    finally:
        # restore by-reference serialization for every module registered above,
        # retaining any modules that were already registered by value before
        # this invocation
        for m in modules_byvalue:
            if isinstance(m, types.ModuleType):
                if m.__name__ not in preexisting_modules_byvalue:
                    cloudpickle.unregister_pickle_by_value(m)
            elif isinstance(m, str) and m in sys.modules:
                if m not in preexisting_modules_byvalue:
                    cloudpickle.unregister_pickle_by_value(sys.modules[m])
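Taken together, the three modes of `modules_byvalue` look like this in use (a sketch mirroring the tests below; `custom_udfs` is a hypothetical module containing the LambdaOp function):

```python
import nvtabular as nvt
import custom_udfs  # hypothetical user module

wf = nvt.Workflow(["col_a"] >> nvt.ops.LambdaOp(custom_udfs.identity))

wf.save("/tmp/wf_byref")                                    # None: all modules by reference
wf.save("/tmp/wf_explicit", modules_byvalue=[custom_udfs])  # explicit module list
wf.save("/tmp/wf_auto", modules_byvalue="auto")             # inferred via _getmodules

# The by-value workflows load even where `custom_udfs` is not importable:
wf2 = nvt.Workflow.load("/tmp/wf_auto")
```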

@classmethod
def load(cls, path, client=None) -> "Workflow":
69 changes: 69 additions & 0 deletions tests/unit/workflow/test_workflow.py
@@ -18,6 +18,7 @@
import math
import os
import shutil
import sys

try:
    import cudf
@@ -666,3 +667,71 @@ def test_workflow_saved_schema(tmpdir):
    for node in postorder_iter_nodes(workflow2.output_node):
        assert node.input_schema is not None
        assert node.output_schema is not None


def test_workflow_infer_modules_byvalue(tmp_path):
    module_fn = tmp_path / "not_a_real_module.py"
    sys.path.append(str(tmp_path))

    with open(module_fn, "w") as module_f:
        module_f.write("def identity(col):\n return col")

    import not_a_real_module

    f_0 = not_a_real_module.identity
    f_1 = lambda x: not_a_real_module.identity(x)  # noqa
    f_2 = lambda x: f_0(x)  # noqa

    try:
        for fn, f in {
            "not_a_real_module.identity": f_0,
            "lambda x: not_a_real_module.identity(x)": f_1,
            "lambda x: f_0(x)": f_2,
        }.items():
            assert not_a_real_module in Workflow._getmodules(
                [f]
            ), f"failed to infer module dependencies from {fn}"

    finally:
        sys.path.pop()
        del sys.modules["not_a_real_module"]


def test_workflow_explicit_modules_byvalue(tmp_path):
    module_fn = tmp_path / "not_a_real_module.py"
    sys.path.append(str(tmp_path))

    with open(module_fn, "w") as module_f:
        module_f.write("def identity(col):\n return col")

    import not_a_real_module

    wf = nvt.Workflow(["col_a"] >> nvt.ops.LambdaOp(not_a_real_module.identity))

    wf.save(str(tmp_path / "identity-workflow"), modules_byvalue=[not_a_real_module])

    del not_a_real_module
    del sys.modules["not_a_real_module"]
    os.unlink(str(tmp_path / "not_a_real_module.py"))

    Workflow.load(str(tmp_path / "identity-workflow"))


def test_workflow_auto_infer_modules_byvalue(tmp_path):
    module_fn = tmp_path / "not_a_real_module.py"
    sys.path.append(str(tmp_path))

    with open(module_fn, "w") as module_f:
        module_f.write("def identity(col):\n return col")

    import not_a_real_module

    wf = nvt.Workflow(["col_a"] >> nvt.ops.LambdaOp(not_a_real_module.identity))

    wf.save(str(tmp_path / "identity-workflow"), modules_byvalue="auto")

    del not_a_real_module
    del sys.modules["not_a_real_module"]
    os.unlink(str(tmp_path / "not_a_real_module.py"))

    Workflow.load(str(tmp_path / "identity-workflow"))