Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor data structures #47

Merged
merged 21 commits into from
Jun 3, 2024
Merged

Refactor data structures #47

merged 21 commits into from
Jun 3, 2024

Conversation

aaprasad
Copy link
Contributor

@aaprasad aaprasad commented May 20, 2024

As per #39, just as we store model inputs in class objects for easy manipulation, we want to do the same for model outputs.
Thus, in this PR we:

  • separate out input/output data classes into biogtr.io and have separate modules for biogtr.io.Frame and biogtr.io.Instance
  • Refactor Frame and Instance classes to use attrs for initialization
  • Add an AssociationMatrix class which
    - stores the association matrix
    - enables easy lookup thru either int or Instance indexing
    - reduces the association matrix to (n_query/n_traj , n_ref/n_traj)
  • Add a Track object which stores instances of the same track id
  • Address Store embeddings in Instance or Frame Objects #35 by storing embeddings in Instance object and have models just return AssociationMatrix's

Summary by CodeRabbit

  • New Features

    • Introduced AssociationMatrix for managing and analyzing association scores.
    • Added Instance class for handling individual tracking instances.
    • Introduced Frame class for managing video frame data and related instances.
    • Added Track class for managing instances of the same track.
  • Enhancements

    • Updated GlobalTrackingTransformer to improve instance handling and feature extraction.
    • Enhanced GTRRunner to operate on lists of Frame objects for better data management and processing.
  • Bug Fixes

    • Corrected import paths and method references to ensure seamless functionality.
  • Tests

    • Added new test cases for AssociationMatrix and Track to improve testing coverage and reliability.
    • Refactored existing tests for Instance and Frame for better accuracy and robustness.

Copy link
Contributor

coderabbitai bot commented May 20, 2024

Walkthrough

The recent changes primarily focus on enhancing the biogtr library by introducing new data structures and functionalities. Major updates include the addition of the AssociationMatrix class, new methods and attributes for the Instance, Frame, and Track classes, and significant refactoring of the GlobalTrackingTransformer and GTRRunner classes. These changes improve the handling of instance data, streamline the tracking process, and enhance the overall performance and flexibility of the library.

Changes

File(s) Change Summary
biogtr/io/association_matrix.py Introduced AssociationMatrix class with methods for matrix manipulation and conversion.
biogtr/io/instance.py Added Instance class with methods for handling instance data, including conversion and device management.
biogtr/io/frame.py Introduced Frame class for managing video frame data, including instances and association matrices.
biogtr/io/track.py Added Track class to manage instances of the same track, with methods for manipulation and retrieval.
biogtr/inference/tracker.py Updated import statements and method calls to reflect new data structures, adjusted data handling in methods.
biogtr/models/global_tracking_transformer.py Refactored GlobalTrackingTransformer class to use torch.nn.Module, updated forward method to handle Instance objects.
biogtr/models/gtr_runner.py Modified GTRRunner class methods to operate on lists of Frame objects, updated forward method parameters.
biogtr/models/transformer.py Updated Transformer class to handle Instance objects, revised embeddings and positional embeddings handling.
tests/test_data_model.py Refactored and added new tests for Instance, Frame, Track, and AssociationMatrix classes.

Poem

In the land where code does thrive,
A matrix blooms, associations alive.
Instances dance in frames so bright,
Tracks align in the data's light.
Transformers now with vision clear,
Propel the models far and near.
Cheers to changes, swift and grand,
In the world of biogtr, they stand! 🌟


Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

