Refactor data structures #47
Walkthrough
The recent changes primarily focus on enhancing the
Changes
Actionable comments posted: 23
Outside diff range and nitpick comments (9)
biogtr/io/__init__.py (4)
3-3: Consider adding `Frame` to `__all__` to clarify its intended use, especially if it's meant to be exposed as part of the module's public API.
4-4: Consider adding `Instance` to `__all__` to clarify its intended use, especially if it's meant to be exposed as part of the module's public API.
5-5: Consider adding `AssociationMatrix` to `__all__` to clarify its intended use, especially if it's meant to be exposed as part of the module's public API.
6-6: Consider adding `Track` to `__all__` to clarify its intended use, especially if it's meant to be exposed as part of the module's public API.
biogtr/datasets/microscopy_dataset.py (1)
Line range hint 87-87: Replace the lambda expression with a function definition for better readability and maintainability.
- parser = lambda x: data_utils.parse_synthetic(x, source=source)
+ def parser(x):
+     return data_utils.parse_synthetic(x, source=source)
tests/test_inference.py (1)
Line range hint 168-168: Local variable `N` is assigned but never used.
- N = N_t * T
tests/test_models.py (1)
Line range hint 429-429: Remove the unused variable `img_shape` to clean up the code.
- img_shape = (1, 100, 100)
biogtr/inference/tracker.py (2)
Line range hint 121-121: F-string without any placeholders.
This line contains an f-string but does not have any placeholders. If dynamic string formatting is not needed, consider using a regular string.
Line range hint 177-177: Undefined name 'curr_track' and local variable 'curr_track' is assigned to but never used.
It seems there is a typo or logical error here. The variable `curr_track` is incremented but never initialized. Ensure that it is correctly initialized and used, or remove it if unnecessary.
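For the `__all__` nitpicks on biogtr/io/__init__.py above, the following is a minimal, self-contained sketch of what declaring `__all__` changes. It builds a throwaway package called `demo_io` (a hypothetical stand-in, not part of biogtr) in a temporary directory and shows that a star-import then exposes only the listed names, which is the "public API" effect the review refers to.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package. `demo_io` and its contents are hypothetical
# stand-ins for biogtr/io, used only to keep the example self-contained.
with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp) / "demo_io"
    pkg.mkdir()
    (pkg / "frame.py").write_text("class Frame:\n    pass\n")
    (pkg / "instance.py").write_text("class Instance:\n    pass\n")
    (pkg / "__init__.py").write_text(
        "from demo_io.frame import Frame\n"
        "from demo_io.instance import Instance\n"
        "\n"
        "# Declare the re-exported names as the package's public API.\n"
        '__all__ = ["Frame", "Instance"]\n'
    )

    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        namespace = {}
        exec("from demo_io import *", namespace)
    finally:
        sys.path.remove(tmp)

# Only the names listed in __all__ were star-imported; the submodules
# `frame` and `instance` were not pulled in.
exported = sorted(name for name in namespace if not name.startswith("__"))
print(exported)
```

The Ruff messages quoted in this review list adding the name to `__all__` as one of the accepted ways to mark such an import as an intentional re-export.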
Review Details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (26)
- biogtr/datasets/base_dataset.py (1 hunks)
- biogtr/datasets/cell_tracking_dataset.py (2 hunks)
- biogtr/datasets/eval_dataset.py (1 hunks)
- biogtr/datasets/microscopy_dataset.py (2 hunks)
- biogtr/datasets/sleap_dataset.py (1 hunks)
- biogtr/inference/metrics.py (2 hunks)
- biogtr/inference/track.py (1 hunks)
- biogtr/inference/track_queue.py (1 hunks)
- biogtr/inference/tracker.py (9 hunks)
- biogtr/io/__init__.py (1 hunks)
- biogtr/io/association_matrix.py (1 hunks)
- biogtr/io/frame.py (1 hunks)
- biogtr/io/instance.py (1 hunks)
- biogtr/io/track.py (1 hunks)
- biogtr/io/visualize.py (1 hunks)
- biogtr/models/global_tracking_transformer.py (3 hunks)
- biogtr/models/gtr_runner.py (7 hunks)
- biogtr/models/model_utils.py (1 hunks)
- biogtr/models/transformer.py (9 hunks)
- biogtr/training/losses.py (2 hunks)
- biogtr/training/train.py (1 hunks)
- tests/test_config.py (1 hunks)
- tests/test_data_model.py (4 hunks)
- tests/test_inference.py (1 hunks)
- tests/test_models.py (6 hunks)
- tests/test_training.py (1 hunks)
Files skipped from review due to trivial changes (1)
- biogtr/training/train.py
Additional Context Used
Ruff (31)
biogtr/datasets/microscopy_dataset.py (1)
87-87: Do not assign a `lambda` expression, use a `def`
biogtr/datasets/sleap_dataset.py (1)
9-9: `warnings` imported but unused
biogtr/inference/track.py (1)
133-133: f-string without any placeholders
biogtr/inference/tracker.py (3)
121-121: f-string without any placeholders
177-177: Undefined name `curr_track`
177-177: Local variable `curr_track` is assigned to but never used
biogtr/io/__init__.py (4)
3-3: `biogtr.io.frame.Frame` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
4-4: `biogtr.io.instance.Instance` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
5-5: `biogtr.io.association_matrix.AssociationMatrix` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
6-6: `biogtr.io.track.Track` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
biogtr/io/frame.py (6)
47-47: Undefined name `Instance`
48-48: Undefined name `AssociationMatrix`
243-243: Undefined name `Instance`
252-252: Undefined name `Instance`
282-282: Undefined name `AssociationMatrix`
301-301: Undefined name `AssociationMatrix`
biogtr/io/instance.py (3)
90-90: Undefined name `Frame`
415-415: Undefined name `Frame`
424-424: Undefined name `Frame`
biogtr/io/track.py (5)
17-17: Undefined name `Instance`
48-48: Undefined name `Instances`
66-66: Undefined name `Frame`
82-82: Undefined name `Instance`
82-82: Undefined name `Instance`
biogtr/models/model_utils.py (1)
3-3: `typing.Tuple` imported but unused
tests/test_inference.py (2)
4-4: `pytest` imported but unused
168-168: Local variable `N` is assigned to but never used
tests/test_models.py (1)
429-429: Local variable `img_shape` is assigned to but never used
tests/test_training.py (3)
10-10: `biogtr.models.global_tracking_transformer.GlobalTrackingTransformer` imported but unused
11-11: `omegaconf.DictConfig` imported but unused
47-47: Local variable `feats` is assigned to but never used
Additional comments not posted (45)
biogtr/datasets/eval_dataset.py (2)
4-4: The import of `Frame` is correctly used in the return type of `__getitem__`. Good job ensuring type consistency.
5-5: The import of `Instance` is correctly used for creating `Instance` objects within the `__getitem__` method. Well done on maintaining clear type usage.
tests/test_config.py (1)
4-4: The import of `Config` is correctly used throughout the test functions. Good job ensuring that the tests cover the updated import paths.
tests/test_training.py (3)
6-6: The import of `Frame` is correctly used in the creation of `Frame` objects within the test functions. Well done on maintaining clear type usage.
7-7: The import of `Instance` is correctly used for creating `Instance` objects within the test functions. Good job ensuring type consistency.
12-12: The import of `Config` is correctly used in the `test_config_gtr_runner` function. Good job ensuring that the tests cover the updated import paths.
biogtr/inference/track.py (2)
3-3: The import of `Config` is correctly used in the `main` function to load configuration settings. Well done on maintaining clear usage.
5-5: The import of `Frame` is correctly used in the `export_trajectories` and `inference` functions to handle frame data. Good job ensuring type consistency.
biogtr/models/global_tracking_transformer.py (1)
Line range hint 82-122: The refactoring to handle `Instance` objects instead of `Frame` objects is well-implemented. Consider adding more detailed comments explaining the processing of `Instance` objects, especially in the feature extraction logic.
biogtr/datasets/base_dataset.py (1)
4-4: The update to the import path for `Frame` aligns with the new structure of the project.
tests/test_data_model.py (3)
10-10: The tests for the `Instance` object are comprehensive and cover all necessary attributes and methods.
108-116: The tests for the `Frame` object, including its interaction with the `AssociationMatrix`, are well-implemented and thorough.
135-192: The tests for the `AssociationMatrix` are comprehensive, effectively testing both normal operations and error conditions.
biogtr/datasets/microscopy_dataset.py (1)
6-7: The update to the import paths for `Frame` and `Instance` aligns with the new structure of the project.
biogtr/models/model_utils.py (1)
9-28: The function `get_boxes` correctly extracts bounding boxes from instances and normalizes them as expected.
biogtr/datasets/cell_tracking_dataset.py (2)
6-7: The update to the import paths for `Frame` and `Instance` aligns with the new structure of the project.
126-126: The `CellTrackingDataset` class correctly handles `Frame` and `Instance` objects, aligning with the refactoring goals of the PR.
tests/test_inference.py (1)
6-7: Updated imports to reflect new module structure.
biogtr/models/gtr_runner.py (5)
11-12: Updated imports to reflect new module structure.
Line range hint 64-76: Refactored `forward` method to handle `Instance` objects instead of `Frame` objects, aligning with the new data structure.
149-185: Refactored `_shared_eval_step` to handle new data structures and metrics. Ensure exception handling and logging are correctly implemented.
Verification successful
The search results indicate that exception handling and logging are implemented in various parts of the codebase, including the `_shared_eval_step` method in `biogtr/models/gtr_runner.py`. The exception handling in `_shared_eval_step` is consistent with other parts of the codebase, where exceptions are caught, logged, and re-raised.
- The `_shared_eval_step` method catches exceptions, logs the frame and video IDs, and re-raises the exception.
- This pattern is consistent with other exception handling instances found in the codebase.
Therefore, the refactored `_shared_eval_step` method appears to handle exceptions and logging correctly.
Scripts Executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify exception handling and logging in `_shared_eval_step`.
# Test: Search for exception handling patterns in the codebase. Expect: Consistent and correct implementation.
rg --type python $'try|except' --context 5
Length of output: 74
Script:
#!/bin/bash
# Description: Verify exception handling and logging in `_shared_eval_step`.
# Test: Search for exception handling patterns in the codebase. Expect: Consistent and correct implementation.
rg 'try|except' --glob '*.py' --context 5
Length of output: 26532
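The catch, log, and re-raise pattern described in the verification above can be sketched as follows. This is not the code from `biogtr/models/gtr_runner.py`; the function name `shared_eval_step`, the dict-based frames, and the `evaluate` callable are hypothetical stand-ins chosen to keep the example self-contained.

```python
import logging

logger = logging.getLogger(__name__)


def shared_eval_step(frames, evaluate):
    """Run `evaluate` on a batch; on failure, log the frames involved and re-raise.

    `frames` is assumed to be a list of dicts with "video_id" and "frame_id"
    keys, and `evaluate` any callable. Both are illustrative assumptions.
    """
    try:
        return evaluate(frames)
    except Exception:
        ids = [(f["video_id"], f["frame_id"]) for f in frames]
        # logger.exception records the message together with the traceback.
        logger.exception("Evaluation failed for (video_id, frame_id): %s", ids)
        # A bare `raise` re-raises the original exception unchanged.
        raise
```

Re-raising after logging keeps the failure visible to the caller while still recording which frames triggered it.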
Line range hint 132-147: Refactored `predict_step` to use the updated tracking logic. Ensure that the `tracker` is correctly initialized and used here.
Verification successful
The `tracker` is correctly initialized in the `gtr_runner.py` file:
biogtr/models/gtr_runner.py: self.tracker = Tracker(**tracker_cfg)
This confirms that the `tracker` is properly set up before being used in the `predict_step` method.
Scripts Executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify the initialization and usage of `tracker` in `predict_step`.
# Test: Search for the tracker initialization in the codebase. Expect: Correct initialization and usage.
rg --type python $'Tracker('
Length of output: 96
Script:
#!/bin/bash
# Description: Verify the initialization and usage of `tracker` in `predict_step`.
# Test: Search for the tracker initialization in the codebase. Expect: Correct initialization and usage.
rg --type python 'Tracker\('
Length of output: 61
Script:
#!/bin/bash
# Description: Verify the initialization and usage of `tracker` in `predict_step`.
# Test: Search for the tracker initialization in the codebase. Expect: Correct initialization and usage.
rg --glob '*.py' 'Tracker\('
Length of output: 166
Line range hint 79-97: Updated `training_step`, `validation_step`, and `test_step` methods to handle lists of `Frame` objects. Ensure that the data passed to these methods are correctly structured as lists of lists of `Frame` objects.
Also applies to: 114-116
Verification successful
The methods `training_step`, `validation_step`, and `test_step` are defined in `biogtr/models/gtr_runner.py` and are called in `tests/test_training.py`. The calls to these methods pass `[batch]` as an argument, which suggests that the data structure is a list of lists of `Frame` objects.
tests/test_training.py: metrics = gtr_runner.training_step([batch], i)
tests/test_training.py: metrics = gtr_runner.validation_step([batch], j)
tests/test_training.py: metrics = gtr_runner.test_step([batch], k)
Scripts Executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that the data passed to training_step, validation_step, and test_step are lists of lists of Frame objects.
# Test: Search for the method calls in the codebase. Expect: Correct data structure in method calls.
rg --type python $'GTRRunner.training_step|GTRRunner.validation_step|GTRRunner.test_step'
Length of output: 121
Script:
#!/bin/bash
# Description: Verify that the data passed to training_step, validation_step, and test_step are lists of lists of Frame objects.
# Test: Search for the method calls in the codebase. Expect: Correct data structure in method calls.
rg --type py $'GTRRunner.training_step|GTRRunner.validation_step|GTRRunner.test_step'
Length of output: 86
Script:
#!/bin/bash
# Description: Verify the definitions and calls of training_step, validation_step, and test_step methods within the GTRRunner class.
# Search for the definitions of the methods in the codebase.
rg --type py 'def training_step|def validation_step|def test_step'
# Search for the calls to these methods in the codebase.
rg --type py 'training_step|validation_step|test_step'
Length of output: 675
biogtr/training/losses.py (2)
3-4: Updated imports to reflect new module structure and utility functions.
52-56: Refactored to use the new `get_boxes` and `get_times` utility functions, aligning with the updated data structures.
biogtr/inference/metrics.py (2)
6-6: Updated import to reflect new module structure.
108-108: Updated documentation to reflect new data structures.
biogtr/inference/track_queue.py (1)
4-4: Updated import to reflect new module structure.
biogtr/io/visualize.py (2)
18-56: The implementation of the `fill_missing` function is robust and well-documented.
270-284: The `save_vid` function is correctly implemented and effectively saves video frames to a file.
biogtr/datasets/sleap_dataset.py (1)
10-11: Correctly updated imports to reflect the new structure of the `biogtr.io` package.
tests/test_models.py (1)
5-6: Updated imports to reflect the new structure of the `biogtr.io` package.
biogtr/io/instance.py (4)
53-90: Refactor the `Instance` class to use `attrs` for initialization.
The use of `attrs` for defining the `Instance` class simplifies the code and makes it more maintainable by automatically generating methods like `__init__`, `__repr__`, and others based on the defined fields. This is a good application of the DRY principle.
92-116: Review the `__attrs_post_init__` method for proper initialization logic.
The method correctly handles the initialization of various attributes with appropriate default values and conditions. This method ensures that the instance is fully prepared for use right after its creation.
131-148: Ensure proper device handling in the `to` method.
The `to` method is well-implemented, providing functionality to change the device or data type of the instance's tensor attributes. This is crucial for operations in different computing environments, especially in deep learning contexts.
150-183: Review the `to_slp` method for converting to `sleap_io.PredictedInstance`.
This method provides a bridge between the internal data representation and the SLEAP framework, which is essential for integrating with other components that expect SLEAP data types.
biogtr/models/transformer.py (9)
14-15: Update imports to reflect new data structure locations.
144-145: Refactor the `forward` method to handle lists of `Instance` objects.
149-150: Update documentation to reflect new parameter types and expected data structures.
173-176: Ensure robust handling of potential NaN values in reference boxes.
179-181: Ensure embeddings are correctly calculated and returned based on the `return_embedding` flag.
Also applies to: 183-185
212-238: Handle optional `query_instances` with appropriate checks and transformations.
257-262: Correctly form and append new `AssociationMatrix` objects.
Line range hint 388-409: Ensure positional embeddings are added correctly in the decoder layer.
Line range hint 512-535: Properly manage the flow of data through the decoder layers, respecting the `return_intermediate` flag.
biogtr/models/transformer.py
Outdated
ref_features = torch.cat(
    [instance.features for instance in ref_instances], dim=0
).unsqueeze(0)
Optimize the concatenation of instance features.
Consider using a more efficient method for concatenating tensors, especially when dealing with large lists of instances. This could potentially improve performance.
- [instance.features for instance in ref_instances], dim=0
+ torch.stack([instance.features for instance in ref_instances], dim=0)
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
- ref_features = torch.cat(
-     [instance.features for instance in ref_instances], dim=0
- ).unsqueeze(0)
+ ref_features = torch.stack(
+     [instance.features for instance in ref_instances], dim=0
+ ).unsqueeze(0)
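One thing worth checking before applying the suggestion above is that `torch.cat` and `torch.stack` do not produce the same shape, so one is not a drop-in replacement for the other. The sketch below does the shape bookkeeping in plain Python (no PyTorch needed) using the documented semantics of the two operations. The helper names are hypothetical, and the assumption that each `instance.features` has shape `(1, d)` is an illustration only; the source does not state the actual shape.

```python
def cat_shape(shapes, dim=0):
    """Result shape of concatenating tensors with these shapes along an existing `dim`."""
    first = shapes[0]
    for shape in shapes:
        same_rank = len(shape) == len(first)
        if not same_rank or any(
            a != b for i, (a, b) in enumerate(zip(shape, first)) if i != dim
        ):
            raise ValueError("shapes must match except along `dim`")
    out = list(first)
    out[dim] = sum(shape[dim] for shape in shapes)
    return tuple(out)


def stack_shape(shapes, dim=0):
    """Result shape of stacking identically shaped tensors along a new `dim`."""
    first = shapes[0]
    if any(shape != first for shape in shapes):
        raise ValueError("all shapes must be identical")
    out = list(first)
    out.insert(dim, len(shapes))
    return tuple(out)


def unsqueeze_shape(shape, dim=0):
    """Result shape after inserting a size-1 axis at `dim`."""
    out = list(shape)
    out.insert(dim, 1)
    return tuple(out)


# Assumption for illustration: three instances, each with features of shape (1, 4).
feature_shapes = [(1, 4)] * 3
cat_result = unsqueeze_shape(cat_shape(feature_shapes))      # (1, 3, 4)
stack_result = unsqueeze_shape(stack_shape(feature_shapes))  # (1, 3, 1, 4)
print(cat_result, stack_result)
```

Under that assumption the stacked version gains an extra axis, which may be why a later review pass on this PR marks the same suggestion as "Likely invalid or redundant comment."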
biogtr/io/association_matrix.py
Outdated
raise ValueError(
    (
        "Query instances must equal number of rows in Association matrix"
        f"Found {len(value)} query instances but {self.matrix.shape[0]} columns."
Typographical error in error message.
- f"Found {len(value)} query instances but {self.matrix.shape[0]} columns."
+ f"Found {len(value)} query instances but {self.matrix.shape[0]} rows."
Committable suggestion
- f"Found {len(value)} query instances but {self.matrix.shape[0]} columns."
+ f"Found {len(value)} query instances but {self.matrix.shape[0]} rows."
@@ -7,7 +7,8 @@
import sleap_io as sio
import random
import warnings
Remove unused import of `warnings`.
- import warnings
Committable suggestion
- import warnings
@@ -3,13 +3,93 @@
import torch
import pytest
Remove unused import `pytest`.
- import pytest
Committable suggestion
- import pytest
Actionable comments posted: 0
Review Details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (3)
- biogtr/io/association_matrix.py (1 hunks)
- biogtr/io/instance.py (1 hunks)
- tests/test_data_model.py (4 hunks)
Additional Context Used
Ruff (3)
biogtr/io/instance.py (3)
91-91: Undefined name `Frame`
461-461: Undefined name `Frame`
470-470: Undefined name `Frame`
Additional comments not posted (12)
tests/test_data_model.py (4)
10-10: Comprehensive test coverage for `Instance` class.
108-116: Effective testing of `Frame` class with comprehensive checks on properties and associated objects.
135-196: Thorough testing of `AssociationMatrix` indexing and error handling.
199-210: Basic but correct testing of `Track` class properties.
biogtr/io/association_matrix.py (5)
11-23: Proper initialization and validation in `AssociationMatrix` constructor.
75-83: Correct conversion of association matrix to numpy array.
85-131: Flexible and well-documented method for converting association matrix to DataFrame.
133-196: Complex but correctly implemented method for reducing association matrix with detailed documentation.
198-221: Correct retrieval of elements from association matrix with comprehensive indexing support.
biogtr/io/instance.py (3)
53-131: Proper initialization and conversion in `Instance` constructor with comprehensive attribute handling.
132-149: Correct handling of device and data type changes in `Instance` objects.
151-184: Correct conversion of `Instance` to `sleap_io.PredictedInstance` with appropriate error handling.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #47      +/-   ##
==========================================
- Coverage   69.60%   68.48%   -1.12%
==========================================
  Files          28       34       +6
  Lines        2385     2891     +506
==========================================
+ Hits         1660     1980     +320
- Misses        725      911     +186
Actionable comments posted: 12
Outside diff range and nitpick comments (3)
tests/test_models.py (1)
Line range hint 429-429: Remove the unused local variable 'img_shape'.
- img_shape = (1, 100, 100)
biogtr/inference/tracker.py (2)
Line range hint 121-121: Remove unnecessary f-string as it contains no placeholders.
- warnings.warn(f"Clearing Queue after tracking")
+ warnings.warn("Clearing Queue after tracking")
Line range hint 177-177: Fix the undefined variable `curr_track` and ensure it is used correctly.
- curr_track += 1
+ curr_track_id += 1
Review Details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (9)
- biogtr/inference/tracker.py (9 hunks)
- biogtr/io/frame.py (1 hunks)
- biogtr/io/instance.py (1 hunks)
- biogtr/io/track.py (1 hunks)
- biogtr/models/global_tracking_transformer.py (3 hunks)
- biogtr/models/gtr_runner.py (7 hunks)
- biogtr/models/transformer.py (8 hunks)
- tests/test_data_model.py (4 hunks)
- tests/test_models.py (6 hunks)
Additional Context Used
Ruff (16)
biogtr/inference/tracker.py (3)
121-121: f-string without any placeholders
177-177: Undefined name `curr_track`
177-177: Local variable `curr_track` is assigned to but never used
biogtr/io/frame.py (3)
59-59: Undefined name `AssociationMatrix`
328-328: Undefined name `AssociationMatrix`
347-347: Undefined name `AssociationMatrix`
biogtr/io/instance.py (3)
93-93: Undefined name `Frame`
520-520: Undefined name `Frame`
529-529: Undefined name `Frame`
biogtr/io/track.py (5)
17-17: Undefined name `Instance`
46-46: Undefined name `Instances`
64-64: Undefined name `Frame`
80-80: Undefined name `Instance`
80-80: Undefined name `Instance`
biogtr/models/global_tracking_transformer.py (1)
84-84: Undefined name `AssociationMatrix`
tests/test_models.py (1)
429-429: Local variable `img_shape` is assigned to but never used
GitHub Check Runs (1)
codecov/patch failure (30)
biogtr/io/frame.py: [warning] 21-21: biogtr/io/frame.py#L21
Added line #L21 was not covered by tests
biogtr/io/frame.py: [warning] 80-80: biogtr/io/frame.py#L80
Added line #L80 was not covered by tests
biogtr/io/frame.py: [warning] 109-109: biogtr/io/frame.py#L109
Added line #L109 was not covered by tests
biogtr/io/frame.py: [warning] 112-112: biogtr/io/frame.py#L112
Added line #L112 was not covered by tests
biogtr/io/frame.py: [warning] 115-116: biogtr/io/frame.py#L115-L116
Added lines #L115 - L116 were not covered by tests
biogtr/io/frame.py: [warning] 121-121: biogtr/io/frame.py#L121
Added line #L121 was not covered by tests
biogtr/io/frame.py: [warning] 141-146: biogtr/io/frame.py#L141-L146
Added lines #L141 - L146 were not covered by tests
biogtr/io/frame.py: [warning] 170-174: biogtr/io/frame.py#L170-L174
Added lines #L170 - L174 were not covered by tests
biogtr/io/frame.py: [warning] 201-201: biogtr/io/frame.py#L201
Added line #L201 was not covered by tests
biogtr/io/frame.py: [warning] 221-221: biogtr/io/frame.py#L221
Added line #L221 was not covered by tests
biogtr/io/frame.py: [warning] 241-241: biogtr/io/frame.py#L241
Added line #L241 was not covered by tests
Additional comments not posted (10)
tests/test_data_model.py (1)
108-116: Ensure the 'AssociationMatrix' is correctly initialized and used in the test.
The test correctly initializes the 'AssociationMatrix' with appropriate instances and asserts its properties.
biogtr/models/gtr_runner.py (1)
Line range hint 64-76: Review the forward method for correct handling of instances.
The method correctly handles the instances and uses the model's forward method to compute predictions.
biogtr/inference/tracker.py (4)
6-6: Update import paths to align with new namespace structure.
Line range hint 17-41: Initialization parameters are well-documented and correctly set up.
Line range hint 43-57: Method delegation is correctly implemented for simplifying tracker usage.
212-212: Ensure the docstring accurately describes the method's functionality and parameters.
biogtr/models/transformer.py (2)
144-145: Ensure correct handling of `Instance` objects in the forward method.
The method signature correctly expects lists of `Instance` objects for both reference and query instances, aligning with the changes in data handling.
158-160: Optimize the concatenation of instance features.
Consider using a more efficient method for concatenating tensors, especially when dealing with large lists of instances. This could potentially improve performance.
- [instance.features for instance in ref_instances], dim=0
+ torch.stack([instance.features for instance in ref_instances], dim=0)
Likely invalid or redundant comment.
biogtr/io/instance.py (2)
11-28: Utility functions `_to_tensor` and `_expand_to_rank` are implemented correctly and efficiently.
Also applies to: 30-52
55-650: The `Instance` class is well-structured and utilizes `attrs` effectively for attribute management. Ensure all dependencies, such as the `Frame` class, are correctly imported to prevent runtime errors.
alias="img_shape", converter=_to_tensor, factory=list
)

_instances: list["Instance"] = attrs.field(alias="instances", factory=list)
Undefined name 'AssociationMatrix' used in type hint.
Please ensure that `AssociationMatrix` is properly imported or defined within this module or in an importable module.
self._embeddings[emb_type] = embedding

@property
def frame(self) -> "Frame":
Ensure the `Frame` class is properly imported to avoid runtime errors.
+ from biogtr.io.frame import Frame
Committable suggestion
- def frame(self) -> "Frame":
+ from biogtr.io.frame import Frame
+ def frame(self) -> "Frame":
"""
return len(self.instances)

def __getitem__(self, ind) -> Union["Instance", list["Instance"]]:
Ensure the type hint for 'Instance' is correctly defined.
- def __getitem__(self, ind) -> Union["Instance", list["Instance"]]:
+ def __getitem__(self, ind) -> Union[Instance, list[Instance]]:
Committable suggestion was skipped due to low confidence.
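The undefined-name comments above on `Frame`, `AssociationMatrix`, and `Instance` all concern names that appear only inside quoted type hints. One common way to satisfy the linter without importing the other module at runtime is an import guarded by `typing.TYPE_CHECKING`. The sketch below assumes the quoted hints exist because `frame.py` and `instance.py` would otherwise import each other; the source does not confirm that, and the `Instance` class here is a trimmed, hypothetical stand-in.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers and linters only; this branch never runs,
    # so the example works even without biogtr installed. The module
    # path is the one used in the suggestion quoted above.
    from biogtr.io.frame import Frame


class Instance:
    """Trimmed, hypothetical stand-in for the real class."""

    def __init__(self, frame=None):
        self._frame = frame

    @property
    def frame(self) -> "Frame":
        # The quoted annotation is stored as a string and not evaluated here.
        return self._frame
```

As rendered, the committable suggestion above appears to place the import between `@property` and `def`, which would not be valid Python; a module-level guarded import avoids that.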
biogtr/io/association_matrix.py
Outdated
row_to, col_to = to
row_by, col_by = by
Would strongly prefer to have these as the input arguments explicitly instead of tuples.
biogtr/io/association_matrix.py
Outdated
col_inds = [i for i in range(len(self.ref_instances))]
row_inds = [i for i in range(len(self.query_instances))]

if "tra" in col_to:
Too clever bud. What if I provide "instrance"???
Be explicit and list out the conditions with equality checks. `if col_to == "track" or col_to == "trajectory"`
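The reviewer's point can be checked directly: a substring test accepts any string that happens to contain "tra", including the typo "instrance". The two helper functions below are hypothetical and exist only to contrast the two conditions quoted in the comment.

```python
def is_track_substring(col_to: str) -> bool:
    # Condition flagged in the review: any string containing "tra" passes.
    return "tra" in col_to


def is_track_explicit(col_to: str) -> bool:
    # Condition proposed by the reviewer: only exact matches pass.
    return col_to == "track" or col_to == "trajectory"


for value in ["track", "trajectory", "instance", "instrance"]:
    print(f"{value!r}: substring={is_track_substring(value)}, "
          f"explicit={is_track_explicit(value)}")
```

With the explicit form, a misspelled option no longer silently selects the track reduction, so it can be rejected or documented as unsupported.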
biogtr/io/association_matrix.py
Outdated
|
||
reduced_matrix = [] | ||
for row_track, row_instances in row_tracks.items(): | ||
# print(row_instances) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove
biogtr/io/association_matrix.py
Outdated
col_tracks = self.get_tracks(self.ref_instances, col_by) | ||
col_inds = list(col_tracks.keys()) | ||
n_cols = len(col_inds) | ||
if "tra" in row_to: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
^
biogtr/io/association_matrix.py
Outdated
# print(col_instances) | ||
# print(asso_matrix) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete
biogtr/io/association_matrix.py
Outdated
asso_matrix = self[row_instances, col_instances] | ||
# print(col_instances) | ||
# print(asso_matrix) | ||
if "tra" in col_to: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
^
biogtr/io/association_matrix.py
Outdated
# print(asso_matrix) | ||
if "tra" in col_to: | ||
asso_matrix = reduce_method(asso_matrix, axis=1) | ||
if "tra" in row_to: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
^
biogtr/io/association_matrix.py
Outdated
Returns: | ||
A dictionary of track_id:instances | ||
""" | ||
if "pred" in label.lower(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
==
biogtr/io/association_matrix.py
Outdated
] | ||
for track_id in traj_ids | ||
} | ||
elif "gt" in label.lower(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
==
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keep complexity in mind. Organization is great, but we don't need this many bells and whistles for input validation. 2 extra lines of docs is worth 10000 if-else checks.
biogtr/io/association_matrix.py
Outdated
AVAILABLE_REDUCTIONS = attrs.field( | ||
init=False, | ||
default={ | ||
"instance": ["inst", "instance"], | ||
"track": ["track", "traj", "trajectory"], | ||
None: ["", None], | ||
}, | ||
) | ||
AVAILABLE_GROUPINGS = attrs.field(init=False, default=["pred", "gt", None]) |
move these to module level and remove attrs since these are just static constants
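A minimal sketch of that suggestion, assuming the constants keep the values shown in the quoted diff. `AssociationMatrixSketch` and `is_valid_grouping` are hypothetical names used only for illustration and are not part of biogtr:

```python
# Module-level constants: defined once at import time and shared by all
# instances, instead of being stored as attrs fields on every object.
# Values are assumed to mirror the ones quoted in the diff above.
AVAILABLE_REDUCTIONS = {
    "instance": ("inst", "instance"),
    "track": ("track", "traj", "trajectory"),
    None: ("", None),
}
AVAILABLE_GROUPINGS = ("pred", "gt", None)


class AssociationMatrixSketch:
    """Hypothetical stand-in class that reads the constants instead of owning them."""

    def is_valid_grouping(self, grouping):
        # Looks up the module-level tuple; no per-instance state is involved.
        return grouping in AVAILABLE_GROUPINGS
```

Tuples are used here so the constants cannot be mutated by accident; whether that matters depends on how the real class uses them.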
@@ -22,6 +22,16 @@ class AssociationMatrix: | |||
ref_instances: list[Instance] = attrs.field() |
These don't need `= attrs.field()` since `auto_attribs=True` is the default in the new attrs API (assuming all fields have type hinting).
biogtr/io/association_matrix.py
Outdated
If list, then must match # of rows/queries | ||
If `"gt"` then label by gt track id. | ||
If `"pred"` then label by pred track id. |
Just needs to be one indentation level (4 spaces) to the right from the starting column of `row_labels`.
biogtr/io/association_matrix.py
Outdated
row_dims: A str indicating how to what dimensions to reduce rows to. Either inst (remains unchanged), or traj (n_rows=n_traj) | ||
col_dims: A str indicating how to dimensions to reduce rows to. Either inst (remains unchanged), or traj (n_cols=n_traj) |
row_dims: A str indicating how to what dimensions to reduce rows to. Either inst (remains unchanged), or traj (n_rows=n_traj) | |
col_dims: A str indicating how to dimensions to reduce rows to. Either inst (remains unchanged), or traj (n_cols=n_traj) | |
row_dims: A str indicating how to what dimensions to reduce rows to. Either "inst" (remains unchanged), or "traj" (n_rows=n_traj) | |
col_dims: A str indicating how to dimensions to reduce rows to. Either "inst" (remains unchanged), or "traj" (n_cols=n_traj) |
biogtr/io/association_matrix.py
Outdated
not in [key for keys in self.AVAILABLE_REDUCTIONS.values() for key in keys] | ||
) or ( | ||
col_dims is not None | ||
and col_dims.lower() | ||
not in [key for keys in self.AVAILABLE_REDUCTIONS.values() for key in keys] |
Cache all of this. Either have pre-defined lists (run this double loop comprehension at the module level), or have a helper function with a `@functools.cache` decorator. The former is easier and does the job. I'd even recommend storing it in a tuple or set, depending on whether order matters, to make it faster to query.
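Both options can be sketched as follows. This is an illustration under assumptions: the mapping is copied from the quoted diff, and `VALID_REDUCTION_NAMES` and `canonical_reduction` are hypothetical names. `functools.cache` requires Python 3.9 or newer.

```python
import functools

# Assumed to match the mapping quoted in the review thread.
AVAILABLE_REDUCTIONS = {
    "instance": ("inst", "instance"),
    "track": ("track", "traj", "trajectory"),
    None: ("", None),
}

# Option 1: flatten once at import time. A frozenset gives average O(1)
# membership checks and cannot be modified afterwards.
VALID_REDUCTION_NAMES = frozenset(
    alias for aliases in AVAILABLE_REDUCTIONS.values() for alias in aliases
)


# Option 2: a cached helper (hypothetical) that resolves an alias to its
# canonical key once and reuses the result for repeated arguments.
@functools.cache
def canonical_reduction(name):
    key = name.lower() if isinstance(name, str) else name
    for canonical, aliases in AVAILABLE_REDUCTIONS.items():
        if key in aliases:
            return canonical
    raise ValueError(f"Unsupported reduction: {name!r}")
```

With option 1 the validation becomes a single `in VALID_REDUCTION_NAMES` check; option 2 also normalizes aliases such as `"traj"` to `"track"`, which may or may not be wanted here.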
biogtr/io/association_matrix.py
Outdated
if ( | ||
row_grouping not in self.AVAILABLE_GROUPINGS | ||
or col_grouping not in self.AVAILABLE_GROUPINGS | ||
): | ||
raise ValueError( | ||
f"Can aggregate by [gt, pred, None] but {row_by} and {col_by} was requested!" | ||
f"Can aggregate by [gt, pred, None] but {row_grouping} and {col_grouping} was requested!" | ||
) |
Do we really need all this error checking? Again, keep the complexity in mind.
A list membership check is O(n), which is small for a small list, but in our case we'll be running this hundreds of thousands to millions of times, so every one of these checks adds up quick.
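One way to judge whether these checks matter in practice is to time them. The sketch below is only a measurement aid; `time_membership` is a hypothetical helper, and the groupings are assumed to be the three values quoted in the diff. For a container this small the gap between a list and a set may be minor, so the numbers should be read on the actual workload.

```python
import timeit

# Assumed to mirror AVAILABLE_GROUPINGS from the quoted diff.
GROUPINGS_LIST = ["pred", "gt", None]
GROUPINGS_SET = frozenset(GROUPINGS_LIST)


def time_membership(container, probe, number=100_000):
    """Return the seconds taken to evaluate `probe in container` `number` times."""
    return timeit.timeit(lambda: probe in container, number=number)


if __name__ == "__main__":
    # "bogus" is the worst case for the list: every element is compared.
    for probe in ("pred", None, "bogus"):
        list_s = time_membership(GROUPINGS_LIST, probe)
        set_s = time_membership(GROUPINGS_SET, probe)
        print(f"{probe!r}: list={list_s:.4f}s set={set_s:.4f}s")
```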
biogtr/io/association_matrix.py
Outdated
|
||
try: | ||
return self.numpy()[query_ind[:, None], ref_ind].squeeze() | ||
|
||
except IndexError as e: | ||
print(f"Query_insts: {type(query_inst)}") | ||
print(f"Query_inds: {query_ind}") |
This is debugging code and comes at the cost of added checks and try-except fencing. I propose just getting rid of it. It's easier to debug these types of things now that things are better refactored.
…`Instance` classes into separate modules within `io`
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
handle string parsing better
74c4889
to
5f2fedc
Compare
Actionable comments posted: 13
Outside diff range and nitpick comments (7)
biogtr/io/__init__.py (1)
3-6: Consider adding imported classes to `__all__` to clarify their intended use.
__all__ = ['Frame', 'Instance', 'AssociationMatrix', 'Track']
biogtr/inference/track.py (1)
Line range hint 133-133: Remove unnecessary f-string.
- print(f"Pod Index: {index}")
+ print("Pod Index:", index)
biogtr/datasets/microscopy_dataset.py (1)
Line range hint 87-87: Replace lambda with a function definition for clarity and maintainability.
- parser = lambda x: data_utils.parse_synthetic(x, source=source)
+ def parser(x):
+     return data_utils.parse_synthetic(x, source=source)
biogtr/datasets/cell_tracking_dataset.py (1)
126-126: Method `get_instances` effectively processes dataset elements. Consider adding more detailed documentation to explain the processing steps.
biogtr/training/losses.py (1)
3-3: Method `forward` effectively calculates the association loss. Consider adding more detailed documentation to explain the processing steps.
biogtr/inference/metrics.py (1)
108-108: Function `to_track_eval` effectively reformats frames for tracking evaluation. Consider adding more detailed documentation to explain the processing steps.
biogtr/inference/tracker.py (1)
Line range hint 177-177: The variable `curr_track` is defined but never used, which could lead to confusion. Consider removing it if it's not needed.
- curr_track += 1
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (25)
- biogtr/datasets/base_dataset.py (1 hunks)
- biogtr/datasets/cell_tracking_dataset.py (2 hunks)
- biogtr/datasets/eval_dataset.py (1 hunks)
- biogtr/datasets/microscopy_dataset.py (2 hunks)
- biogtr/datasets/sleap_dataset.py (1 hunks)
- biogtr/inference/metrics.py (2 hunks)
- biogtr/inference/track.py (1 hunks)
- biogtr/inference/track_queue.py (1 hunks)
- biogtr/inference/tracker.py (5 hunks)
- biogtr/io/__init__.py (1 hunks)
- biogtr/io/association_matrix.py (1 hunks)
- biogtr/io/frame.py (1 hunks)
- biogtr/io/instance.py (1 hunks)
- biogtr/io/track.py (1 hunks)
- biogtr/models/global_tracking_transformer.py (3 hunks)
- biogtr/models/gtr_runner.py (3 hunks)
- biogtr/models/model_utils.py (1 hunks)
- biogtr/models/transformer.py (7 hunks)
- biogtr/training/losses.py (1 hunks)
- biogtr/training/train.py (1 hunks)
- tests/test_config.py (1 hunks)
- tests/test_data_model.py (4 hunks)
- tests/test_inference.py (1 hunks)
- tests/test_models.py (4 hunks)
- tests/test_training.py (1 hunks)
Files skipped from review due to trivial changes (1)
- biogtr/inference/track_queue.py
Additional context used
Learnings (3)
biogtr/io/__init__.py (1)
User: aaprasad PR: talmolab/biogtr#47 File: biogtr/io/__init__.py:8-8 Timestamp: 2024-05-20T21:57:47.690Z Learning: To resolve circular import issues in Python, consider using local imports within functions or methods to delay the import until it's actually needed.
biogtr/models/global_tracking_transformer.py (1)
User: aaprasad PR: talmolab/biogtr#46 File: biogtr/models/global_tracking_transformer.py:5-6 Timestamp: 2024-05-29T17:50:29.550Z Learning: The `Instance` class is used for type hinting in the `forward` method parameters of the `GlobalTrackingTransformer` class in `biogtr/models/global_tracking_transformer.py`.
biogtr/io/frame.py (1)
User: aaprasad PR: talmolab/biogtr#46 File: biogtr/models/global_tracking_transformer.py:5-6 Timestamp: 2024-05-29T17:50:29.550Z Learning: The `Instance` class is used for type hinting in the `forward` method parameters of the `GlobalTrackingTransformer` class in `biogtr/models/global_tracking_transformer.py`.
Ruff
biogtr/io/__init__.py
3-3: `biogtr.io.frame.Frame` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
4-4: `biogtr.io.instance.Instance` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
5-5: `biogtr.io.association_matrix.AssociationMatrix` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
6-6: `biogtr.io.track.Track` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
biogtr/io/track.py
17-17: Undefined name `Instance`
46-46: Undefined name `Instances`
64-64: Undefined name `Frame`
80-80: Undefined name `Instance`
80-80: Undefined name `Instance`
tests/test_training.py
10-10: `biogtr.models.global_tracking_transformer.GlobalTrackingTransformer` imported but unused
11-11: `omegaconf.DictConfig` imported but unused
47-47: Local variable `feats` is assigned to but never used
biogtr/inference/track.py
133-133: f-string without any placeholders
biogtr/models/global_tracking_transformer.py
84-84: Undefined name `AssociationMatrix`
biogtr/datasets/microscopy_dataset.py
87-87: Do not assign a `lambda` expression, use a `def`
biogtr/models/model_utils.py
3-3: `typing.Tuple` imported but unused
tests/test_inference.py
4-4: `pytest` imported but unused
168-168: Local variable `N` is assigned to but never used
biogtr/io/association_matrix.py
113-113: f-string without any placeholders
137-137: f-string without any placeholders
biogtr/datasets/sleap_dataset.py
9-9: `warnings` imported but unused
tests/test_models.py
429-429: Local variable `img_shape` is assigned to but never used
biogtr/inference/tracker.py
121-121: f-string without any placeholders
177-177: Undefined name `curr_track`
177-177: Local variable `curr_track` is assigned to but never used
biogtr/io/frame.py
59-59: Undefined name `AssociationMatrix`
328-328: Undefined name `AssociationMatrix`
347-347: Undefined name `AssociationMatrix`
biogtr/io/instance.py
93-93: Undefined name `Frame`
520-520: Undefined name `Frame`
529-529: Undefined name `Frame`
Additional comments not posted (56)
biogtr/datasets/eval_dataset.py (1)
4-4: LGTM! The imports and class definition align well with the PR objectives of enhancing data handling.
Also applies to: 5-5
tests/test_config.py (1)
4-4: LGTM! The tests are comprehensive and provide good coverage for the `Config` class.
biogtr/training/train.py (1)
6-6: LGTM! The training script is well-organized and effectively handles different training configurations.
tests/test_training.py (1)
6-6: LGTM! The tests are comprehensive and validate the training logic effectively.
Also applies to: 7-7
biogtr/inference/track.py (1)
3-5: Update imports to reflect new structure.
Ensure that the updated imports from `biogtr.io.frame` and `biogtr.models.gtr_runner` are correctly used in the file.
biogtr/models/global_tracking_transformer.py (2)
5-5: Ensure correct usage of the `Instance` class from `biogtr.io.instance`.
The import of `Instance` is used for type hinting in the `forward` method parameters, aligning with previous implementations.
84-84: Import the `AssociationMatrix` class to fix the undefined name.
+ from biogtr.io.association_matrix import AssociationMatrix
Likely invalid or redundant comment.
biogtr/datasets/base_dataset.py (1)
4-4: Update import to reflect new structure.
The import from `biogtr.io.frame` is correctly updated and used in the type hints and method implementations.
tests/test_data_model.py (4)
3-6: Ensure correct imports for testing.
Imports from `biogtr.io.frame`, `biogtr.io.instance`, `biogtr.io.association_matrix`, and `biogtr.io.track` are correctly used in the test functions.
108-116: Test the functionality of `AssociationMatrix` within `Frame`.
The test checks the integration of `AssociationMatrix` with `Frame`, ensuring that methods like `has_matches`, `has_asso_output`, and `has_traj_score` work as expected.
135-196: Test the functionality and error handling of `AssociationMatrix`.
The tests cover various aspects of `AssociationMatrix`, including error handling with incorrect tensor shapes and indexing, and reductions based on different criteria.
199-210: Test the functionality of the `Track` class.
The test checks the `Track` class for correct handling of instances, ensuring that properties like `track_id` and indexing work as expected.
biogtr/datasets/microscopy_dataset.py (1)
6-7: Update imports to reflect new structure.
The imports from `biogtr.io.frame` and `biogtr.io.instance` are correctly updated and used in the class methods.
biogtr/models/model_utils.py (6)
5-5: Ensure correct usage of the `Instance` class from `biogtr.io.instance`.
The import of `Instance` is correctly used in the function parameters for type hinting and ensuring correct data handling.
5-5: Review the softmax application to ensure it handles edge cases correctly.
The function correctly handles the softmax application, including the addition of a zero-padding column to maintain shape consistency.
5-5: Review the error handling in optimizer initialization.
The function includes comprehensive error handling to ensure that the optimizer is correctly instantiated based on the provided configuration.
5-5: Review the error handling in scheduler initialization.
The function includes comprehensive error handling to ensure that the scheduler is correctly instantiated based on the provided configuration.
5-5: Review the logger initialization to ensure it handles different configurations correctly.
The function correctly initializes different types of loggers based on the configuration, handling potential exceptions and providing informative error messages.
5-5: Ensure the device detection logic is robust and accurate.
The function correctly identifies the available computing device, checking for CUDA and MPS support and defaulting to CPU if necessary.
biogtr/datasets/cell_tracking_dataset.py (3)
6-7: Updated imports to reflect new module structure.
Line range hint 15-126: Constructor is well-structured and handles dataset initialization effectively.
126-126: Method `get_indices` correctly retrieves label and frame indices.
tests/test_inference.py (4)
6-7: Updated imports to reflect new module structure.
15-91: Comprehensive test for `TrackQueue` functionality.
92-92: Test for `Tracker` is well-structured and effectively checks the functionality.
92-92: Test for post-processing methods is detailed and covers multiple scenarios.
biogtr/models/gtr_runner.py (8)
11-12: Updated imports to reflect new module structure.
Line range hint 11-76: Constructor is well-structured and handles model initialization effectively.
76-76: Method `forward` correctly implements the model's forward logic.
76-76: Method `training_step` is well-structured and effectively handles the training logic.
76-76: Method `validation_step` is well-structured and effectively handles the validation logic.
76-76: Method `test_step` is well-structured and effectively handles the testing logic.
76-76: Method `predict_step` correctly implements the inference logic.
Line range hint 76-169: Method `_shared_eval_step` is well-structured and effectively handles the evaluation logic.
biogtr/training/losses.py (3)
3-3: Updated imports to reflect new module structure.
3-3: Constructor is well-structured and handles loss function initialization effectively.
3-3: Method `_get_asso_gt` correctly computes the association ground truth.
biogtr/inference/metrics.py (4)
6-6: Updated imports to reflect new module structure.
6-6: Function `get_matches` correctly extracts matches between predicted and ground truth labels.
6-6: Function `get_switches` correctly identifies misassigned predicted trajectory labels.
6-6: Function `get_switch_count` correctly calculates the number of mislabeled predicted trajectories.
biogtr/io/association_matrix.py (4)
11-19: Class definition and documentation are clear and concise.
21-23: Attributes are well-defined using the `attrs` library.
75-83: Method to convert matrix to numpy array is correctly implemented.
245-287: Method for retrieving indices is robust and handles various input types effectively.
biogtr/datasets/sleap_dataset.py (1)
10-11: Imports updated to reflect new module structure.
tests/test_models.py (1)
5-6: Imports updated to reflect new module structure.
biogtr/inference/tracker.py (2)
6-6: Imports updated to reflect new module structure.
131-131: The f-string in the docstring does not contain any placeholders, which might be a typo or an oversight.
- frames: A list of Frames (See `biogtr.io.data_structures.Frame` for more info).
+ frames: A list of Frames (See `biogtr.io.frame.Frame` for more info).
Likely invalid or redundant comment.
biogtr/io/frame.py (2)
328-328: Undefined name `AssociationMatrix` used in type hint.
347-347: Undefined name `AssociationMatrix` used in type hint.
biogtr/io/instance.py (5)
11-27: The `_to_tensor` function is well-implemented with appropriate type checks and conversions.
55-120: The `Instance` class is well-structured with clear attribute definitions and initialization. The use of `attrs` simplifies the code and ensures type safety and default handling.
95-120: The `__attrs_post_init__` method effectively handles default values and type conversions. However, consider handling exceptions for operations that might fail, such as tensor operations.
#!/bin/bash
# Description: Verify tensor operations in __attrs_post_init__ do not raise exceptions with typical inputs.
# Test: Run typical inputs through the method. Expect: No exceptions.
python -c "from biogtr.io.instance import Instance; inst = Instance(); inst.__attrs_post_init__()"
154-209: The `from_slp` class method is a good example of a factory method, but it lacks error handling for potential issues during the conversion process, such as type mismatches or missing attributes in `slp_instance`.
#!/bin/bash
# Description: Verify that `from_slp` handles all potential input variations without error.
# Test: Provide various malformed inputs to the method. Expect: Graceful handling or specific error messages.
python -c "from biogtr.io.instance import Instance; Instance.from_slp(malformed_input)"
211-243: The `to_slp` method converts the instance to a `sleap_io.PredictedInstance` object. Ensure that all attributes used in the conversion are properly validated and handled to avoid runtime errors.
#!/bin/bash
# Description: Verify that all attributes used in `to_slp` are validated.
# Test: Check for attribute presence and correct types before conversion. Expect: No runtime errors.
python -c "from biogtr.io.instance import Instance; inst = Instance(); inst.to_slp()"
) | ||
|
||
_instances: list["Instance"] = attrs.field(alias="instances", factory=list) | ||
_asso_output: "AssociationMatrix" = attrs.field(alias="asso_output", default=None) |
Undefined name `AssociationMatrix` used in type hint.
Please ensure that `AssociationMatrix` is properly imported or defined within this module or in an importable module.
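One common way to satisfy a linter for a name that appears only in a quoted type hint, without creating a runtime circular import, is a `typing.TYPE_CHECKING` guard. The sketch below is an assumption about how that could look, not the change made in this PR: `FrameSketch` is hypothetical, and `fractions.Fraction` stands in for `AssociationMatrix` so that the example runs on its own.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # This branch is read by type checkers and linters but never executed
    # at runtime, so it cannot take part in an import cycle.
    # `Fraction` is a stand-in for the real `AssociationMatrix` import.
    from fractions import Fraction


class FrameSketch:
    """Hypothetical stand-in for a class that holds an association output."""

    def __init__(self, asso_output: "Fraction | None" = None):
        # The quoted annotation is stored as a plain string and is not
        # evaluated here, so the name does not need to exist at runtime.
        self._asso_output = asso_output

    def has_asso_output(self) -> bool:
        return self._asso_output is not None
```

The local-import approach recorded in the Learnings section is an alternative when the class is needed at runtime rather than only in annotations.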
@@ -141,7 +142,7 @@ def _reset_parameters(self): | |||
|
|||
def forward( | |||
self, ref_instances: list[Instance], query_instances: list[Instance] = None | |||
) -> tuple[list[torch.Tensor], dict[str, torch.Tensor]]: | |||
) -> list[AssociationMatrix]: |
Optimize the concatenation of instance features.
Consider using a more efficient method for concatenating tensors, especially when dealing with large lists of instances. This could potentially improve performance.
- [instance.features for instance in ref_instances], dim=0
+ torch.stack([instance.features for instance in ref_instances], dim=0)
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
) -> list[AssociationMatrix]: | |
) -> list[AssociationMatrix]: | |
torch.stack([instance.features for instance in ref_instances], dim=0) |
def to_dataframe( | ||
self, row_labels: str = "gt", col_labels: str = "gt" | ||
) -> pd.DataFrame: | ||
"""Convert the association matrix to a pandas DataFrame. | ||
|
||
Args: | ||
row_labels: How to label the rows(queries). | ||
If list, then must match # of rows/queries | ||
If `"gt"` then label by gt track id. | ||
If `"pred"` then label by pred track id. | ||
Otherwise label by the query_instance indices | ||
col_labels: How to label the columns(references). | ||
If list, then must match # of columns/refs | ||
If `"gt"` then label by gt track id. | ||
If `"pred"` then label by pred track id. | ||
Otherwise label by the ref_instance indices | ||
|
||
Returns: | ||
The association matrix as a pandas dataframe. | ||
""" | ||
matrix = self.numpy() | ||
|
||
if not isinstance(row_labels, str): | ||
if len(row_labels) == len(self.query_instances): | ||
row_inds = row_labels | ||
else: | ||
raise ValueError( | ||
( | ||
f"Mismatched # of rows and labels!", | ||
f"Found {len(row_labels)} with {len(self.query_instances)} rows", | ||
) | ||
) | ||
else: | ||
if row_labels == "gt": | ||
row_inds = [ | ||
instance.gt_track_id.item() for instance in self.query_instances | ||
] | ||
|
||
elif row_labels == "pred": | ||
row_inds = [ | ||
instance.pred_track_id.item() for instance in self.query_instances | ||
] | ||
|
||
else: | ||
row_inds = np.arange(len(self.query_instances)) | ||
|
||
if not isinstance(col_labels, str): | ||
if len(col_labels) == len(self.ref_instances): | ||
col_inds = col_labels | ||
else: | ||
raise ValueError( | ||
( | ||
f"Mismatched # of columns and labels!", | ||
f"Found {len(col_labels)} with {len(self.ref_instances)} columns", | ||
) | ||
) | ||
else: | ||
if col_labels == "gt": | ||
col_inds = [ | ||
instance.gt_track_id.item() for instance in self.ref_instances | ||
] | ||
|
||
elif col_labels == "pred": | ||
col_inds = [ | ||
instance.pred_track_id.item() for instance in self.ref_instances | ||
] | ||
|
||
else: | ||
col_inds = np.arange(len(self.ref_instances)) | ||
|
||
asso_df = pd.DataFrame(matrix, index=row_inds, columns=col_inds) | ||
|
||
return asso_df |
Enhance flexibility in labeling rows and columns in DataFrame conversion.
Consider adding an optional parameter to allow custom labels for rows and columns, enhancing the flexibility of this method:
def to_dataframe(self, row_labels=None, col_labels=None):
row_inds = row_labels if row_labels is not None else [instance.id for instance in self.query_instances]
col_inds = col_labels if col_labels is not None else [instance.id for instance in self.ref_instances]
...
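Separately from the labeling suggestion, the quoted `to_dataframe` code wraps its error message in parentheses with trailing commas, so `ValueError` receives a tuple of two strings rather than one message; the first f-string also has no placeholders, which matches the Ruff reports for this file. A minimal sketch of a single-string message, using a hypothetical helper `check_label_count`:

```python
def check_label_count(labels, n_rows):
    """Return the labels as a list, or raise if their count differs from n_rows."""
    if len(labels) != n_rows:
        # Adjacent string literals (no commas between them) are joined into
        # one message, so the exception carries a single string argument.
        raise ValueError(
            "Mismatched # of rows and labels! "
            f"Found {len(labels)} labels with {n_rows} rows"
        )
    return list(labels)
```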
_skeleton: sio.Skeleton = attrs.field(alias="skeleton", default=None) | ||
_pose: dict[str, ArrayLike] = attrs.field(alias="pose", factory=dict) | ||
_device: str = attrs.field(alias="device", default=None) | ||
_frame: "Frame" = None |
Correct the missing import for `Frame` in the setter method.
+ from biogtr.io.frame import Frame
Also applies to: 520-520, 529-529
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
_frame: "Frame" = None | |
+ from biogtr.io.frame import Frame | |
_frame: "Frame" = None |
def get_tracks( | ||
self, instances: list["Instance"], label: str = "pred" | ||
) -> dict[int, list["Instance"]]: | ||
"""Group instances by track. | ||
|
||
Args: | ||
instances: The list of instances to group | ||
label: the track id type to group by. Either `pred` or `gt`. | ||
|
||
Returns: | ||
A dictionary of track_id:instances | ||
""" | ||
if label == "pred": | ||
traj_ids = set([instance.pred_track_id.item() for instance in instances]) | ||
traj = { | ||
track_id: [ | ||
instance | ||
for instance in instances | ||
if instance.pred_track_id.item() == track_id | ||
] | ||
for track_id in traj_ids | ||
} | ||
|
||
elif label == "gt": | ||
traj_ids = set( | ||
[instance.gt_track_id.item() for instance in self.ref_instances] | ||
) | ||
traj = { | ||
track_id: [ | ||
instance | ||
for instance in self.ref_instances | ||
if instance.gt_track_id.item() == track_id | ||
] | ||
for track_id in traj_ids | ||
} | ||
|
||
else: | ||
raise ValueError(f"Unsupported label '{label}'. Expected 'pred' or 'gt'.") | ||
|
||
return traj |
Validate input parameters more robustly.
- raise ValueError(f"Unsupported label '{label}'. Expected 'pred' or 'gt'.")
+ raise ValueError(f"Unsupported label '{label}'. Expected 'pred' or 'gt'.")
Committable suggestion was skipped due low confidence.
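For reference, the grouping in the quoted `get_tracks` can also be written as a single pass with `collections.defaultdict`. This is a sketch under assumptions, not the PR's implementation: `group_by_track` is a hypothetical name, and track ids are plain ints here, whereas the quoted code calls `.item()` on tensors. Note that the `gt` branch of the quoted code appears to iterate `self.ref_instances` rather than the `instances` argument; the sketch always uses the argument.

```python
from collections import defaultdict
from types import SimpleNamespace


def group_by_track(instances, label="pred"):
    """Group instances by predicted or ground-truth track id in one pass."""
    if label not in ("pred", "gt"):
        raise ValueError(f"Unsupported label '{label}'. Expected 'pred' or 'gt'.")
    attr = f"{label}_track_id"
    tracks = defaultdict(list)
    for instance in instances:
        # Plain ints are assumed; with tensors this would need `.item()`.
        tracks[getattr(instance, attr)].append(instance)
    return dict(tracks)


# Stand-in objects carrying only the two id attributes used above.
demo = [
    SimpleNamespace(pred_track_id=0, gt_track_id=1),
    SimpleNamespace(pred_track_id=0, gt_track_id=2),
    SimpleNamespace(pred_track_id=1, gt_track_id=1),
]
pred_tracks = group_by_track(demo, "pred")
gt_tracks = group_by_track(demo, "gt")
```

Each instance is visited once, instead of once per track id as in the nested comprehension.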
def __getitem__(self, inds) -> np.ndarray: | ||
"""Get elements of the association matrix. | ||
|
||
Args: | ||
inds: A tuple of query indices and reference indices. | ||
Indices can be either: | ||
A single instance or integer. | ||
A list of instances or integers. | ||
|
||
Returns: | ||
An np.ndarray containing the elements requested. | ||
""" | ||
query_inst, ref_inst = inds | ||
|
||
query_ind = self.__getindices__(query_inst, self.query_instances) | ||
ref_ind = self.__getindices__(ref_inst, self.ref_instances) | ||
|
||
try: | ||
return self.numpy()[query_ind[:, None], ref_ind].squeeze() | ||
|
||
except IndexError as e: | ||
print(f"Query_insts: {type(query_inst)}") | ||
print(f"Query_inds: {query_ind}") | ||
print(f"Ref_insts: {type(ref_inst)}") | ||
print(f"Ref_ind: {ref_ind}") | ||
raise (e) |
Improve error handling by providing more context.
- print(f"Query_insts: {type(query_inst)}")
- print(f"Query_inds: {query_ind}")
- print(f"Ref_insts: {type(ref_inst)}")
- print(f"Ref_ind: {ref_ind}")
+ print(f"Error accessing matrix with query instances of type {type(query_inst)} and reference instances of type {type(ref_inst)}")
+ print(f"Query indices attempted: {query_ind}, Reference indices attempted: {ref_ind}")
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
def __getitem__(self, inds) -> np.ndarray: | |
"""Get elements of the association matrix. | |
Args: | |
inds: A tuple of query indices and reference indices. | |
Indices can be either: | |
A single instance or integer. | |
A list of instances or integers. | |
Returns: | |
An np.ndarray containing the elements requested. | |
""" | |
query_inst, ref_inst = inds | |
query_ind = self.__getindices__(query_inst, self.query_instances) | |
ref_ind = self.__getindices__(ref_inst, self.ref_instances) | |
try: | |
return self.numpy()[query_ind[:, None], ref_ind].squeeze() | |
except IndexError as e: | |
print(f"Query_insts: {type(query_inst)}") | |
print(f"Query_inds: {query_ind}") | |
print(f"Ref_insts: {type(ref_inst)}") | |
print(f"Ref_ind: {ref_ind}") | |
raise (e) | |
def __getitem__(self, inds) -> np.ndarray: | |
"""Get elements of the association matrix. | |
Args: | |
inds: A tuple of query indices and reference indices. | |
Indices can be either: | |
A single instance or integer. | |
A list of instances or integers. | |
Returns: | |
An np.ndarray containing the elements requested. | |
""" | |
query_inst, ref_inst = inds | |
query_ind = self.__getindices__(query_inst, self.query_instances) | |
ref_ind = self.__getindices__(ref_inst, self.ref_instances) | |
try: | |
return self.numpy()[query_ind[:, None], ref_ind].squeeze() | |
except IndexError as e: | |
print(f"Error accessing matrix with query instances of type {type(query_inst)} and reference instances of type {type(ref_inst)}") | |
print(f"Query indices attempted: {query_ind}, Reference indices attempted: {ref_ind}") | |
raise (e) |
def reduce( | ||
self, | ||
row_dims: str = "instance", | ||
col_dims: str = "track", | ||
row_grouping: str = None, | ||
col_grouping: str = "pred", | ||
reduce_method: callable = np.sum, | ||
) -> pd.DataFrame: | ||
"""Aggregate the association matrix by specified dimensions and grouping. | ||
|
||
Args: | ||
row_dims: A str indicating how to what dimensions to reduce rows to. | ||
Either "instance" (remains unchanged), or "track" (n_rows=n_traj). | ||
col_dims: A str indicating how to dimensions to reduce rows to. | ||
Either "instance" (remains unchanged), or "track" (n_cols=n_traj) | ||
row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt". | ||
col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt". | ||
method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing. | ||
|
||
Returns: | ||
The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe. | ||
""" | ||
n_rows = len(self.query_instances) | ||
n_cols = len(self.ref_instances) | ||
|
||
col_tracks = {-1: self.ref_instances} | ||
row_tracks = {-1: self.query_instances} | ||
|
||
col_inds = [i for i in range(len(self.ref_instances))] | ||
row_inds = [i for i in range(len(self.query_instances))] | ||
|
||
if col_dims == "track": | ||
col_tracks = self.get_tracks(self.ref_instances, col_grouping) | ||
col_inds = list(col_tracks.keys()) | ||
n_cols = len(col_inds) | ||
|
||
if row_dims == "track": | ||
row_tracks = self.get_tracks(self.query_instances, row_grouping) | ||
row_inds = list(row_tracks.keys()) | ||
n_rows = len(row_inds) | ||
|
||
reduced_matrix = [] | ||
for row_track, row_instances in row_tracks.items(): | ||
|
||
for col_track, col_instances in col_tracks.items(): | ||
asso_matrix = self[row_instances, col_instances] | ||
|
||
if col_dims == "track": | ||
asso_matrix = reduce_method(asso_matrix, axis=1) | ||
|
||
if row_dims == "track": | ||
asso_matrix = reduce_method(asso_matrix, axis=0) | ||
|
||
reduced_matrix.append(asso_matrix) | ||
|
||
reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T | ||
|
||
return pd.DataFrame(reduced_matrix, index=row_inds, columns=col_inds) |
Refactor to use more descriptive parameter names in the `reduce` method.
- row_dims: str = "instance",
- col_dims: str = "track",
+ row_dimensions: str = "instance",
+ column_dimensions: str = "track",
...
- "Reduce association matrix rows/columns to inst/traj x traj/inst."
+ "Aggregate the association matrix by specified dimensions and grouping."
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.
    def reduce(
        self,
        row_dims: str = "instance",
        col_dims: str = "track",
        row_grouping: str = None,
        col_grouping: str = "pred",
        reduce_method: callable = np.sum,
    ) -> pd.DataFrame:
        """Aggregate the association matrix by specified dimensions and grouping.

        Args:
            row_dims: A str indicating what dimensions to reduce rows to.
                Either "instance" (remains unchanged), or "track" (n_rows=n_traj).
            col_dims: A str indicating what dimensions to reduce columns to.
                Either "instance" (remains unchanged), or "track" (n_cols=n_traj).
            row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt".
            col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt".
            reduce_method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.

        Returns:
            The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe.
        """
        n_rows = len(self.query_instances)
        n_cols = len(self.ref_instances)
        col_tracks = {-1: self.ref_instances}
        row_tracks = {-1: self.query_instances}
        col_inds = [i for i in range(len(self.ref_instances))]
        row_inds = [i for i in range(len(self.query_instances))]
        if col_dims == "track":
            col_tracks = self.get_tracks(self.ref_instances, col_grouping)
            col_inds = list(col_tracks.keys())
            n_cols = len(col_inds)
        if row_dims == "track":
            row_tracks = self.get_tracks(self.query_instances, row_grouping)
            row_inds = list(row_tracks.keys())
            n_rows = len(row_inds)
        reduced_matrix = []
        for row_track, row_instances in row_tracks.items():
            for col_track, col_instances in col_tracks.items():
                asso_matrix = self[row_instances, col_instances]
                if col_dims == "track":
                    asso_matrix = reduce_method(asso_matrix, axis=1)
                if row_dims == "track":
                    asso_matrix = reduce_method(asso_matrix, axis=0)
                reduced_matrix.append(asso_matrix)
        reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T
        return pd.DataFrame(reduced_matrix, index=row_inds, columns=col_inds)
    def reduce(
        self,
        row_dimensions: str = "instance",
        column_dimensions: str = "track",
        row_grouping: str = None,
        col_grouping: str = "pred",
        reduce_method: callable = np.sum,
    ) -> pd.DataFrame:
        """Aggregate the association matrix by specified dimensions and grouping.

        Args:
            row_dimensions: A str indicating what dimensions to reduce rows to.
                Either "instance" (remains unchanged), or "track" (n_rows=n_traj).
            column_dimensions: A str indicating what dimensions to reduce columns to.
                Either "instance" (remains unchanged), or "track" (n_cols=n_traj).
            row_grouping: A str indicating how to group rows when aggregating. Either "pred" or "gt".
            col_grouping: A str indicating how to group columns when aggregating. Either "pred" or "gt".
            reduce_method: A callable function that operates on numpy matrices and can take an `axis` arg for reducing.

        Returns:
            The association matrix reduced to an inst/traj x traj/inst association matrix as a dataframe.
        """
        n_rows = len(self.query_instances)
        n_cols = len(self.ref_instances)
        col_tracks = {-1: self.ref_instances}
        row_tracks = {-1: self.query_instances}
        col_inds = [i for i in range(len(self.ref_instances))]
        row_inds = [i for i in range(len(self.query_instances))]
        if column_dimensions == "track":
            col_tracks = self.get_tracks(self.ref_instances, col_grouping)
            col_inds = list(col_tracks.keys())
            n_cols = len(col_inds)
        if row_dimensions == "track":
            row_tracks = self.get_tracks(self.query_instances, row_grouping)
            row_inds = list(row_tracks.keys())
            n_rows = len(row_inds)
        reduced_matrix = []
        for row_track, row_instances in row_tracks.items():
            for col_track, col_instances in col_tracks.items():
                asso_matrix = self[row_instances, col_instances]
                if column_dimensions == "track":
                    asso_matrix = reduce_method(asso_matrix, axis=1)
                if row_dimensions == "track":
                    asso_matrix = reduce_method(asso_matrix, axis=0)
                reduced_matrix.append(asso_matrix)
        reduced_matrix = np.array(reduced_matrix).reshape(n_cols, n_rows).T
        return pd.DataFrame(reduced_matrix, index=row_inds, columns=col_inds)
from biogtr.io.association_matrix import AssociationMatrix
from biogtr.io.track import Track

# TODO: expose config without circular import error from biogtr.io.config import Config
Ensure the dynamic import of `Config` is implemented to resolve the circular import issue as previously discussed.
Would you like to proceed with the implementation of the dynamic import solution for `Config`?
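One common way to defer such an import is a module-level `__getattr__` (PEP 562), which resolves a name only on first access. The sketch below is an assumption about how that could look, not the project's implementation: `make_lazy_module` and `demo_io` are hypothetical names, and the standard-library `collections.OrderedDict` stands in for `biogtr.io.config.Config` so the example runs on its own.

```python
import importlib
import types


def make_lazy_module(name: str, lazy_attrs: dict) -> types.ModuleType:
    """Build a module whose listed attributes are imported on first access.

    `lazy_attrs` maps an attribute name to the module path that defines it.
    Both names are hypothetical and used for illustration only.
    """
    module = types.ModuleType(name)

    def _lazy_getattr(attr: str):
        # Called only when normal attribute lookup on the module fails (PEP 562).
        if attr not in lazy_attrs:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        target = importlib.import_module(lazy_attrs[attr])
        value = getattr(target, attr)
        setattr(module, attr, value)  # cache so later lookups skip this hook
        return value

    module.__getattr__ = _lazy_getattr
    return module


# Stand-in for `biogtr.io` exposing `Config` lazily.
demo_io = make_lazy_module("demo_io", {"OrderedDict": "collections"})

loaded_before = "OrderedDict" in vars(demo_io)  # nothing imported yet
lazy_cls = demo_io.OrderedDict                  # triggers the deferred import
loaded_after = "OrderedDict" in vars(demo_io)   # now cached on the module
```

In a real package, the equivalent hook would be a plain `def __getattr__(name)` at the top level of the package's `__init__.py`, so the deferred import runs only after both modules have finished loading.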
""" | ||

    _id: int = attrs.field(alias="id")
    _instances: list["Instance"] = attrs.field(alias="instances", factory=list)
Correct the undefined names and type hints.
- _instances: list["Instance"] = attrs.field(alias="instances", factory=list)
+ from biogtr.io.instance import Instance
+ _instances: list[Instance] = attrs.field(alias="instances", factory=list)
- def instances(self) -> list["Instances"]:
+ def instances(self) -> list[Instance]:
- def frames(self) -> set["Frame"]:
+ from biogtr.io.frame import Frame
+ def frames(self) -> set[Frame]:
- def __getitem__(self, ind) -> Union["Instance", list["Instance"]]:
+ def __getitem__(self, ind) -> Union[Instance, list[Instance]]:
Also applies to: 46-46, 64-64, 80-80
Committable suggestion
    _instances: list["Instance"] = attrs.field(alias="instances", factory=list)
    from biogtr.io.instance import Instance
    from biogtr.io.frame import Frame
    _instances: list[Instance] = attrs.field(alias="instances", factory=list)
    def instances(self) -> list[Instance]:
        ...
    def frames(self) -> set[Frame]:
        ...
    def __getitem__(self, ind) -> Union[Instance, list[Instance]]:
        ...
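An alternative that keeps the quoted forward references and still gives linters a defined name is a `typing.TYPE_CHECKING` guard. The sketch below is an assumption, not the change made in this PR: it uses a standard-library `dataclass` as a stand-in for the `attrs` class so it runs without third-party packages, and it covers only `Instance`.

```python
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Seen only by static type checkers and linters; skipped at runtime,
    # so it cannot trigger a circular import.
    from biogtr.io.instance import Instance


@dataclass
class Track:
    """Simplified stand-in for the track class discussed above."""

    id: int
    instances: list["Instance"] = field(default_factory=list)

    def __len__(self) -> int:
        return len(self.instances)

    def __getitem__(
        self, ind: Union[int, list[int]]
    ) -> Union["Instance", list["Instance"]]:
        # A single int returns one instance; a list of ints returns a list.
        if isinstance(ind, int):
            return self.instances[ind]
        return [self.instances[i] for i in ind]


# Plain strings stand in for instance objects in this illustration.
track = Track(id=0, instances=["a", "b", "c"])
```

The quoted names are never evaluated at runtime, so the class can be defined even when `biogtr.io.instance` is not importable yet.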
@@ -2,7 +2,7 @@

from typing import List, Tuple, Iterable
Remove unused import.
- from typing import List, Tuple, Iterable
+ from typing import List, Iterable
Committable suggestion
‼️ IMPORTANT
from typing import List, Tuple, Iterable
from typing import List, Iterable
As per #39, just as we store model inputs in class objects for easy manipulation, we want to do the same for model outputs.
Thus, this PR covers:
- `biogtr.io`, with separate modules for `biogtr.io.Frame` and `biogtr.io.Instance`
- `Frame` and `Instance` classes that use `attrs` for initialization
- An `AssociationMatrix` class which
  - stores the association matrix
  - enables easy lookup through either `int` or `Instance` indexing
  - reduces the association matrix to `(n_query/n_traj, n_ref/n_traj)`
- A `Track` object which stores instances of the same track id
- #35 (`Instance` or `Frame` Objects), by storing embeddings in the `Instance` object and having models just return `AssociationMatrix` objects

Summary by CodeRabbit
New Features
- `AssociationMatrix` for managing and analyzing association scores.
- `Instance` class for handling individual tracking instances.
- `Frame` class for managing video frame data and related instances.
- `Track` class for managing instances of the same track.
Enhancements
- `GlobalTrackingTransformer`: improved instance handling and feature extraction.
- `GTRRunner`: operates on lists of `Frame` objects for better data management and processing.
Bug Fixes
Tests
- `AssociationMatrix` and `Track`: improved testing coverage and reliability.
- `Instance` and `Frame`: better accuracy and robustness.
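The `AssociationMatrix` behaviour described in the PR summary (lookup by `int` or by instance, and reduction of instance columns to track columns) can be sketched in plain Python. This is a toy illustration under stated assumptions, not the class added in this PR: `ToyInstance`, `ToyAssociationMatrix`, and `reduce_columns_by_track` are hypothetical names, and nested lists replace the numpy/pandas containers.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyInstance:
    """Hypothetical stand-in for an instance: an id and a predicted track id."""

    id: int
    pred_track_id: int


class ToyAssociationMatrix:
    """Scores with rows = query instances and columns = reference instances."""

    def __init__(self, scores, query_instances, ref_instances):
        self.scores = scores
        self.query_instances = list(query_instances)
        self.ref_instances = list(ref_instances)

    @staticmethod
    def _index(key, instances):
        # Accept either a positional int or an instance object.
        return key if isinstance(key, int) else instances.index(key)

    def __getitem__(self, key):
        row, col = key
        r = self._index(row, self.query_instances)
        c = self._index(col, self.ref_instances)
        return self.scores[r][c]

    def reduce_columns_by_track(self, reduce_method=sum):
        """Collapse reference columns sharing a track id into one column per track."""
        tracks = defaultdict(list)
        for col, inst in enumerate(self.ref_instances):
            tracks[inst.pred_track_id].append(col)
        return [
            {tid: reduce_method(row[c] for c in cols) for tid, cols in tracks.items()}
            for row in self.scores
        ]


refs = [ToyInstance(0, 0), ToyInstance(1, 0), ToyInstance(2, 1)]
queries = [ToyInstance(3, 0), ToyInstance(4, 1)]
matrix = ToyAssociationMatrix([[1, 2, 3], [4, 5, 6]], queries, refs)

by_position = matrix[0, 2]                  # integer indexing
by_instance = matrix[queries[1], refs[0]]   # instance indexing
reduced = matrix.reduce_columns_by_track()  # one column per predicted track
```

Here the two reference instances on track 0 are summed into a single column, which mirrors the `(n_query, n_traj)` shape mentioned in the summary.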