Use instances as model input #46
Conversation
Walkthrough

This update introduces significant enhancements to the …

Changes
Sequence Diagram(s) (Beta)

```mermaid
sequenceDiagram
    participant User
    participant Tracker
    participant GlobalTrackingTransformer
    participant GTRRunner
    User->>Tracker: Call sliding_inference(device)
    Tracker->>GlobalTrackingTransformer: Call forward(ref_instances, query_instances)
    GlobalTrackingTransformer->>GTRRunner: Call extract_features(instances)
    GTRRunner->>GlobalTrackingTransformer: Return extracted features
    GlobalTrackingTransformer->>Tracker: Return tracking results
    Tracker->>User: Return inference results
```
Actionable comments posted: 10
Out of diff range and nitpick comments (2)
biogtr/models/global_tracking_transformer.py (1)

Line range hint 82-122: Update docstring to reflect new parameters.

```diff
- Args:
-     frames: List of Frames from chunk containing crops of objects + gt label info
-     query_frame: Frame index used as query for self attention. Only used in sliding inference where query frame is the last frame in the window.
+ Args:
+     ref_instances: List of reference instances containing crops of objects + gt label info.
+     query_instances: List of query instances used for self attention. Only used in sliding inference where query instances are the last instances in the window.
```

The docstring should accurately describe the new parameters `ref_instances` and `query_instances` to avoid confusion.

biogtr/models/gtr_runner.py (1)

Line range hint 63-75: Update docstring to reflect new parameters.

```diff
- Args:
-     instances: a list of dicts where each dict is a frame with gt data
+ Args:
+     ref_instances: List of reference instances containing crops of objects + gt label info.
+     query_instances: List of query instances used for self attention. Only used in sliding inference where query instances are the last instances in the window.
```

The docstring should accurately describe the new parameters `ref_instances` and `query_instances` to avoid confusion.
Review Details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (8)
- biogtr/data_structures.py (5 hunks)
- biogtr/inference/tracker.py (5 hunks)
- biogtr/models/global_tracking_transformer.py (3 hunks)
- biogtr/models/gtr_runner.py (7 hunks)
- biogtr/models/model_utils.py (1 hunks)
- biogtr/models/transformer.py (8 hunks)
- biogtr/training/losses.py (2 hunks)
- tests/test_models.py (5 hunks)
Additional Context Used

Ruff (5)

biogtr/inference/tracker.py (3)
121-121: f-string without any placeholders
177-177: Undefined name `curr_track`
177-177: Local variable `curr_track` is assigned to but never used

biogtr/models/model_utils.py (1)
3-3: `typing.Tuple` imported but unused

tests/test_models.py (1)
428-428: Local variable `img_shape` is assigned to but never used
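For reference, the F541 finding above ("f-string without any placeholders") is fixed by simply dropping the `f` prefix. A minimal illustration (an invented string, not the actual `tracker.py` line, which is not shown in this review):

```python
# An f-string with no {placeholders} is identical to a plain string;
# Ruff's F541 flags the useless prefix.
flagged = f"running sliding inference"  # what Ruff F541 complains about
fixed = "running sliding inference"     # the fix: drop the `f`

print(flagged == fixed)  # True
```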
Additional comments not posted (16)

tests/test_models.py (4)

385-388: Ensure that `ref_pos_emb` and `query_pos_emb` are correctly initialized and passed to the `transformer_decoder` function. This change appears to be consistent with the function's purpose of testing the transformer decoder layer logic.

417-418: The list comprehension correctly flattens the list of instances from frames. Ensure that the transformer can handle the flattened list of instances as expected.

442-443: The list comprehension correctly flattens the list of instances from frames. Ensure that the transformer can handle the flattened list of instances as expected.

512-513: The list comprehension correctly flattens the list of instances from frames. Ensure that the tracking transformer can handle the flattened list of instances as expected.

biogtr/inference/tracker.py (4)

144-146: Ensure that `frame_to_track.frame_id.device` is correctly initialized and passed to the `collate_tracks` method. This change appears to be consistent with the method's purpose of performing sliding inference.

237-239: The list comprehensions correctly extract `query_instances` and `all_instances` from the frames. Ensure that the method can handle these lists of instances as expected.

262-262: Ensure that the model can handle `all_instances` and `query_instances` as inputs. This change appears to be consistent with the method's purpose of running the global tracker.

330-330: The updated calculation correctly uses `nonquery_inds` to extract `asso_nonquery`. Ensure that the method can handle this updated calculation as expected.

biogtr/models/transformer.py (4)

14-17: Imports look good and necessary for the changes made in the file.

143-149: The updated `forward` method signature and docstring correctly reflect the new parameters `ref_instances` and `query_instances`.

173-237: The handling of embeddings and positional encodings appears correct and aligns with the new input structure.

Line range hint 240-267: The decoder and attention head logic correctly integrates the new input structure and performs the necessary operations.

biogtr/data_structures.py (4)

127-128: Initialize `_frame` attribute to `None`. The initialization of the `_frame` attribute to `None` in the constructor is correct and follows best practices.

609-610: Ensure `instances` is a list of `Instance` objects. Ensure that the `instances` parameter is always a list of `Instance` objects. If not, consider adding a type check or validation to handle unexpected types:

```python
if not all(isinstance(instance, Instance) for instance in instances):
    raise ValueError("All elements in instances must be of type Instance")
```

824-826: Ensure `instances` is a list of `Instance` objects. Ensure that the `instances` parameter is always a list of `Instance` objects. If not, consider adding a type check or validation to handle unexpected types:

```python
if not all(isinstance(instance, Instance) for instance in instances):
    raise ValueError("All elements in instances must be of type Instance")
```

640-640: Update `__repr__` method to display keys of `_traj_score`. The update to the `__repr__` method to display the keys of `_traj_score` instead of the values is correct and improves the readability of the output.
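The frame-to-instance flattening pattern these comments refer to can be illustrated with minimal stand-in classes (hypothetical, not the real `biogtr` `Frame`/`Instance` implementations):

```python
class Instance:
    """Minimal stand-in for a detected instance."""
    def __init__(self, name):
        self.name = name

class Frame:
    """Minimal stand-in for a frame holding its instances."""
    def __init__(self, instances):
        self.instances = instances

frames = [Frame([Instance("a"), Instance("b")]),
          Frame([]),
          Frame([Instance("c")])]

# The reviewed pattern: flatten all instances across all frames.
instances = [instance for frame in frames for instance in frame.instances]
print([i.name for i in instances])  # ['a', 'b', 'c']
```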
```python
ref_features = torch.cat(
    [instance.features for instance in ref_instances], dim=0
).unsqueeze(0)
```
Consider adding a check to ensure that `ref_instances` is not empty before concatenating features to avoid potential runtime errors.

```diff
+ if not ref_instances:
+     raise ValueError("ref_instances should not be empty")
+
  ref_features = torch.cat(
      [instance.features for instance in ref_instances], dim=0
  ).unsqueeze(0)
```
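As a side note on why the emptiness guard matters: `torch.cat` raises an opaque RuntimeError on an empty list, while the suggested guard fails early with a clear message. A dependency-free sketch of the guard's behavior (plain list concatenation stands in for `torch.cat`):

```python
def cat_ref_features(ref_instances):
    # Guard first: torch.cat([]) would otherwise raise
    # "expected a non-empty list of Tensors" deep inside the model.
    if not ref_instances:
        raise ValueError("ref_instances should not be empty")
    return [f for inst in ref_instances for f in inst]  # stand-in for torch.cat

try:
    cat_ref_features([])
except ValueError as e:
    print(e)  # ref_instances should not be empty

print(cat_ref_features([[1, 2], [3]]))  # [1, 2, 3]
```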
```python
from biogtr.data_structures import Instance
import torch
```
Remove unused import.

```diff
- from biogtr.data_structures import Instance
```

The `Instance` class is not used in this file. Removing unused imports helps keep the code clean and maintainable.
It's used for type hinting
@aaprasad, thank you for pointing that out. You're absolutely right; the `Instance` class is used for type hinting in the `forward` method parameters. I appreciate your correction, and I'll make sure to consider this in future reviews. 🐇
New Learnings Added
User: aaprasad
PR: talmolab/biogtr#46
File: biogtr/models/global_tracking_transformer.py:5-6
Timestamp: 2024-05-29T17:50:29.550Z
Learning: The `Instance` class is used for type hinting in the `forward` method parameters of the `GlobalTrackingTransformer` class in `biogtr/models/global_tracking_transformer.py`.
Note: Learnings are effective only in the context of specific code segments. To apply general coding guidelines, please configure review instructions.
```diff
         if any(
             [
                 (not instance.has_features()) and instance.has_crop()
                 for instance in ref_instances
             ]
         ):
-            crops = frame.get_crops()
-            z = self.visual_encoder(crops)
+            ref_crops = torch.concat(
+                [instance.crop for instance in ref_instances], axis=0
+            )
+            ref_z = self.visual_encoder(ref_crops)
+            for i, z_i in enumerate(ref_z):
+                ref_instances[i].features = z_i

-            for i, z_i in enumerate(z):
-                frame.instances[i].features = z_i
         if query_instances:
             if any(
                 [
                     (not instance.has_features()) and instance.has_crop()
                     for instance in query_instances
                 ]
             ):
                 query_crops = torch.concat(
                     [instance.crop for instance in query_instances], axis=0
                 )
                 query_z = self.visual_encoder(query_crops)
                 for i, z_i in enumerate(query_z):
                     query_instances[i].features = z_i
```
Simplify feature extraction logic.

```diff
- if any(
-     [
-         (not instance.has_features()) and instance.has_crop()
-         for instance in ref_instances
-     ]
- ):
-     ref_crops = torch.concat(
-         [instance.crop for instance in ref_instances], axis=0
-     )
-     ref_z = self.visual_encoder(ref_crops)
-     for i, z_i in enumerate(ref_z):
-         ref_instances[i].features = z_i
-
- if query_instances:
-     if any(
-         [
-             (not instance.has_features()) and instance.has_crop()
-             for instance in query_instances
-         ]
-     ):
-         query_crops = torch.concat(
-             [instance.crop for instance in query_instances], axis=0
-         )
-         query_z = self.visual_encoder(query_crops)
-         for i, z_i in enumerate(query_z):
-             query_instances[i].features = z_i
+ def extract_features(instances):
+     if any((not instance.has_features()) and instance.has_crop() for instance in instances):
+         crops = torch.concat([instance.crop for instance in instances], axis=0)
+         features = self.visual_encoder(crops)
+         for i, feature in enumerate(features):
+             instances[i].features = feature
+
+ extract_features(ref_instances)
+ if query_instances:
+     extract_features(query_instances)
```

This refactor reduces code duplication and improves readability by extracting the feature extraction logic into a helper function.
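To sanity-check the suggested helper, here is a runnable sketch with a dummy encoder in place of `self.visual_encoder` and a minimal hypothetical `Instance` class (concatenation of crops is simplified to a plain list, so no tensors are needed):

```python
class Instance:
    """Minimal stand-in: holds an optional crop and optional features."""
    def __init__(self, crop=None, features=None):
        self.crop = crop
        self.features = features
    def has_crop(self):
        return self.crop is not None
    def has_features(self):
        return self.features is not None

def visual_encoder(crops):
    """Dummy encoder: a 'feature' is just crop * 10."""
    return [c * 10 for c in crops]

def extract_features(instances):
    # Same shape as the suggested helper: only run the encoder when
    # some instance has a crop but no features yet.
    if any(not i.has_features() and i.has_crop() for i in instances):
        crops = [i.crop for i in instances]
        for inst, feat in zip(instances, visual_encoder(crops)):
            inst.features = feat

ref_instances = [Instance(crop=1), Instance(crop=2)]
extract_features(ref_instances)
print([i.features for i in ref_instances])  # [10, 20]
```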
Do this! Could be useful to have it as an instance method so that you can call it by itself as well for different kinds of preprocessing pipelines.
It may also be helpful to have it not be all-or-nothing and instead just run on the ones without features (or have a flag in this function to recompute the features explicitly).
@talmo, great suggestions! Implementing the feature extraction as an instance method will indeed enhance modularity and allow for more flexible usage across different preprocessing pipelines. Here's how you could modify the `GlobalTrackingTransformer` class to include this functionality:

```python
class GlobalTrackingTransformer(torch.nn.Module):
    ...

    def extract_features(self, instances, force_recompute=False):
        """Extract features for instances that either don't have them or are forced to recompute.

        Args:
            instances (list[Instance]): The instances to process.
            force_recompute (bool): If True, recompute features even if they already exist.
        """
        for instance in instances:
            if force_recompute or (not instance.has_features() and instance.has_crop()):
                crop = instance.crop.unsqueeze(0)  # Assuming crop is a tensor that needs batch dimension
                feature = self.visual_encoder(crop)
                instance.features = feature.squeeze(0)  # Remove batch dimension after processing

    def forward(self, ref_instances, query_instances=None):
        """Execute forward pass of GTR Model to get association matrix."""
        self.extract_features(ref_instances)
        if query_instances:
            self.extract_features(query_instances)
        ...
```

This `extract_features` method checks each instance to determine if it needs feature extraction based on the `force_recompute` flag or the absence of existing features. It processes only the necessary instances, which can improve efficiency, especially when dealing with large datasets.

Feel free to adjust the tensor handling (`unsqueeze` and `squeeze`) based on your actual data structure. This should give you a good starting point to integrate the method into your preprocessing pipelines!
@talmo if I make it a class method it won't have access to `self.visual_encoder`
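aaprasad's objection is about Python method kinds rather than the refactor itself: a `@classmethod` receives the class rather than an instance, so per-instance attributes assigned in `__init__` (like `visual_encoder`) are out of reach. A minimal sketch with dummy names (not the biogtr classes):

```python
class TrackerModel:
    def __init__(self):
        # Per-instance state, assigned in __init__ (like self.visual_encoder).
        self.visual_encoder = lambda x: x + 1

    def encode(self, x):
        # Instance method: `self` gives access to instance attributes.
        return self.visual_encoder(x)

    @classmethod
    def can_see_encoder(cls):
        # Classmethod: only the class object; attributes set on
        # instances in __init__ are invisible here.
        return hasattr(cls, "visual_encoder")

model = TrackerModel()
print(model.encode(1))                 # 2
print(TrackerModel.can_see_encoder())  # False
```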
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
```diff
@@ -2,35 +2,69 @@
 from typing import List, Tuple, Iterable
```
Remove unused import.

```diff
- from typing import Tuple
```

The `Tuple` type is not used in this file. Removing unused imports helps keep the code clean and maintainable.
```diff
@@ -8,6 +8,7 @@
 from biogtr.training.losses import AssoLoss
 from biogtr.models.model_utils import init_optimizer, init_scheduler
 from pytorch_lightning import LightningModule
+from biogtr.data_structures import Frame, Instance
```
Remove unused import.

```diff
- from biogtr.data_structures import Frame, Instance
+ from biogtr.data_structures import Frame
```

The `Instance` class is not used in this file. Removing unused imports helps keep the code clean and maintainable.
biogtr/models/gtr_runner.py
Outdated
```python
frames = [frame for frame in frames if frame.has_instances()]
if len(frames) == 0:
    return None

instances = [instance for frame in frames for instance in frame.instances]
```
Optimize instance extraction.

```diff
- frames = [frame for frame in frames if frame.has_instances()]
- if len(frames) == 0:
-     return None
-
- instances = [instance for frame in frames for instance in frame.instances]
+ instances = [
+     instance
+     for frame in frames if frame.has_instances()
+     for instance in frame.instances
+ ]
+ if len(instances) == 0:
+     return None
```

This refactor combines the filtering and extraction of instances into a single list comprehension, making the code more concise and efficient.
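The two versions are behaviorally equivalent, including the "no instances" early return: filtering empty frames and then flattening yields the same list as the single comprehension. A quick check with plain lists standing in for `frame.instances`:

```python
frames = [["a", "b"], [], ["c"]]  # each inner list stands in for frame.instances

# Original two-step version: filter out empty frames, then flatten.
non_empty = [f for f in frames if f]
two_step = [inst for f in non_empty for inst in f]

# Suggested single comprehension: filter and flatten in one pass.
one_step = [inst for f in frames if f for inst in f]

print(two_step == one_step)  # True
print(one_step)              # ['a', 'b', 'c']
```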
```diff
@@ -59,7 +60,9 @@ def __init__(
         self.metrics = metrics
         self.persistent_tracking = persistent_tracking

-    def forward(self, instances) -> torch.Tensor:
+    def forward(
+        self, ref_instances: list[Instance], query_instances: list[Instance] = None
```
Update the docstring for the new signature
```diff
     def training_step(
-        self, train_batch: list[dict], batch_idx: int
+        self, train_batch: list[list[Frame]], batch_idx: int
```
Update docstring
```diff
@@ -92,7 +93,7 @@ def training_step(
         return result

     def validation_step(
-        self, val_batch: list[dict], batch_idx: int
+        self, val_batch: list[list[Frame]], batch_idx: int
```
Update docstring
@@ -109,7 +110,9 @@ def validation_step( | |||
|
|||
return result | |||
|
|||
def test_step(self, test_batch: list[dict], batch_idx: int) -> dict[str, float]: | |||
def test_step( | |||
self, test_batch: list[list[Frame]], batch_idx: int |
Update docstring
biogtr/models/model_utils.py
Outdated
```diff
 import torch


-def get_boxes_times(frames: List[Frame]) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Extract the bounding boxes and frame indices from the input list of instances.
+def get_boxes(instances: List[Instance]) -> torch.tensor:
```
```diff
- def get_boxes(instances: List[Instance]) -> torch.tensor:
+ def get_boxes(instances: List[Instance]) -> torch.Tensor:
```
biogtr/models/model_utils.py

Outdated

```diff
-        Tuple[torch.Tensor, torch.Tensor]: A tuple of two tensors containing the
-            bounding boxes normalized by the height and width of the image
-            and corresponding frame indices, respectively.
+        The bounding boxes normalized by the height and width of the image
```
shape? dtype?
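One way to answer the "shape? dtype?" ask is to state both explicitly in the Returns section. An illustrative docstring sketch — the specific shape, coordinate order, and dtype below are assumptions about a normalized-box convention, not taken from the PR:

```python
def get_boxes(instances):
    """Extract the bounding boxes from the input list of instances.

    Returns:
        torch.Tensor: The bounding boxes normalized by the height and width
            of the image. Shape: (n_instances, 4) in (x1, y1, x2, y2) order;
            dtype: torch.float32.
    """
    ...
```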
Actionable comments posted: 0
Review Details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (4)
- biogtr/models/global_tracking_transformer.py (2 hunks)
- biogtr/models/gtr_runner.py (5 hunks)
- biogtr/models/model_utils.py (1 hunks)
- biogtr/models/transformer.py (8 hunks)
Additional Context Used

Ruff (1)

biogtr/models/model_utils.py (1)
3-3: `typing.Tuple` imported but unused

Learnings (1)

biogtr/models/global_tracking_transformer.py (1)
User: aaprasad PR: talmolab/biogtr#46 File: biogtr/models/global_tracking_transformer.py:5-6 Timestamp: 2024-05-29T17:50:29.550Z Learning: The `Instance` class is used for type hinting in the `forward` method parameters of the `GlobalTrackingTransformer` class in `biogtr/models/global_tracking_transformer.py`.
Additional comments not posted (25)

biogtr/models/global_tracking_transformer.py (4)

5-6: Imports are correctly used for type hinting in the `forward` method parameters.

Line range hint 11-82: The `__init__` method is well-implemented with clear documentation and appropriate initialization of components.

82-100: The `forward` method is correctly implemented, handling both reference and optional query instances effectively.

104-132: The `extract_features` method is efficiently implemented, with checks for necessary feature computation and handling of batch normalization errors.

biogtr/models/model_utils.py (5)

9-29: The `get_boxes` function is correctly implemented, efficiently extracting and normalizing bounding boxes from instances.

32-68: The `get_times` function is well-implemented, correctly handling both reference and optional query instances to extract time indices.

Line range hint 70-86: The `softmax_asso` function is correctly implemented, applying softmax to association outputs efficiently.

Line range hint 88-118: The `init_optimizer` function is well-implemented, handling various configurations and correctly initializing the appropriate optimizer.

Line range hint 120-156: The `init_scheduler` function is correctly implemented, handling various configurations and correctly initializing the appropriate scheduler.

biogtr/models/gtr_runner.py (8)

11-11: The import of `Frame` and `Instance` is correctly used for type hinting and data handling within the class.

63-76: The `forward` method is correctly implemented, handling both reference and optional query instances effectively.

Line range hint 79-91: The `training_step` method is well-implemented, correctly processing the training batch and logging metrics.

Line range hint 97-109: The `validation_step` method is correctly implemented, efficiently processing the validation batch and logging metrics.

Line range hint 115-127: The `test_step` method is well-implemented, correctly processing the test batch and logging metrics.

132-147: The `predict_step` method is correctly implemented, effectively running inference on the batch and returning predicted frames.

149-182: The `_shared_eval_step` method is efficiently implemented, used by train, test, and validation steps to run evaluation and handle exceptions appropriately.

Line range hint 184-213: The `configure_optimizers` method is well-implemented, handling optional configurations for optimizers and schedulers effectively.

biogtr/models/transformer.py (8)

14-17: Imports are correctly used for type hinting and functionality within the class.

Line range hint 25-59: The `__init__` method of the `Transformer` class is well-implemented, with clear documentation and appropriate initialization of components.

Line range hint 143-250: The `forward` method is correctly implemented, handling both reference and optional query instances effectively, and returning detailed outputs.

Line range hint 262-322: The `TransformerEncoderLayer` class is well-implemented, with a clear and efficient implementation of the encoder layer.

Line range hint 324-410: The `TransformerDecoderLayer` class is correctly implemented, handling optional decoder self-attention and providing a clear implementation of the decoder layer.

Line range hint 412-438: The `TransformerEncoder` class is well-implemented, correctly handling multiple encoder layers and optional normalization.

Line range hint 440-472: The `TransformerDecoder` class is correctly implemented, handling multiple decoder layers and optional normalization and intermediate layer returns efficiently.

Line range hint 474-502: The utility functions `_get_clones` and `_get_activation_fn` are well-implemented, providing necessary support for the transformer model.
```diff
     Args:
         instances: A list of instances to compute features for
-        force_recompute: indicate whether to compute features for all instances regardless of if they have instances
+        force_recompute: indicate whether to compute features for all instances regardless of if they have features
```
```diff
-            train_batch: A single batch from the dataset which is a list of dicts
-                with length `clip_length` where each dict is a frame
+            train_batch: A single batch from the dataset which is a list of `Frame` objects
+                with length `clip_length` containing Instances and other metadata.
```
Currently, at all levels we use `biogtr.Frame` as the base unit of data given to the models as input. However, one of the core advantages of the Global Tracking Transformer is that it really associates detections. Thus, if we are able to just use detections as input into the model, then at a higher level we can come up with more useful ways to group the instances when passing them into the model. Here we refactor `biogtr.models.GlobalTrackingTransformer` and `biogtr.models.Transformer` to take as input a set of `ref_instances` and a set of `query_instances`. The `query_instances` are used as an input into the decoder to associate against the `ref_instances`. One level up, in `biogtr.inference.Tracker`, we show an example of how we can group instances by frame against the window.

Summary by CodeRabbit
- New Features
- Improvements
- Bug Fixes
- Refactor
- Tests
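The frame-vs-window grouping described in the PR summary can be sketched with plain lists (hypothetical data; the real `Tracker` operates on `Frame`/`Instance` objects):

```python
# A sliding window of frames, each holding its detected instances.
window = [["a1", "a2"], ["b1"], ["c1", "c2"]]

# All instances in the window become the reference set...
all_instances = [inst for frame in window for inst in frame]

# ...and the last frame's instances are the queries that the decoder
# associates against the references, as in sliding inference.
query_instances = window[-1]

print(all_instances)    # ['a1', 'a2', 'b1', 'c1', 'c2']
print(query_instances)  # ['c1', 'c2']
```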