Share
Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai generate interesting stats about this repository and render them as a table.
    • @coderabbitai show all the console.log statements in this repository.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (invoked as PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Additionally, you can add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.

CodeRabbit Configration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 23

Outside diff range and nitpick comments (9)
biogtr/io/__init__.py (4)

3-3: Consider adding Frame to __all__ to clarify its intended use, especially if it's meant to be exposed as part of the module's public API.


4-4: Consider adding Instance to __all__ to clarify its intended use, especially if it's meant to be exposed as part of the module's public API.


5-5: Consider adding AssociationMatrix to __all__ to clarify its intended use, especially if it's meant to be exposed as part of the module's public API.


6-6: Consider adding Track to __all__ to clarify its intended use, especially if it's meant to be exposed as part of the module's public API.

biogtr/datasets/microscopy_dataset.py (1)

Line range hint 87-87: Replace the lambda expression with a function definition for better readability and maintainability.

- parser = lambda x: data_utils.parse_synthetic(x, source=source)
+ def parser(x):
+     return data_utils.parse_synthetic(x, source=source)
tests/test_inference.py (1)

Line range hint 168-168: Local variable N is assigned but never used.

- N = N_t * T
tests/test_models.py (1)

Line range hint 429-429: Remove the unused variable img_shape to clean up the code.

- img_shape = (1, 100, 100)
biogtr/inference/tracker.py (2)

Line range hint 121-121: F-string without any placeholders.

This line contains an f-string but does not have any placeholders. If dynamic string formatting is not needed, consider using a regular string.


Line range hint 177-177: Undefined name 'curr_track' and local variable 'curr_track' is assigned to but never used.

It seems there is a typo or logical error here. The variable curr_track is incremented but never initialized. Ensure that it is correctly initialized and used, or remove it if unnecessary.

Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits Files that changed from the base of the PR and between f0c6f3c and c0ceac1.
Files selected for processing (26)
  • biogtr/datasets/base_dataset.py (1 hunks)
  • biogtr/datasets/cell_tracking_dataset.py (2 hunks)
  • biogtr/datasets/eval_dataset.py (1 hunks)
  • biogtr/datasets/microscopy_dataset.py (2 hunks)
  • biogtr/datasets/sleap_dataset.py (1 hunks)
  • biogtr/inference/metrics.py (2 hunks)
  • biogtr/inference/track.py (1 hunks)
  • biogtr/inference/track_queue.py (1 hunks)
  • biogtr/inference/tracker.py (9 hunks)
  • biogtr/io/init.py (1 hunks)
  • biogtr/io/association_matrix.py (1 hunks)
  • biogtr/io/frame.py (1 hunks)
  • biogtr/io/instance.py (1 hunks)
  • biogtr/io/track.py (1 hunks)
  • biogtr/io/visualize.py (1 hunks)
  • biogtr/models/global_tracking_transformer.py (3 hunks)
  • biogtr/models/gtr_runner.py (7 hunks)
  • biogtr/models/model_utils.py (1 hunks)
  • biogtr/models/transformer.py (9 hunks)
  • biogtr/training/losses.py (2 hunks)
  • biogtr/training/train.py (1 hunks)
  • tests/test_config.py (1 hunks)
  • tests/test_data_model.py (4 hunks)
  • tests/test_inference.py (1 hunks)
  • tests/test_models.py (6 hunks)
  • tests/test_training.py (1 hunks)
Files skipped from review due to trivial changes (1)
  • biogtr/training/train.py
Additional Context Used
Ruff (31)
biogtr/datasets/microscopy_dataset.py (1)

87-87: Do not assign a lambda expression, use a def

biogtr/datasets/sleap_dataset.py (1)

9-9: warnings imported but unused

biogtr/inference/track.py (1)

133-133: f-string without any placeholders

biogtr/inference/tracker.py (3)

121-121: f-string without any placeholders


177-177: Undefined name curr_track


177-177: Local variable curr_track is assigned to but never used

biogtr/io/__init__.py (4)

3-3: biogtr.io.frame.Frame imported but unused; consider removing, adding to __all__, or using a redundant alias


4-4: biogtr.io.instance.Instance imported but unused; consider removing, adding to __all__, or using a redundant alias


5-5: biogtr.io.association_matrix.AssociationMatrix imported but unused; consider removing, adding to __all__, or using a redundant alias


6-6: biogtr.io.track.Track imported but unused; consider removing, adding to __all__, or using a redundant alias

biogtr/io/frame.py (6)

47-47: Undefined name Instance


48-48: Undefined name AssociationMatrix


243-243: Undefined name Instance


252-252: Undefined name Instance


282-282: Undefined name AssociationMatrix


301-301: Undefined name AssociationMatrix

biogtr/io/instance.py (3)

90-90: Undefined name Frame


415-415: Undefined name Frame


424-424: Undefined name Frame

biogtr/io/track.py (5)

17-17: Undefined name Instance


48-48: Undefined name Instances


66-66: Undefined name Frame


82-82: Undefined name Instance


82-82: Undefined name Instance

biogtr/models/model_utils.py (1)

3-3: typing.Tuple imported but unused

tests/test_inference.py (2)

4-4: pytest imported but unused


168-168: Local variable N is assigned to but never used

tests/test_models.py (1)

429-429: Local variable img_shape is assigned to but never used

tests/test_training.py (3)

10-10: biogtr.models.global_tracking_transformer.GlobalTrackingTransformer imported but unused


11-11: omegaconf.DictConfig imported but unused


47-47: Local variable feats is assigned to but never used

Additional comments not posted (45)
biogtr/datasets/eval_dataset.py (2)

4-4: The import of Frame is correctly used in the return type of __getitem__. Good job ensuring type consistency.


5-5: The import of Instance is correctly used for creating Instance objects within the __getitem__ method. Well done on maintaining clear type usage.

tests/test_config.py (1)

4-4: The import of Config is correctly used throughout the test functions. Good job ensuring that the tests cover the updated import paths.

tests/test_training.py (3)

6-6: The import of Frame is correctly used in the creation of Frame objects within the test functions. Well done on maintaining clear type usage.


7-7: The import of Instance is correctly used for creating Instance objects within the test functions. Good job ensuring type consistency.


12-12: The import of Config is correctly used in the test_config_gtr_runner function. Good job ensuring that the tests cover the updated import paths.

biogtr/inference/track.py (2)

3-3: The import of Config is correctly used in the main function to load configuration settings. Well done on maintaining clear usage.


5-5: The import of Frame is correctly used in the export_trajectories and inference functions to handle frame data. Good job ensuring type consistency.

biogtr/models/global_tracking_transformer.py (1)

Line range hint 82-122: The refactoring to handle Instance objects instead of Frame objects is well-implemented. Consider adding more detailed comments explaining the processing of Instance objects, especially in the feature extraction logic.

biogtr/datasets/base_dataset.py (1)

4-4: The update to the import path for Frame aligns with the new structure of the project.

tests/test_data_model.py (3)

10-10: The tests for the Instance object are comprehensive and cover all necessary attributes and methods.


108-116: The tests for the Frame object, including its interaction with the AssociationMatrix, are well-implemented and thorough.


135-192: The tests for the AssociationMatrix are comprehensive, effectively testing both normal operations and error conditions.

biogtr/datasets/microscopy_dataset.py (1)

6-7: The update to the import paths for Frame and Instance aligns with the new structure of the project.

biogtr/models/model_utils.py (1)

9-28: The function get_boxes correctly extracts bounding boxes from instances and normalizes them as expected.

biogtr/datasets/cell_tracking_dataset.py (2)

6-7: The update to the import paths for Frame and Instance aligns with the new structure of the project.


126-126: The CellTrackingDataset class correctly handles Frame and Instance objects, aligning with the refactoring goals of the PR.

tests/test_inference.py (1)

6-7: Updated imports to reflect new module structure.

biogtr/models/gtr_runner.py (5)

11-12: Updated imports to reflect new module structure.


Line range hint 64-76: Refactored forward method to handle Instance objects instead of Frame objects, aligning with the new data structure.


149-185: Refactored _shared_eval_step to handle new data structures and metrics. Ensure exception handling and logging are correctly implemented.

Verification successful

The search results indicate that exception handling and logging are implemented in various parts of the codebase, including the _shared_eval_step method in biogtr/models/gtr_runner.py. The exception handling in _shared_eval_step is consistent with other parts of the codebase, where exceptions are caught, logged, and re-raised.

  • The _shared_eval_step method catches exceptions, logs the frame and video IDs, and re-raises the exception.
  • This pattern is consistent with other exception handling instances found in the codebase.

Therefore, the refactored _shared_eval_step method appears to handle exceptions and logging correctly.

Scripts Executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Verify exception handling and logging in `_shared_eval_step`.

# Test: Search for exception handling patterns in the codebase. Expect: Consistent and correct implementation.
rg --type python $'try|except' --context 5

Length of output: 74



Script:

#!/bin/bash
# Description: Verify exception handling and logging in `_shared_eval_step`.

# Test: Search for exception handling patterns in the codebase. Expect: Consistent and correct implementation.
rg 'try|except' --glob '*.py' --context 5

Length of output: 26532


Line range hint 132-147: Refactored predict_step to use the updated tracking logic. Ensure that the tracker is correctly initialized and used here.

Verification successful

The tracker is correctly initialized in the gtr_runner.py file:

  • biogtr/models/gtr_runner.py: self.tracker = Tracker(**tracker_cfg)

This confirms that the tracker is properly set up before being used in the predict_step method.

Scripts Executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Verify the initialization and usage of `tracker` in `predict_step`.

# Test: Search for the tracker initialization in the codebase. Expect: Correct initialization and usage.
rg --type python $'Tracker('

Length of output: 96



Script:

#!/bin/bash
# Description: Verify the initialization and usage of `tracker` in `predict_step`.

# Test: Search for the tracker initialization in the codebase. Expect: Correct initialization and usage.
rg --type python 'Tracker\('

Length of output: 61



Script:

#!/bin/bash
# Description: Verify the initialization and usage of `tracker` in `predict_step`.

# Test: Search for the tracker initialization in the codebase. Expect: Correct initialization and usage.
rg --glob '*.py' 'Tracker\('

Length of output: 166


Line range hint 79-97: Updated training_step, validation_step, and test_step methods to handle lists of Frame objects. Ensure that the data passed to these methods are correctly structured as lists of lists of Frame objects.

Also applies to: 114-116

Verification successful

The methods training_step, validation_step, and test_step are defined in biogtr/models/gtr_runner.py and are called in tests/test_training.py. The calls to these methods pass [batch] as an argument, which suggests that the data structure is a list of lists of Frame objects.

  • tests/test_training.py: metrics = gtr_runner.training_step([batch], i)
  • tests/test_training.py: metrics = gtr_runner.validation_step([batch], j)
  • tests/test_training.py: metrics = gtr_runner.test_step([batch], k)
Scripts Executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Verify that the data passed to training_step, validation_step, and test_step are lists of lists of Frame objects.

# Test: Search for the method calls in the codebase. Expect: Correct data structure in method calls.
rg --type python $'GTRRunner.training_step|GTRRunner.validation_step|GTRRunner.test_step'

Length of output: 121



Script:

#!/bin/bash
# Description: Verify that the data passed to training_step, validation_step, and test_step are lists of lists of Frame objects.

# Test: Search for the method calls in the codebase. Expect: Correct data structure in method calls.
rg --type py $'GTRRunner.training_step|GTRRunner.validation_step|GTRRunner.test_step'

Length of output: 86



Script:

#!/bin/bash
# Description: Verify the definitions and calls of training_step, validation_step, and test_step methods within the GTRRunner class.

# Search for the definitions of the methods in the codebase.
rg --type py 'def training_step|def validation_step|def test_step'

# Search for the calls to these methods in the codebase.
rg --type py 'training_step|validation_step|test_step'

Length of output: 675

biogtr/training/losses.py (2)

3-4: Updated imports to reflect new module structure and utility functions.


52-56: Refactored to use the new get_boxes and get_times utility functions, aligning with the updated data structures.

biogtr/inference/metrics.py (2)

6-6: Updated import to reflect new module structure.


108-108: Updated documentation to reflect new data structures.

biogtr/inference/track_queue.py (1)

4-4: Updated import to reflect new module structure.

biogtr/io/visualize.py (2)

18-56: The implementation of fill_missing function is robust and well-documented.


270-284: The save_vid function is correctly implemented and effectively saves video frames to a file.

biogtr/datasets/sleap_dataset.py (1)

10-11: Correctly updated imports to reflect the new structure of the biogtr.io package.

tests/test_models.py (1)

5-6: Updated imports to reflect the new structure of the biogtr.io package.

biogtr/io/instance.py (4)

53-90: Refactor the Instance class to use attrs for initialization.

The use of attrs for defining the Instance class simplifies the code and makes it more maintainable by automatically generating methods like __init__, __repr__, and others based on the defined fields. This is a good application of the DRY principle.


92-116: Review the __attrs_post_init__ method for proper initialization logic.

The method correctly handles the initialization of various attributes with appropriate default values and conditions. This method ensures that the instance is fully prepared for use right after its creation.


131-148: Ensure proper device handling in the to method.

The to method is well-implemented, providing functionality to change the device or data type of the instance's tensor attributes. This is crucial for operations in different computing environments, especially in deep learning contexts.


150-183: Review the to_slp method for converting to sleap_io.PredictedInstance.

This method provides a bridge between the internal data representation and the SLEAP framework, which is essential for integrating with other components that expect SLEAP data types.

biogtr/models/transformer.py (9)

14-15: Update imports to reflect new data structure locations.


144-145: Refactor the forward method to handle lists of Instance objects.


149-150: Update documentation to reflect new parameter types and expected data structures.


173-176: Ensure robust handling of potential NaN values in reference boxes.


179-181: Ensure embeddings are correctly calculated and returned based on the return_embedding flag.

Also applies to: 183-185


212-238: Handle optional query_instances with appropriate checks and transformations.


257-262: Correctly form and append new AssociationMatrix objects.


Line range hint 388-409: Ensure positional embeddings are added correctly in the decoder layer.


Line range hint 512-535: Properly manage the flow of data through the decoder layers, respecting the return_intermediate flag.

ref_features = torch.cat(
[instance.features for instance in ref_instances], dim=0
).unsqueeze(0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimize the concatenation of instance features.

Consider using a more efficient method for concatenating tensors, especially when dealing with large lists of instances. This could potentially improve performance.

- [instance.features for instance in ref_instances], dim=0
+ torch.stack([instance.features for instance in ref_instances], dim=0)

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
ref_features = torch.cat(
[instance.features for instance in ref_instances], dim=0
).unsqueeze(0)
ref_features = torch.stack(
[instance.features for instance in ref_instances], dim=0
).unsqueeze(0)

biogtr/io/frame.py Show resolved Hide resolved
biogtr/io/frame.py Show resolved Hide resolved
biogtr/io/association_matrix.py Outdated Show resolved Hide resolved
raise ValueError(
(
"Query instances must equal number of rows in Association matrix"
f"Found {len(value)} query instances but {self.matrix.shape[0]} columns."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typographical error in error message.

- f"Found {len(value)} query instances but {self.matrix.shape[0]} columns."
+ f"Found {len(value)} query instances but {self.matrix.shape[0]} rows."

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
f"Found {len(value)} query instances but {self.matrix.shape[0]} columns."
f"Found {len(value)} query instances but {self.matrix.shape[0]} rows."

biogtr/io/association_matrix.py Outdated Show resolved Hide resolved
@@ -7,7 +7,8 @@
import sleap_io as sio
import random
import warnings
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused import of warnings.

- import warnings

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
import warnings

@@ -3,13 +3,93 @@
import torch
import pytest
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused import pytest.

- import pytest

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
import pytest

biogtr/io/__init__.py Show resolved Hide resolved
biogtr/models/model_utils.py Outdated Show resolved Hide resolved
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits Files that changed from the base of the PR and between c0ceac1 and a6a6ace.
Files selected for processing (3)
  • biogtr/io/association_matrix.py (1 hunks)
  • biogtr/io/instance.py (1 hunks)
  • tests/test_data_model.py (4 hunks)
Additional Context Used
Ruff (3)
biogtr/io/instance.py (3)

91-91: Undefined name Frame


461-461: Undefined name Frame


470-470: Undefined name Frame

Additional comments not posted (12)
tests/test_data_model.py (4)

10-10: Comprehensive test coverage for Instance class.


108-116: Effective testing of Frame class with comprehensive checks on properties and associated objects.


135-196: Thorough testing of AssociationMatrix indexing and error handling.


199-210: Basic but correct testing of Track class properties.

biogtr/io/association_matrix.py (5)

11-23: Proper initialization and validation in AssociationMatrix constructor.


75-83: Correct conversion of association matrix to numpy array.


85-131: Flexible and well-documented method for converting association matrix to DataFrame.


133-196: Complex but correctly implemented method for reducing association matrix with detailed documentation.


198-221: Correct retrieval of elements from association matrix with comprehensive indexing support.

biogtr/io/instance.py (3)

53-131: Proper initialization and conversion in Instance constructor with comprehensive attribute handling.


132-149: Correct handling of device and data type changes in Instance objects.


151-184: Correct conversion of Instance to sleap_io.PredictedInstance with appropriate error handling.

@aaprasad aaprasad marked this pull request as draft May 21, 2024 19:45
Copy link

codecov bot commented May 21, 2024

Codecov Report

Attention: Patch coverage is 67.34450% with 273 lines in your changes are missing coverage. Please review.

Project coverage is 68.48%. Comparing base (041d0a4) to head (89007a9).
Report is 12 commits behind head on main.

Current head 89007a9 differs from pull request most recent head 48efb7d

Please upload reports for the commit 48efb7d to get more accurate results.

Files Patch % Lines
biogtr/io/visualize.py 0.00% 115 Missing ⚠️
biogtr/io/instance.py 75.39% 62 Missing ⚠️
biogtr/io/frame.py 75.00% 49 Missing ⚠️
biogtr/io/association_matrix.py 75.96% 25 Missing ⚠️
biogtr/io/track.py 77.41% 7 Missing ⚠️
biogtr/models/global_tracking_transformer.py 76.47% 4 Missing ⚠️
biogtr/models/gtr_runner.py 80.95% 4 Missing ⚠️
biogtr/models/model_utils.py 86.36% 3 Missing ⚠️
biogtr/datasets/eval_dataset.py 0.00% 2 Missing ⚠️
biogtr/inference/track.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #47      +/-   ##
==========================================
- Coverage   69.60%   68.48%   -1.12%     
==========================================
  Files          28       34       +6     
  Lines        2385     2891     +506     
==========================================
+ Hits         1660     1980     +320     
- Misses        725      911     +186     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 12

Outside diff range and nitpick comments (3)
tests/test_models.py (1)

Line range hint 429-429: Remove the unused local variable 'img_shape'.

- img_shape = (1, 100, 100)
biogtr/inference/tracker.py (2)

Line range hint 121-121: Remove unnecessary f-string as it contains no placeholders.

- warnings.warn(f"Clearing Queue after tracking")
+ warnings.warn("Clearing Queue after tracking")

Line range hint 177-177: Fix the undefined variable curr_track and ensure it is used correctly.

- curr_track += 1
+ curr_track_id += 1
Review Details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits Files that changed from the base of the PR and between a6a6ace and 89007a9.
Files selected for processing (9)
  • biogtr/inference/tracker.py (9 hunks)
  • biogtr/io/frame.py (1 hunks)
  • biogtr/io/instance.py (1 hunks)
  • biogtr/io/track.py (1 hunks)
  • biogtr/models/global_tracking_transformer.py (3 hunks)
  • biogtr/models/gtr_runner.py (7 hunks)
  • biogtr/models/transformer.py (8 hunks)
  • tests/test_data_model.py (4 hunks)
  • tests/test_models.py (6 hunks)
Additional Context Used
Ruff (16)
biogtr/inference/tracker.py (3)

121-121: f-string without any placeholders


177-177: Undefined name curr_track


177-177: Local variable curr_track is assigned to but never used

biogtr/io/frame.py (3)

59-59: Undefined name AssociationMatrix


328-328: Undefined name AssociationMatrix


347-347: Undefined name AssociationMatrix

biogtr/io/instance.py (3)

93-93: Undefined name Frame


520-520: Undefined name Frame


529-529: Undefined name Frame

biogtr/io/track.py (5)

17-17: Undefined name Instance


46-46: Undefined name Instances


64-64: Undefined name Frame


80-80: Undefined name Instance


80-80: Undefined name Instance

biogtr/models/global_tracking_transformer.py (1)

84-84: Undefined name AssociationMatrix

tests/test_models.py (1)

429-429: Local variable img_shape is assigned to but never used

GitHub Check Runs (1)
codecov/patch failure (30)

biogtr/io/frame.py: [warning] 21-21: biogtr/io/frame.py#L21
Added line #L21 was not covered by tests


biogtr/io/frame.py: [warning] 80-80: biogtr/io/frame.py#L80
Added line #L80 was not covered by tests


biogtr/io/frame.py: [warning] 109-109: biogtr/io/frame.py#L109
Added line #L109 was not covered by tests


biogtr/io/frame.py: [warning] 112-112: biogtr/io/frame.py#L112
Added line #L112 was not covered by tests


biogtr/io/frame.py: [warning] 115-116: biogtr/io/frame.py#L115-L116
Added lines #L115 - L116 were not covered by tests


biogtr/io/frame.py: [warning] 121-121: biogtr/io/frame.py#L121
Added line #L121 was not covered by tests


biogtr/io/frame.py: [warning] 141-146: biogtr/io/frame.py#L141-L146
Added lines #L141 - L146 were not covered by tests


biogtr/io/frame.py: [warning] 170-174: biogtr/io/frame.py#L170-L174
Added lines #L170 - L174 were not covered by tests


biogtr/io/frame.py: [warning] 201-201: biogtr/io/frame.py#L201
Added line #L201 was not covered by tests


biogtr/io/frame.py: [warning] 221-221: biogtr/io/frame.py#L221
Added line #L221 was not covered by tests


biogtr/io/frame.py: [warning] 241-241: biogtr/io/frame.py#L241
Added line #L241 was not covered by tests

Additional comments not posted (10)
tests/test_data_model.py (1)

108-116: Ensure the 'AssociationMatrix' is correctly initialized and used in the test.

The test correctly initializes the 'AssociationMatrix' with appropriate instances and asserts its properties.

biogtr/models/gtr_runner.py (1)

Line range hint 64-76: Review the forward method for correct handling of instances.

The method correctly handles the instances and uses the model's forward method to compute predictions.

biogtr/inference/tracker.py (4)

6-6: Update import paths to align with new namespace structure.


Line range hint 17-41: Initialization parameters are well-documented and correctly set up.


Line range hint 43-57: Method delegation is correctly implemented for simplifying tracker usage.


212-212: Ensure the docstring accurately describes the method's functionality and parameters.

biogtr/models/transformer.py (2)

144-145: Ensure correct handling of Instance objects in the forward method.

The method signature correctly expects lists of Instance objects for both reference and query instances, aligning with the changes in data handling.


158-160: Optimize the concatenation of instance features.

Consider using a more efficient method for concatenating tensors, especially when dealing with large lists of instances. This could potentially improve performance.

- [instance.features for instance in ref_instances], dim=0
+ torch.stack([instance.features for instance in ref_instances], dim=0)

Likely invalid or redundant comment.

biogtr/io/instance.py (2)

11-28: Utility functions _to_tensor and _expand_to_rank are implemented correctly and efficiently.

Also applies to: 30-52


55-650: The Instance class is well-structured and utilizes attrs effectively for attribute management. Ensure all dependencies, such as the Frame class, are correctly imported to prevent runtime errors.

biogtr/io/frame.py Show resolved Hide resolved
alias="img_shape", converter=_to_tensor, factory=list
)

_instances: list["Instance"] = attrs.field(alias="instances", factory=list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Undefined name 'AssociationMatrix' used in type hint.

Please ensure that AssociationMatrix is properly imported or defined within this module or in an importable module.

biogtr/io/frame.py Show resolved Hide resolved
biogtr/io/instance.py Show resolved Hide resolved
self._embeddings[emb_type] = embedding

@property
def frame(self) -> "Frame":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure the Frame class is properly imported to avoid runtime errors.

+ from biogtr.io.frame import Frame

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def frame(self) -> "Frame":
from biogtr.io.frame import Frame
def frame(self) -> "Frame":

biogtr/inference/tracker.py Outdated Show resolved Hide resolved
biogtr/io/track.py Show resolved Hide resolved
"""
return len(self.instances)

def __getitem__(self, ind) -> Union["Instance", list["Instance"]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure the type hint for 'Instance' is correctly defined.

- def __getitem__(self, ind) -> Union["Instance", list["Instance"]]:
+ def __getitem__(self, ind) -> Union[Instance, list[Instance]]:

Committable suggestion was skipped due low confidence.

biogtr/io/track.py Show resolved Hide resolved
biogtr/models/global_tracking_transformer.py Show resolved Hide resolved
@talmo talmo changed the base branch from main to aadi/instance-model-input May 28, 2024 21:31
biogtr/io/association_matrix.py Outdated Show resolved Hide resolved
Comment on lines 149 to 150
row_to, col_to = to
row_by, col_by = by
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would strongly prefer to have these as the input arguments explicitly instead of tuples.

col_inds = [i for i in range(len(self.ref_instances))]
row_inds = [i for i in range(len(self.query_instances))]

if "tra" in col_to:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too clever bud. What if I provide "instrance"???

Be explicit and list out the conditions with equality checks. if col_to == "track" or col_to == "trajectory"


reduced_matrix = []
for row_track, row_instances in row_tracks.items():
# print(row_instances)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

col_tracks = self.get_tracks(self.ref_instances, col_by)
col_inds = list(col_tracks.keys())
n_cols = len(col_inds)
if "tra" in row_to:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^

Comment on lines 186 to 187
# print(col_instances)
# print(asso_matrix)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete

asso_matrix = self[row_instances, col_instances]
# print(col_instances)
# print(asso_matrix)
if "tra" in col_to:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^

# print(asso_matrix)
if "tra" in col_to:
asso_matrix = reduce_method(asso_matrix, axis=1)
if "tra" in row_to:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^

Returns:
A dictionary of track_id:instances
"""
if "pred" in label.lower():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

==

]
for track_id in traj_ids
}
elif "gt" in label.lower():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

==

Copy link
Contributor

@talmo talmo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep complexity in mind. Organization is great, but we don't need this many bells and whistles for input validation. 2 extra lines of docs is worth 10000 if-else checks.

Comment on lines 25 to 33
AVAILABLE_REDUCTIONS = attrs.field(
init=False,
default={
"instance": ["inst", "instance"],
"track": ["track", "traj", "trajectory"],
None: ["", None],
},
)
AVAILABLE_GROUPINGS = attrs.field(init=False, default=["pred", "gt", None])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move these to module level and remove attrs since these are just static constants

@@ -22,6 +22,16 @@ class AssociationMatrix:
ref_instances: list[Instance] = attrs.field()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These don't need = attrs.field() since auto_attribs=True is the default in the new attrs API (assuming all fields have type hinting).

image

https://www.attrs.org/en/stable/api.html#attrs.define

Comment on lines 102 to 104
If list, then must match # of rows/queries
If `"gt"` then label by gt track id.
If `"pred"` then label by pred track id.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just needs to be one indentation level (4 spaces) to the right from the starting column of row_labels.

Comment on lines 180 to 181
row_dims: A str indicating how to what dimensions to reduce rows to. Either inst (remains unchanged), or traj (n_rows=n_traj)
col_dims: A str indicating how to dimensions to reduce rows to. Either inst (remains unchanged), or traj (n_cols=n_traj)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
row_dims: A str indicating how to what dimensions to reduce rows to. Either inst (remains unchanged), or traj (n_rows=n_traj)
col_dims: A str indicating how to dimensions to reduce rows to. Either inst (remains unchanged), or traj (n_cols=n_traj)
row_dims: A str indicating how to what dimensions to reduce rows to. Either "inst" (remains unchanged), or "traj" (n_rows=n_traj)
col_dims: A str indicating how to dimensions to reduce rows to. Either "inst" (remains unchanged), or "traj" (n_cols=n_traj)

Comment on lines 192 to 196
not in [key for keys in self.AVAILABLE_REDUCTIONS.values() for key in keys]
) or (
col_dims is not None
and col_dims.lower()
not in [key for keys in self.AVAILABLE_REDUCTIONS.values() for key in keys]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cache all of this. Either have pre-defined lists (run this double loop comprehension at the module level), or have a helper function with a @functools.cache decorator. The former is easier and does the job. I'd even recommend storing it in a tuple or set, depending on whether order matters, to make it faster to query.

Comment on lines 202 to 208
if (
row_grouping not in self.AVAILABLE_GROUPINGS
or col_grouping not in self.AVAILABLE_GROUPINGS
):
raise ValueError(
f"Can aggregate by [gt, pred, None] but {row_by} and {col_by} was requested!"
f"Can aggregate by [gt, pred, None] but {row_grouping} and {col_grouping} was requested!"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need all this error checking? Again, keep the complexity in mind.

A list membership check is O(n), which is small for a small list, but in our case we'll be running this hundreds of thousands to millions of times, so every one of these checks adds up quick.

Refs:

Comment on lines 263 to 269

try:
return self.numpy()[query_ind[:, None], ref_ind].squeeze()

except IndexError as e:
print(f"Query_insts: {type(query_inst)}")
print(f"Query_inds: {query_ind}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is debugging code and comes at the cost of added checks and try-except fencing. I propose just getting rid of it. It's easier to debug these types of things now that things are better refactored.

Base automatically changed from aadi/instance-model-input to main May 29, 2024 21:59
@aaprasad aaprasad force-pushed the aadi/refactor-data-structures branch from 74c4889 to 5f2fedc Compare May 29, 2024 22:38
@aaprasad aaprasad marked this pull request as ready for review May 31, 2024 02:37
@aaprasad aaprasad requested a review from talmo May 31, 2024 02:37
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 13

Outside diff range and nitpick comments (7)
biogtr/io/__init__.py (1)

3-6: Consider adding imported classes to __all__ to clarify their intended use.

__all__ = ['Frame', 'Instance', 'AssociationMatrix', 'Track']
biogtr/inference/track.py (1)

Line range hint 133-133: Remove unnecessary f-string.

- print(f"Pod Index: {index}")
+ print("Pod Index:", index)
biogtr/datasets/microscopy_dataset.py (1)

Line range hint 87-87: Replace lambda with a function definition for clarity and maintainability.

- parser = lambda x: data_utils.parse_synthetic(x, source=source)
+ def parser(x):
+     return data_utils.parse_synthetic(x, source=source)
biogtr/datasets/cell_tracking_dataset.py (1)

126-126: Method get_instances effectively processes dataset elements. Consider adding more detailed documentation to explain the processing steps.

biogtr/training/losses.py (1)

3-3: Method forward effectively calculates the association loss. Consider adding more detailed documentation to explain the processing steps.

biogtr/inference/metrics.py (1)

108-108: Function to_track_eval effectively reformats frames for tracking evaluation. Consider adding more detailed documentation to explain the processing steps.

biogtr/inference/tracker.py (1)

Line range hint 177-177: The variable curr_track is defined but never used, which could lead to confusion. Consider removing it if it's not needed.

- curr_track += 1
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

Commits

Files that changed from the base of the PR and between 89007a9 and 48efb7d.

Files selected for processing (25)
  • biogtr/datasets/base_dataset.py (1 hunks)
  • biogtr/datasets/cell_tracking_dataset.py (2 hunks)
  • biogtr/datasets/eval_dataset.py (1 hunks)
  • biogtr/datasets/microscopy_dataset.py (2 hunks)
  • biogtr/datasets/sleap_dataset.py (1 hunks)
  • biogtr/inference/metrics.py (2 hunks)
  • biogtr/inference/track.py (1 hunks)
  • biogtr/inference/track_queue.py (1 hunks)
  • biogtr/inference/tracker.py (5 hunks)
  • biogtr/io/init.py (1 hunks)
  • biogtr/io/association_matrix.py (1 hunks)
  • biogtr/io/frame.py (1 hunks)
  • biogtr/io/instance.py (1 hunks)
  • biogtr/io/track.py (1 hunks)
  • biogtr/models/global_tracking_transformer.py (3 hunks)
  • biogtr/models/gtr_runner.py (3 hunks)
  • biogtr/models/model_utils.py (1 hunks)
  • biogtr/models/transformer.py (7 hunks)
  • biogtr/training/losses.py (1 hunks)
  • biogtr/training/train.py (1 hunks)
  • tests/test_config.py (1 hunks)
  • tests/test_data_model.py (4 hunks)
  • tests/test_inference.py (1 hunks)
  • tests/test_models.py (4 hunks)
  • tests/test_training.py (1 hunks)
Files skipped from review due to trivial changes (1)
  • biogtr/inference/track_queue.py
Additional context used
Learnings (3)
biogtr/io/__init__.py (1)
User: aaprasad
PR: talmolab/biogtr#47
File: biogtr/io/__init__.py:8-8
Timestamp: 2024-05-20T21:57:47.690Z
Learning: To resolve circular import issues in Python, consider using local imports within functions or methods to delay the import until it's actually needed.
biogtr/models/global_tracking_transformer.py (1)
User: aaprasad
PR: talmolab/biogtr#46
File: biogtr/models/global_tracking_transformer.py:5-6
Timestamp: 2024-05-29T17:50:29.550Z
Learning: The `Instance` class is used for type hinting in the `forward` method parameters of the `GlobalTrackingTransformer` class in `biogtr/models/global_tracking_transformer.py`.
biogtr/io/frame.py (1)
User: aaprasad
PR: talmolab/biogtr#46
File: biogtr/models/global_tracking_transformer.py:5-6
Timestamp: 2024-05-29T17:50:29.550Z
Learning: The `Instance` class is used for type hinting in the `forward` method parameters of the `GlobalTrackingTransformer` class in `biogtr/models/global_tracking_transformer.py`.
Ruff
biogtr/io/__init__.py

3-3: biogtr.io.frame.Frame imported but unused; consider removing, adding to __all__, or using a redundant alias


4-4: biogtr.io.instance.Instance imported but unused; consider removing, adding to __all__, or using a redundant alias


5-5: biogtr.io.association_matrix.AssociationMatrix imported but unused; consider removing, adding to __all__, or using a redundant alias


6-6: biogtr.io.track.Track imported but unused; consider removing, adding to __all__, or using a redundant alias

biogtr/io/track.py

17-17: Undefined name Instance


46-46: Undefined name Instances


64-64: Undefined name Frame


80-80: Undefined name Instance


80-80: Undefined name Instance

tests/test_training.py

10-10: biogtr.models.global_tracking_transformer.GlobalTrackingTransformer imported but unused


11-11: omegaconf.DictConfig imported but unused


47-47: Local variable feats is assigned to but never used

biogtr/inference/track.py

133-133: f-string without any placeholders

biogtr/models/global_tracking_transformer.py

84-84: Undefined name AssociationMatrix

biogtr/datasets/microscopy_dataset.py

87-87: Do not assign a lambda expression, use a def

biogtr/models/model_utils.py

3-3: typing.Tuple imported but unused

tests/test_inference.py

4-4: pytest imported but unused


168-168: Local variable N is assigned to but never used

biogtr/io/association_matrix.py

113-113: f-string without any placeholders


137-137: f-string without any placeholders

biogtr/datasets/sleap_dataset.py

9-9: warnings imported but unused

tests/test_models.py

429-429: Local variable img_shape is assigned to but never used

biogtr/inference/tracker.py

121-121: f-string without any placeholders


177-177: Undefined name curr_track


177-177: Local variable curr_track is assigned to but never used

biogtr/io/frame.py

59-59: Undefined name AssociationMatrix


328-328: Undefined name AssociationMatrix


347-347: Undefined name AssociationMatrix

biogtr/io/instance.py

93-93: Undefined name Frame


520-520: Undefined name Frame


529-529: Undefined name Frame

Additional comments not posted (56)
biogtr/datasets/eval_dataset.py (1)

4-4: LGTM! The imports and class definition align well with the PR objectives of enhancing data handling.

Also applies to: 5-5

tests/test_config.py (1)

4-4: LGTM! The tests are comprehensive and provide good coverage for the Config class.

biogtr/training/train.py (1)

6-6: LGTM! The training script is well-organized and effectively handles different training configurations.

tests/test_training.py (1)

6-6: LGTM! The tests are comprehensive and validate the training logic effectively.

Also applies to: 7-7

biogtr/inference/track.py (1)

3-5: Update imports to reflect new structure.

Ensure that the updated imports from biogtr.io.frame and biogtr.models.gtr_runner are correctly used in the file.

biogtr/models/global_tracking_transformer.py (2)

5-5: Ensure correct usage of the Instance class from biogtr.io.instance.

The import of Instance is used for type hinting in the forward method parameters, aligning with previous implementations.


84-84: Import the AssociationMatrix class to fix the undefined name.

+ from biogtr.io.association_matrix import AssociationMatrix

Likely invalid or redundant comment.

biogtr/datasets/base_dataset.py (1)

4-4: Update import to reflect new structure.

The import from biogtr.io.frame is correctly updated and used in the type hints and method implementations.

tests/test_data_model.py (4)

3-6: Ensure correct imports for testing.

Imports from biogtr.io.frame, biogtr.io.instance, biogtr.io.association_matrix, and biogtr.io.track are correctly used in the test functions.


108-116: Test the functionality of AssociationMatrix within Frame.

The test checks the integration of AssociationMatrix with Frame, ensuring that methods like has_matches, has_asso_output, and has_traj_score work as expected.


135-196: Test the functionality and error handling of AssociationMatrix.

The tests cover various aspects of AssociationMatrix, including error handling with incorrect tensor shapes and indexing, and reductions based on different criteria.


199-210: Test the functionality of the Track class.

The test checks the Track class for correct handling of instances, ensuring that properties like track_id and indexing work as expected.

biogtr/datasets/microscopy_dataset.py (1)

6-7: Update imports to reflect new structure.

The imports from biogtr.io.frame and biogtr.io.instance are correctly updated and used in the class methods.

biogtr/models/model_utils.py (6)

5-5: Ensure correct usage of the Instance class from biogtr.io.instance.

The import of Instance is correctly used in the function parameters for type hinting and ensuring correct data handling.


5-5: Review the softmax application to ensure it handles edge cases correctly.

The function correctly handles the softmax application, including the addition of a zero-padding column to maintain shape consistency.


5-5: Review the error handling in optimizer initialization.

The function includes comprehensive error handling to ensure that the optimizer is correctly instantiated based on the provided configuration.


5-5: Review the error handling in scheduler initialization.

The function includes comprehensive error handling to ensure that the scheduler is correctly instantiated based on the provided configuration.


5-5: Review the logger initialization to ensure it handles different configurations correctly.

The function correctly initializes different types of loggers based on the configuration, handling potential exceptions and providing informative error messages.


5-5: Ensure the device detection logic is robust and accurate.

The function correctly identifies the available computing device, checking for CUDA and MPS support and defaulting to CPU if necessary.

biogtr/datasets/cell_tracking_dataset.py (3)

6-7: Updated imports to reflect new module structure.


Line range hint 15-126: Constructor is well-structured and handles dataset initialization effectively.


126-126: Method get_indices correctly retrieves label and frame indices.

tests/test_inference.py (4)

6-7: Updated imports to reflect new module structure.


15-91: Comprehensive test for TrackQueue functionality.


92-92: Test for Tracker is well-structured and effectively checks the functionality.


92-92: Test for post-processing methods is detailed and covers multiple scenarios.

biogtr/models/gtr_runner.py (8)

11-12: Updated imports to reflect new module structure.


Line range hint 11-76: Constructor is well-structured and handles model initialization effectively.


76-76: Method forward correctly implements the model's forward logic.


76-76: Method training_step is well-structured and effectively handles the training logic.


76-76: Method validation_step is well-structured and effectively handles the validation logic.


76-76: Method test_step is well-structured and effectively handles the testing logic.


76-76: Method predict_step correctly implements the inference logic.


Line range hint 76-169: Method _shared_eval_step is well-structured and effectively handles the evaluation logic.

biogtr/training/losses.py (3)

3-3: Updated imports to reflect new module structure.


3-3: Constructor is well-structured and handles loss function initialization effectively.


3-3: Method _get_asso_gt correctly computes the association ground truth.

biogtr/inference/metrics.py (4)

6-6: Updated imports to reflect new module structure.


6-6: Function get_matches correctly extracts matches between predicted and ground truth labels.


6-6: Function get_switches correctly identifies misassigned predicted trajectory labels.


6-6: Function get_switch_count correctly calculates the number of mislabeled predicted trajectories.

biogtr/io/association_matrix.py (4)

11-19: Class definition and documentation are clear and concise.


21-23: Attributes are well-defined using the attrs library.


75-83: Method to convert matrix to numpy array is correctly implemented.


245-287: Method for retrieving indices is robust and handles various input types effectively.

biogtr/datasets/sleap_dataset.py (1)

10-11: Imports updated to reflect new module structure.

tests/test_models.py (1)

5-6: Imports updated to reflect new module structure.

biogtr/inference/tracker.py (2)

6-6: Imports updated to reflect new module structure.


131-131: The f-string in the docstring does not contain any placeholders, which might be a typo or an oversight.

- frames: A list of Frames (See `biogtr.io.data_structures.Frame` for more info).
+ frames: A list of Frames (See `biogtr.io.frame.Frame` for more info).

Likely invalid or redundant comment.

biogtr/io/frame.py (2)

328-328: Undefined name AssociationMatrix used in type hint.


347-347: Undefined name AssociationMatrix used in type hint.

biogtr/io/instance.py (5)

11-27: The _to_tensor function is well-implemented with appropriate type checks and conversions.


55-120: The Instance class is well-structured with clear attribute definitions and initialization. The use of attrs simplifies the code and ensures type safety and default handling.


95-120: The __attrs_post_init__ method effectively handles default values and type conversions. However, consider handling exceptions for operations that might fail, such as tensor operations.

#!/bin/bash
# Description: Verify tensor operations in __attrs_post_init__ do not raise exceptions with typical inputs.

# Test: Run typical inputs through the method. Expect: No exceptions.
python -c "from biogtr.io.instance import Instance; inst = Instance(); inst.__attrs_post_init__()"

154-209: The from_slp class method is a good example of a factory method, but it lacks error handling for potential issues during the conversion process, such as type mismatches or missing attributes in slp_instance.

#!/bin/bash
# Description: Verify that `from_slp` handles all potential input variations without error.

# Test: Provide various malformed inputs to the method. Expect: Graceful handling or specific error messages.
python -c "from biogtr.io.instance import Instance; Instance.from_slp(malformed_input)"

211-243: The to_slp method converts the instance to a sleap_io.PredictedInstance object. Ensure that all attributes used in the conversion are properly validated and handled to avoid runtime errors.

#!/bin/bash
# Description: Verify that all attributes used in `to_slp` are validated.

# Test: Check for attribute presence and correct types before conversion. Expect: No runtime errors.
python -c "from biogtr.io.instance import Instance; inst = Instance(); inst.to_slp()"

)

_instances: list["Instance"] = attrs.field(alias="instances", factory=list)
_asso_output: "AssociationMatrix" = attrs.field(alias="asso_output", default=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Undefined name AssociationMatrix used in type hint.

Please ensure that AssociationMatrix is properly imported or defined within this module or in an importable module.

@@ -141,7 +142,7 @@ def _reset_parameters(self):

def forward(
self, ref_instances: list[Instance], query_instances: list[Instance] = None
) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]:
) -> list[AssociationMatrix]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimize the concatenation of instance features.

Consider using a more efficient method for concatenating tensors, especially when dealing with large lists of instances. This could potentially improve performance.

- [instance.features for instance in ref_instances], dim=0
+ torch.stack([instance.features for instance in ref_instances], dim=0)

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
) -> list[AssociationMatrix]:
) -> list[AssociationMatrix]:
torch.stack([instance.features for instance in ref_instances], dim=0)

