Add Predictor class #34
Conversation
Walkthrough

The update involves the introduction of a framework for performing predictions using SLEAP models. Key components for inference, such as a predictor interface and classes for cropping and peak detection, have been added. These elements work together to enable the processing of image data through trained models, facilitating the identification of instances within images.
Codecov Report

Attention:

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main      #34       +/-   ##
===========================================
- Coverage   99.78%   83.40%   -16.38%
===========================================
  Files          19       20        +1
  Lines         936     1139      +203
===========================================
+ Hits          934      950       +16
- Misses          2      189      +187
```

☔ View full report in Codecov by Sentry.
Review Status
Actionable comments generated: 4
Configuration used: CodeRabbit UI
Files selected for processing (1)
- sleap_nn/inference/inference.py (1 hunks)
Additional comments: 6
sleap_nn/inference/inference.py (6)
- 3-26: Ensure that all imported modules are used within the file to avoid unnecessary imports, which can lead to increased memory usage and longer load times.
- 83-100: The abstract methods in the `Predictor` class are correctly defined using the `@abstractmethod` decorator. Ensure that subclasses of `Predictor` implement these methods.
- 229-341: The `forward` method in the `FindInstancePeaks` class contains logic to predict instance peaks from images. Ensure that the `preprocess` method is adequately implemented and that the `peak_finding.find_global_peaks` function is correctly used with the appropriate parameters.
- 344-415: The `forward` method in the `TopDownInferenceModel` class correctly chains the `centroid_crop` and `instance_peaks` layers. Verify that the error handling for missing `centroid_crop` and `instance_peaks` is appropriate and that the method's logic is sound.
- 535-567: The `make_pipeline` method in the `TopDownPredictor` class correctly creates a data loading pipeline. Ensure that the `TopdownConfmapsPipeline` and `LabelsReader` are correctly used and that the configuration is properly applied.
- 569-636: The `_make_labeled_frames_from_generator` method in the `TopDownPredictor` class appears to correctly convert inference results into SLEAP-specific data structures. Verify that the method handles all possible keys in the inference result dictionaries and that the SLEAP data structures are correctly populated.
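The `peak_finding.find_global_peaks` function mentioned above reduces each confidence-map channel to its single best-scoring location. The actual `sleap_nn` API may differ; purely as an illustration of the idea, here is a self-contained PyTorch sketch (the name `find_global_peaks_sketch`, the `(batch, channels, height, width)` layout, and the threshold semantics are assumptions, not the PR's implementation):

```python
import torch


def find_global_peaks_sketch(cms: torch.Tensor, threshold: float = 0.2):
    """Find the single highest peak per confidence-map channel.

    Args:
        cms: Confidence maps of shape (batch, channels, height, width).
        threshold: Minimum confidence for a peak to be kept.

    Returns:
        peaks: (batch, channels, 2) tensor of (x, y) coordinates,
            NaN where the peak value falls below the threshold.
        vals: (batch, channels) tensor of peak confidences.
    """
    b, c, h, w = cms.shape
    flat = cms.reshape(b, c, -1)
    # Global argmax over each flattened (height * width) map.
    vals, idx = flat.max(dim=-1)
    ys = torch.div(idx, w, rounding_mode="floor").float()
    xs = (idx % w).float()
    peaks = torch.stack([xs, ys], dim=-1)
    # Mask out sub-threshold peaks.
    peaks[vals < threshold] = float("nan")
    return peaks, vals
```

Real implementations typically add sub-pixel refinement around the argmax (e.g. integral regression over a local patch), which is what the `integral_refinement` and `integral_patch_size` parameters used later in the PR suggest.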
```python
class Predictor(ABC):
    """Base interface class for predictors."""

    @classmethod
    def from_model_paths(cls, ckpt_paths: Dict[Text, Text], model: Text) -> "Predictor":
        """Create the appropriate `Predictor` subclass from the ckpt paths.

        Args:
            ckpt_paths: Dict with keys as model names and values as paths to the
                checkpoint files having the trained model weights.
            model: Name of the model type to create a predictor for.

        Returns:
            A subclass of `Predictor`.

        See also: `SingleInstancePredictor`, `TopDownPredictor`, `BottomUpPredictor`,
            `MoveNetPredictor`, `TopDownMultiClassPredictor`,
            `BottomUpMultiClassPredictor`.
        """
        # Read configs and find model types.
        # configs = {}
        # for p in ckpt_paths:
        #     ckpt = torch.load(p)  # ???? should we load the checkpoint every time
        #     configs[ckpt.config.model_name] = [ckpt.config, ckpt]  # checkpoint or checkpoint path

        # ckpt_paths = {"centroid": ckpt_path, "centered": ckpt_path}
        model_names = ckpt_paths.keys()

        if "single_instance" in model:
            # predictor = SingleInstancePredictor.from_trained_models(
            #     model_path=model_paths["single_instance"],
            #     inference_config=inference_config,
            # )
            pass

        elif "topdown" in model:
            centroid_ckpt_path = None
            confmap_ckpt_path = None
            if "centroid" in model_names:
                centroid_ckpt_path = ckpt_paths["centroid"]
            if "centered" in model_names:
                confmap_ckpt_path = ckpt_paths["centered"]

            predictor = TopDownPredictor.from_trained_models(
                centroid_ckpt_path=centroid_ckpt_path,
                confmap_ckpt_path=confmap_ckpt_path,
            )

        else:
            raise ValueError(
                "Could not create predictor from model paths:" + "\n".join(ckpt_paths)
            )
        predictor.model_path = ckpt_paths
        return predictor
```
The method `from_model_paths` in the `Predictor` class contains commented-out code and a `pass` statement under the `if "single_instance" in model:` condition. This should be cleaned up or implemented before merging.
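Following the review note above, one cleaned-up shape for the dispatch would raise explicitly on the unimplemented single-instance path instead of silently passing. This is only a sketch against the names used in the PR (`TopDownPredictor` is assumed to be importable from the same module):

```python
from typing import Dict


def predictor_from_model_paths(ckpt_paths: Dict[str, str], model: str):
    """Dispatch to a Predictor subclass based on the requested model type.

    Sketch only: `TopDownPredictor` is the class from the PR; the
    single-instance branch fails loudly instead of silently passing.
    """
    if "single_instance" in model:
        raise NotImplementedError(
            "Single-instance inference is not implemented yet."
        )
    elif "topdown" in model:
        # `dict.get` returns None for missing keys, matching the PR's
        # explicit `if "centroid" in model_names` checks.
        predictor = TopDownPredictor.from_trained_models(  # noqa: F821
            centroid_ckpt_path=ckpt_paths.get("centroid"),
            confmap_ckpt_path=ckpt_paths.get("centered"),
        )
    else:
        raise ValueError(f"Could not create predictor for model type: {model}")
    predictor.model_path = ckpt_paths
    return predictor
```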
```python
class InferenceModule(nn.Module):
    """Inference model base class.

    This class wraps the `nn.Module` class to provide inference
    utilities such as handling different input data types, preprocessing and variable
    output shapes. This layer expects the same input as the model (rank-4 image).

    Attributes:
        model: A `torch.nn.Module` that will be called on the input to this layer.
        conf: OmegaConf file with inference related parameters.
    """

    def __init__(
        self,
        torch_model: L.LightningModule,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.torch_model = torch_model

    def preprocess(self, imgs):
        return imgs

    def forward(
        self,
        data: Union[
            np.ndarray,
            torch.Tensor,
            Dict[str, torch.Tensor],
            sio.Video,
        ],
        **kwargs,
    ) -> Union[Dict[str, np.ndarray], Dict[str, torch.Tensor]]:
        """Predict instances with the data.

        Args:
            data: Input data in any form. Possible types:
                - `np.ndarray`, `torch.Tensor`: Images of shape
                  `(samples, t, channels, height, width)`
                - `dict` with key `"image"` as a tensor
                - `torch.utils.data.DataLoader` that generates examples in one of
                  the above formats.
                - `sio.Video` which will be converted into a pipeline that generates
                  batches of `batch_size` frames.

        Returns:
            The model outputs as a dictionary of tensors.
        """
        imgs = self.preprocess(data)
        outs = self.torch_model(imgs)
        return outs
```
The `InferenceModule` class correctly extends `nn.Module` and provides a `forward` method. Ensure that the `preprocess` method is implemented or overridden in subclasses, as it currently only returns the input without any processing.
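To make that point concrete, a subclass can override `preprocess` to do real work, e.g. normalizing `uint8` frames before the wrapped model runs. A minimal self-contained sketch of the pattern (the class name and the normalization choice are hypothetical, not part of the PR):

```python
import torch
import torch.nn as nn


class NormalizingInference(nn.Module):
    """Example of the InferenceModule pattern with a non-trivial preprocess.

    Hypothetical subclass: converts uint8 images to float in [0, 1] before
    calling the wrapped model, instead of returning the input unchanged.
    """

    def __init__(self, torch_model: nn.Module):
        super().__init__()
        self.torch_model = torch_model

    def preprocess(self, imgs: torch.Tensor) -> torch.Tensor:
        # Scale integer pixel values into [0, 1] floats.
        if imgs.dtype == torch.uint8:
            imgs = imgs.float() / 255.0
        return imgs

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        return self.torch_model(self.preprocess(imgs))
```

With `nn.Identity()` as the wrapped model, passing a `uint8` tensor of 255s through the module yields a float tensor of 1.0s, confirming the hook runs before the model.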
```python
class CentroidCrop(InferenceModule):
    pass


class FindInstancePeaks(InferenceModule):
    ...
```
The `CentroidCrop` class is empty and does not override any methods from its superclass `InferenceModule`. If this class is intended to be implemented in the future, consider adding a `TODO` comment or implement the required functionality.
```python
    centroid_config: Optional[DictConfig] = attr.ib(default=None)
    confmap_config: Optional[DictConfig] = attr.ib(default=None)
    centroid_model: Optional[L.LightningModule] = attr.ib(default=None)
    confmap_model: Optional[L.LightningModule] = attr.ib(default=None)

    def _initialize_inference_model(self):
        """Initialize the inference model from the trained models and configuration."""
        self.model_config = OmegaConf.create()

        # Create an instance of CentroidLayer if centroid_config is not None
        if self.centroid_config is None:
            centroid_crop_layer = None
        else:
            self.model_config["centroid"] = self.centroid_config
            self.model_config["data"] = self.centroid_config.inference_config.data
            pass

        # Create an instance of FindInstancePeaks layer if confmap_config is not None
        if self.confmap_config is None:
            pass
        else:
            self.model_config["confmaps"] = self.confmap_config
            self.model_config["data"] = self.confmap_config.inference_config.data
            instance_peaks_layer = FindInstancePeaks(
                torch_model=self.confmap_model,
                peak_threshold=self.confmap_config.inference_config.peak_threshold,
                output_stride=self.confmap_config.inference_config.output_stride,
                refinement=self.confmap_config.inference_config.integral_refinement,
                integral_patch_size=self.confmap_config.inference_config.integral_patch_size,
            )

        # Initialize the inference model with centroid and conf map layers
        self.inference_model = TopDownInferenceModel(
            centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer
        )

    @property
    def data_config(self) -> DictConfig:
        # Returns data config section from the overall config
        return self.model_config.data

    @classmethod
    def from_trained_models(
        cls,
        centroid_ckpt_path: Optional[Text] = None,
        confmap_ckpt_path: Optional[Text] = None,
    ) -> "TopDownPredictor":
        """Create predictor from saved models.

        Args:
            centroid_ckpt_path: Path to a centroid ckpt file.
            confmap_ckpt_path: Path to a centered-instance confmap ckpt file.

        Returns:
            An instance of `TopDownPredictor` with the loaded models.

        One of the two models can be left as `None` to perform inference with ground
        truth data. This will only work with `LabelsReader` as the provider.
        """
        if centroid_ckpt_path is None and confmap_ckpt_path is None:
            raise ValueError(
                "Either the centroid or topdown confidence map model must be provided."
            )

        if centroid_ckpt_path is not None:
            # Load centroid model.
            pass

        else:
            centroid_config = None
            centroid_model = None

        if confmap_ckpt_path is not None:
            # Load confmap model.
            confmap_ckpt = torch.load(confmap_ckpt_path)
            skeleton = confmap_ckpt["skeleton"]
            confmap_config = confmap_ckpt["config"]
            confmap_model = TopDownCenteredInstanceModel.load_from_checkpoint(
                "mice_default_init_with_amsgrad/last-v4.ckpt", config=confmap_config
            )
            confmap_model.to(confmap_config.inference_config.device)
            confmap_model.m_device = confmap_config.inference_config.device

        else:
            confmap_config = None
            confmap_model = None

        # Create an instance of the TopDownPredictor class.
        obj = cls(
            centroid_config=centroid_config,
            centroid_model=centroid_model,
            confmap_config=confmap_config,
            confmap_model=confmap_model,
        )

        obj._initialize_inference_model()
        obj.skeleton = skeleton

        return obj
```
The `TopDownPredictor` class's `_initialize_inference_model` method contains commented-out code and placeholders that should be either implemented or removed. Additionally, the `from_trained_models` method has a hardcoded checkpoint path, which should be parameterized to allow for flexibility.
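A related hazard in `_initialize_inference_model` as written: `centroid_crop_layer` is never assigned when `centroid_config` is not None (that branch only sets `model_config` entries and then passes), and `instance_peaks_layer` is never assigned when `confmap_config` is None, so the final `TopDownInferenceModel(...)` call can raise a `NameError`. Binding both names to `None` up front avoids this. A sketch of the control flow, with a hypothetical `make_peaks_layer` factory standing in for the `FindInstancePeaks(...)` construction:

```python
def initialize_layers(centroid_config, confmap_config, make_peaks_layer):
    """Sketch of `_initialize_inference_model` with both layer variables
    bound to None up front, so a missing config can never leave a name
    unbound when the layers are finally combined.

    `make_peaks_layer` is a hypothetical factory, not part of the PR.
    """
    centroid_crop_layer = None  # CentroidCrop construction is still a placeholder
    instance_peaks_layer = None
    if confmap_config is not None:
        instance_peaks_layer = make_peaks_layer(confmap_config)
    # Both names are always defined here, regardless of which configs were given.
    return centroid_crop_layer, instance_peaks_layer
```

With this shape, a centroid-only configuration simply yields `(None, None)` until `CentroidCrop` is implemented, instead of crashing.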