Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix vista3d transpose bug #8059

Merged
merged 18 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/apps/vista3d/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def point_based_window_inferer(
point_labels=point_labels,
class_vector=class_vector,
prompt_class=prompt_class,
patch_coords=unravel_slice,
patch_coords=[unravel_slice],
prev_mask=prev_mask,
**kwargs,
)
Expand Down
16 changes: 10 additions & 6 deletions monai/networks/nets/vista3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
def forward(
self,
input_images: torch.Tensor,
patch_coords: Sequence[slice] | None = None,
patch_coords: list[Sequence[slice]] | None = None,
point_coords: torch.Tensor | None = None,
point_labels: torch.Tensor | None = None,
class_vector: torch.Tensor | None = None,
Expand Down Expand Up @@ -364,8 +364,12 @@ def forward(
the points are for zero-shot or supported class. When class_vector and point_coords are both
provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
will be considered novel class.
patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase.
patch_coords: a list of sequence of the python slice objects representing the patch coordinates during sliding window
inference. This value is passed from sliding_window_inferer.
This is an indicator for training phase or validation phase.
Notice for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will inlcude
coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the
functions using patch_coords will by default use patch_coords[0].
labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
Expand Down Expand Up @@ -395,14 +399,14 @@ def forward(
if val_point_sampler is None:
# TODO: think about how to refactor this part.
val_point_sampler = self.sample_points_patch_val
point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set)
point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set)
if prompt_class[0].item() == 0: # type: ignore
point_labels[0] = -1 # type: ignore
labels, prev_mask = None, None
elif point_coords is not None:
# If not performing patch-based point only validation, use user provided click points for inference.
# the point clicks is in original image space, convert it to current patch-coordinate space.
point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore
point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels) # type: ignore

if point_coords is not None and point_labels is not None:
# remove points that used for padding purposes (point_label = -1)
Expand Down Expand Up @@ -455,7 +459,7 @@ def forward(
logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
if prev_mask is not None and patch_coords is not None:
logits = self.connected_components_combine(
prev_mask[patch_coords].transpose(1, 0).to(logits.device),
prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device),
logits[mapping_index],
point_coords, # type: ignore
point_labels, # type: ignore
Expand Down
Loading