Comment on lines +85 to +157
def to_dataframe(
self, row_labels: str = "gt", col_labels: str = "gt"
) -> pd.DataFrame:
"""Convert the association matrix to a pandas DataFrame.

Args:
row_labels: How to label the rows(queries).
If list, then must match # of rows/queries
If `"gt"` then label by gt track id.
If `"pred"` then label by pred track id.
Otherwise label by the query_instance indices
col_labels: How to label the columns(references).
If list, then must match # of columns/refs
If `"gt"` then label by gt track id.
If `"pred"` then label by pred track id.
Otherwise label by the ref_instance indices

Returns:
The association matrix as a pandas dataframe.
"""
matrix = self.numpy()

if not isinstance(row_labels, str):
if len(row_labels) == len(self.query_instances):
row_inds = row_labels
else:
raise ValueError(
(
f"Mismatched # of rows and labels!",
f"Found {len(row_labels)} with {len(self.query_instances)} rows",
)
)
else:
if row_labels == "gt":
row_inds = [
instance.gt_track_id.item() for instance in self.query_instances
]

elif row_labels == "pred":
row_inds = [
instance.pred_track_id.item() for instance in self.query_instances
]

else:
row_inds = np.arange(len(self.query_instances))

if not isinstance(col_labels, str):
if len(col_labels) == len(self.ref_instances):
col_inds = col_labels
else:
raise ValueError(
(
f"Mismatched # of columns and labels!",
f"Found {len(col_labels)} with {len(self.ref_instances)} columns",
)
)
else:
if col_labels == "gt":
col_inds = [
instance.gt_track_id.item() for instance in self.ref_instances
]

