Here is the code in question:
x_t_d = torch.clip(x_t[..., :3], min=-1 * clip_bounds_m, max=clip_bounds_m)
x_t_m = torch.clip(x_t[..., 3:], min=-1 * clip_bounds_d, max=clip_bounds_d)
x_t = torch.cat((x_t_d, x_t_m), dim=-1)
Why is the clipping applied along the last dimension rather than the third (channel) dimension, given that `x_t` has shape B x N x C x H x W?
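The slicing difference can be checked directly. This is a minimal sketch that assumes `x_t` really has shape B x N x C x H x W; the sizes below are made up for illustration. `x_t[..., :3]` indexes the last axis (W), whereas `x_t[:, :, :3]` indexes the channel axis C. The helper `clip_channels` is hypothetical and shows what a channel-wise version of the same clip-and-concatenate step might look like.

```python
import torch

# Hypothetical sizes for illustration only: B=2, N=4, C=6, H=8, W=16
B, N, C, H, W = 2, 4, 6, 8, 16
x_t = torch.randn(B, N, C, H, W) * 5

# `...` expands to all leading axes, so this slices the LAST axis (W)
last_dim_slice = x_t[..., :3]
# Explicit indexing of axis 2, the channel axis C
channel_slice = x_t[:, :, :3]

print(last_dim_slice.shape)  # torch.Size([2, 4, 6, 8, 3])
print(channel_slice.shape)   # torch.Size([2, 4, 3, 8, 16])


def clip_channels(x, split, bound_a, bound_b, dim=2):
    """Hypothetical helper: clip the first `split` channels to [-bound_a, bound_a]
    and the remaining channels to [-bound_b, bound_b], splitting along `dim`."""
    first, second = torch.split(x, [split, x.shape[dim] - split], dim=dim)
    first = torch.clamp(first, min=-bound_a, max=bound_a)
    second = torch.clamp(second, min=-bound_b, max=bound_b)
    # Re-join along the same axis that was split
    return torch.cat((first, second), dim=dim)


# Example bounds are arbitrary placeholders
clipped = clip_channels(x_t, 3, 1.0, 2.0)
print(clipped.shape)  # torch.Size([2, 4, 6, 8, 16])
```

If the tensor is actually stored channels-last at this point (for example B x N x H x W x C), the original `[..., :3]` slicing would already be the channel split, so it is worth printing `x_t.shape` right before the clip to confirm which layout is in use.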