We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hi, dear developers! Here is my test code. Could you please tell me whether I wrote it correctly?
import os

import torch
import numpy as np
from lib import networks
from lib import models
from lib.data.med_transforms import *
from lib.utils import set_seed, dist_setup, get_conf
from monai.losses import DiceCELoss, DiceLoss
from collections import defaultdict, OrderedDict
from monai.metrics import compute_meandice, compute_hausdorff_distance
from functools import partial
from lib.data.med_datasets import *
from lib.utils import SmoothedValue, concat_all_gather, LayerDecayValueAssigner
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
import nibabel as nib

# Where predicted segmentations are written unless ``args.output_dir`` is set.
DEFAULT_OUTPUT_DIR = '/home/lzb/wby/3D_Project/SelfMedMAEv2.0/Test_Output'


class Test():
    """Evaluate a trained segmentation model on the validation loader.

    ``evaluate()`` builds the dataloader and the model from ``args``, runs
    sliding-window inference on every case, prints the per-case and overall
    Dice for label 1, and saves each prediction as a NIfTI file.
    """

    def __init__(self, args):
        self.args = args
        self.model_name = args.proj_name
        # Not used by evaluate(); kept so existing references keep working.
        self.scaler = torch.cuda.amp.GradScaler()
        # Not used by evaluate(); kept for callers that reference it.
        self.metric_funcs = OrderedDict([
            ('Dice', compute_meandice),
            ('HD', partial(compute_hausdorff_distance, percentile=95)),
        ])

    def build_model(self):
        """Instantiate the model, load checkpoint weights and move it to the GPU.

        Reads every setting from ``self.args`` so the class also works when it
        is imported and constructed from another script (the previous version
        silently relied on a module-global ``args``).
        """
        args = self.args
        print(f"=> creating model {self.model_name}")
        self.loss_fn = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True,
                                  smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr)
        self.post_pred, self.post_label = get_post_transforms(args)
        self.model = getattr(models, self.model_name)(
            encoder=getattr(networks, args.enc_arch),
            decoder=getattr(networks, args.dec_arch),
            args=args,
        )
        print(f"=> loading checkpoint")
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(args.pretrain, map_location='cpu')
        state_dict = checkpoint['state_dict']
        # strict=False tolerates missing/unexpected keys; the returned result lists
        # them, so printing it makes a mismatched checkpoint visible in the log.
        msg = self.model.load_state_dict(state_dict, strict=False)
        print(f"Loading messages: \n {msg}")
        print(f"=> Finish loading pretrained weights from {args.pretrain}")
        self.model.eval()
        self.model.cuda(args.gpu)

    def build_dataloader(self):
        """Create ``self.val_dataloader`` with the test-time transforms."""
        print("=> creating test dataloader")
        args = self.args
        test_transform = get_testV2_transforms(args)
        self.val_dataloader = get_val_loader(args, args.batch_size, args.workers, test_transform)

    @torch.no_grad()
    def evaluate(self):
        """Predict every validation case and report Dice for label 1.

        Each prediction is resampled to the label's spatial size, compared with
        the label using ``dice``, and saved to ``<output_dir>/<image file name>``
        with the label's affine. ``output_dir`` is ``args.output_dir`` when set,
        otherwise ``DEFAULT_OUTPUT_DIR``.

        Returns:
            The mean of the per-case Dice scores, or None if no case was evaluated.
        """
        args = self.args
        self.build_dataloader()
        self.build_model()
        model = self.model

        output_dir = getattr(args, 'output_dir', None) or DEFAULT_OUTPUT_DIR
        os.makedirs(output_dir, exist_ok=True)

        dice_list_case = []
        print("=> Start Evaluating")
        roi_size = (args.roi_x, args.roi_y, args.roi_z) if args.spatial_dim == 3 else None

        for batch_data in self.val_dataloader:
            image = batch_data['image'].to(args.gpu, non_blocking=True)
            target = batch_data['label'].to(args.gpu, non_blocking=True)
            # This unpacking fixes the expected label layout as (B, C, H, W, D).
            _, _, h, w, d = target.shape
            target_shape = (h, w, d)

            with torch.cuda.amp.autocast():
                logits = sliding_window_inference(image, roi_size=roi_size, sw_batch_size=4,
                                                  predictor=model, overlap=args.infer_overlap)
            # softmax is monotonic, so argmax over the raw logits gives the same class
            # map without copying a full probability volume to the CPU first.
            pred_batch = torch.argmax(logits, dim=1).cpu().numpy().astype(np.uint8)
            target_batch = target.cpu().numpy()

            # Score every sample in the batch (previously only index 0 was used).
            for b in range(pred_batch.shape[0]):
                original_affine = batch_data["label_meta_dict"]["affine"][b].numpy()
                img_name = os.path.basename(batch_data["image_meta_dict"]["filename_or_obj"][b])

                val_output = resample_3d(img=pred_batch[b], target_size=target_shape)
                case_target = target_batch[b, 0]
                print(f'val_output shape is {val_output.shape} | target shape is {target_shape}')

                mean_dice = dice(val_output == 1, case_target == 1)
                print(f"=>Evaluating on {img_name}, Mean Dice: {mean_dice}")
                dice_list_case.append(mean_dice)

                nib.save(
                    nib.Nifti1Image(val_output.astype(np.uint8), original_affine),
                    os.path.join(output_dir, img_name),
                )

        if not dice_list_case:
            # np.mean([]) would print NaN together with a RuntimeWarning.
            print("Overall Mean Dice: no cases were evaluated")
            return None
        overall = np.mean(dice_list_case)
        print("Overall Mean Dice: {}".format(overall))
        return overall


