
Add Predictor class #34

Closed
wants to merge 2 commits into from

Conversation

gitttt-1234 (Collaborator) commented Dec 21, 2023

Summary by CodeRabbit

  • New Features
    • Introduced a new inference system for enhanced model predictions.
    • Added support for more efficient instance peak detection and centroid cropping.
    • Users can now utilize multiple prediction models within the application.

coderabbitai bot (Contributor) commented Dec 21, 2023

Walkthrough

This update introduces a framework for running predictions with SLEAP models. Key inference components have been added, including a predictor interface and classes for centroid cropping and peak detection. Together, these pieces run image data through trained models to identify instances within images.

Changes

File Path                          Summary
sleap_nn/inference/inference.py    Introduced `Predictor` interface and related classes for image instance predictions.
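
As a rough usage sketch (not part of the PR itself): given the `from_model_paths` signature reviewed below, calling the new interface could look like this, with placeholder checkpoint paths:

```python
from sleap_nn.inference.inference import Predictor

# Placeholder paths; substitute your own trained-model checkpoints.
ckpt_paths = {
    "centroid": "models/centroid/best.ckpt",
    "centered": "models/centered_instance/best.ckpt",
}

# Dispatches to the matching Predictor subclass (here, TopDownPredictor).
predictor = Predictor.from_model_paths(ckpt_paths, model="topdown")
```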

🐇

In the code where tensors flow,
New classes bloom and objects grow,
Predictions run both fast and keen—
A leap in SLEAP, now to be seen.
🌟🔍



Tips

Chat with CodeRabbit Bot (@coderabbitai)

  • You can directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>.
  • You can tag CodeRabbit on specific lines of code or entire files in the PR by tagging @coderabbitai in a comment.
  • You can tag @coderabbitai in a PR comment and ask questions about the PR and the codebase. Use quoted replies to pass the context for follow-up questions. Examples:
    • @coderabbitai render interesting statistics about this repository as a table.
    • @coderabbitai show all the console.log statements in this repository.
    • @coderabbitai generate unit tests for the src/utils.ts file.

CodeRabbit Commands (invoked as PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger a review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai resolve to resolve all the CodeRabbit review comments.
  • @coderabbitai help to get help.

Additionally, you can add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • The JSON schema for the configuration file is available here.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/coderabbit-overrides.v2.json

CodeRabbit Discord Community

Join our Discord Community to get help, request features, and share feedback.

codecov bot commented Dec 21, 2023

Codecov Report

Attention: 187 lines in your changes are missing coverage. Please review.

Comparison is base (649280d) 99.78% compared to head (34f49a8) 83.40%.
Report is 15 commits behind head on main.

Files Patch % Lines
sleap_nn/inference/inference.py 0.00% 187 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main      #34       +/-   ##
===========================================
- Coverage   99.78%   83.40%   -16.38%     
===========================================
  Files          19       20        +1     
  Lines         936     1139      +203     
===========================================
+ Hits          934      950       +16     
- Misses          2      189      +187     

☔ View full report in Codecov by Sentry.

gitttt-1234 marked this pull request as ready for review December 23, 2023 00:31
coderabbitai bot (Contributor) left a comment

Review Status

Actionable comments generated: 4

Configuration used: CodeRabbit UI

Commits: files that changed from the base of the PR and between ff6801b and 34f49a8.
Files selected for processing (1)
  • sleap_nn/inference/inference.py (1 hunks)
Additional comments: 6
sleap_nn/inference/inference.py (6)
  • 3-26: Ensure that all imported modules are used within the file to avoid unnecessary imports which can lead to increased memory usage and longer load times.

  • 83-100: The abstract methods in the Predictor class are correctly defined using the @abstractmethod decorator. Ensure that subclasses of Predictor implement these methods (see the sketch after this list).

  • 229-341: The forward method in the FindInstancePeaks class contains logic to predict instance peaks from images. Ensure that the preprocess method is adequately implemented and that the peak_finding.find_global_peaks function is correctly used with the appropriate parameters.

  • 344-415: The forward method in the TopDownInferenceModel class correctly chains the centroid_crop and instance_peaks layers. Verify that the error handling for missing centroid_crop and instance_peaks is appropriate and that the method's logic is sound.

  • 535-567: The make_pipeline method in the TopDownPredictor class correctly creates a data loading pipeline. Ensure that the TopdownConfmapsPipeline and LabelsReader are correctly used and that the configuration is properly applied.

  • 569-636: The _make_labeled_frames_from_generator method in the TopDownPredictor class appears to correctly convert inference results into SLEAP-specific data structures. Verify that the method handles all possible keys in the inference result dictionaries and that the SLEAP data structures are correctly populated.
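
To illustrate the note on lines 83-100 (which are not quoted in this view), here is a minimal sketch of the abstract-method contract. The method name `make_pipeline` is borrowed from the `TopDownPredictor` comment above; whether it is actually one of the abstract methods is an assumption.

```python
from abc import ABC, abstractmethod


class Predictor(ABC):
    """Base interface class for predictors."""

    @abstractmethod
    def make_pipeline(self):
        """Hypothetical abstract method; the real names live on lines 83-100."""
        ...


class TopDownPredictor(Predictor):
    def make_pipeline(self):
        # Overriding every @abstractmethod is what makes a subclass instantiable;
        # omitting one raises TypeError when TopDownPredictor() is constructed.
        return None
```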

Comment on lines +29 to +81
class Predictor(ABC):
    """Base interface class for predictors."""

    @classmethod
    def from_model_paths(cls, ckpt_paths: Dict[Text, Text], model: Text) -> "Predictor":
        """Create the appropriate `Predictor` subclass from from the ckpt path.

        Args:
            ckpt_paths: Dict with keys as model names and values as paths to the checkpoint file having the trained model weights
            config_paths: Dict with keys as model names and values as paths to the config.yaml file used for training

        Returns:
            A subclass of `Predictor`.

        See also: `SingleInstancePredictor`, `TopDownPredictor`, `BottomUpPredictor`,
            `MoveNetPredictor`, `TopDownMultiClassPredictor`,
            `BottomUpMultiClassPredictor`.
        """
        # Read configs and find model types.
        # configs={}
        # for p in ckpt_paths:
        #     ckpt = torch.load(p)  # ???? should we load the checkpoint everytime
        #     configs[ckpt.config.model_name] = [ckpt.config, ckpt]  # checkpoint or checkpoint path

        # ckpt_paths = {"centroid": ckpt_path, "centered": ckpt_path}
        model_names = ckpt_paths.keys()

        if "single_instance" in model:
            # predictor = SingleInstancePredictor.from_trained_models(
            #     model_path=model_paths["single_instance"],
            #     inference_config=inference_config,
            # )
            pass

        elif "topdown" in model:
            centroid_ckpt_path = None
            confmap_ckpt_path = None
            if "centroid" in model_names:
                centroid_ckpt_path = ckpt_paths["centroid"]
            if "centered" in model_names:
                confmap_ckpt_path = ckpt_paths["centered"]

            predictor = TopDownPredictor.from_trained_models(
                centroid_ckpt_path=centroid_ckpt_path,
                confmap_ckpt_path=confmap_ckpt_path,
            )

        else:
            raise ValueError(
                "Could not create predictor from model paths:" + "\n".join(ckpt_paths)
            )
        predictor.model_path = ckpt_paths
        return predictor

The method from_model_paths in the Predictor class contains commented-out code and a pass statement under the if "single_instance" in model: condition. This should be cleaned up or implemented before merging.
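
For illustration only, one way to tidy the `single_instance` branch inside `from_model_paths` is to fail loudly until `SingleInstancePredictor` exists, rather than leaving the commented-out call and a bare `pass`:

```python
if "single_instance" in model:
    # Until SingleInstancePredictor is implemented, fail loudly rather than
    # silently passing through the commented-out branch.
    raise NotImplementedError("Single-instance inference is not supported yet.")
elif "topdown" in model:
    ...
```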

Comment on lines +170 to +222
class InferenceModule(nn.Module):
    """
    Inference model base class.

    This class wraps the `nn.Model` class to provide inference
    utilities such as handling different input data types, preprocessing and variable
    output shapes. This layer expects the same input as the model (rank-4 image).

    Attributes:
        model: A `torch.nn.Model` that will be called on the input to this layer.
        conf: OmegaConf file with inference related parameters.

    """

    def __init__(
        self,
        torch_model: L.LightningModule,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.torch_model = torch_model

    def preprocess(self, imgs):
        return imgs

    def forward(
        self,
        data: Union[
            np.ndarray,
            torch.Tensor,
            Dict[str, torch.Tensor],
            sio.Video,
        ],
        **kwargs,
    ) -> Union[Dict[str, np.ndarray], Dict[str, torch.Tensor]]:
        """Predict instances with the data.

        Args:
            data: Input data in any form. Possible types:
                - `np.ndarray`, `tf.Tensor`: Images of shape
                  `(samples, t, channels, height, width)`
                - `dict` with key `"image"` as a tensor
                - `torch.utils.data.DataLoader` that generates examples in one of the above formats.
                - `sleap.Video` which will be converted into a pipeline that generates
                  batches of `batch_size` frames.

        Returns:
            The model outputs as a dictionary of tensors
        """
        imgs = self.preprocess(data)
        outs = self.torch_model(imgs)
        return outs

The InferenceModule class correctly extends nn.Module and provides a forward method. Ensure that the preprocess method is implemented or overridden in subclasses as it currently only returns the input without any processing.
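
As an illustration of the kind of override the comment asks for (the normalization shown is an assumption, not what this PR does), a subclass could convert incoming frames to float tensors before the wrapped model runs:

```python
import torch


class NormalizingInferenceModule(InferenceModule):
    # Hypothetical subclass: pulls the image tensor out of dict inputs and
    # scales uint8 frames to float32 in [0, 1] before the wrapped model runs.
    def preprocess(self, imgs):
        if isinstance(imgs, dict):
            imgs = imgs["image"]
        imgs = torch.as_tensor(imgs)
        if imgs.dtype == torch.uint8:
            imgs = imgs.float() / 255.0
        return imgs
```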

Comment on lines +225 to +229
class CentroidCrop(InferenceModule):
    pass


class FindInstancePeaks(InferenceModule):

The CentroidCrop class is empty and does not override any methods from its superclass InferenceModule. If this class is intended to be implemented in the future, consider adding a TODO comment or implement the required functionality.
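
A minimal version of that suggestion (illustrative only):

```python
class CentroidCrop(InferenceModule):
    # TODO: Implement centroid detection and instance cropping.
    # Raising here makes accidental use of the unimplemented layer fail loudly.
    def forward(self, data, **kwargs):
        raise NotImplementedError("CentroidCrop is not implemented yet.")
```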

Comment on lines +418 to +533
    centroid_config: Optional[DictConfig] = attr.ib(default=None)
    confmap_config: Optional[DictConfig] = attr.ib(default=None)
    centroid_model: Optional[L.LightningModule] = attr.ib(default=None)
    confmap_model: Optional[L.LightningModule] = attr.ib(default=None)

    def _initialize_inference_model(self):
        """Initialize the inference model from the trained models and configuration."""
        self.model_config = OmegaConf.create()

        # Create an instance of CentroidLayer if centroid_config is not None
        if self.centroid_config is None:
            centroid_crop_layer = None
        else:
            self.model_config["centroid"] = self.centroid_config
            self.model_config["data"] = self.centroid_config.inference_config.data
            pass

        # Create an instance of FindInstancePeaks layer if confmap_config is not None
        if self.confmap_config is None:
            pass
        else:
            self.model_config["confmaps"] = self.confmap_config
            self.model_config["data"] = self.confmap_config.inference_config.data
            instance_peaks_layer = FindInstancePeaks(
                torch_model=self.confmap_model,
                peak_threshold=self.confmap_config.inference_config.peak_threshold,
                output_stride=self.confmap_config.inference_config.output_stride,
                refinement=self.confmap_config.inference_config.integral_refinement,
                integral_patch_size=self.confmap_config.inference_config.integral_patch_size,
            )

        # Initialize the inference model with centroid and conf map layers
        self.inference_model = TopDownInferenceModel(
            centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer
        )

    @property
    def data_config(self) -> DictConfig:
        # Returns data config section from the overall config
        return self.model_config.data

    @classmethod
    def from_trained_models(
        cls,
        centroid_ckpt_path: Optional[Text] = None,
        confmap_ckpt_path: Optional[Text] = None,
    ) -> "TopDownPredictor":
        """Create predictor from saved models.

        Args:
            centroid_ckpt_path: Path to a centroid ckpt file.
            confmap_ckpt_path: Path to a centroid ckpt file.

        Returns:
            An instance of `TopDownPredictor` with the loaded models.

        One of the two models can be left as `None` to perform inference with ground
        truth data. This will only work with `LabelsReader` as the provider.
        """
        if centroid_ckpt_path is None and confmap_ckpt_path is None:
            raise ValueError(
                "Either the centroid or topdown confidence map model must be provided."
            )

        if centroid_ckpt_path is not None:
            # Load centroid model.
            pass

        else:
            centroid_config = None
            centroid_model = None

        if confmap_ckpt_path is not None:
            # Load confmap model.
            confmap_ckpt = torch.load(confmap_ckpt_path)
            skeleton = confmap_ckpt["skeleton"]
            confmap_config = confmap_ckpt["config"]
            confmap_model = TopDownCenteredInstanceModel.load_from_checkpoint(
                "mice_default_init_with_amsgrad/last-v4.ckpt", config=confmap_config
            )
            confmap_model.to(confmap_config.inference_config.device)
            confmap_model.m_device = confmap_config.inference_config.device

        else:
            confmap_config = None
            confmap_model = None

        # create an instance of TopDownPredictor class
        obj = cls(
            centroid_config=centroid_config,
            centroid_model=centroid_model,
            confmap_config=confmap_config,
            confmap_model=confmap_model,
        )

        obj._initialize_inference_model()
        obj.skeleton = skeleton

        return obj

The TopDownPredictor class's _initialize_inference_model method contains commented-out code and placeholders that should be either implemented or removed. Additionally, the from_trained_models method has a hardcoded checkpoint path which should be parameterized to allow for flexibility.
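
A sketch of the parameterization, assuming the checkpoint file passed to `from_trained_models` is itself a Lightning checkpoint that `load_from_checkpoint` can consume; only the hardcoded path changes:

```python
if confmap_ckpt_path is not None:
    # Load confmap model from the caller-supplied path instead of a hardcoded one.
    confmap_ckpt = torch.load(confmap_ckpt_path)
    skeleton = confmap_ckpt["skeleton"]
    confmap_config = confmap_ckpt["config"]
    confmap_model = TopDownCenteredInstanceModel.load_from_checkpoint(
        confmap_ckpt_path, config=confmap_config
    )
    confmap_model.to(confmap_config.inference_config.device)
    confmap_model.m_device = confmap_config.inference_config.device
```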

gitttt-1234 deleted the divya/inference branch May 15, 2024 19:14