Skip to content

Commit

Permalink
[Doc] Add doc on export with nested keys
Browse files Browse the repository at this point in the history
ghstack-source-id: 9c95e2dba6751d93c20c66d0dba0d4219dc61c0b
Pull Request resolved: #1085

(cherry picked from commit 3cb5855)
  • Loading branch information
vmoens committed Nov 14, 2024
1 parent 1ef1188 commit d64c33d
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions tutorials/sphinx_tuto/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,25 @@
# and the FX graph:
print("fx graph:", model_export.graph_module.print_readable())

##################################################
# Working with nested keys
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# Nested keys are a core feature of the tensordict library, and being able to export modules that read and write
# nested entries is therefore an important feature to support.
# Because keyword arguments must be regualar strings, it is not possible for :class:`~tensordict.nn.dispatch` to work
# directly with them. Instead, ``dispatch`` will unpack nested keys joined with a regular underscore (`"_"`), as the
# following example shows.

model_nested = Seq(
Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
).select_out_keys(("some", "output"))

model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
print("exported module with nested input:", model_nested_export.module())


##################################################
# Note that the callable returned by `module()` is a pure python callable that can be in turn compiled using
# :func:`~torch.compile`.
Expand Down

0 comments on commit d64c33d

Please sign in to comment.