elif col_labels == "pred":
col_inds = [
instance.pred_track_id.item() for instance in self.ref_instances
]

else:
col_inds = np.arange(len(self.ref_instances))

asso_df = pd.DataFrame(matrix, index=row_inds, columns=col_inds)

return asso_df
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enhance flexibility in labeling rows and columns in DataFrame conversion.

Consider adding an optional parameter to allow custom labels for rows and columns, enhancing the flexibility of this method:

def to_dataframe(self, row_labels=None, col_labels=None):
    row_inds = row_labels if row_labels is not None else [instance.id for instance in self.query_instances]
    col_inds = col_labels if col_labels is not None else [instance.id for instance in self.ref_instances]
    ...

_skeleton: sio.Skeleton = attrs.field(alias="skeleton", default=None)
_pose: dict[str, ArrayLike] = attrs.field(alias="pose", factory=dict)
_device: str = attrs.field(alias="device", default=None)
_frame: "Frame" = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct the missing import for Frame in the setter method.

+ from biogtr.io.frame import Frame

Also applies to: 520-520, 529-529


Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
_frame: "Frame" = None
+ from biogtr.io.frame import Frame
_frame: "Frame" = None

Comment on lines +288 to +327
def get_tracks(
self, instances: list["Instance"], label: str = "pred"
) -> dict[int, list["Instance"]]:
"""Group instances by track.

Args:
instances: The list of instances to group
label: the track id type to group by. Either `pred` or `gt`.

Returns:
A dictionary of track_id:instances
"""
if label == "pred":
traj_ids = set([instance.pred_track_id.item() for instance in instances])
traj = {
track_id: [
instance
for instance in instances
if instance.pred_track_id.item() == track_id
]
for track_id in traj_ids
}

