feat: updated readme with APIs

Commit a053b12 by flavioschneider, May 3, 2022 (1 parent: bcbe64f).
Showing 1 changed file, README.md, with 33 additions and 7 deletions.

#### API

```python
model = Inspect(
    model: nn.Module,
    layer: Union[str, Sequence[str], Dict[str, str]],
    keep_output: bool = True,
)
```
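
A minimal usage sketch for the wrapper documented above (illustrative only: it assumes a model with a `layer1` submodule, and that with `keep_output=True` the wrapped model returns the original output together with the requested layer output):

```python
model_wrapped = Inspect(model, layer='layer1')  # wrapper from the signature above

x = torch.rand(1, 5)
y, layer1_output = model_wrapped(x)  # assumed: (original output, requested layer output)
```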


### Extract

Given a PyTorch model, we can display all intermediate nodes of the graph using `get_nodes`:

```python
print(get_nodes(model)) # ['x', 'layer1', 'relu', 'layer2', 'sigmoid', 'layer3', ...]
```
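
The model definition behind these examples is not visible in this excerpt; a minimal definition consistent with the node names above and the tensor shapes used below (layer sizes and class name are assumptions) could look like:

```python
import torch
import torch.nn as nn

# `get_nodes` and `Extract` are imported from this package
# (the import line is not part of this excerpt).

class SomeModel(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(5, 3)   # the model input x has shape (1, 5)
        self.layer2 = nn.Linear(3, 2)   # so the sigmoid output has shape (1, 2)
        self.layer3 = nn.Linear(2, 1)

    def forward(self, x):
        x1 = torch.relu(self.layer1(x))      # traced as node 'relu'
        x2 = torch.sigmoid(self.layer2(x1))  # traced as node 'sigmoid'
        return self.layer3(x2)

model = SomeModel()
```
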
Then we can extract outputs using `Extract`, which will create a new model that returns the requested output node:

```python
model_ext = Extract(model, node_out='sigmoid')
x = torch.rand(1, 5)
sigmoid = model_ext(x)
print(sigmoid) # tensor([[0.5570, 0.3652]], grad_fn=<SigmoidBackward0>)
```

We can also extract a model with new input nodes; since the new input replaces the output of `layer1`, the tensor we feed in now has shape `(1, 3)` rather than `(1, 5)`:

```python
model_ext = Extract(model, node_in='layer1', node_out='sigmoid')
x = torch.rand(1, 3)
sigmoid = model_ext(x)
print(sigmoid) # tensor([[0.5444, 0.3965]], grad_fn=<SigmoidBackward0>)
```

We can also provide multiple inputs and outputs, and give them custom names:

```python
model_ext = Extract(model, node_in={'layer1': 'x'}, node_out={'sigmoid': 'y1', 'relu': 'y2'})
out = model_ext(x=torch.rand(1, 3))
print(out)
"""
{

We can also operate on a model with multiple branches, where two layers are joined by an `add` node:

```python
print(get_nodes(model)) # ['x', 'layer1a', 'layer1b', 'add', 'layer2']

model_ext = Extract(model, node_in={'layer1a': 'my_input'}, node_out={'add': 'my_add'})
print(model_ext.summary) # {'input': ('x', 'my_input'), 'output': {'my_add': add}}

out = model_ext(x=torch.rand(1, 2), my_input=torch.rand(1, 2))
print(out) # {'my_add': tensor([[ 0.3722, -0.6843]], grad_fn=<AddBackward0>)}
```
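
The model behind this example is likewise not shown here; a definition consistent with the node names (`layer1a`, `layer1b`, `add`, `layer2`) and the `(1, 2)` shapes above (layer sizes and class name are assumptions) could be:

```python
class SkipModel(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.layer1a = nn.Linear(2, 2)
        self.layer1b = nn.Linear(2, 2)
        self.layer2 = nn.Linear(2, 1)

    def forward(self, x):
        a = self.layer1a(x)
        b = self.layer1b(x)
        return self.layer2(a + b)  # `a + b` is traced as the 'add' node

model = SkipModel()
```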

#### API

```python
model = Extract(
    model: nn.Module,
    node_in: Optional[Union[str, Sequence[str], Dict[str, str]]] = None,
    node_out: Optional[Union[str, Sequence[str], Dict[str, str]]] = None,
    tracer: Optional[Type[Tracer]] = None, # Tracer class used, default: torch.fx.Tracer
    concrete_args: Optional[Dict[str, Any]] = None, # concrete_args forwarded to the tracer, default: None
    keep_output: bool = None, # set to True to return the original outputs as the first element; default: True unless node_out is provided
    share_modules: bool = False, # set to True to share module weights with the original model
)
```
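
As an illustrative sketch (not from the original README), `share_modules=True` makes the extracted model reuse the original model's modules instead of copying them, so the two stay in sync if either is trained further:

```python
# Sketch: the extracted model reuses (rather than copies) the original submodules.
model_ext = Extract(model, node_out='sigmoid', share_modules=True)

x = torch.rand(1, 5)
print(model_ext(x))  # same 'sigmoid' activation as in the earlier example
```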




