forked from SuperMedIntel/Medical-SAM2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbctv_test.py
30 lines (21 loc) · 921 Bytes
/
bctv_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import cfg
from func_3d.utils import get_network
from func_3d.function import validation_sam
from func_3d.dataset import BTCV
from func_3d.dataset import get_dataloader
from hydra import initialize, compose
import hydra
from hydra.core.global_hydra import GlobalHydra
def main(args):
    """Run a single validation pass on the test split and print the results.

    Args:
        args: Parsed options object. This function reads ``args.gpu_device``
            and ``args.net`` directly and forwards the whole object to
            ``get_dataloader``, ``get_network`` and ``validation_sam``.
    """
    cuda_available = torch.cuda.is_available()
    device = torch.device(f'cuda:{args.gpu_device}' if cuda_available else 'cpu')
    print("deviceeeee :" , device)

    # Only the test loader is used here; the train loader is discarded.
    _, nice_test_loader = get_dataloader(args)

    # Ask for GPU placement only when CUDA is actually present, so the CPU
    # fallback selected above is not contradicted by a hard-coded use_gpu=True.
    # NOTE(review): validation_sam may still assume a CUDA device internally —
    # confirm before relying on CPU-only runs.
    net = get_network(
        args,
        args.net,
        use_gpu=cuda_available,
        gpu_device=args.gpu_device,
    )

    # The literal 0 is the third positional argument of validation_sam
    # (presumably the epoch index — confirm against its signature).
    validation_loss, validation_metrics = validation_sam(args, nice_test_loader, 0, net)
    print(f"Validation Loss: {validation_loss}")
    print(f"Validation Metrics (IoU and Dice): {validation_metrics}")
if __name__ == "__main__":
    # Initialise Hydra against the ./sam2_train config directory before any
    # other work. NOTE(review): presumably required so that network
    # construction can compose its configs — confirm.
    initialize(version_base='1.2', config_path="./sam2_train")

    # Parse the command-line options and hand them straight to main().
    main(cfg.parse_args())