
Use instances as model input #46

Merged
merged 5 commits into from
May 29, 2024

Conversation

Contributor

@aaprasad aaprasad commented May 15, 2024

Currently, at every level we use biogtr.Frame as the base unit of data passed to the models. However, a core advantage of the Global Tracking Transformer is that it associates individual detections. If detections themselves are the model input, then higher-level code is free to choose more useful ways of grouping instances before passing them in. Here we therefore refactor biogtr.models.GlobalTrackingTransformer and biogtr.models.Transformer to take a set of ref_instances and a set of query_instances as input. The query_instances are fed to the decoder and associated against the ref_instances. One level up, in biogtr.inference.Tracker, we demonstrate how instances can be grouped by frame versus by window.
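The new calling convention can be sketched schematically. The Frame and Instance classes below are simplified stand-ins for illustration only (not the real biogtr data structures); only the ref_instances/query_instances convention and the flatten-then-query-last-frame grouping come from this PR:

```python
from dataclasses import dataclass, field

@dataclass
class Instance:
    track_id: int

@dataclass
class Frame:
    frame_id: int
    instances: list = field(default_factory=list)

def forward(ref_instances, query_instances=None):
    # Associate a flat list of detections; if no query set is given,
    # associate the reference instances against themselves.
    if query_instances is None:
        query_instances = ref_instances
    return len(ref_instances), len(query_instances)

frames = [
    Frame(0, [Instance(0), Instance(1)]),
    Frame(1, [Instance(0), Instance(1)]),
]

# Flatten the window of frames into reference instances, and use the last
# frame's instances as the query set -- the grouping is the caller's choice.
ref_instances = [inst for frame in frames for inst in frame.instances]
query_instances = frames[-1].instances
n_ref, n_query = forward(ref_instances, query_instances)
```

Because the model only sees flat lists, other grouping strategies (e.g. all instances of one track, or a stride of frames) need no model changes.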

Summary by CodeRabbit

  • New Features

    • Enhanced tracking functionality with improved instance handling and feature extraction.
  • Improvements

    • Updated methods to operate on instances instead of frames for better efficiency.
    • Refined metrics computation and evaluation steps in model training and testing.
  • Bug Fixes

    • Corrected bounding box extraction and time index handling to ensure accurate results.
  • Refactor

    • Renamed and restructured functions for clarity and consistency.
    • Adjusted imports and method signatures across various modules.
  • Tests

    • Updated test cases to align with new instance-based processing.

@aaprasad aaprasad requested a review from talmo May 15, 2024 21:57

coderabbitai bot commented May 15, 2024

Walkthrough

This update introduces significant enhancements to the biogtr module, focusing on refining the handling of instances within frames. Key changes include the addition of a _frame attribute to the Instance class, updates to the Frame class to manage instance attributes, and modifications to various methods to operate on instances instead of frames. Additionally, new utility functions were added to handle bounding box and time extraction, and several methods were refactored for improved clarity and performance.
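The instance-to-frame back-reference described above can be sketched as follows (simplified stand-ins rather than the actual biogtr classes; the `_frame` attribute, `frame` property, and instances-setter wiring follow the summary, everything else is illustrative):

```python
class Instance:
    def __init__(self):
        self._frame = None  # set by the owning Frame

    @property
    def frame(self):
        return self._frame

    @frame.setter
    def frame(self, frame):
        self._frame = frame

class Frame:
    def __init__(self, frame_id, instances):
        self.frame_id = frame_id
        self._instances = []
        self.instances = instances  # setter wires up back-references

    @property
    def instances(self):
        return self._instances

    @instances.setter
    def instances(self, instances):
        for inst in instances:
            inst.frame = self  # each instance can find its frame
        self._instances = instances

f = Frame(7, [Instance(), Instance()])
```

With this wiring, a flat list of instances still retains per-frame information (e.g. frame_id) through each instance's `frame` property.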

Changes

File(s) Change Summary
biogtr/data_structures.py Added _frame attribute and frame property to Instance; modified Frame class to set frame attribute for each Instance; updated Frame.__repr__ and Frame.instances setter.
biogtr/inference/tracker.py Updated sliding_inference and _run_global_tracker methods to operate on instances instead of frames.
biogtr/models/global_tracking_transformer.py Refactored to use torch.nn.Module; updated forward method to accept ref_instances and query_instances; added extract_features method.
biogtr/models/gtr_runner.py Modified forward, training_step, validation_step, test_step, predict_step, and _shared_eval_step methods to operate on instances; updated method signatures and processing logic.
biogtr/models/model_utils.py Renamed get_boxes_times to get_boxes; added get_times function; updated bounding box extraction to operate on instances.
biogtr/models/transformer.py Replaced Frame with Instance and updated related functions; modified forward method to operate on instances.
biogtr/training/losses.py Updated imports and logic to use get_boxes and get_times; adjusted bounding box and time index extraction to work with instances.
tests/test_models.py Modified test cases to reflect changes in argument passing and processing of instances within frames; updated iteration and return values accordingly.

Sequence Diagram(s) (Beta)

sequenceDiagram
    participant User
    participant Tracker
    participant GlobalTrackingTransformer
    participant GTRRunner

    User->>Tracker: Call sliding_inference(device)
    Tracker->>GlobalTrackingTransformer: Call forward(ref_instances, query_instances)
    GlobalTrackingTransformer->>GTRRunner: Call extract_features(instances)
    GTRRunner->>GlobalTrackingTransformer: Return extracted features
    GlobalTrackingTransformer->>Tracker: Return tracking results
    Tracker->>User: Return inference results

Poem

In the realm of code, where data flows,
A rabbit hops, where logic grows.
Instances now with frames align,
Bounding boxes, times define.
Transformations smooth and clear,
Bringing enhancements, far and near.
With every change, precision's cheer! 🐰✨




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

Out of diff range and nitpick comments (2)
biogtr/models/global_tracking_transformer.py (1)

Line range hint 82-122: Update docstring to reflect new parameters.

-        Args:
-            frames: List of Frames from chunk containing crops of objects + gt label info
-            query_frame: Frame index used as query for self attention. Only used in sliding inference where query frame is the last frame in the window.
+        Args:
+            ref_instances: List of reference instances containing crops of objects + gt label info.
+            query_instances: List of query instances used for self attention. Only used in sliding inference where query instances are the last instances in the window.

The docstring should accurately describe the new parameters ref_instances and query_instances to avoid confusion.

biogtr/models/gtr_runner.py (1)

Line range hint 63-75: Update docstring to reflect new parameters.

-        Args:
-            instances: a list of dicts where each dict is a frame with gt data
+        Args:
+            ref_instances: List of reference instances containing crops of objects + gt label info.
+            query_instances: List of query instances used for self attention. Only used in sliding inference where query instances are the last instances in the window.

The docstring should accurately describe the new parameters ref_instances and query_instances to avoid confusion.

Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits: files that changed from the base of the PR, between f0c6f3c and 9bab7bc.
Files selected for processing (8)
  • biogtr/data_structures.py (5 hunks)
  • biogtr/inference/tracker.py (5 hunks)
  • biogtr/models/global_tracking_transformer.py (3 hunks)
  • biogtr/models/gtr_runner.py (7 hunks)
  • biogtr/models/model_utils.py (1 hunks)
  • biogtr/models/transformer.py (8 hunks)
  • biogtr/training/losses.py (2 hunks)
  • tests/test_models.py (5 hunks)
Additional Context Used
Ruff (5)
biogtr/inference/tracker.py (3)

121-121: f-string without any placeholders


177-177: Undefined name curr_track


177-177: Local variable curr_track is assigned to but never used

biogtr/models/model_utils.py (1)

3-3: typing.Tuple imported but unused

tests/test_models.py (1)

428-428: Local variable img_shape is assigned to but never used

Additional comments not posted (16)
tests/test_models.py (4)

385-388: Ensure that ref_pos_emb and query_pos_emb are correctly initialized and passed to the transformer_decoder function. This change appears to be consistent with the function's purpose of testing the transformer decoder layer logic.


417-418: The list comprehension correctly flattens the list of instances from frames. Ensure that the transformer can handle the flattened list of instances as expected.


442-443: The list comprehension correctly flattens the list of instances from frames. Ensure that the transformer can handle the flattened list of instances as expected.


512-513: The list comprehension correctly flattens the list of instances from frames. Ensure that the tracking transformer can handle the flattened list of instances as expected.

biogtr/inference/tracker.py (4)

144-146: Ensure that frame_to_track.frame_id.device is correctly initialized and passed to the collate_tracks method. This change appears to be consistent with the method's purpose of performing sliding inference.


237-239: The list comprehensions correctly extract query_instances and all_instances from the frames. Ensure that the method can handle these lists of instances as expected.


262-262: Ensure that the model can handle all_instances and query_instances as inputs. This change appears to be consistent with the method's purpose of running the global tracker.


330-330: The updated calculation correctly uses nonquery_inds to extract asso_nonquery. Ensure that the method can handle this updated calculation as expected.

biogtr/models/transformer.py (4)

14-17: Imports look good and necessary for the changes made in the file.


143-149: The updated forward method signature and docstring correctly reflect the new parameters ref_instances and query_instances.


173-237: The handling of embeddings and positional encodings appears correct and aligns with the new input structure.


Line range hint 240-267: The decoder and attention head logic correctly integrates the new input structure and performs the necessary operations.

biogtr/data_structures.py (4)

127-128: Initialize _frame attribute to None.

The initialization of the _frame attribute to None in the constructor is correct and follows best practices.


609-610: Ensure instances is a list of Instance objects.

Ensure that the instances parameter is always a list of Instance objects. If not, consider adding a type check or validation to handle unexpected types.

if not all(isinstance(instance, Instance) for instance in instances):
    raise ValueError("All elements in instances must be of type Instance")

824-826: Ensure instances is a list of Instance objects.

Ensure that the instances parameter is always a list of Instance objects. If not, consider adding a type check or validation to handle unexpected types.

if not all(isinstance(instance, Instance) for instance in instances):
    raise ValueError("All elements in instances must be of type Instance")

640-640: Update __repr__ method to display keys of _traj_score.

The update to the __repr__ method to display the keys of _traj_score instead of values is correct and improves the readability of the output.

Comment on lines +160 to +162
ref_features = torch.cat(
[instance.features for instance in ref_instances], dim=0
).unsqueeze(0)
Contributor


Consider adding a check to ensure that ref_instances is not empty before concatenating features to avoid potential runtime errors.

+        if not ref_instances:
+            raise ValueError("ref_instances should not be empty")
+
        ref_features = torch.cat(
            [instance.features for instance in ref_instances], dim=0
        ).unsqueeze(0)


Comment on lines +5 to +6
from biogtr.data_structures import Instance
import torch

@coderabbitai coderabbitai bot May 15, 2024


Remove unused import.

- from biogtr.data_structures import Instance

The Instance class is not used in this file. Removing unused imports helps keep the code clean and maintainable.



Contributor Author


It's used for type hinting

Contributor


@aaprasad, thank you for pointing that out. You're absolutely right; the Instance class is used for type hinting in the forward method parameters. I appreciate your correction, and I'll make sure to consider this in future reviews. 🐇


New Learnings Added
User: aaprasad
PR: talmolab/biogtr#46
File: biogtr/models/global_tracking_transformer.py:5-6
Timestamp: 2024-05-29T17:50:29.550Z
Learning: The `Instance` class is used for type hinting in the `forward` method parameters of the `GlobalTrackingTransformer` class in `biogtr/models/global_tracking_transformer.py`.

Note: Learnings are effective only in the context of specific code segments. To apply general coding guidelines, please configure review instructions.

Comment on lines 95 to 120
if any(
[
(not instance.has_features()) and instance.has_crop()
for instance in ref_instances
]
):
crops = frame.get_crops()
z = self.visual_encoder(crops)
ref_crops = torch.concat(
[instance.crop for instance in ref_instances], axis=0
)
ref_z = self.visual_encoder(ref_crops)
for i, z_i in enumerate(ref_z):
ref_instances[i].features = z_i

for i, z_i in enumerate(z):
frame.instances[i].features = z_i
if query_instances:
if any(
[
(not instance.has_features()) and instance.has_crop()
for instance in query_instances
]
):
query_crops = torch.concat(
[instance.crop for instance in query_instances], axis=0
)
query_z = self.visual_encoder(query_crops)
for i, z_i in enumerate(query_z):
query_instances[i].features = z_i

@coderabbitai coderabbitai bot May 15, 2024


Simplify feature extraction logic.

-        if any(
-            [
-                (not instance.has_features()) and instance.has_crop()
-                for instance in ref_instances
-            ]
-        ):
-            ref_crops = torch.concat(
-                [instance.crop for instance in ref_instances], axis=0
-            )
-            ref_z = self.visual_encoder(ref_crops)
-            for i, z_i in enumerate(ref_z):
-                ref_instances[i].features = z_i
-
-        if query_instances:
-            if any(
-                [
-                    (not instance.has_features()) and instance.has_crop()
-                    for instance in query_instances
-                ]
-            ):
-                query_crops = torch.concat(
-                    [instance.crop for instance in query_instances], axis=0
-                )
-                query_z = self.visual_encoder(query_crops)
-                for i, z_i in enumerate(query_z):
-                    query_instances[i].features = z_i
+        def extract_features(instances):
+            if any((not instance.has_features()) and instance.has_crop() for instance in instances):
+                crops = torch.concat([instance.crop for instance in instances], axis=0)
+                features = self.visual_encoder(crops)
+                for i, feature in enumerate(features):
+                    instances[i].features = feature
+
+        extract_features(ref_instances)
+        if query_instances:
+            extract_features(query_instances)

This refactor reduces code duplication and improves readability by extracting the feature extraction logic into a helper function.



Contributor


Do this! Could be useful to have it as an instance method so that you can call it by itself as well for different kinds of preprocessing pipelines.

It may also be helpful to have it not be all-or-nothing and instead just run on the ones without features (or have a flag in this function to recompute the features explicitly).

Contributor


@talmo, great suggestions! Implementing the feature extraction as an instance method will indeed enhance modularity and allow for more flexible usage across different preprocessing pipelines. Here's how you could modify the GlobalTrackingTransformer class to include this functionality:

class GlobalTrackingTransformer(torch.nn.Module):
    ...
    def extract_features(self, instances, force_recompute=False):
        """
        Extract features for instances that either don't have them or are forced to recompute.

        Args:
            instances (list[Instance]): The instances to process.
            force_recompute (bool): If True, recompute features even if they already exist.
        """
        for instance in instances:
            if force_recompute or (not instance.has_features() and instance.has_crop()):
                crop = instance.crop.unsqueeze(0)  # Assuming crop is a tensor that needs batch dimension
                feature = self.visual_encoder(crop)
                instance.features = feature.squeeze(0)  # Remove batch dimension after processing

    def forward(self, ref_instances, query_instances=None):
        """Execute forward pass of GTR Model to get association matrix."""
        self.extract_features(ref_instances)
        if query_instances:
            self.extract_features(query_instances)
        ...

This method extract_features checks each instance to determine if it needs feature extraction based on the force_recompute flag or the absence of existing features. It processes only the necessary instances, which can improve efficiency, especially when dealing with large datasets.

Feel free to adjust the tensor handling (unsqueeze and squeeze) based on your actual data structure. This should give you a good starting point to integrate the method into your preprocessing pipelines!

Contributor Author


@talmo if I make it a class method it won't have access to self.visual_encoder
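For illustration, the distinction being discussed can be shown with a toy model (not the real GlobalTrackingTransformer; the lambda encoder is a placeholder): a classmethod receives the class rather than an instance, so attributes created in `__init__` such as the visual encoder are out of reach, while a plain instance method gets `self` and can use them.

```python
class Model:
    def __init__(self):
        # Stand-in for a visual encoder module created per instance.
        self.visual_encoder = lambda crops: [c * 2 for c in crops]

    def extract_features(self, crops):
        # Instance method: receives self, so the encoder is available.
        return self.visual_encoder(crops)

    @classmethod
    def extract_features_cls(cls, crops):
        # Classmethod: receives cls only; no access to self.visual_encoder.
        return crops

m = Model()
out = m.extract_features([1, 2, 3])
```

So the suggestion to make it an instance method (rather than a classmethod or a nested helper) keeps `self.visual_encoder` accessible while still allowing standalone calls in other preprocessing pipelines.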

Contributor


Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

@@ -2,35 +2,69 @@

from typing import List, Tuple, Iterable
Contributor


Remove unused import.

- from typing import Tuple

The Tuple type is not used in this file. Removing unused imports helps keep the code clean and maintainable.



@@ -8,6 +8,7 @@
from biogtr.training.losses import AssoLoss
from biogtr.models.model_utils import init_optimizer, init_scheduler
from pytorch_lightning import LightningModule
from biogtr.data_structures import Frame, Instance
Contributor


Remove unused import.

- from biogtr.data_structures import Frame, Instance
+ from biogtr.data_structures import Frame

The Instance class is not used in this file. Removing unused imports helps keep the code clean and maintainable.



Comment on lines 159 to 164
frames = [frame for frame in frames if frame.has_instances()]
if len(frames) == 0:
return None

instances = [instance for frame in frames for instance in frame.instances]

Contributor


Optimize instance extraction.

-            frames = [frame for frame in frames if frame.has_instances()]
-            if len(frames) == 0:
-                return None
-
-            instances = [instance for frame in frames for instance in frame.instances]
+            instances = [
+                instance
+                for frame in frames if frame.has_instances()
+                for instance in frame.instances
+            ]
+            if len(instances) == 0:
+                return None

This refactor combines the filtering and extraction of instances into a single list comprehension, making the code more concise and efficient.



@@ -59,7 +60,9 @@ def __init__(
self.metrics = metrics
self.persistent_tracking = persistent_tracking

def forward(self, instances) -> torch.Tensor:
def forward(
self, ref_instances: list[Instance], query_instances: list[Instance] = None
Contributor


Update the docstring for the new signature


def training_step(
self, train_batch: list[dict], batch_idx: int
self, train_batch: list[list[Frame]], batch_idx: int
Contributor


Update docstring

@@ -92,7 +93,7 @@ def training_step(
return result

def validation_step(
self, val_batch: list[dict], batch_idx: int
self, val_batch: list[list[Frame]], batch_idx: int
Contributor


Update docstring

@@ -109,7 +110,9 @@ def validation_step(

return result

def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]:
def test_step(
self, test_batch: list[list[Frame]], batch_idx: int
Contributor


Update docstring

import torch


def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Extract the bounding boxes and frame indices from the input list of instances.
def get_boxes(instances: List[Instance]) -> torch.tensor:
Contributor


Suggested change
def get_boxes(instances: List[Instance]) -> torch.tensor:
def get_boxes(instances: List[Instance]) -> torch.Tensor:

Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors containing the
bounding boxes normalized by the height and width of the image
and corresponding frame indices, respectively.
The bounding boxes normalized by the height and width of the image
Contributor


shape? dtype?
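A sketch of the kind of shape/dtype annotation being requested for the return docstring (plain lists stand in for torch tensors to keep the example dependency-free; the actual biogtr layout and dtype may differ):

```python
def get_boxes(instances):
    """Extract normalized bounding boxes from a list of instances.

    Returns:
        A nested list of shape (N, 4): one [x1, y1, x2, y2] row per instance,
        with coordinates normalized to [0, 1] by image width and height.
        (In biogtr this would be a float torch.Tensor of shape (N, 4); plain
        lists are used here only to keep the sketch self-contained.)
    """
    return [inst["bbox"] for inst in instances]

boxes = get_boxes([
    {"bbox": [0.1, 0.2, 0.3, 0.4]},
    {"bbox": [0.5, 0.5, 0.9, 0.9]},
])
```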


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits: files that changed from the base of the PR, between 9bab7bc and 076398a.
Files selected for processing (4)
  • biogtr/models/global_tracking_transformer.py (2 hunks)
  • biogtr/models/gtr_runner.py (5 hunks)
  • biogtr/models/model_utils.py (1 hunks)
  • biogtr/models/transformer.py (8 hunks)
Additional Context Used
Ruff (1)
biogtr/models/model_utils.py (1)

3-3: typing.Tuple imported but unused

Learnings (1)
biogtr/models/global_tracking_transformer.py (1)
User: aaprasad
PR: talmolab/biogtr#46
File: biogtr/models/global_tracking_transformer.py:5-6
Timestamp: 2024-05-29T17:50:29.550Z
Learning: The `Instance` class is used for type hinting in the `forward` method parameters of the `GlobalTrackingTransformer` class in `biogtr/models/global_tracking_transformer.py`.
Additional comments not posted (25)
biogtr/models/global_tracking_transformer.py (4)

5-6: Imports are correctly used for type hinting in the forward method parameters.


Line range hint 11-82: The __init__ method is well-implemented with clear documentation and appropriate initialization of components.


82-100: The forward method is correctly implemented, handling both reference and optional query instances effectively.


104-132: The extract_features method is efficiently implemented, with checks for necessary feature computation and handling of batch normalization errors.

biogtr/models/model_utils.py (5)

9-29: The get_boxes function is correctly implemented, efficiently extracting and normalizing bounding boxes from instances.


32-68: The get_times function is well-implemented, correctly handling both reference and optional query instances to extract time indices.
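A hedged sketch of that ref/query handling (dicts with an assumed `frame_id` field stand in for `Instance` objects; the real function would presumably return tensors of time indices rather than lists):

```python
def get_times(ref_instances, query_instances=None):
    """Return (ref_times, query_times): the frame index of each instance.

    When no query instances are given, the reference instances serve as
    their own queries, i.e. self-association over the window.
    """
    ref_times = [inst["frame_id"] for inst in ref_instances]
    if query_instances is None:
        query_times = list(ref_times)
    else:
        query_times = [inst["frame_id"] for inst in query_instances]
    return ref_times, query_times
```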


Line range hint 70-86: The softmax_asso function is correctly implemented, applying softmax to association outputs efficiently.


Line range hint 88-118: The init_optimizer function is well-implemented, handling various configurations and correctly initializing the appropriate optimizer.


Line range hint 120-156: The init_scheduler function is correctly implemented, handling various configurations and correctly initializing the appropriate scheduler.

biogtr/models/gtr_runner.py (8)

11-11: The import of Frame and Instance is correctly used for type hinting and data handling within the class.


63-76: The forward method is correctly implemented, handling both reference and optional query instances effectively.


Line range hint 79-91: The training_step method is well-implemented, correctly processing the training batch and logging metrics.


Line range hint 97-109: The validation_step method is correctly implemented, efficiently processing the validation batch and logging metrics.


Line range hint 115-127: The test_step method is well-implemented, correctly processing the test batch and logging metrics.


132-147: The predict_step method is correctly implemented, effectively running inference on the batch and returning predicted frames.


149-182: The _shared_eval_step method is efficiently implemented, used by train, test, and validation steps to run evaluation and handle exceptions appropriately.


Line range hint 184-213: The configure_optimizers method is well-implemented, handling optional configurations for optimizers and schedulers effectively.

biogtr/models/transformer.py (8)

14-17: Imports are correctly used for type hinting and functionality within the class.


Line range hint 25-59: The __init__ method of the Transformer class is well-implemented, with clear documentation and appropriate initialization of components.


Line range hint 143-250: The forward method is correctly implemented, handling both reference and optional query instances effectively, and returning detailed outputs.


Line range hint 262-322: The TransformerEncoderLayer class is well-implemented, with a clear and efficient implementation of the encoder layer.


Line range hint 324-410: The TransformerDecoderLayer class is correctly implemented, handling optional decoder self-attention and providing a clear implementation of the decoder layer.


Line range hint 412-438: The TransformerEncoder class is well-implemented, correctly handling multiple encoder layers and optional normalization.


Line range hint 440-472: The TransformerDecoder class is correctly implemented, handling multiple decoder layers and optional normalization and intermediate layer returns efficiently.


Line range hint 474-502: The utility functions _get_clones and _get_activation_fn are well-implemented, providing necessary support for the transformer model.


Args:
instances: A list of instances to compute features for
force_recompute: indicate whether to compute features for all instances regardless of if they have instances

Suggested change:
- force_recompute: indicate whether to compute features for all instances regardless of if they have instances
+ force_recompute: indicate whether to compute features for all instances regardless of if they have features
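A sketch of the `force_recompute` semantics described in that docstring. The function signature, dict-based instances, and injected `visual_encoder` callable are assumptions for illustration, not the real biogtr API:

```python
def extract_features(instances, visual_encoder, force_recompute=False):
    """Compute visual features only for instances that are missing them.

    force_recompute=True recomputes features for every instance,
    regardless of whether it already has features.
    """
    to_compute = [
        inst for inst in instances
        if force_recompute or inst.get("features") is None
    ]
    for inst in to_compute:
        inst["features"] = visual_encoder(inst["crop"])
    return instances
```

Skipping instances that already have features avoids redundant encoder passes when the same instances appear in overlapping tracking windows.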

- train_batch: A single batch from the dataset which is a list of dicts
-     with length `clip_length` where each dict is a frame
+ train_batch: A single batch from the dataset which is a list of `Frame` objects
+     with length `clip_length` containing Instances and other metadata.

Suggested change (the removed and added lines render identically; the edit appears to be whitespace-only):
- with length `clip_length` containing Instances and other metadata.
+ with length `clip_length` containing Instances and other metadata.

@aaprasad aaprasad merged commit dc30967 into main May 29, 2024
3 checks passed
@aaprasad aaprasad deleted the aadi/instance-model-input branch May 29, 2024 21:59