def resample_3d(img, target_size):
    """Resize a 3-D volume to ``target_size`` using nearest-neighbour interpolation.

    ``order=0`` means every output voxel copies an input voxel value, so integer
    label maps are not blended into non-existent labels.

    Args:
        img: array of rank 3.
        target_size: (x, y, z) output size.

    Returns:
        The resampled array.
    """
    import scipy.ndimage as ndimage

    imx, imy, imz = img.shape
    tx, ty, tz = target_size
    zoom_ratio = (float(tx) / float(imx), float(ty) / float(imy), float(tz) / float(imz))
    img_resampled = ndimage.zoom(img, zoom_ratio, order=0, prefilter=False)
    return img_resampled


def dice(x, y):
    """Return the Dice coefficient ``2 * sum(x * y) / (sum(x) + sum(y))``.

    Intended for binary masks (e.g. ``pred == 1`` and ``label == 1``).
    Returns 0.0 when the reference ``y`` is empty, regardless of ``x``.
    """
    # np.sum without an axis already reduces over all elements.
    intersect = np.sum(x * y)
    y_sum = np.sum(y)
    if y_sum == 0:
        return 0.0
    x_sum = np.sum(x)
    return 2 * intersect / (x_sum + y_sum)
def compute_avg_metric(metric, meters, metric_name, batch_size, args):
    """Collapse a 2-D metric matrix to one scalar and record it in ``meters``.

    The matrix is first averaged over axis 0 while ignoring NaNs. The resulting
    vector is then averaged again:
      * for ``args.dataset == 'btcv'``, with NaN/inf entries masked out
        (``np.ma.masked_invalid``);
      * otherwise, with ``np.nanmean`` (NaNs skipped).
    The scalar is passed to ``meters[metric_name].update`` with weight ``batch_size``.

    NOTE(review): rows presumably index samples and columns classes -- confirm
    against the caller.

    Raises:
        AssertionError: if ``metric`` is not 2-D.
    """
    assert len(metric.shape) == 2
    per_column = np.nanmean(metric, axis=0)
    if args.dataset != 'btcv':
        scalar = np.nanmean(per_column)
    else:
        scalar = np.mean(np.ma.masked_invalid(per_column))
    meters[metric_name].update(value=scalar, n=batch_size)


if __name__ == '__main__':
    # Kept as the module-level name ``args``: Test.build_model may look it up globally.
    args = get_conf()
    # Evaluate in test mode with two output classes.
    args.test = True
    args.num_classes = 2
    Test(args).evaluate()
The text was updated successfully, but these errors were encountered:
No branches or pull requests
Hi, dear developers!
Here is my test code. Could you please tell me whether I wrote it correctly?
The text was updated successfully, but these errors were encountered: