From bc16b1f53d4b2a4066a14308a0ab13fe642c8ac7 Mon Sep 17 00:00:00 2001 From: Luca Lumetti Date: Mon, 14 Nov 2022 23:55:50 +0100 Subject: [PATCH] improve compute_and_normalize_coords() --- models/LatePosPadUNet3D.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/models/LatePosPadUNet3D.py b/models/LatePosPadUNet3D.py index 8afc3f6..b8bf62b 100755 --- a/models/LatePosPadUNet3D.py +++ b/models/LatePosPadUNet3D.py @@ -1,19 +1,25 @@ import torch import torch.nn as nn -# TODO: improve -def compute_and_normalize_coords(coords, MAX_X = 168, MAX_Y= 280, MAX_Z=360): - coord_list = [] - for coord in coords: - x = torch.arange(coord[0], coord[3]) - y = torch.arange(coord[1], coord[4]) - z = torch.arange(coord[2], coord[5]) - gx, gy, gz = torch.meshgrid(x,y,z, indexing="ij") - gx = gx/(MAX_X-1) - gy = gy/(MAX_Y-1) - gz = gz/(MAX_Z-1) - coord_list.append(torch.cat([gx.unsqueeze(0), gy.unsqueeze(0), gz.unsqueeze(0)]).unsqueeze(0)) - return torch.cat(coord_list).to(coords.device) +def compute_and_normalize_coords(coords, MAX_X=168, MAX_Y=280, MAX_Z=360): + max_dims = torch.tensor([MAX_X, MAX_Y, MAX_Z])-1 + patch_size = coords[:,3:] - coords[:,:3] + + # assert that all the elements are the same + assert torch.all(patch_size[0] == patch_size), f"patch_size not constant across batch! Got: {patch_size}" + + patch_size = patch_size[0] + dim_x, dim_y, dim_z = patch_size + x = torch.arange(dim_x) + y = torch.arange(dim_y) + z = torch.arange(dim_z) + + cprod = torch.cartesian_prod(x,y,z).T + cprod = cprod.reshape(3, dim_x, dim_y, dim_z) + cprod = cprod.unsqueeze(0) # dim: 1 x 3 x X x Y x Z + offset = coords[:, :3][:, :, None, None, None] # dim: B x 3 x 1 x 1 x 1 + max_dims = max_dims[None,:,None,None,None] # dim: 1 x 3 x 1 x 1 x 1 + return (cprod+offset)/max_dims def initialize_weights(*models): for model in models: