-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmodel.py
34 lines (30 loc) · 1.03 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import os
import time
import h5py
from glob import glob
import numpy as np
from torch.utils.data import Dataset, DataLoader
from scipy.io import loadmat
import torch.utils.data as Data
from PIL import Image
import torch.utils.data
import matplotlib.pyplot as plt
import torch.distributed as dist
import argparse
class Net(nn.Module):
    """Two-stage model: a correction network on projection data, then an image network.

    Stage 1 (``se_net``) predicts a residual that is blended into ``LI_proj``
    inside ``metal_trace``. The result is reconstructed with ``recon``.
    Stage 2 (``im_net``) refines the reconstruction together with ``pre_CT``.

    NOTE(review): ``IM_net``, ``SE_net`` and ``RECON`` are not defined or
    imported in this module. They must be in scope when ``Net()`` is
    constructed; confirm where they come from.
    """

    def __init__(self):
        super().__init__()
        self.im_net = IM_net()
        self.se_net = SE_net()
        # RECON is stored as-is, not instantiated like IM_net/SE_net.
        # NOTE(review): presumably a ready-made callable (e.g. a
        # differentiable reconstruction operator); confirm.
        self.recon = RECON

    def forward(self, sub, LI_proj, mp, pre_CT, metal_trace):
        """Run both stages.

        Args:
            sub: first input to ``se_net`` (cast to float32).
            LI_proj: projection data kept unchanged outside the metal trace
                and summed with the ``se_net`` output inside it.
            mp: second input to ``se_net`` (cast to float32).
            pre_CT: image concatenated with the reconstruction along dim 1
                and also passed to ``im_net`` as its second argument.
            metal_trace: blending mask; presumably binary in {0, 1}
                (TODO confirm).

        Returns:
            Tuple ``(sino, sino2im, im)``: the blended projection data, its
            reconstruction, and the ``im_net`` output.

        NOTE(review): inputs look like (batch, channel, H, W), since dims 2/3
        are transposed and dim 1 is concatenated; confirm against callers.
        """
        # Move each intermediate to the device of the tensor it is combined
        # with. The previous `.cuda(args.local_rank)` relied on a global
        # `args` that this module never defines. The arithmetic below already
        # requires the operands to share a device.
        sinosub = self.se_net(sub.float(), mp.float()).to(LI_proj.device)
        # Inside the metal trace: LI_proj + predicted residual.
        # Outside it: LI_proj unchanged.
        sino = (sinosub + LI_proj).mul(metal_trace) + LI_proj.mul(1 - metal_trace)
        # Swap the last two axes before handing the data to the reconstructor.
        sino2im = self.recon(torch.transpose(sino, 2, 3)).to(pre_CT.device)
        inputs_im = torch.cat((sino2im, pre_CT), dim=1)
        im = self.im_net(inputs_im.float(), pre_CT.float())
        return sino, sino2im, im