From a053b12e5b97ca70102d64dff614d90470f2e3ca Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Tue, 3 May 2022 17:57:41 +0200 Subject: [PATCH] feat: updated readme with APIs --- README.md | 40 +++++++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 92c6e44..505d397 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,17 @@ print(layers) """ ``` +#### API + +```python +model = Extract( + model: nn.Module, + layer: Union[str, Sequence[str], Dict[str, str]], + keep_output: bool = True, +) +``` + + ### Extract Given a PyTorch model we can display all intermediate nodes of the graph using `get_nodes`: @@ -113,26 +124,26 @@ print(get_nodes(model)) # ['x', 'layer1', 'relu', 'layer2', 'sigmoid', 'layer3', Then we can extract outputs using `Extract`, which will create a new model that returns the requested output node: ```python -model_extracted = Extract(model, node_out='sigmoid') +model_ext = Extract(model, node_out='sigmoid') x = torch.rand(1, 5) -sigmoid = model_extracted(x) +sigmoid = model_ext(x) print(sigmoid) # tensor([[0.5570, 0.3652]], grad_fn=) ``` We can also extract a model with new input nodes: ```python -model_extracted = Extract(model, node_in='layer1', node_out='sigmoid') +model_ext = Extract(model, node_in='layer1', node_out='sigmoid') x = torch.rand(1, 3) -sigmoid = model_extracted(x) +sigmoid = model_ext(x) print(sigmoid) # tensor([[0.5444, 0.3965]], grad_fn=) ``` We can also provide multiple inputs and outputs and name them: ```python -model_extracted = Extract(model, node_in={ 'layer1': 'x' }, node_out={ 'sigmoid': 'y1', 'relu': 'y2'}) -out = model_extracted(x = torch.rand(1, 3)) +model_ext = Extract(model, node_in={ 'layer1': 'x' }, node_out={ 'sigmoid': 'y1', 'relu': 'y2'}) +out = model_ext(x = torch.rand(1, 3)) print(out) """ { @@ -170,10 +181,25 @@ print(get_nodes(model)) # ['x', 'layer1a', 'layer1b', 'add', 'layer2'] model_ext = Extract(model, node_in = {'layer1a': 'my_input'}, node_out = {'add': 'my_add'}) print(model_ext.summary) # {'input': ('x', 'my_input'), 'output': {'my_add': add}} -out = model_ext(x = torch.rand(1, 2), my_add = torch.rand(1,2)) +out = model_ext(x = torch.rand(1, 2), my_input = torch.rand(1,2)) print(out) # {'my_add': tensor([[ 0.3722, -0.6843]], grad_fn=)} ``` +#### API + +```python +model = Extract( + model: nn.Module, + node_in: Optional[Union[str, Sequence[str], Dict[str, str]]] = None, + node_out: Optional[Union[str, Sequence[str], Dict[str, str]]] = None, + tracer: Optional[Type[Tracer]] = None, # Tracer class used, default: torch.fx.Tracer + concrete_args: Optional[Dict[str, Any]] = None, # Tracer concrete_args, default: None + keep_output: bool = None, # Set to `True` to return original outputs as first argument, default: True except if node_out are provided + share_modules: bool = False, # Set to true if you want to share module weights with original model +) +``` + +