A collection of pretrained networks in PyTorch, readily available for transfer learning tasks such as neural system identification.
pip install ptrnets
Find a list of all available models like this:
from ptrnets import AVAILABLE_MODELS
print(AVAILABLE_MODELS)
Import a model like this:
from ptrnets import simclr_resnet50x2
model = simclr_resnet50x2(pretrained=True)
You can access intermediate representations in two ways:
The ptrnets.utils.mlayer.probe_model function conveniently gives access to intermediate representations of a forward pass.
Example:
import torch
from ptrnets import resnet50
from ptrnets.utils.mlayer import probe_model
model = resnet50(pretrained=True)
available_layers = [name for name, _ in model.named_modules()]
layer_name = "layer2.1"
assert layer_name in available_layers, f"Layer {layer_name} not available. Choose from {available_layers}"
model_probe = probe_model(model, layer_name)
x = torch.rand(1, 3, 224, 224)
output = model_probe(x)
Note: if the input is not large enough to do a full forward pass through the network, you might need to use a try-except block to catch the RuntimeError.
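The try-except pattern from the note can be sketched with a plain PyTorch forward hook. This is an illustration only, not the ptrnets implementation: the toy network, the `captured` dictionary, and the `save_output` helper are hypothetical names, and whether probe_model itself keeps the intermediate output after such a failure is not confirmed here. The idea is that the hook on an early layer fires before a later layer rejects the undersized input.

```python
import torch
from torch import nn

# Hypothetical toy network (not a ptrnets model): the final Linear layer
# expects 8 * 4 * 4 = 128 features, which corresponds to an 8x8 input image.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # module "0": the layer we probe
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 4 * 4, 10),
)
model.eval()

captured = {}  # hypothetical container for the probed activation


def save_output(module, inputs, output):
    """Forward hook: store the output of the probed layer."""
    captured["0"] = output.detach()


handle = model[0].register_forward_hook(save_output)

# A 4x4 input is too small: Flatten yields 8 * 2 * 2 = 32 features, not 128.
x_small = torch.rand(1, 3, 4, 4)
try:
    with torch.no_grad():
        model(x_small)
    full_pass_ok = True
except RuntimeError:
    # The Linear layer fails, but the hook on module "0" has already run.
    full_pass_ok = False
finally:
    handle.remove()  # always detach the hook afterwards

features = captured["0"]
print(full_pass_ok, tuple(features.shape))
```

Catching only RuntimeError keeps unrelated errors (for example a misspelled layer name) visible instead of silently swallowing them.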
ptrnets.utils.mlayer.clip_model creates a copy of the model up to a specific layer. Because the clipped model is smaller, a forward pass can run faster.
However, its output is only guaranteed to match the original model's representation at that layer if the architecture is fully sequential up to that layer.
Example:
import torch
from ptrnets import vgg16
from ptrnets.utils.mlayer import clip_model, probe_model
model = vgg16(pretrained=True)
available_layers = [name for name, _ in model.named_modules()]
layer_name = "features.18"
assert layer_name in available_layers, f"Layer {layer_name} not available. Choose from {available_layers}"
model_clipped = clip_model(model, layer_name) # Creates new model up to the layer
x = torch.rand(1, 3, 224, 224)
output = model_clipped(x)
assert torch.allclose(output, probe_model(model, layer_name)(x)), "Output of clipped model is not the same as the original model"
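The caveat about fully sequential architectures can be illustrated with a small residual model. The sketch below is plain PyTorch and makes an assumption: it treats "clipping" as chaining the registered child modules in order, which may differ from what clip_model actually does. `TinyResidual` and `naive_clip` are hypothetical names. Because the skip connection lives in `forward()` rather than in a module, the chained copy silently drops it.

```python
import torch
from torch import nn


class TinyResidual(nn.Module):
    """Hypothetical toy model with a skip connection (not a ptrnets model)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        # The addition below is not a module, so it is invisible to
        # anything that only walks model.children().
        return self.relu(out + x)


torch.manual_seed(0)
model = TinyResidual().eval()

# Naive clipping (assumed strategy): chain conv1 -> conv2 -> relu in order.
naive_clip = nn.Sequential(*list(model.children()))

x = torch.rand(1, 3, 8, 8)
with torch.no_grad():
    full_out = model(x)          # relu(conv2(conv1(x)) + x)
    clipped_out = naive_clip(x)  # relu(conv2(conv1(x))), skip connection lost

same = torch.allclose(clipped_out, full_out)
print("clipped output matches original:", same)
```

For architectures with skip connections or other branching, comparing against probe_model as in the example above is a reasonable sanity check before relying on a clipped model.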
Pull requests are welcome. Please see instructions here.