Skip to content

Commit

Permalink
feat: added keep_output to inspect
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed May 3, 2022
1 parent 3f48afe commit bcbe64f
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions surgeon_pytorch/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,16 @@ def get_layers_modules(

class Inspect(nn.Module):
def __init__(
self, model: nn.Module, layer: Union[str, Sequence[str], Dict[str, str]]
self,
model: nn.Module,
layer: Union[str, Sequence[str], Dict[str, str]],
keep_output: bool = True,
):
super().__init__()
self.model = model
self.layers = get_layers_modules(model, layer)
self.layer = layer
self.keep_output = keep_output

def register_hooks(self):
for layer in self.layers:
Expand Down Expand Up @@ -94,11 +98,11 @@ def forward(self, *args, **kwargs):
if isinstance(self.layer, list) or is_tuple:
layers_output = [layer.output for layer in self.layers]
layers_output = tuple(layers_output) if is_tuple else layers_output
return model_output, layers_output
return (model_output, layers_output) if self.keep_output else layers_output
elif isinstance(self.layer, dict):
layers_output = {layer.key: layer.output for layer in self.layers}
return model_output, layers_output
return (model_output, layers_output) if self.keep_output else layers_output

layer_output = self.layers[0].output

return model_output, layer_output
return (model_output, layer_output) if self.keep_output else layer_output

0 comments on commit bcbe64f

Please sign in to comment.