Skip to content

Commit 3b33b7b

Browse files
author
Vincent Moens
committed
[Feature] Plotting TensorDictSequential graphs
ghstack-source-id: ff93fb4 Pull Request resolved: #1144
1 parent 2360386 commit 3b33b7b

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

tensordict/nn/probabilistic.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010

1111
from textwrap import indent
12-
from typing import Any, Dict, List, Optional, overload, OrderedDict
12+
from typing import Any, Dict, List, Optional, OrderedDict, overload
1313

1414
import torch
1515

@@ -800,8 +800,7 @@ def __init__(
800800
aggregate_probabilities: bool | None = None,
801801
include_sum: bool | None = None,
802802
inplace: bool | None = None,
803-
) -> None:
804-
...
803+
) -> None: ...
805804

806805
@overload
807806
def __init__(
@@ -812,8 +811,7 @@ def __init__(
812811
aggregate_probabilities: bool | None = None,
813812
include_sum: bool | None = None,
814813
inplace: bool | None = None,
815-
) -> None:
816-
...
814+
) -> None: ...
817815

818816
def __init__(
819817
self,

tensordict/nn/sequence.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from __future__ import annotations
77

88
import collections
9+
import contextlib
910
import logging
1011
from copy import deepcopy
1112
from typing import Any, Callable, Iterable, List, OrderedDict, overload
@@ -574,3 +575,64 @@ def __setitem__(
574575

575576
def __delitem__(self, index: int | slice | str) -> None:
576577
self.module.__delitem__(idx=index)
578+
579+
def plot(self, example_input: TensorDictBase | None = None, **kwargs):
580+
import pydot
581+
582+
graph = pydot.Dot(
583+
"my_graph", graph_type="digraph", bgcolor="yellow", splines="curved"
584+
)
585+
graph.set_bgcolor("white")
586+
587+
if example_input is not None:
588+
from torch._subclasses.fake_tensor import FakeTensorMode
589+
590+
fake_mode = FakeTensorMode()
591+
converter = fake_mode.fake_tensor_converter
592+
fake_td = example_input.apply(
593+
lambda x: converter.from_real_tensor(fake_mode, x)
594+
)
595+
else:
596+
fake_td = None
597+
fake_mode = contextlib.nullcontext()
598+
599+
with fake_mode:
600+
iterator = (
601+
enumerate(self._module_iter())
602+
if not isinstance(self.module, nn.ModuleDict)
603+
else self.module.items()
604+
)
605+
for name, module in iterator:
606+
graph.add_node(
607+
pydot.Node(str(name), shape="box")
608+
) # label=str(node.module)))
609+
610+
# Check if in_keys are there already
611+
in_keys = module.in_keys
612+
for in_key in in_keys:
613+
if in_key not in graph.obj_dict["nodes"]:
614+
in_key_node = pydot.Node(
615+
in_key, label=in_key, shape="plaintext"
616+
)
617+
graph.add_node(in_key_node)
618+
in_key_edge = pydot.Edge(
619+
in_key, str(name), color="blue", style="arrow"
620+
)
621+
graph.add_edge(in_key_edge)
622+
623+
if not isinstance(module, TensorDictModule):
624+
fake_td = self._run_module(module, fake_td, **kwargs)
625+
626+
out_keys = module.out_keys
627+
for out_key in out_keys:
628+
if out_key not in graph.obj_dict["nodes"]:
629+
out_key_node = pydot.Node(
630+
out_key, label=out_key, shape="plaintext"
631+
)
632+
graph.add_node(out_key_node)
633+
out_key_edge = pydot.Edge(
634+
str(name), out_key, color="blue", style="arrow"
635+
)
636+
graph.add_edge(out_key_edge)
637+
638+
graph.write_png("/Users/vmoens/Downloads/my_graph.png")

0 commit comments

Comments
 (0)