elif label == "gt":
traj_ids = set(
[instance.gt_track_id.item() for instance in self.ref_instances]
)
traj = {
track_id: [
instance
for instance in self.ref_instances
if instance.gt_track_id.item() == track_id
]
for track_id in traj_ids
}

else:
raise ValueError(f"Unsupported label '{label}'. Expected 'pred' or 'gt'.")

return traj
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Validate input parameters more robustly.

- raise ValueError(f"Unsupported label '{label}'. Expected 'pred' or 'gt'.")
+ raise ValueError(f"Unsupported label '{label}'. Expected 'pred' or 'gt'.")

Committable suggestion was skipped due low confidence.

Comment on lines +218 to +243
def __getitem__(self, inds) -> np.ndarray:
"""Get elements of the association matrix.

Args:
inds: A tuple of query indices and reference indices.
Indices can be either:
A single instance or integer.
A list of instances or integers.

Returns:
An np.ndarray containing the elements requested.
"""
query_inst, ref_inst = inds

query_ind = self.__getindices__(query_inst, self.query_instances)
ref_ind = self.__getindices__(ref_inst, self.ref_instances)

try:
return self.numpy()[query_ind[:, None], ref_ind].squeeze()

except IndexError as e:
print(f"Query_insts: {type(query_inst)}")
print(f"Query_inds: {query_ind}")
print(f"Ref_insts: {type(ref_inst)}")
print(f"Ref_ind: {ref_ind}")
raise (e)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improve error handling by providing more context.

- print(f"Query_insts: {type(query_inst)}")
- print(f"Query_inds: {query_ind}")
- print(f"Ref_insts: {type(ref_inst)}")
- print(f"Ref_ind: {ref_ind}")
+ print(f"Error accessing matrix with query instances of type {type(query_inst)} and reference instances of type {type(ref_inst)}")
+ print(f"Query indices attempted: {query_ind}, Reference indices attempted: {ref_ind}")

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def __getitem__(self, inds) -> np.ndarray:
"""Get elements of the association matrix.
Args:
inds: A tuple of query indices and reference indices.
Indices can be either:
A single instance or integer.
A list of instances or integers.
Returns:
An np.ndarray containing the elements requested.
"""
query_inst, ref_inst = inds
query_ind = self.__getindices__(query_inst, self.query_instances)
ref_ind = self.__getindices__(ref_inst, self.ref_instances)
try:
return self.numpy()[query_ind[:, None], ref_ind].squeeze()
except IndexError as e:
print(f"Query_insts: {type(query_inst)}")
print(f"Query_inds: {query_ind}")
print(f"Ref_insts: {type(ref_inst)}")
print(f"Ref_ind: {ref_ind}")
raise (e)
def __getitem__(self, inds) -> np.ndarray:
"""Get elements of the association matrix.
Args:
inds: A tuple of query indices and reference indices.
Indices can be either:
A single instance or integer.
A list of instances or integers.
Returns:
An np.ndarray containing the elements requested.
"""
query_inst, ref_inst = inds
query_ind = self.__getindices__(query_inst, self.query_instances)
ref_ind = self.__getindices__(ref_inst, self.ref_instances)
try:
return self.numpy()[query_ind[:, None], ref_ind].squeeze()
except IndexError as e:
print(f"Error accessing matrix with query instances of type {type(query_inst)} and reference instances of type {type(ref_inst)}")
print(f"Query indices attempted: {query_ind}, Reference indices attempted: {ref_ind}")
raise (e)

