-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmodel.py
24 lines (22 loc) · 1010 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# ----------------------------------------------#
# Pro : cbct
# File : model.py
# Date : 2023/2/22
# Author : Qing Wu
# Email : wuqing@shanghaitech.edu.cn
# ----------------------------------------------#
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attenuation_Smootion_Over_Energies_Loss(nn.Module):
    """Masked total-variation style penalty along the energy axis.

    For every sample point the absolute differences between neighbouring
    energy levels of ``intensity`` are summed, weighted by a mask value
    looked up at that point, and the weighted total is averaged over all
    sample points and scaled by ``lamb``.
    """

    def __init__(self, mask, lamb):
        """Keep the mask volume and the loss weight.

        Args:
            mask: Tensor handed to ``F.grid_sample`` as its ``input`` in
                ``forward``. Presumably a region mask with batch and
                channel dimensions of size 1 -- TODO confirm with caller.
            lamb: Factor multiplied into the returned loss.
        """
        super().__init__()
        self.mask = mask
        self.lamb = lamb

    def forward(self, ray, intensity):
        """Return the weighted, masked energy-smoothness loss.

        Args:
            ray: Sampling coordinates for ``F.grid_sample``; two leading
                singleton dimensions are prepended before sampling.
                NOTE(review): the exact shape is not visible here, it only
                has to yield ``batch_size * num_sample_ray * k`` samples
                -- confirm against the caller.
            intensity: 4-D tensor laid out as
                ``(batch_size, num_sample_ray, k, e_level)``.

        Returns:
            ``lamb * sum(masked variation) / (batch_size * num_sample_ray * k)``.
        """
        n_batch, n_ray, n_point, _ = intensity.shape

        # Nearest-neighbour lookup of the mask at each ray sample, then
        # reshape to one weight per (batch, ray, point); the original
        # author's note gives this as (batch_size, num_sample_ray, 2*SOD).
        grid = ray[None, None]
        sampled = F.grid_sample(
            self.mask, grid, mode='nearest', align_corners=False
        )
        weight = sampled[0, 0, 0, :].view(n_batch, n_ray, n_point)

        # Absolute change between consecutive energy levels, accumulated
        # over the energy axis and gated by the mask weight.
        step = intensity[..., 1:] - intensity[..., :-1]
        variation = step.abs().sum(dim=-1) * weight

        n_total = n_batch * n_ray * n_point
        return self.lamb * variation.sum() / n_total