Comment on lines +159 to +216
def reduce(
self,
row_dims: str = "instance",
col_dims: str = "track",
row_grouping: str = None,
col_grouping: str = "pred",
reduce_method: callable = np.sum,
) -> pd.DataFrame:
"""Aggregate the association matrix by specified dimensions and grouping.

Args:
row_dims: A str indicating how to what dimensions to reduce rows to.
Either "instance" (remains unchanged), or "track" (n_rows=n_traj).
col_dims: A str indicating how to dimensions to reduce rows to.
Either "instance" (remains unchanged), or "track" (n_cols=n_traj)
row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt".
col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt".
method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.

Returns:
The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe.
"""
n_rows = len(self.query_instances)
n_cols = len(self.ref_instances)

col_tracks = {-1: self.ref_instances}
row_tracks = {-1: self.query_instances}

col_inds = [i for i in range(len(self.ref_instances))]
row_inds = [i for i in range(len(self.query_instances))]

if col_dims == "track":
col_tracks = self.get_tracks(self.ref_instances, col_grouping)
col_inds = list(col_tracks.keys())
n_cols = len(col_inds)

if row_dims == "track":
row_tracks = self.get_tracks(self.query_instances, row_grouping)
row_inds = list(row_tracks.keys())
n_rows = len(row_inds)

reduced_matrix = []
for row_track, row_instances in row_tracks.items():

for col_track, col_instances in col_tracks.items():
asso_matrix = self[row_instances, col_instances]

if col_dims == "track":
asso_matrix = reduce_method(asso_matrix, axis=1)

if row_dims == "track":
asso_matrix = reduce_method(asso_matrix, axis=0)

reduced_matrix.append(asso_matrix)

reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T

return pd.DataFrame(reduced_matrix, index=row_inds, columns=col_inds)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor to use more descriptive parameter names in the reduce method.

- row_dims: str = "instance",
- col_dims: str = "track",
+ row_dimensions: str = "instance",
+ column_dimensions: str = "track",
...
- "Reduce association matrix rows/columns to inst/traj x traj/inst."
+ "Aggregate the association matrix by specified dimensions and grouping."

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def reduce(
self,
row_dims: str = "instance",
col_dims: str = "track",
row_grouping: str = None,
col_grouping: str = "pred",
reduce_method: callable = np.sum,
) -> pd.DataFrame:
"""Aggregate the association matrix by specified dimensions and grouping.
Args:
row_dims: A str indicating how to what dimensions to reduce rows to.
Either "instance" (remains unchanged), or "track" (n_rows=n_traj).
col_dims: A str indicating how to dimensions to reduce rows to.
Either "instance" (remains unchanged), or "track" (n_cols=n_traj)
row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt".
col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt".
method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.
Returns:
The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe.
"""
n_rows = len(self.query_instances)
n_cols = len(self.ref_instances)
col_tracks = {-1: self.ref_instances}
row_tracks = {-1: self.query_instances}
col_inds = [i for i in range(len(self.ref_instances))]
row_inds = [i for i in range(len(self.query_instances))]
if col_dims == "track":
col_tracks = self.get_tracks(self.ref_instances, col_grouping)
col_inds = list(col_tracks.keys())
n_cols = len(col_inds)
if row_dims == "track":
row_tracks = self.get_tracks(self.query_instances, row_grouping)
row_inds = list(row_tracks.keys())
n_rows = len(row_inds)
reduced_matrix = []
for row_track, row_instances in row_tracks.items():
for col_track, col_instances in col_tracks.items():
asso_matrix = self[row_instances, col_instances]
if col_dims == "track":
asso_matrix = reduce_method(asso_matrix, axis=1)
if row_dims == "track":
asso_matrix = reduce_method(asso_matrix, axis=0)
reduced_matrix.append(asso_matrix)
reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T
return pd.DataFrame(reduced_matrix, index=row_inds, columns=col_inds)
def reduce(
self,
row_dimensions: str = "instance",
column_dimensions: str = "track",
row_grouping: str = None,
col_grouping: str = "pred",
reduce_method: callable = np.sum,
) -> pd.DataFrame:
"""Aggregate the association matrix by specified dimensions and grouping.
Args:
row_dimensions: A str indicating how to what dimensions to reduce rows to.
Either "instance" (remains unchanged), or "track" (n_rows=n_traj).
column_dimensions: A str indicating how to dimensions to reduce rows to.
Either "instance" (remains unchanged), or "track" (n_cols=n_traj)
row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt".
col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt".
method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.
Returns:
The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe.
"""
n_rows = len(self.query_instances)
n_cols = len(self.ref_instances)
col_tracks = {-1: self.ref_instances}
row_tracks = {-1: self.query_instances}
col_inds = [i for i in range(len(self.ref_instances))]
row_inds = [i for i in range(len(self.query_instances))]
if column_dimensions == "track":
col_tracks = self.get_tracks(self.ref_instances, col_grouping)
col_inds = list(col_tracks.keys())
n_cols = len(col_inds)
if row_dimensions == "track":
row_tracks = self.get_tracks(self.query_instances, row_grouping)
row_inds = list(row_tracks.keys())
n_rows = len(row_inds)
reduced_matrix = []
for row_track, row_instances in row_tracks.items():
for col_track, col_instances in col_tracks.items():
asso_matrix = self[row_instances, col_instances]
if column_dimensions == "track":
asso_matrix = reduce_method(asso_matrix, axis=1)
if row_dimensions == "track":
asso_matrix = reduce_method(asso_matrix, axis=0)
reduced_matrix.append(asso_matrix)
reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T
return pd.DataFrame(reduced_matrix, index=row_inds, columns=col_inds)

from biogtr.io.association_matrix import AssociationMatrix
from biogtr.io.track import Track

# TODO: expose config without circular import error from biogtr.io.config import Config
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure the dynamic import of Config is implemented to resolve the circular import issue as previously discussed.

Would you like to proceed with the implementation of the dynamic import solution for Config?

"""

_id: int = attrs.field(alias="id")
_instances: list["Instance"] = attrs.field(alias="instances", factory=list)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct the undefined names and type hints.

- _instances: list["Instance"] = attrs.field(alias="instances", factory=list)
+ from biogtr.io.instance import Instance
+ _instances: list[Instance] = attrs.field(alias="instances", factory=list)

- def instances(self) -> list["Instances"]:
+ def instances(self) -> list[Instance]:

- def frames(self) -> set["Frame"]:
+ from biogtr.io.frame import Frame
+ def frames(self) -> set[Frame]:

- def __getitem__(self, ind) -> Union["Instance", list["Instance"]]:
+ def __getitem__(self, ind) -> Union[Instance, list[Instance]]:

Also applies to: 46-46, 64-64, 80-80


Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
_instances: list["Instance"] = attrs.field(alias="instances", factory=list)
from biogtr.io.instance import Instance
from biogtr.io.frame import Frame
_instances: list[Instance] = attrs.field(alias="instances", factory=list)
def instances(self) -> list[Instance]:
...
def frames(self) -> set[Frame]:
...
def __getitem__(self, ind) -> Union[Instance, list[Instance]]:
...

@@ -2,7 +2,7 @@

from typing import List, Tuple, Iterable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused import.

- from typing import List, Tuple, Iterable
+ from typing import List, Iterable

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
from typing import List, Tuple, Iterable
from typing import List, Iterable

@aaprasad aaprasad merged commit 98106b7 into main Jun 3, 2024
3 checks passed
@aaprasad aaprasad deleted the aadi/refactor-data-structures branch June 3, 2024 17:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants