diff --git a/README.md b/README.md
index fd325f0..573c4dc 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,34 @@
-# ssl_for_MFR
-TODO
+# ssl_for_MFR
+Self-supervised (SimSiam) pretraining for machining feature recognition. Typical `single_train.py` invocations:
+
+```
+# supervised training / linear evaluation
+python single_train.py --arch=MsvNetLite --num_train=1 --pretrain=pretrain\PPLCNet_x1_0_ssld_pretrained.pth --data_aug
+
+python single_train.py --arch=MsvNetLite --num_train=1 --simsiam_pretrain=pretrain\MsvNetLite1.0_simsiam_pretrain_BS32_6000\checkpoint_0099.pth.tar --lr_sch=cos --data_aug --base_lr=30
+
+python single_train.py --arch=MsvNetLite --num_train=1 --simsiam_pretrain=pretrain\MsvNetLite1.0_simsiam_pretrain_BS32_6000\checkpoint_0099.pth.tar --base_lr=30 --optim=sgdm --weight_decay=1e-6 --lr_sch=cos --freeze
+
+python single_train.py --arch=FeatureNetLite --num_train=1 --simsiam_pretrain=pretrain\FeatureNetLite1.0_simsiam_pretrain_BS32_6000_40e\checkpoint_0039.pth.tar --weight_decay=1e-4 --lr_sch=cos --data_aug --base_lr=20
+
+#### t-SNE visualization (--program_type=draw_TSNE)
+
+python single_train.py --arch=MsvNetLite --simsiam_pretrain=pretrain\MsvNetLite1.0_simsiam_pretrain_BS32_6000_200e\checkpoint_0199.pth.tar --program_type=draw_TSNE
+
+python single_train.py --arch=MsvNetLite --program_type=draw_TSNE
+
+python single_train.py --arch=MsvNetLite --pretrain=pretrain\PPLCNet_x1_0_ssld_pretrained.pth --program_type=draw_TSNE
+
+python single_train.py --arch=MsvNetLite --model_path=output\2022_03_18_10_50_08\best_model.pth --program_type=draw_TSNE
+
+# supervised 2022_03_09_16_31_21
+
+
+#### ROC curves and confusion matrix (--program_type=draw_ROC_CM)
+# MsvNetLite
+python single_train.py --arch=MsvNetLite --model_path=output\2022_03_17_12_00_03\best_model.pth --program_type=draw_ROC_CM
+
+# FeatureNetLite
+
+```
\ No newline at end of file
diff --git a/core/__init__.py b/core/__init__.py
new file mode 100644
index 0000000..6ed6628
--- /dev/null
+++ b/core/__init__.py
@@ -0,0 +1,9 @@
+'''
+Author: whj
+Date: 2022-02-15 15:53:55
+LastEditors: whj
+LastEditTime: 2022-02-17 11:52:47
+Description: file content
+'''
+
+# coding: utf-8
\ No newline at end of file
diff --git a/core/dataset/__init__.py b/core/dataset/__init__.py
new file mode 100644
index 0000000..7796a26
--- /dev/null
+++ b/core/dataset/__init__.py
@@ -0,0 +1,12 @@
+'''
+Author: whj
+Date: 2022-02-15 15:53:55
+LastEditors: whj
+LastEditTime: 2022-03-04 13:24:46
+Description: file content
+'''
+
+# coding: utf-8
+
+from .datasets import FeatureDataset, SimsiamDataset, createPartition
+
diff --git a/core/dataset/augmentations.py b/core/dataset/augmentations.py
new file mode 100644
index 0000000..c58b181
--- /dev/null
+++ b/core/dataset/augmentations.py
@@ -0,0 +1,240 @@
+
+# coding: utf-8
+
+import os
+from pathlib import Path
+import random
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data.dataset import Dataset
+import torch.backends.cudnn as cudnn
+
+import numpy as np
+import cupy as cp
+import cupyx.scipy
+import cupyx.scipy.ndimage
+import cupyx
+from PIL import Image
+from scipy import ndimage
+
+
+def randomRotation(sample):
+    rotation = random.randint(0, 23)
+    if rotation == 1:
+        sample = cp.rot90(sample, 1, (1, 2))
+    elif rotation == 2:
+        sample = cp.rot90(sample, 2, (1, 2))
+    elif rotation == 3:
+        sample = cp.rot90(sample, 1, (2, 1))
+    elif rotation == 4:
+        sample = cp.rot90(sample, 1, (0, 1))
+    elif rotation == 5:
+        sample = cp.rot90(sample, 1, (0, 1))
+        sample = cp.rot90(sample, 1, (1, 2))
+    elif rotation == 6:
+        sample = cp.rot90(sample, 1, (0, 1))
+        sample = cp.rot90(sample, 2, (1, 2))
+    elif rotation == 7:
+        sample = cp.rot90(sample, 1, (0, 1))
+        sample = cp.rot90(sample, 1, (2, 1))
+    elif rotation == 8:
+        sample
= cp.rot90(sample, 1, (1, 0)) + elif rotation == 9: + sample = cp.rot90(sample, 1, (1, 0)) + sample = cp.rot90(sample, 1, (1, 2)) + elif rotation == 10: + sample = cp.rot90(sample, 1, (1, 0)) + sample = cp.rot90(sample, 2, (1, 2)) + elif rotation == 11: + sample = cp.rot90(sample, 1, (1, 0)) + sample = cp.rot90(sample, 1, (2, 1)) + elif rotation == 12: + sample = cp.rot90(sample, 2, (1, 0)) + elif rotation == 13: + sample = cp.rot90(sample, 2, (1, 0)) + sample = cp.rot90(sample, 1, (1, 2)) + elif rotation == 14: + sample = cp.rot90(sample, 2, (1, 0)) + sample = cp.rot90(sample, 2, (1, 2)) + elif rotation == 15: + sample = cp.rot90(sample, 2, (1, 0)) + sample = cp.rot90(sample, 1, (2, 1)) + elif rotation == 16: + sample = cp.rot90(sample, 1, (0, 2)) + elif rotation == 17: + sample = cp.rot90(sample, 1, (0, 2)) + sample = cp.rot90(sample, 1, (1, 2)) + elif rotation == 18: + sample = cp.rot90(sample, 1, (0, 2)) + sample = cp.rot90(sample, 2, (1, 2)) + elif rotation == 19: + sample = cp.rot90(sample, 1, (0, 2)) + sample = cp.rot90(sample, 1, (2, 1)) + elif rotation == 20: + sample = cp.rot90(sample, 1, (2, 0)) + elif rotation == 21: + sample = cp.rot90(sample, 1, (2, 0)) + sample = cp.rot90(sample, 1, (1, 2)) + elif rotation == 22: + sample = cp.rot90(sample, 1, (2, 0)) + sample = cp.rot90(sample, 2, (1, 2)) + elif rotation == 23: + sample = cp.rot90(sample, 1, (2, 0)) + sample = cp.rot90(sample, 1, (2, 1)) + + return sample + + +def randomScaleCrop(sample): + resolution = int(sample.shape[0]) + strategy = random.randint(0, 9) + if strategy == 0: + factor = random.uniform(1.0625, 1.25) + sample = ndimage.zoom(sample, factor, order=0) + startx = random.randint(0, sample.shape[0] - resolution) + starty = random.randint(0, sample.shape[1] - resolution) + startz = random.randint(0, sample.shape[2] - resolution) + sample = sample[startx:startx+resolution, + starty:starty+resolution, startz:startz+resolution] + elif strategy == 1: + factor = random.uniform(0.9375, 0.75) + sample = ndimage.zoom(sample, factor, order=0) + padxwl = random.randint(0, resolution - sample.shape[0]) + padxwr = resolution - padxwl - sample.shape[0] + padywl = random.randint(0, resolution - sample.shape[1]) + padywr = resolution - padywl - sample.shape[1] + padzwl = random.randint(0, resolution - sample.shape[2]) + padzwr = resolution - padzwl - sample.shape[2] + sample = np.pad(sample, ((padxwl, padxwr), + (padywl, padywr), (padzwl, padzwr)), mode='edge') + elif strategy == 2: + padr = int(resolution/8) + loc = 2*padr + startx = random.randint(0, loc) + starty = padr + startz = padr + sample = np.pad(sample, ((padr, padr), (padr, padr), + (padr, padr)), mode='edge') + sample = sample[startx:startx+resolution, + starty:starty+resolution, startz:startz+resolution] + elif strategy == 3: + padr = int(resolution/8) + loc = 2*padr + startx = padr + starty = random.randint(0, loc) + startz = padr + sample = np.pad(sample, ((padr, padr), (padr, padr), + (padr, padr)), mode='edge') + sample = sample[startx:startx+resolution, + starty:starty+resolution, startz:startz+resolution] + elif strategy == 4: + padr = int(resolution/8) + loc = 2*padr + startx = padr + starty = padr + startz = random.randint(0, loc) + sample = np.pad(sample, ((padr, padr), (padr, padr), + (padr, padr)), mode='edge') + sample = sample[startx:startx+resolution, + starty:starty+resolution, startz:startz+resolution] + + return sample + + +def randomScale(sample): + resolution = int(sample.shape[0]) + strategy = random.randint(0, 2) + if strategy == 0: + factor = 
random.uniform(1.0625, 1.1)
+        sample = ndimage.zoom(sample, factor, order=0)
+        startx = random.randint(0, sample.shape[0] - resolution)
+        starty = random.randint(0, sample.shape[1] - resolution)
+        startz = random.randint(0, sample.shape[2] - resolution)
+        sample = sample[startx:startx+resolution,
+                        starty:starty+resolution, startz:startz+resolution]
+    elif strategy == 1:
+        factor = random.uniform(0.9375, 0.75)
+        sample = ndimage.zoom(sample, factor, order=0)
+        padxwl = random.randint(0, resolution - sample.shape[0])
+        padxwr = resolution - padxwl - sample.shape[0]
+        padywl = random.randint(0, resolution - sample.shape[1])
+        padywr = resolution - padywl - sample.shape[1]
+        padzwl = random.randint(0, resolution - sample.shape[2])
+        padzwr = resolution - padzwl - sample.shape[2]
+        sample = np.pad(sample, ((padxwl, padxwr),
+                        (padywl, padywr), (padzwl, padzwr)), mode='edge')
+
+    return sample
+
+
+def randomPadCrop(sample):
+    resolution = int(sample.shape[0])
+    strategy = random.randint(0, 3)
+    if strategy == 0:
+        padr = int(resolution/8)
+        loc = 2*padr
+        startx = random.randint(0, loc)
+        starty = padr
+        startz = padr
+        sample = np.pad(sample, ((padr, padr), (padr, padr),
+                        (padr, padr)), mode='edge')
+        sample = sample[startx:startx+resolution,
+                        starty:starty+resolution, startz:startz+resolution]
+    elif strategy == 1:
+        padr = int(resolution/8)
+        loc = 2*padr
+        startx = padr
+        starty = random.randint(0, loc)
+        startz = padr
+        sample = np.pad(sample, ((padr, padr), (padr, padr),
+                        (padr, padr)), mode='edge')
+        sample = sample[startx:startx+resolution,
+                        starty:starty+resolution, startz:startz+resolution]
+    elif strategy == 2:
+        padr = int(resolution/8)
+        loc = 2*padr
+        startx = padr
+        starty = padr
+        startz = random.randint(0, loc)
+        sample = np.pad(sample, ((padr, padr), (padr, padr),
+                        (padr, padr)), mode='edge')
+        sample = sample[startx:startx+resolution,
+                        starty:starty+resolution, startz:startz+resolution]
+
+    return sample
+
+
+
+def cutout3D(sample):
+    # parameters
+    max_holes = 3
+    max_cutout_size = 12
+    # the random number of holes
+    holes = random.randint(0, max_holes)
+    if holes == 0:
+        return sample
+    # cutout
+    resolution = int(sample.shape[0])
+    for n in range(holes):
+        y = np.random.randint(resolution)
+        x = np.random.randint(resolution)
+        z = np.random.randint(resolution)
+
+        sizey = np.random.randint(4, max_cutout_size)
+        sizex = np.random.randint(4, max_cutout_size)
+        sizez = np.random.randint(4, max_cutout_size)
+
+        y1 = np.clip(y - sizey // 2, 0, resolution)
+        y2 = np.clip(y + sizey // 2, 0, resolution)
+        x1 = np.clip(x - sizex // 2, 0, resolution)
+        x2 = np.clip(x + sizex // 2, 0, resolution)
+        z1 = np.clip(z - sizez // 2, 0, resolution)
+        z2 = np.clip(z + sizez // 2, 0, resolution)
+
+        sample[y1: y2, x1: x2, z1: z2] = 0
+
+    return sample
+
+
diff --git a/core/dataset/datasets.py b/core/dataset/datasets.py
new file mode 100644
index 0000000..bca75ef
--- /dev/null
+++ b/core/dataset/datasets.py
@@ -0,0 +1,236 @@
+
+# coding: utf-8
+
+import os
+from pathlib import Path
+import random
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data.dataset import Dataset
+import torch.backends.cudnn as cudnn
+
+import numpy as np
+import cupy as cp
+import cupyx.scipy
+import cupyx.scipy.ndimage
+import cupyx
+from PIL import Image
+from scipy import ndimage
+
+from ..utils import read_as_3d_array
+from .augmentations import *
+
+
+
+class Object3DTo2D(object):
+    def __init__(self, img_num, resolution):
+        self.img_num = img_num
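+        # resize each projected sectional view and normalize with ImageNet statistics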
+        self.transform = transforms.Compose([transforms.ToPILImage(),
+                                             transforms.Resize(
+                                                 (resolution, resolution),
+                                                 interpolation=transforms.InterpolationMode.NEAREST),
+                                             transforms.ToTensor(),
+                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                                                  std=[0.229, 0.224, 0.225])
+                                             ])
+
+    def __call__(self, obj3d):
+        imgs = []
+        for _ in range(self.img_num):
+            minres = min(obj3d.shape[0], obj3d.shape[1], obj3d.shape[2])
+            proj_dir = random.randint(0, 1)
+            sel_axis = random.randint(0, 2)
+            sel_idx = random.randint(1, minres - 2)
+            if sel_axis == 0:
+                if proj_dir == 0:
+                    img = cp.mean(obj3d[sel_idx:, :, :], sel_axis)
+                else:
+                    img = cp.mean(obj3d[:sel_idx, :, :], sel_axis)
+            elif sel_axis == 1:
+                if proj_dir == 0:
+                    img = cp.mean(obj3d[:, sel_idx:, :], sel_axis)
+                else:
+                    img = cp.mean(obj3d[:, :sel_idx, :], sel_axis)
+            elif sel_axis == 2:
+                if proj_dir == 0:
+                    img = cp.mean(obj3d[:, :, sel_idx:], sel_axis)
+                else:
+                    img = cp.mean(obj3d[:, :, :sel_idx], sel_axis)
+
+            img = torch.from_numpy(img).float()
+            img = img.expand(3, img.shape[0], img.shape[1])  # convert to rgb channels
+            img = self.transform(img)
+            imgs.append(img)
+
+        return torch.stack(imgs)
+
+
+class dataAugmentation(object):
+    def __init__(self, ops=None):
+        self.ops = ops
+
+    def __call__(self, sample):
+        if self.ops is not None:
+            for op in self.ops:
+                sample = eval(op)(sample)
+        return sample
+
+
+class FeatureDataset(Dataset):
+    def __init__(self, list_IDs, resolution, output_type='3d', num_cuts=12, data_augmentation=None):
+        self.list_IDs = list_IDs
+        self.resolution = resolution
+        self.output_type = output_type
+        self.num_cuts = num_cuts
+        self.data_augmentation = dataAugmentation(data_augmentation)
+        self.createImgs = Object3DTo2D(self.num_cuts, self.resolution)
+
+    def __len__(self):
+        return len(self.list_IDs)
+
+    def __getitem__(self, index):
+        idx = index
+        ID = self.list_IDs[idx][0]
+        rotation = self.list_IDs[idx][1]
+
+        filename = ID + '.binvox'
+        filepath = os.path.join('data', os.path.join(str(self.resolution), filename))
+        with open(filepath, 'rb') as f:
+            sample = read_as_3d_array(f).data
+
+        if rotation == 1:
+            sample = cp.rot90(sample, 2, (0, 1))
+        elif rotation == 2:
+            sample = cp.rot90(sample, 1, (0, 1))
+        elif rotation == 3:
+            sample = cp.rot90(sample, 1, (1, 0))
+        elif rotation == 4:
+            sample = cp.rot90(sample, 1, (2, 0))
+        elif rotation == 5:
+            sample = cp.rot90(sample, 1, (0, 2))
+
+        sample = self.data_augmentation(sample)
+        label = int(ID.split('_')[0])
+
+        if self.output_type == '3d':
+            sample = np.expand_dims(sample, axis=0)
+            sample = torch.from_numpy(sample.copy()).float()
+            sample = 2.0 * (sample - 0.5)
+        elif self.output_type == '2d_multiple':
+            sample = self.createImgs(sample)
+        elif self.output_type == '2d_single':
+            sample = self.createImgs(sample)
+            label = torch.zeros(self.num_cuts, dtype=torch.int64) + label
+
+        return sample, label
+
+
+def createPartition(data_path, num_classes = 24, resolution=16, num_train=30, num_val_test=30):
+    counter = np.zeros(num_classes, np.int64)
+    partition = {}
+    for i in range(num_classes):
+        partition['train', i] = []
+        partition['val', i] = []
+        partition['test', i] = []
+
+    with open(os.devnull, 'w') as devnull:
+        path = Path(os.path.join(data_path, str(resolution)))
+        for filename in sorted(path.glob('*.binvox')):
+            namelist = os.path.basename(filename).split('_')
+            label = int(namelist[0])
+            counter[label] += 1
+
+            items = []
+            for i in range(6):
+                items.append((os.path.basename(filename).split('.')[0], i))
+
+            if counter[label] % 10 < 8:
+                partition['train', label]
+= items + elif counter[label] % 10 == 8: + partition['val', label] += items + elif counter[label] % 10 == 9: + partition['test', label] += items + + ret = {} + ret['train'] = [] + ret['val'] = [] + ret['test'] = [] + + for i in range(num_classes): + random.shuffle(partition['train', i]) + random.shuffle(partition['val', i]) + random.shuffle(partition['test', i]) + + ret['train'] += partition['train', i][0:num_train] + ret['val'] += partition['val', i][0:num_val_test] + ret['test'] += partition['test', i][0:num_val_test] + + random.shuffle(ret['train']) + random.shuffle(ret['val']) + random.shuffle(ret['test']) + + return ret + + +class SimsiamDataset(Dataset): + def __init__(self, data_path, num_classes = 24, num_train=4800, resolution=64, output_type='3d', num_cuts=12, transform=None): + self.resolution = resolution + self.output_type = output_type + self.num_cuts = num_cuts + self.transform = dataAugmentation(transform) + self.createImgs = Object3DTo2D(self.num_cuts, self.resolution) + + self.list_IDs = [] + with open(os.devnull, 'w') as devnull: + path = Path(os.path.join(data_path, str(resolution))) + for filename in sorted(path.glob('*.binvox')): + items = [] + for i in range(6): + items.append((os.path.basename(filename).split('.')[0], i)) + self.list_IDs.extend(items) + + random.shuffle(self.list_IDs) + total_train = num_classes * num_train + assert total_train <= len(self.list_IDs) + self.list_IDs = self.list_IDs[0:total_train] + + def __len__(self): + return len(self.list_IDs) + + def __getitem__(self, index): + idx = index + ID = self.list_IDs[idx][0] + rotation = self.list_IDs[idx][1] + + filename = ID + '.binvox' + filepath = os.path.join('data', os.path.join(str(self.resolution), filename)) + with open(filepath, 'rb') as f: + sample = read_as_3d_array(f).data + + if rotation == 1: + sample = cp.rot90(sample, 2, (0, 1)) + elif rotation == 2: + sample = cp.rot90(sample, 1, (0, 1)) + elif rotation == 3: + sample = cp.rot90(sample, 1, (1, 0)) + elif rotation == 4: + sample = cp.rot90(sample, 1, (2, 0)) + elif rotation == 5: + sample = cp.rot90(sample, 1, (0, 2)) + + q = self.transform(sample) + k = self.transform(sample) + + if self.output_type == '3d': + q, k = np.expand_dims(q, axis=0), np.expand_dims(k, axis=0) + q, k = torch.from_numpy(q.copy()).float(), torch.from_numpy(k.copy()).float() + q, k = 2.0 * (q - 0.5), 2.0 * (k - 0.5) + elif self.output_type == '2d_multiple': + q, k = self.createImgs(q), self.createImgs(k) + elif self.output_type == '2d_single': + q, k = self.createImgs(q), self.createImgs(k) + + return [q, k] diff --git a/core/engine/__init__.py b/core/engine/__init__.py new file mode 100644 index 0000000..6ed6628 --- /dev/null +++ b/core/engine/__init__.py @@ -0,0 +1,9 @@ +''' +Author: whj +Date: 2022-02-15 15:53:55 +LastEditors: whj +LastEditTime: 2022-02-17 11:52:47 +Description: file content +''' + +# coding: utf-8 \ No newline at end of file diff --git a/core/engine/multi.py b/core/engine/multi.py new file mode 100644 index 0000000..22812ac --- /dev/null +++ b/core/engine/multi.py @@ -0,0 +1,462 @@ +import os +import random +from pathlib import Path + +from skimage import measure +from scipy import ndimage as ndi +from skimage.segmentation import watershed +import selectivesearch +import cupy as cp +import numpy as np +from scipy.special import softmax +from PIL import Image + +import torch +import torch.nn as nn +from torchvision import transforms, utils + +from ..utils import vis +from ..utils import binvox_rw + + +def load_3dmodel(fidx, sidx): + 
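# read the voxelized (binvox) part model fidx from data/set{sidx}/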
modelfilename = 'data/set' + str(sidx) + '/'+str(fidx) + '.binvox' + + with open(modelfilename, 'rb') as f: + sample = binvox_rw.read_as_3d_array(f).data + + return sample + + +def featurenet_segmentation(sample): + blobs = ~sample + final_labels = np.zeros(blobs.shape) + all_labels = measure.label(blobs) + + for i in range(1, np.max(all_labels)+1): + mk = (all_labels == i) + distance = ndi.distance_transform_edt(mk) + + labels = watershed(-distance) + + max_val = np.max(final_labels)+1 + idx = np.where(mk) + + final_labels[idx] += (labels[idx] + max_val) + + results = get_seg_samples(final_labels) + results = 2*(results - 0.5) + + return results + + +def get_seg_samples(labels): + samples = np.zeros((0, labels.shape[0], labels.shape[1], labels.shape[2])) + + for i in range(1, np.max(labels.astype(int))+1): + idx = np.where(labels == i) + + if len(idx[0]) == 0: + continue + + cursample = np.ones(labels.shape) + cursample[idx] = 0 + cursample = np.expand_dims(cursample, axis=0) + samples = np.append(samples, cursample, axis=0) + + return samples + + +def soft_nms_pytorch(samples, box_scores, sigma=0.028): + N = samples.shape[0] + dets = np.zeros((N, 6)) + + for i in range(N): + idx = np.where(samples[i, :, :, :] == 0) + # print(idx) + dets[i, 0] = idx[2].min() + dets[i, 1] = idx[1].min() + dets[i, 2] = idx[0].min() + dets[i, 3] = idx[2].max() + dets[i, 4] = idx[1].max() + dets[i, 5] = idx[0].max() + + indexes = torch.arange(0, N, dtype=torch.double).view(N, 1) + dets = torch.from_numpy(dets).double() + box_scores = torch.from_numpy(box_scores).double() + + dets = torch.cat((dets, indexes), dim=1) + + z1 = dets[:, 0] + y1 = dets[:, 1] + x1 = dets[:, 2] + z2 = dets[:, 3] + y2 = dets[:, 4] + x2 = dets[:, 5] + scores = box_scores + areas = (x2 - x1 + 1) * (y2 - y1 + 1) * (z2 - z1 + 1) + + for i in range(N): + tscore = scores[i].clone() + pos = i + 1 + + if i != N - 1: + maxscore, maxpos = torch.max(scores[pos:], dim=0) + if tscore < maxscore: + dets[i], dets[maxpos.item() + i + 1] = dets[maxpos.item() + + i + 1].clone(), dets[i].clone() + scores[i], scores[maxpos.item() + i + 1] = scores[maxpos.item() + + i + 1].clone(), scores[i].clone() + areas[i], areas[maxpos + i + 1] = areas[maxpos + + i + 1].clone(), areas[i].clone() + + # IoU calculate + zz1 = np.maximum(dets[i, 0].to("cpu").numpy(), + dets[pos:, 0].to("cpu").numpy()) + yy1 = np.maximum(dets[i, 1].to("cpu").numpy(), + dets[pos:, 1].to("cpu").numpy()) + xx1 = np.maximum(dets[i, 2].to("cpu").numpy(), + dets[pos:, 2].to("cpu").numpy()) + zz2 = np.minimum(dets[i, 3].to("cpu").numpy(), + dets[pos:, 3].to("cpu").numpy()) + yy2 = np.minimum(dets[i, 4].to("cpu").numpy(), + dets[pos:, 4].to("cpu").numpy()) + xx2 = np.minimum(dets[i, 5].to("cpu").numpy(), + dets[pos:, 5].to("cpu").numpy()) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + l = np.maximum(0.0, zz2 - zz1 + 1) + inter = torch.tensor(w * h * l) + ovr = torch.div(inter, (areas[i] + areas[pos:] - inter)) + + # Gaussian decay + weight = torch.exp(-(ovr * ovr) / sigma) + + scores[pos:] = weight * scores[pos:] + + max_margin = 0 + for i in range(scores.shape[0]-1): + if scores[i] - scores[i + 1] > max_margin: + max_margin = scores[i] - scores[i + 1] + thresh = (scores[i] + scores[i+1])/2 + + keep = dets[:, 6][scores > thresh].int() + + return keep.to("cpu").numpy() + + +def test_msvnet(sidx): + num_cuts = 12 + + random.seed(213) + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + + net = torch.load('models/msvnet.pt') + net.eval() + + predictions = 
np.zeros(24) + truelabels = np.zeros(24) + truepositives = np.zeros(24) + + with torch.no_grad(): + with open(os.devnull, 'w') as devnull: + for filename in Path('data/set' + str(sidx) + '/').glob('*.STL'): + + filename = os.path.basename(filename) + idx = int(os.path.splitext(filename)[0]) + + sample = load_3dmodel(idx, sidx) + segs = msvnet_segmentation(sample) + inputs = get_msv_samples(segs, num_cuts) + + inputs = inputs.to(device) + outputs = net(inputs) + m = nn.Softmax(dim=1) + outputs = m(outputs) + + val, predicted = outputs.max(1) + val = val.cpu().numpy() + predicted = predicted.cpu().numpy() + + keepidx = soft_nms_pytorch(segs, val) + predicted = predicted[keepidx] + + pred = get_lvec(predicted).astype(int) + trul = get_lvec(vis.get_label( + 'data/set'+str(sidx)+'/'+str(idx)+'.csv', 0)[:, 6]).astype(int) + + print(idx) + print('Predicted labels:\t', pred) + print('True labels:\t\t', trul) + + tp = np.minimum(pred, trul) + + predictions += pred + truelabels += trul + truepositives += tp + precision, recall = eval_metric(predictions, truelabels, truepositives) + + return precision.mean(), recall.mean() + + +def cal_segs(img_org, rotation): + img = 1-img_org + + img_lbl, regions = selectivesearch.selective_search( + img, scale=500, sigma=0.8, min_size=10) + + samples = np.zeros((0, 64, 64, 64)) + + candidates = set() + for r in regions: + x, y, w, h = r['rect'] + w += 1 + h += 1 + + labels = r['labels'] + idx = np.where(img_lbl[y:y+h, x:x+w, 3] == labels[0]) + + for i in range(1, len(labels)): + idx2 = np.where(img_lbl[y:y+h, x:x+w, 3] == labels[i]) + idx = (np.append(idx[0], idx2[0], axis=0), + np.append(idx[1], idx2[1], axis=0)) + + tmpimg = np.zeros(img_lbl[y:y+h, x:x+w, 0].shape) + tmpimg[idx] = 1 + idx = np.where(img_lbl[y:y+h, x:x+w, 0] == 1) + tmpimg[idx] = 0 + + selval = tmpimg.sum() + maskval = h*w-idx[0].shape[0] + + if maskval <= 0 or selval/maskval < 0.5: + continue + + idx = np.where(tmpimg == 1) + minx = idx[1].min() + maxx = idx[1].max() + miny = idx[0].min() + maxy = idx[0].max() + w = maxx - minx + 1 + h = maxy - miny + 1 + x += minx + y += miny + + r['rect'] = (x, y, w, h) + + if r['rect'] in candidates: + continue + + if w <= 0 or h <= 0 or w/h <= 0.1 or h/w <= 0.1 or w < 6 or h < 6: + continue + + tmpimg = tmpimg[miny:miny+h, minx:minx+w] + + all_labels = measure.label(tmpimg) + if all_labels.max() >= 2: + continue + + cursample = np.ones((64, 64, 64)) + + for i in range(y, y+h): + for j in range(x, x+w): + if tmpimg[i-y, j-x] == 1: + depth = int(64*img_org[i, j, 0]) + cursample[0:depth, i, j] = 0 + + cursample = rotate_sample(cursample, rotation, True) + + cursample = np.expand_dims(cursample, axis=0) + samples = np.append(samples, cursample, axis=0) + + candidates.add(r['rect']) + + return samples + + +def create_img(obj3d): + img = np.zeros((obj3d.shape[1], obj3d.shape[2])) + + for i in range(img.shape[0]): + for j in range(img.shape[1]): + for d in range(obj3d.shape[0]): + + if obj3d[d, i, j] == True: + img[i, j] = d/obj3d.shape[0] + break + + if d == obj3d.shape[0] - 1: + img[i, j] = 1 + break + + img = np.stack((img,)*3, axis=-1) + + return img + + +def rotate_sample(sample, rotation, reverse=False): + if reverse: + if rotation == 1: + sample = cp.rot90(sample, -2, (0, 1)).copy() + elif rotation == 2: + sample = cp.rot90(sample, -1, (0, 1)).copy() + elif rotation == 3: + sample = cp.rot90(sample, -1, (1, 0)).copy() + elif rotation == 4: + sample = cp.rot90(sample, -1, (2, 0)).copy() + elif rotation == 5: + sample = cp.rot90(sample, -1, (0, 2)).copy() + 
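# forward path: apply one of the six canonical part orientations (matches the dataset loader)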
else: + if rotation == 1: + sample = cp.rot90(sample, 2, (0, 1)).copy() + elif rotation == 2: + sample = cp.rot90(sample, 1, (0, 1)).copy() + elif rotation == 3: + sample = cp.rot90(sample, 1, (1, 0)).copy() + elif rotation == 4: + sample = cp.rot90(sample, 1, (2, 0)).copy() + elif rotation == 5: + sample = cp.rot90(sample, 1, (0, 2)).copy() + + return sample + + +def msvnet_segmentation(sample): + allsamples = np.zeros( + (0, sample.shape[0], sample.shape[1], sample.shape[2])) + + for i in range(0, 6): + cursample = sample.copy() + cursample = rotate_sample(cursample, i) + + img = create_img(cursample) + cursample = cal_segs(img, i) + allsamples = np.append(allsamples, cursample, axis=0) + + return allsamples + + +# def test_featurenet(sidx): +# resolution = 64 + +# x=tf.placeholder(tf.float32,shape=[None,resolution,resolution,resolution,1]) +# output_layer = lfmd.inference2(x) + + +# saver = tf.train.Saver() + +# predictions = np.zeros(24) +# truelabels = np.zeros(24) +# truepositives = np.zeros(24) + + +# with tf.Session() as sess: +# saver.restore(sess, "models/featurenet.ckpt") +# print("Model restored.") + +# with open(os.devnull, 'w') as devnull: +# for filename in Path('data/set' + str(sidx)+ '/').glob('*.STL'): + +# filename = os.path.basename(filename) +# idx = int(os.path.splitext(filename)[0]) + +# sample = load_3dmodel(idx,sidx) +# segs = featurenet_segmentation(sample) + +# inputs = segs +# temp = output_layer.eval({x: inputs.reshape(-1, resolution,resolution,resolution,1)}) + +# outputs = softmax(temp,axis=1) + +# pred = get_lvec(outputs.argmax(1)).astype(int) +# trul = get_lvec(vis.get_label('data/set'+str(sidx)+'/'+str(idx)+'.csv',0)[:,6]).astype(int) + +# print(idx) +# print('Predicted labels:\t',pred) +# print('True labels:\t\t',trul) + +# tp = np.minimum(pred,trul) + +# predictions += pred +# truelabels += trul +# truepositives += tp + +# precision, recall = eval_metric(predictions,truelabels,truepositives) +# return precision.mean(), recall.mean() + + +def create_sectional_view(obj3d): + minres = min(obj3d.shape[0], obj3d.shape[1], obj3d.shape[2]) + proj_dir = random.randint(0, 1) + sel_axis = random.randint(0, 2) + sel_idx = random.randint(1, minres - 2) + + if sel_axis == 0: + if proj_dir == 0: + img = cp.mean(obj3d[sel_idx:, :, :], sel_axis) + else: + img = cp.mean(obj3d[:sel_idx, :, :], sel_axis) + elif sel_axis == 1: + if proj_dir == 0: + img = cp.mean(obj3d[:, sel_idx:, :], sel_axis) + else: + img = cp.mean(obj3d[:, :sel_idx, :], sel_axis) + elif sel_axis == 2: + if proj_dir == 0: + img = cp.mean(obj3d[:, :, sel_idx:], sel_axis) + else: + img = cp.mean(obj3d[:, :, :sel_idx], sel_axis) + + img = torch.from_numpy(img).float() + img = img.expand(3, img.shape[0], img.shape[1]) # convert to rgb chanels + + trans = transforms.Compose([transforms.ToPILImage(), + transforms.Resize( + (64, 64), interpolation=Image.NEAREST), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ + 0.229, 0.224, 0.225]) + ]) + img = trans(img) + return img + + +def create_sectional_views(obj3d, img_num): + imgs = [] + for i in range(img_num): + imgs.append(create_sectional_view(obj3d)) + return torch.stack(imgs) + + +#input: n*64*64*64 +#output: n*12*3*64*64 +def get_msv_samples(samples, num_cuts): + batch_size = samples.shape[0] + results = torch.zeros((batch_size, num_cuts, 3, 64, 64)) + + for i in range(batch_size): + results[i] = create_sectional_views(samples[i], num_cuts) + return results + + +def dense_to_one_hot(labels_dense, num_classes=24): + """Convert 
class labels from scalars to one-hot vectors""" + num_labels = labels_dense.shape[0] + index_offset = np.arange(num_labels) * num_classes + labels_one_hot = np.zeros((num_labels, num_classes)) + labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 + return labels_one_hot + + +def get_lvec(labels): + results = np.zeros(24) + for i in labels: + results[int(i)] += 1 + return results + + +def eval_metric(pre, trul, tp): + precision = tp/pre + recall = tp/trul + return precision, recall diff --git a/core/engine/trainer.py b/core/engine/trainer.py new file mode 100644 index 0000000..b794832 --- /dev/null +++ b/core/engine/trainer.py @@ -0,0 +1,659 @@ +from random import random +import time +import os +import warnings +from tqdm import tqdm +import numpy as np + +import torch +import torch.nn as nn +import torch.optim as optim +from torchsummary import summary +from torch.utils.tensorboard import SummaryWriter + +from thop import profile +from warmup_scheduler import GradualWarmupScheduler + +from ..dataset import FeatureDataset +from ..dataset import createPartition +from ..models import FeatureNet, FeatureNetLite, SCNN, MCNN, MsvNetLite, BaselineNet, BaselineNet2, VoxNet +from ..utils import setup_logger +from ..utils import LayerOutHook +from ..utils import plot_with_labels, plot3D_with_labels + + +class Trainer: + def __init__(self, cfg): + self.base_lr = cfg.base_lr + self.num_train = cfg.num_train + self.num_val_test = cfg.num_val_test + self.arch = cfg.arch + self.train_epochs = cfg.train_epochs + self.resolution = cfg.resolution + self.num_cuts = cfg.num_cuts # obsolete + if cfg.train_batchsize > cfg.num_train * cfg.num_of_class: + # batch szie should be smaller than the number of samples + cfg.train_batchsize = cfg.num_train * cfg.num_of_class + self.train_bs = cfg.train_batchsize + self.val_bs = cfg.val_batchsize + self.weight_decay = cfg.weight_decay + self.warmup_epochs = cfg.warmup_epochs + self.lr_sch = cfg.lr_sch + self.optim = cfg.optim + self.data_aug = cfg.data_aug + self.pretrained = cfg.pretrained + self.simsiam_pretrained = cfg.simsiam_pretrained + self.freeze = cfg.freeze + self.val_interval = cfg.val_interval + self.data_path = cfg.data_path + self.num_of_class = cfg.num_of_class + self.workers = cfg.workers + self.optimizer = None + self.scheduler = None + self.output_type = '3d' + self.model_path = cfg.model_path + self.shapetypes = ['O ring', 'Through hole', 'Blind hole', 'Triangular passage', 'Rectangular passage', + 'Circular through slot', 'Triangular through slot', 'Rectangular through slot', + 'Rectangular blind slot', 'Triangular pocket', 'Rectangular pocket', 'Circular end pocket', + 'Triangular blind step', 'Circular blind step', 'Rectangular blind step', 'Rectangular through step', + '2-sides through step', 'Slanted through step', 'Chamfer', 'Round', 'Vertical circular end blind slot', + 'Horizontal circular end blind slot', '6-sides passage', '6-sides pocket' + ] + + # training output directory + time_str = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + self.output_dir = os.path.join(cfg.output_dir, time_str) + self.summary_writer = SummaryWriter( + os.path.join(self.output_dir, 'SummaryWriter')) + self.logger = setup_logger(os.path.join( + self.output_dir, 'log.txt'), 'FeatureRecognition') + + # print and log config + self.logger.info('Training config:') + for s in str(cfg).split(','): + self.logger.info(s) + print(s) + + # set device + if cfg.device == 'gpu': + self.device = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu") + 
elif cfg.device == "cpu": + self.device = torch.device("cpu") + + # data augmentation ops + augmentations = None + if self.data_aug == True: + if self.simsiam_pretrained is not None: + augmentations = ['randomRotation'] + else: + augmentations = ['randomRotation', 'randomScaleCrop'] + s = 'Data augmentation is enabled: ' + for op in augmentations: + s += op + ' ' + self.logger.info(s) + print(s) + + # define the network + self.build_net() + + # create dataloaders + partition = createPartition(self.data_path, + self.num_of_class, + self.resolution, + self.num_train, + self.num_val_test) + + training_set = FeatureDataset(partition['train'], + resolution=self.resolution, + output_type=self.output_type, + num_cuts=self.num_cuts, + data_augmentation=augmentations) + self.trainloader = torch.utils.data.DataLoader(training_set, + batch_size=self.train_bs, + shuffle=True, + num_workers=self.workers, + pin_memory=True, + drop_last=False) + + val_set = FeatureDataset(partition['val'], + resolution=self.resolution, + output_type=self.output_type, + num_cuts=self.num_cuts, + data_augmentation=None) + self.validloader = torch.utils.data.DataLoader(val_set, + batch_size=self.val_bs, + shuffle=False, + num_workers=self.workers, + pin_memory=True, + drop_last=False) + + test_set = FeatureDataset(partition['test'], + resolution=self.resolution, + output_type=self.output_type, + num_cuts=self.num_cuts, + data_augmentation=None) + self.testloader = torch.utils.data.DataLoader(test_set, + batch_size=self.val_bs, + shuffle=False, + num_workers=self.workers, + pin_memory=True, + drop_last=False) + + # define the criterion and optimizer + self.criterion = nn.CrossEntropyLoss().to(self.device) + self.build_optim() + if self.lr_sch != 'constant': + self.build_lr_sch() + + def build_net(self): + if self.arch == 'FeatureNet': + self.net = FeatureNet(num_classes=self.num_of_class) + self.output_type = '3d' + elif self.arch == 'FeatureNetLite': + self.net = FeatureNetLite( + scale=1.0, num_classes=self.num_of_class, dropout_prob=0., class_expand=1280) + self.output_type = '3d' + elif self.arch == 'MsvNet': + scnn = SCNN(num_classes=self.num_of_class, pretraining=False) + self.net = MCNN( + model=scnn, num_classes=self.num_of_class, num_cuts=self.num_cuts) + self.output_type = '2d_multiple' + elif self.arch == 'MsvNetLite': + self.net = MsvNetLite(scale=1.0, num_classes=self.num_of_class, + dropout_prob=0., class_expand=1280, num_cuts=self.num_cuts) + self.output_type = '2d_multiple' + elif self.arch == 'BaselineNet': + ''' + paper: Identifying manufacturability and machining processes using deep 3D convolutional networks + ''' + self.net = BaselineNet(num_classes=self.num_of_class, input_shape=( + self.resolution, self.resolution, self.resolution)) + self.output_type = '3d' + elif self.arch == 'BaselineNet2': + ''' + paper: Part machining feature recognition based on a deep learning method + ''' + self.net = BaselineNet2(num_classes=self.num_of_class, input_shape=( + self.resolution, self.resolution, self.resolution)) + self.output_type = '3d' + elif self.arch == 'VoxNet': + self.net = VoxNet(num_classes=self.num_of_class, input_shape=( + self.resolution, self.resolution, self.resolution)) + self.output_type = '3d' + else: + raise ValueError('Invalid network type') + # if running on GPU and we want to use cuda move model there + self.net.to(self.device) + + def load_simsiam_pretrained_model(self): + # freeze all layers but the last classifier layer + if self.freeze: + for name, param in self.net.named_parameters(): + 
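# linear evaluation: only classifier.weight / classifier.bias stay trainable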
if name not in ['classifier.weight', 'classifier.bias']: + param.requires_grad = False + # init the classifier layer + self.net.classifier.weight.data.normal_(mean=0.0, std=0.01) + self.net.classifier.bias.data.zero_() + + if not self.freeze: + classifier_parameters, model_parameters = [], [] + for name, param in self.net.named_parameters(): + if name in {'classifier.weight', 'classifier.bias'}: + classifier_parameters.append(param) + else: + model_parameters.append(param) + # set different learning rate for the classifier + param_groups = [dict(params=classifier_parameters, lr=self.base_lr)] + param_groups.append(dict(params=model_parameters, lr=0.001)) + #self.optimizer = optim.SGD(param_groups, 0, momentum=0.9, weight_decay=self.weight_decay) + self.optimizer = optim.Adam(param_groups, 0, weight_decay=self.weight_decay) + print(self.optimizer) + self.build_lr_sch() + + # load from pre-trained, before DistributedDataParallel constructor + if self.simsiam_pretrained: + if os.path.isfile(self.simsiam_pretrained): + print("=> loading checkpoint '{}'".format( + self.simsiam_pretrained)) + checkpoint = torch.load( + self.simsiam_pretrained, map_location="cpu") + + # rename moco pre-trained keys + state_dict = checkpoint['state_dict'] + for k in list(state_dict.keys()): + # retain only encoder up to before the embedding layer + if k.startswith('encoder') and not k.startswith('encoder.classifier'): + # remove prefix + state_dict[k[len("encoder."):]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + # check whether pretrained parameters loading successfully + msg = self.net.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == { + "classifier.weight", "classifier.bias"} + + s = '=> pre-trained model loaded from {}'.format( + self.simsiam_pretrained) + print(s) + self.logger.info(s) + else: + s = "=> no checkpoint found at '{}'".format( + self.simsiam_pretrained) + print(s) + self.logger.info(s) + + def load_params(self, model_path): + if os.path.isfile(model_path): + s = "=> loading model params '{}'".format(model_path) + print(s) + self.logger.info(s) + checkpoint = torch.load(model_path) + try: + state_dict = checkpoint['state_dict'] + except: + state_dict = checkpoint + msg = self.net.load_state_dict(state_dict, strict=False) + print(msg) + + def build_lr_sch(self): + # TODO scheduler parameters + # multistepLR remember the real decay epoch = milestone + warmup epoch + if self.lr_sch == 'step': + self.scheduler = optim.lr_scheduler.StepLR( + self.optimizer, step_size=10, gamma=0.1) + elif self.lr_sch == 'multistep': + self.scheduler = optim.lr_scheduler.MultiStepLR( + self.optimizer, milestones=[60, 80], gamma=0.1) + elif self.lr_sch == 'exp': + self.scheduler = optim.lr_scheduler.ExponentialLR( + self.optimizer, gamma=0.99) + elif self.lr_sch == 'cos': + self.scheduler = optim.lr_scheduler.CosineAnnealingLR( + self.optimizer, T_max=self.train_epochs, eta_min=0) + elif self.lr_sch == 'constant': + self.scheduler = optim.lr_scheduler.LambdaLR( + self.optimizer, lambda epoch: self.base_lr) + else: + raise ValueError('Invalid lr_sch') + + if self.warmup_epochs > 0: + temp_scheduler = self.scheduler + self.scheduler = GradualWarmupScheduler(self.optimizer, + multiplier=1, + total_epoch=self.warmup_epochs, + after_scheduler=temp_scheduler) + + def build_optim(self): + # TODO optimizer parameters + if self.optim == 'adam': + self.optimizer = optim.Adam( + self.net.parameters(), lr=self.base_lr, weight_decay=self.weight_decay) + elif self.optim == 'sgdm': + 
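# all optimizer variants share base_lr and weight_decay from the config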
+            self.optimizer = optim.SGD(
+                self.net.parameters(), lr=self.base_lr, momentum=0.9, weight_decay=self.weight_decay)
+        elif self.optim == 'rmsprop':
+            self.optimizer = optim.RMSprop(
+                self.net.parameters(), lr=self.base_lr, weight_decay=self.weight_decay)
+        elif self.optim == 'adamw':
+            self.optimizer = optim.AdamW(
+                self.net.parameters(), lr=self.base_lr, weight_decay=self.weight_decay)
+        else:
+            raise ValueError('Invalid optimizer')
+
+    def train(self):
+        # load pretrained model
+        if self.simsiam_pretrained is not None and self.pretrained is None:
+            self.load_simsiam_pretrained_model()
+
+        if self.pretrained is not None and self.simsiam_pretrained is None:
+            self.load_params(self.pretrained)
+
+        # input data shape
+        if self.output_type == '3d':
+            input_shape = (1, self.resolution,
+                           self.resolution, self.resolution)
+            input = torch.randn(1, 1, self.resolution,
+                                self.resolution, self.resolution).to(self.device)
+        elif self.output_type == '2d_multiple':
+            input_shape = (self.num_cuts, 3, self.resolution, self.resolution)
+            input = torch.randn(
+                1, self.num_cuts, 3, self.resolution, self.resolution).to(self.device)
+
+        # network structure description
+        summary(self.net, input_shape)
+
+        # calc the network macs & params
+        macs, params = profile(self.net, inputs=(input, ))
+        macs_params_info = 'model size: MACS: {}G || Parameters: {}M'.format(
+            macs / 1.e9, params / 1.e6)
+        self.logger.info(macs_params_info)
+        print(macs_params_info)
+
+        print('\n supervised training the network with labeled data ...')
+        best_acc = 0.
+        pbar = tqdm(range(self.train_epochs))
+        for epoch in pbar:
+            # train
+            start_time = time.time()
+            train_acc, train_loss = self.train_epoch()
+            train_epoch_time = time.time()-start_time
+
+            # validate
+            if (epoch+1) % self.val_interval == 0:
+                start_time = time.time()
+                val_acc, val_loss = self.valtest_epoch(testval='val')
+                val_epoch_time = time.time()-start_time
+
+                # save best model
+                if val_acc > best_acc:
+                    best_acc = val_acc
+                    torch.save({'state_dict': self.net.state_dict()},
+                               os.path.join(self.output_dir, 'best_model.pth'))
+                    self.logger.info(
+                        'Best model saved
at epoch {} with val_acc {:.2f}%'.format(epoch+1, val_acc)) + + # update validate log + self.summary_writer.add_scalar( + 'train/val_acc', val_acc, epoch+1) + self.summary_writer.add_scalar( + 'train/val_loss', val_loss, epoch+1) + self.logger.info('Eval at epoch {} with val_acc {:.2f}%, val_loss {:.5f} took {:.2f}s'.format( + epoch+1, val_acc, val_loss, val_epoch_time)) + + cur_lr = self.optimizer.param_groups[0]['lr'] + self.summary_writer.add_scalar('train/lr', cur_lr, epoch+1) + + # update training log + self.summary_writer.add_scalar( + 'train/train_acc', train_acc, epoch+1) + self.summary_writer.add_scalar( + 'train/train_loss', train_loss, epoch+1) + log_str = "epoch: {}/{} | Train Time: {:.2f}s | LR: {:.8f} | train_acc: {:.2f}%, train_loss: {:.5f}".format( + epoch+1, self.train_epochs, train_epoch_time, cur_lr, train_acc, train_loss) + pbar.set_description(log_str) + self.logger.info(log_str) + + def train_epoch(self): + self.net.train() + train_loss = 0 + correct = 0 + total = 0 + for batch_idx, (inputs, targets) in enumerate(self.trainloader): + inputs, targets = inputs.to(self.device), targets.to(self.device) + + outputs = self.net(inputs) + loss = self.criterion(outputs, targets) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + train_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + + if self.scheduler is not None: + self.scheduler.step() + + return 100.*correct/total, train_loss + + def valtest_epoch(self, testval='val'): + if testval == 'val': + loader = self.validloader + elif testval == 'test': + loader = self.testloader + else: + raise ValueError('Invalid testval loader') + + self.net.eval() + val_loss = 0 + correct = 0 + total = 0 + with torch.no_grad(): + for batch_idx, (inputs, targets) in enumerate(loader): + inputs, targets = inputs.to( + self.device), targets.to(self.device) + + outputs = self.net(inputs) + loss = self.criterion(outputs, targets) + + val_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + + return 100.*correct/total, val_loss + + def infer(self, inputs): + self.net.eval() + with torch.no_grad(): + inputs = inputs.to(self.device) + outputs = self.net(inputs).max(1)[1].cpu().numpy() + return outputs + + def estimate_infer_time(self): + if self.output_type == '3d': + dummy_input = torch.randn(1, 1, self.resolution, + self.resolution, self.resolution).to(self.device) + elif self.output_type == '2d_multiple': + dummy_input = torch.randn( + 1, self.num_cuts, 3, self.resolution, self.resolution).to(self.device) + + # warmup 20 times + for _ in range(20): + _ = self.infer(dummy_input) + + # estimate average inference time + start_time = time.perf_counter() + for _ in range(1000): + _ = self.infer(dummy_input) + end_time = time.perf_counter() - start_time + + # average = end_time / 1000 (s) + # ms = average * 1000 (ms) + return end_time + + def tsne_visualize(self): + from sklearn.manifold import TSNE + import matplotlib.pyplot as plt + + if self.model_path is not None: + self.load_params(self.model_path) + else: + # load pretrained model + if self.simsiam_pretrained is not None and self.pretrained is None: + self.load_simsiam_pretrained_model() + + if self.pretrained is not None and self.simsiam_pretrained is None: + self.load_params(self.pretrained) + + self.net.eval() + layer_name = 'flatten' + hook = LayerOutHook(self.net, layer_name) + + torch.backends.cudnn.deterministic = 
True + os.environ['PYTHONHASHSEED'] = str(42) + np.random.seed(42) + torch.manual_seed(42) + torch.cuda.manual_seed(42) + torch.cuda.manual_seed_all(42) + + # take n models per class for TSNE visualization + num_data_vis = 40 + dim = 2 + partition = createPartition(self.data_path, + self.num_of_class, + self.resolution, + self.num_train, + num_val_test=num_data_vis) + test_set = FeatureDataset(partition['test'], + resolution=self.resolution, + output_type=self.output_type, + num_cuts=self.num_cuts, + data_augmentation=None) + testloader = torch.utils.data.DataLoader(test_set, + batch_size=num_data_vis*self.num_of_class, + shuffle=False, + num_workers=self.workers, + pin_memory=True, + drop_last=False) + + with torch.no_grad(): + for batch_idx, (inputs, targets) in enumerate(testloader): + inputs, targets = inputs.to( + self.device), targets.to(self.device) + + _ = self.net(inputs) + last_layer = hook.output + + tsne = TSNE(perplexity=100, n_components=dim, init='pca', n_iter=5000) + low_dim_embs = tsne.fit_transform(last_layer.cpu().data.numpy()) + labels = targets.cpu().numpy() + if dim == 2: + plot_with_labels(low_dim_embs, labels) + elif dim == 3: + plot3D_with_labels(low_dim_embs, labels) + + plt.ioff() + + def draw_ROC_CM(self): + """ + Draw ROC curve + """ + from sklearn.metrics import roc_curve, auc, confusion_matrix + import matplotlib.pyplot as plt + + if self.model_path is not None: + self.load_params(self.model_path) + else: + # load pretrained model + if self.simsiam_pretrained is not None and self.pretrained is None: + self.load_simsiam_pretrained_model() + + if self.pretrained is not None and self.simsiam_pretrained is None: + self.load_params(self.pretrained) + + self.net.eval() + val_loss = 0 + correct = 0 + total = 0 + with torch.no_grad(): + scores_list = [] + labels_list = [] + preds_list = [] + for batch_idx, (inputs, targets) in tqdm(enumerate(self.testloader)): + inputs, targets = inputs.to(self.device), targets.to(self.device) + outputs = self.net(inputs) + scores = torch.softmax(outputs, dim=1) + preds = torch.argmax(outputs, dim=1) + + loss = self.criterion(outputs, targets) + val_loss += loss.item() + _, predicted = outputs.max(1) + total += targets.size(0) + correct += predicted.eq(targets).sum().item() + + scores_list.append(scores.cpu().numpy()) + labels_list.append(targets.cpu().numpy()) + preds_list.extend(preds.cpu().numpy().tolist()) + scores = np.concatenate(scores_list) + labels = np.concatenate(labels_list) + + print('ACC: ', 100.*correct/total) + + # Assuming you have true labels in labels and predicted scores in scores + n_classes = len(self.shapetypes) + fpr = dict() + tpr = dict() + roc_auc = dict() + for i in range(n_classes): + fpr[i], tpr[i], _ = roc_curve(labels == i, scores[:, i]) + roc_auc[i] = auc(fpr[i], tpr[i]) + + # Plot ROC curve for each class + plt.figure(figsize=(8, 6)) + for i in range(n_classes): + plt.plot(fpr[i], tpr[i], label=f'ROC curve (AUC = {roc_auc[i]:.2f}) for class {self.shapetypes[i]}') + plt.plot([0, 1], [0, 1], 'k--') + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + plt.title('Multiclass ROC Curve') + plt.legend(loc="lower right") + plt.show() + + # Draw ROC curve + cm = confusion_matrix(labels, preds_list) + # Plot confusion matrix as an image + plt.imshow(cm) + plt.colorbar() + plt.xticks(np.arange(len(self.shapetypes)), self.shapetypes, rotation=90) + plt.yticks(np.arange(len(self.shapetypes)), self.shapetypes) + plt.xlabel('Predicted') + 
plt.ylabel('True')
+        plt.show()
+
+
+
+def train_eval_model(cfg):
+    warnings.filterwarnings('ignore', '.*output shape of zoom.*')
+
+    trainer = Trainer(cfg)
+    # train the model
+    trainer.train()
+    # load best model and evaluate the model
+    trainer.load_params(os.path.join(trainer.output_dir, 'best_model.pth'))
+    val_acc, _ = trainer.valtest_epoch(testval='val')
+    test_acc, _ = trainer.valtest_epoch(testval='test')
+    # log results
+    result_str = '\n\nVal Acc: {:.2f} | Test Acc: {:.2f}'.format(
+        val_acc, test_acc)
+    print(result_str)
+    trainer.logger.info(result_str)
+
+
+def eval_model(cfg):
+    warnings.filterwarnings('ignore', '.*output shape of zoom.*')
+
+    trainer = Trainer(cfg)
+    # load best model and evaluate the model
+    if cfg.model_path is not None:
+        trainer.load_params(cfg.model_path)
+    else:
+        return None
+    val_acc, _ = trainer.valtest_epoch(testval='val')
+    test_acc, _ = trainer.valtest_epoch(testval='test')
+    # log results
+    result_str = '\n\nVal Acc: {:.2f} | Test Acc: {:.2f}'.format(
+        val_acc, test_acc)
+    print(result_str)
+    trainer.logger.info(result_str)
+
+
+def infer_time_test(cfg):
+    warnings.filterwarnings('ignore', '.*output shape of zoom.*')
+
+    trainer = Trainer(cfg)
+    avg_infer_time = trainer.estimate_infer_time()
+
+    # log results
+    result_str = '\n\n average inference time(on {:s}): {:.2f}ms'.format(cfg.device, avg_infer_time)
+    print(result_str)
+    trainer.logger.info(result_str)
+
+
+def draw_TSNE(cfg):
+    warnings.filterwarnings('ignore', '.*output shape of zoom.*')
+
+    trainer = Trainer(cfg)
+    trainer.tsne_visualize()
+
+def draw_ROC_CM(cfg):
+    warnings.filterwarnings('ignore', '.*output shape of zoom.*')
+
+    cfg.val_batchsize = 256  # larger eval batches speed this up
+    trainer = Trainer(cfg)
+    trainer.draw_ROC_CM()
\ No newline at end of file
diff --git a/core/models/__init__.py b/core/models/__init__.py
new file mode 100644
index 0000000..06c8464
--- /dev/null
+++ b/core/models/__init__.py
@@ -0,0 +1,16 @@
+'''
+Author: whj
+Date: 2022-02-15 15:53:55
+LastEditors: whj
+LastEditTime: 2022-02-28 13:50:07
+Description: file content
+'''
+
+# coding: utf-8
+
+from .featurenet import *
+from .featurenetlite import *
+from .msvnetlite import *
+from .msvnet import *
+from .others import *
+
diff --git a/core/models/act_helper.py b/core/models/act_helper.py
new file mode 100644
index 0000000..a355679
--- /dev/null
+++ b/core/models/act_helper.py
@@ -0,0 +1,25 @@
+'''
+Author: whj
+Date: 2022-02-15 12:36:42
+LastEditors: whj
+LastEditTime: 2022-02-16 11:38:53
+Description: file content
+'''
+import torch
+import torch.nn as nn
+
+
+
+def act_helper(act_type):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'leakyrelu':
+        return nn.LeakyReLU(0.2, inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type == 'hardswish':
+        return nn.Hardswish(inplace=True)
+    elif act_type == 'gelu':
+        return nn.GELU()  # nn.GELU has no inplace option
+    else:
+        raise ValueError('Unsupported activation type: ' + act_type)
\ No newline at end of file
diff --git a/core/models/featurenet.py b/core/models/featurenet.py
new file mode 100644
index 0000000..91f9d7d
--- /dev/null
+++ b/core/models/featurenet.py
@@ -0,0 +1,122 @@
+'''
+Author: whj
+Date: 2022-02-15 12:36:42
+LastEditors: whj
+LastEditTime: 2022-02-28 13:25:01
+Description: file content
+'''
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import transforms, utils
+import torchvision.models as models
+
+from .act_helper import act_helper
+
+
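+# Conv3d -> optional BatchNorm3d -> activation, the basic FeatureNet building block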
+class conv3dBlock(nn.Module): + def __init__(self, _in, _out, ksize=3, stride=1, padding='same', dila=1, groups=1, bias=True, bn=False, act='relu'): + super(conv3dBlock, self).__init__() + self.bn = bn + if padding == 'same': + padding = (ksize - 1) // 2 + self.conv = nn.Conv3d(in_channels=_in, + out_channels=_out, + kernel_size=ksize, + stride=stride, + padding=padding, + dilation=dila, + groups=groups, + bias=bias) + + if self.bn: + self.bn = nn.BatchNorm3d(_out) + + self.act = act_helper(act) + + def forward(self, x): + x = self.conv(x) + if self.bn: + x = self.bn(x) + return self.act(x) + + +class conv3d_4x4_Block(nn.Module): + def __init__(self, _in, _out, ksize=4, stride=1, padding='same', dila=1, groups=1, bias=True, bn=False, act='relu'): + super(conv3d_4x4_Block, self).__init__() + self.bn = bn + if padding == 'same': + padding = (ksize - 1) // 2 + # pytorch do not have same padding + # a even number convolution with same padding + # need to pad 1 more element + # at left, _, top, _, front, _ + self.pad = nn.ConstantPad3d((1, 0, 1, 0, 1, 0), 0) + self.conv = nn.Conv3d(in_channels=_in, + out_channels=_out, + kernel_size=ksize, + stride=stride, + padding=padding, + dilation=dila, + groups=groups, + bias=bias) + + if self.bn: + self.bn = nn.BatchNorm3d(_out) + + self.act = act_helper(act) + + def forward(self, x): + x = self.pad(x) + x = self.conv(x) + if self.bn: + x = self.bn(x) + return self.act(x) + + +class FeatureNetBackbone(nn.Module): + def __init__(self, in_channels=1): + super(FeatureNetBackbone, self).__init__() + self.block_1 = conv3dBlock(in_channels, 32, ksize=7, stride=2, dila=1, padding='same', groups=1, bias=True, bn=False, act='relu') + self.block_2 = conv3dBlock(32, 32, ksize=5, stride=1, dila=1, groups=1, padding='same', bias=True, bn=False, act='relu') + # special 4x4 convolution layer + self.block_3 = conv3d_4x4_Block(32, 64) + # pytorch do not have same padding + # not use a 4x4 convolution layer + # self.block_3 = conv3dBlock(32, 64, ksize=5, stride=1, dila=1, groups=1, padding='same', bias=True, bn=False, act='relu') + self.block_4 = conv3dBlock(64, 64, ksize=3, stride=1, dila=1, groups=1, padding='same', bias=True, bn=False, act='relu') + self.pool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0) + + def forward(self, x): + x = self.block_1(x) + x = self.block_2(x) + x = self.block_3(x) + x = self.block_4(x) + x = self.pool(x) + + return x + + +class FeatureNet(nn.Module): + ''' + paper: FeatureNet: Machining feature recognition based on 3D Convolution Neural Network + ''' + def __init__(self, num_classes=24, input_shape=(64, 64, 64)): + super(FeatureNet, self).__init__() + self.backbone = FeatureNetBackbone(in_channels=1) + # downsample 1/4 & out channel 64 + inshape = 64 * (input_shape[0] // 4) ** 3 + self.fc1 = nn.Linear(inshape, 128, bias=True) + self.fc2 = nn.Linear(128, num_classes, bias=True) + + def forward(self, x): + x = self.backbone(x) + x = x.view(x.size(0), -1) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + return x \ No newline at end of file diff --git a/core/models/featurenetlite.py b/core/models/featurenetlite.py new file mode 100644 index 0000000..ea3b666 --- /dev/null +++ b/core/models/featurenetlite.py @@ -0,0 +1,252 @@ +''' +Author: whj +Date: 2022-02-15 12:36:42 +LastEditors: whj +LastEditTime: 2022-03-02 14:25:14 +Description: file content +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +__all__ = [ + "FeatureNetLite" +] + +NET_CONFIG = { + #k, in_c, out_c, s, use_se + "blocks2":[[3, 16, 32, 
1, False]],
+    "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
+    "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
+    "blocks5": [[3, 128, 256, 2, False], [5, 256, 256, 1, False],
+                [5, 256, 256, 1, False], [5, 256, 256, 1, False],
+                [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
+    "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
+}
+
+# NET_CONFIG = {
+#     #k, in_c, out_c, s, use_se
+#     "blocks2":[[3, 16, 32, 1, False]],
+#     "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]],
+#     "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]],
+#     "blocks5": [[5, 128, 256, 2, False], [5, 256, 256, 1, False],
+#                 [5, 256, 256, 1, False], [5, 256, 256, 1, False],
+#                 [5, 256, 256, 1, False], [5, 256, 256, 1, False]],
+#     "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
+# }
+
+
+def make_divisible(v, divisor=8, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+class Hardswish(nn.Module):
+    def __init__(self, inplace=True):
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return x * F.relu6(x + 3., inplace=self.inplace) / 6.
+
+class Hardsigmoid(nn.Module):
+    def __init__(self, inplace=True):
+        super().__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return F.relu6(x + 3., inplace=self.inplace) / 6.
+
+class ConvBNLayer(nn.Module):
+    def __init__(self,
+                 num_channels,
+                 filter_size,
+                 num_filters,
+                 stride,
+                 num_groups=1):
+        super().__init__()
+
+        self.conv = nn.Conv3d(
+            in_channels=num_channels,
+            out_channels=num_filters,
+            kernel_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
+            groups=num_groups,
+            bias=False)
+
+        self.bn = nn.BatchNorm3d(
+            num_filters,
+        )
+        self.hardswish = Hardswish()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.hardswish(x)
+        return x
+
+class DepthwiseSeparable(nn.Module):
+    def __init__(self,
+                 num_channels,
+                 num_filters,
+                 stride,
+                 dw_size=3,
+                 use_se=False):
+        super().__init__()
+        self.use_se = use_se
+        self.dw_conv = ConvBNLayer(
+            num_channels=num_channels,
+            num_filters=num_channels,
+            filter_size=dw_size,
+            stride=stride,
+            num_groups=num_channels)
+        if use_se:
+            self.se = SEModule(num_channels)
+        self.pw_conv = ConvBNLayer(
+            num_channels=num_channels,
+            filter_size=1,
+            num_filters=num_filters,
+            stride=1)
+
+    def forward(self, x):
+        x = self.dw_conv(x)
+        if self.use_se:
+            x = self.se(x)
+        x = self.pw_conv(x)
+        return x
+
+
+class SEModule(nn.Module):
+    def __init__(self, channel, reduction=4):
+        super().__init__()
+        self.avg_pool = nn.AdaptiveAvgPool3d(1)  # 3D pooling to match the Conv3d squeeze-excite
+        self.conv1 = nn.Conv3d(
+            in_channels=channel,
+            out_channels=channel // reduction,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+        self.relu = nn.ReLU()
+        self.conv2 = nn.Conv3d(
+            in_channels=channel // reduction,
+            out_channels=channel,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+        self.hardsigmoid = Hardsigmoid()
+
+    def forward(self, x):
+        identity = x
+        x = self.avg_pool(x)
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.hardsigmoid(x)
+        x = torch.mul(identity, x)
+        return x
+
+
+class FeatureNetLite(nn.Module):
+    def __init__(self,
+                 scale=1.0,
+                 num_classes=1000,
+                 dropout_prob=0.0,
+                 class_expand=1280):
+        super().__init__()
+        self.scale = scale
+        self.class_expand = class_expand
+
+        self.conv1 = ConvBNLayer(
+            num_channels=1,
+            filter_size=3,
+            num_filters=make_divisible(16 * scale),
+            stride=2)
+
+        self.blocks2 =
nn.Sequential(*[ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"]) + ]) + + self.blocks3 = nn.Sequential(*[ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"]) + ]) + + self.blocks4 = nn.Sequential(*[ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"]) + ]) + + self.blocks5 = nn.Sequential(*[ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"]) + ]) + + self.blocks6 = nn.Sequential(*[ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"]) + ]) + + self.avg_pool = nn.AdaptiveAvgPool3d(1) + + self.last_conv = nn.Conv3d( + in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale), + out_channels=self.class_expand, + kernel_size=1, + stride=1, + padding=0, + bias=False) + + self.hardswish = Hardswish() + self.dropout = nn.Dropout(dropout_prob) + self.flatten = nn.Flatten(start_dim=1, end_dim=-1) + + self.classifier = nn.Linear(self.class_expand, num_classes) + + def forward(self, x): + x = self.conv1(x) + + x = self.blocks2(x) + x = self.blocks3(x) + x = self.blocks4(x) + x = self.blocks5(x) + x = self.blocks6(x) + + x = self.avg_pool(x) + x = self.last_conv(x) + x = self.hardswish(x) + x = self.dropout(x) + x = self.flatten(x) + x = self.classifier(x) + return x \ No newline at end of file diff --git a/core/models/msvnet.py b/core/models/msvnet.py new file mode 100644 index 0000000..f81933e --- /dev/null +++ b/core/models/msvnet.py @@ -0,0 +1,66 @@ +''' +Author: whj +Date: 2022-02-28 13:27:06 +LastEditors: whj +LastEditTime: 2022-03-05 14:49:45 +Description: file content +''' +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + + + +class SCNN(nn.Module): + def __init__(self, num_classes=24, pretraining=True): + super(SCNN, self).__init__() + self.num_classes = num_classes + self.pretraining = pretraining + + self.features = models.vgg11(pretrained=self.pretraining).features + self.adppool = nn.AdaptiveAvgPool2d((7,7)) + self.classifier = models.vgg11(pretrained=self.pretraining).classifier + self.classifier._modules['6'] = nn.Linear(4096, num_classes) + + if self.pretraining: + nn.init.normal_(self.classifier._modules['6'].weight, 0, 0.01) + nn.init.constant_(self.classifier._modules['6'].bias, 0) + else: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + y = self.features(x) + y = self.adppool(y) + + return self.classifier(y.view(y.shape[0],-1)) + + 
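+# MCNN below reuses the convolutional features and classifier of a trained
+# SCNN and fuses the per-view feature maps with an element-wise max over the
+# view axis ("view pooling"). Minimal usage sketch (shapes illustrative):
+#
+#   scnn = SCNN(num_classes=24, pretraining=True)
+#   mcnn = MCNN(scnn, num_classes=24, num_cuts=12)
+#   x = torch.rand(2, 12, 3, 224, 224)  # (batch, views, C, H, W)
+#   logits = mcnn(x)                    # -> (2, 24)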
+class MCNN(nn.Module): + def __init__(self, model, num_classes=24, num_cuts=12): + super(MCNN, self).__init__() + self.num_classes = num_classes + self.num_cuts = num_cuts + + self.features = model.features + self.adppool = model.adppool + self.classifier = model.classifier + + def forward(self, x): + x = x.reshape(-1,3,x.shape[3],x.shape[4]) + y = self.features(x) + y = y.view((int(x.shape[0]/self.num_cuts),self.num_cuts,y.shape[-3],y.shape[-2],y.shape[-1]))#(8,12,512,7,7) + y = self.adppool(torch.max(y,1)[0]) + y = self.classifier(y.view(y.shape[0],-1)) + return y + diff --git a/core/models/msvnetlite.py b/core/models/msvnetlite.py new file mode 100644 index 0000000..95bff0f --- /dev/null +++ b/core/models/msvnetlite.py @@ -0,0 +1,249 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = [ + "MsvNetLite" +] + +NET_CONFIG = { + "blocks2": + #k, in_c, out_c, s, use_se + [[3, 16, 32, 1, False]], + "blocks3": [[3, 32, 64, 2, False], [3, 64, 64, 1, False]], + "blocks4": [[3, 64, 128, 2, False], [3, 128, 128, 1, False]], + "blocks5": [[3, 128, 256, 2, False], [5, 256, 256, 1, False], + [5, 256, 256, 1, False], [5, 256, 256, 1, False], + [5, 256, 256, 1, False], [5, 256, 256, 1, False]], + "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]] +} + + +def make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class Hardswish(nn.Module): + def __init__(self, inplace=True): + super().__init__() + self.inplace = inplace + + def forward(self, x): + return x * F.relu6(x + 3., inplace=self.inplace) / 6. + + +class Hardsigmoid(nn.Module): + def __init__(self, inplace=True): + super().__init__() + self.inplace = inplace + + def forward(self, x): + return F.relu6(x + 3., inplace=True) / 6. 
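+
+# Hardswish and Hardsigmoid above are the piecewise-linear activations used
+# in PP-LCNet-style blocks:
+#   hswish(x)   = x * relu6(x + 3) / 6
+#   hsigmoid(x) = relu6(x + 3) / 6
+# Both avoid exp(); e.g. hswish(1.0) = 1.0 * relu6(4.0) / 6 = 4/6 ≈ 0.667,
+# and hsigmoid saturates to 0 below x = -3 and to 1 above x = 3.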
+ + +class ConvBNLayer(nn.Module): + def __init__(self, + num_channels, + filter_size, + num_filters, + stride, + num_groups=1): + super().__init__() + + self.conv = nn.Conv2d( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=num_groups, + bias=False) + + self.bn = nn.BatchNorm2d( + num_filters, + ) + self.hardswish = Hardswish() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.hardswish(x) + return x + + +class DepthwiseSeparable(nn.Module): + def __init__(self, + num_channels, + num_filters, + stride, + dw_size=3, + use_se=False): + super().__init__() + self.use_se = use_se + self.dw_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=num_channels, + filter_size=dw_size, + stride=stride, + num_groups=num_channels) + if use_se: + self.se = SEModule(num_channels) + self.pw_conv = ConvBNLayer( + num_channels=num_channels, + filter_size=1, + num_filters=num_filters, + stride=1) + + def forward(self, x): + x = self.dw_conv(x) + if self.use_se: + x = self.se(x) + x = self.pw_conv(x) + return x + + +class SEModule(nn.Module): + def __init__(self, channel, reduction=4): + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv1 = nn.Conv2d( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0) + self.hardsigmoid = Hardsigmoid() + + def forward(self, x): + identity = x + x = self.avg_pool(x) + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.hardsigmoid(x) + x = torch.mul(identity, x) + return x + + +class MsvNetLite(nn.Module): + def __init__(self, + scale=1.0, + num_classes=24, + dropout_prob=0.0, + class_expand=1280, + num_cuts=12): + super().__init__() + self.scale = scale + self.class_expand = class_expand + self.num_cuts = num_cuts + + self.conv1 = ConvBNLayer( + num_channels=3, + filter_size=3, + num_filters=make_divisible(16 * scale), + stride=2) + + self.blocks2 = nn.Sequential(*[ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"]) + ]) + + self.blocks3 = nn.Sequential(*[ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"]) + ]) + + self.blocks4 = nn.Sequential(*[ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"]) + ]) + + self.blocks5 = nn.Sequential(*[ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"]) + ]) + + self.blocks6 = nn.Sequential(*[ + DepthwiseSeparable( + num_channels=make_divisible(in_c * scale), + num_filters=make_divisible(out_c * scale), + dw_size=k, + stride=s, + use_se=se) + for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"]) + ]) + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + + self.last_conv = nn.Conv2d( + 
in_channels=make_divisible(NET_CONFIG["blocks6"][-1][2] * scale), + out_channels=self.class_expand, + kernel_size=1, + stride=1, + padding=0, + bias=False) + self.hardswish = Hardswish() + self.dropout = nn.Dropout(dropout_prob) + self.flatten = nn.Flatten(start_dim=1, end_dim=-1) + + self.classifier = nn.Linear(self.class_expand, num_classes) + + def forward(self, x): + # input shape; (b, v, 3, h, w) + x = x.reshape(-1, 3, x.shape[3], x.shape[4]) # out: (b*v, 3, h, w) + + x = self.conv1(x) + x = self.blocks2(x) + x = self.blocks3(x) + x = self.blocks4(x) + x = self.blocks5(x) + x = self.blocks6(x) + + # multiview features fusion + x = x.view((int(x.shape[0]/self.num_cuts), self.num_cuts, + x.shape[-3], x.shape[-2], x.shape[-1])) # out: (b,v,c,h,w) + # x = torch.mean(x, 1) # out: (b,c,h,w) + # view pooling + x = torch.max(x, 1)[0] # out: (b,c,h,w) + + x = self.avg_pool(x) + + x = self.last_conv(x) + x = self.hardswish(x) + x = self.dropout(x) + + x = self.flatten(x) + x = self.classifier(x) + return x \ No newline at end of file diff --git a/core/models/others.py b/core/models/others.py new file mode 100644 index 0000000..44d9bfd --- /dev/null +++ b/core/models/others.py @@ -0,0 +1,205 @@ +''' +Author: whj +Date: 2022-02-28 13:23:59 +LastEditors: whj +LastEditTime: 2022-02-28 13:25:11 +Description: file content +''' +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms, utils +import torchvision.models as models + +from .act_helper import act_helper + + + +class conv3dBlock(nn.Module): + def __init__(self, _in, _out, ksize=3, stride=1, padding='same', dila=1, groups=1, bias=True, bn=False, act='relu'): + super(conv3dBlock, self).__init__() + self.bn = bn + if padding == 'same': + padding = (ksize - 1) // 2 + self.conv = nn.Conv3d(in_channels=_in, + out_channels=_out, + kernel_size=ksize, + stride=stride, + padding=padding, + dilation=dila, + groups=groups, + bias=bias) + + if self.bn: + self.bn = nn.BatchNorm3d(_out) + + self.act = act_helper(act) + + def forward(self, x): + x = self.conv(x) + if self.bn: + x = self.bn(x) + return self.act(x) + + +class BaselineNetBackbone(nn.Module): + ''' + paper: Identifying manufacturability and machining processes using deep 3D convolutional networks + ''' + def __init__(self, in_channels=1): + super(BaselineNetBackbone, self).__init__() + self.block_1 = conv3dBlock(in_channels, 32, ksize=7, stride=2, dila=1, padding='same', groups=1, bias=True, bn=False, act='relu') + self.block_2 = conv3dBlock(32, 32, ksize=5, stride=1, dila=1, groups=1, padding='same', bias=True, bn=False, act='relu') + self.block_3 = conv3dBlock(32, 64, ksize=3, stride=1, dila=1, groups=1, padding='same', bias=True, bn=False, act='relu') + self.pool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0) + + def forward(self, x): + x = self.block_1(x) + x = self.block_2(x) + x = self.block_3(x) + x = self.pool(x) + + return x + + +class BaselineNet(nn.Module): + ''' + paper: Identifying manufacturability and machining processes using deep 3D convolutional networks + ''' + def __init__(self, num_classes=24, input_shape=(64, 64, 64)): + super(BaselineNet, self).__init__() + self.backbone = BaselineNetBackbone(in_channels=1) + self.drop = nn.Dropout(p=0.2) + # downsample 1/4 & out channel 64 + inshape = 64 * (input_shape[0] // 4) ** 3 + self.fc1 = nn.Linear(inshape, 128, bias=True) + self.fc2 = nn.Linear(128, num_classes, bias=True) + + def forward(self, x): + x = self.backbone(x) + x = 
x.view(x.size(0), -1) + x = self.drop(x) + + x = self.fc1(x) + x = F.relu(x) + x = self.drop(x) + + x = self.fc2(x) + return x + + +class BaselineNet2Backbone(nn.Module): + ''' + paper: Part machining feature recognition based on a deep learning method + ''' + def __init__(self, in_channels=1): + super(BaselineNet2Backbone, self).__init__() + self.block_1 = conv3dBlock(in_channels, 32, ksize=3, stride=1, dila=1, padding='same', groups=1, bias=True, bn=False, act='relu') + self.block_2 = conv3dBlock(32, 32, ksize=3, stride=1, dila=1, groups=1, padding='same', bias=True, bn=False, act='relu') + self.block_3 = conv3dBlock(32, 32, ksize=3, stride=2, dila=1, groups=1, padding='same', bias=True, bn=False, act='relu') + self.block_4 = conv3dBlock(32, 64, ksize=3, stride=1, dila=1, groups=1, padding='same', bias=True, bn=False, act='relu') + self.block_5 = conv3dBlock(64, 64, ksize=3, stride=2, dila=1, groups=1, padding='same', bias=True, bn=False, act='relu') + + def forward(self, x): + x = self.block_1(x) + x = self.block_2(x) + x = self.block_3(x) + x = self.block_4(x) + x = self.block_5(x) + + return x + + +class BaselineNet2(nn.Module): + ''' + paper: Part machining feature recognition based on a deep learning method + ''' + def __init__(self, num_classes=24, input_shape=(64, 64, 64)): + super(BaselineNet2, self).__init__() + self.backbone = BaselineNet2Backbone(in_channels=1) + # downsample 1/4 & out channel 64 + inshape = 64 * (input_shape[0] // 4) ** 3 + self.fc1 = nn.Linear(inshape, 124, bias=True) + self.fc2 = nn.Linear(124, 124, bias=True) + self.fc3 = nn.Linear(124, num_classes, bias=True) + + def forward(self, x): + x = self.backbone(x) + x = x.view(x.size(0), -1) + + x = self.fc1(x) + x = F.relu(x) + + x = self.fc2(x) + x = F.relu(x) + + x = self.fc3(x) + return x + + +class VoxNet(torch.nn.Module): + def __init__(self, num_classes, input_shape=(64, 64, 64)): + #weights_path=None, + #load_body_weights=True, + #load_head_weights=True): + """ + VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition. + Modified in order to accept different input shapes. + Parameters + ---------- + num_classes: int, optional + Default: 10 + input_shape: (x, y, z) tuple, optional + Default: (32, 32, 32) + weights_path: str or None, optional + Default: None + load_body_weights: bool, optional + Default: True + load_head_weights: bool, optional + Default: True + Notes + ----- + Weights available at: url to be added + If you want to finetune with custom classes, set load_head_weights to False. + Default head weights are pretrained with ModelNet10. 
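+
+        Example
+        -------
+        >>> net = VoxNet(num_classes=24, input_shape=(64, 64, 64))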
+ """ + super(VoxNet, self).__init__() + self.body = torch.nn.Sequential(OrderedDict([ + ('conv1', torch.nn.Conv3d(in_channels=1, + out_channels=32, kernel_size=5, stride=2)), + ('lkrelu1', torch.nn.LeakyReLU()), + ('drop1', torch.nn.Dropout(p=0.2)), + ('conv2', torch.nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3)), + ('lkrelu2', torch.nn.LeakyReLU()), + ('pool2', torch.nn.MaxPool3d(2)), + ('drop2', torch.nn.Dropout(p=0.3)) + ])) + + # Trick to accept different input shapes + x = self.body(torch.autograd.Variable( + torch.rand((1, 1) + input_shape))) + first_fc_in_features = 1 + for n in x.size()[1:]: + first_fc_in_features *= n + + self.head = torch.nn.Sequential(OrderedDict([ + ('fc1', torch.nn.Linear(first_fc_in_features, 128)), + ('relu1', torch.nn.ReLU()), + ('drop3', torch.nn.Dropout(p=0.4)), + ('fc2', torch.nn.Linear(128, num_classes)) + ])) + + #if weights_path is not None: + # weights = torch.load(weights_path) + # if load_body_weights: + # self.body.load_state_dict(weights["body"]) + # elif load_head_weights: + # self.head.load_state_dict(weights["head"]) + + def forward(self, x): + x = self.body(x) + x = x.view(x.size(0), -1) + x = self.head(x) + return x \ No newline at end of file diff --git a/core/simsiam/__init__.py b/core/simsiam/__init__.py new file mode 100644 index 0000000..aaeb9b8 --- /dev/null +++ b/core/simsiam/__init__.py @@ -0,0 +1,7 @@ +''' +Author: whj +Date: 2022-02-27 17:03:34 +LastEditors: whj +LastEditTime: 2022-02-27 19:21:21 +Description: file content +''' diff --git a/core/simsiam/builder.py b/core/simsiam/builder.py new file mode 100644 index 0000000..f593957 --- /dev/null +++ b/core/simsiam/builder.py @@ -0,0 +1,73 @@ +''' +Author: whj +Date: 2022-02-27 17:03:34 +LastEditors: whj +LastEditTime: 2022-03-04 20:10:19 +Description: file content +''' + +import torch +import torch.nn as nn + + +class SimSiam(nn.Module): + """ + Build a SimSiam model. + """ + def __init__(self, base_encoder, dim=2048, pred_dim=512, encoder_pretrain=None): + """ + dim: feature dimension (default: 2048) + pred_dim: hidden dimension of the predictor (default: 512) + """ + super(SimSiam, self).__init__() + + # create the encoder + # num_classes is the output classifier dimension, zero-initialize last BNs + self.encoder = base_encoder(num_classes=dim) + + if encoder_pretrain: + self.load_encoder_pretrain(encoder_pretrain) + + # build a 3-layer projector + prev_dim = self.encoder.classifier.weight.shape[1] + self.encoder.classifier = nn.Sequential(nn.Linear(prev_dim, prev_dim, bias=False), + nn.BatchNorm1d(prev_dim), + nn.ReLU(inplace=True), # first layer + nn.Linear(prev_dim, prev_dim, bias=False), + nn.BatchNorm1d(prev_dim), + nn.ReLU(inplace=True), # second layer + self.encoder.classifier, + nn.BatchNorm1d(dim, affine=False)) # output layer + self.encoder.classifier[6].bias.requires_grad = False # hack: not use bias as it is followed by BN + + # build a 2-layer predictor + self.predictor = nn.Sequential(nn.Linear(dim, pred_dim, bias=False), + nn.BatchNorm1d(pred_dim), + nn.ReLU(inplace=True), # hidden layer + nn.Linear(pred_dim, dim)) # output layer + + def forward(self, x1, x2): + """ + Input: + x1: first views of images + x2: second views of images + Output: + p1, p2, z1, z2: predictors and targets of the network + See Sec. 
3 of https://arxiv.org/abs/2011.10566 for detailed notations + """ + + # compute features for one view + z1 = self.encoder(x1) # NxC + z2 = self.encoder(x2) # NxC + + p1 = self.predictor(z1) # NxC + p2 = self.predictor(z2) # NxC + + return p1, p2, z1.detach(), z2.detach() + + def load_encoder_pretrain(self, pretrain_path): + checkpoint = torch.load(pretrain_path) + msg = self.encoder.load_state_dict(checkpoint, strict=False) + print(msg) + + diff --git a/core/utils/IntermediateLayerGetter.py b/core/utils/IntermediateLayerGetter.py new file mode 100644 index 0000000..9eca938 --- /dev/null +++ b/core/utils/IntermediateLayerGetter.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +from collections import OrderedDict + +class IntermediateLayerGetter(nn.ModuleDict): + """ get the output of certain layers """ + def __init__(self, model, return_layers): + # 判断传入的return_layers是否存在于model中 + if not set(return_layers).issubset([name for name, _ in model.named_children()]): + raise ValueError("return_layers are not present in model") + + orig_return_layers = return_layers + return_layers = {k: v for k, v in return_layers.items()} # 构造dict + layers = OrderedDict() + # 将要从model中获取信息的最后一层之前的模块全部复制下来 + for name, module in model.named_children(): + layers[name] = module + if name in return_layers: + del return_layers[name] + if not return_layers: + break + + super(IntermediateLayerGetter, self).__init__(layers) # 将所需的网络层通过继承的方式保存下来 + self.return_layers = orig_return_layers + + def forward(self, x): + out = OrderedDict() + # 将所需的值以k,v的形式保存到out中 + for name, module in self.named_children(): + x = module(x) + if name in self.return_layers: + out_name = self.return_layers[name] + out[out_name] = x + return out diff --git a/core/utils/LayerOutHook.py b/core/utils/LayerOutHook.py new file mode 100644 index 0000000..c8051e7 --- /dev/null +++ b/core/utils/LayerOutHook.py @@ -0,0 +1,21 @@ +''' +Author: whj +Date: 2022-05-31 21:38:24 +LastEditors: whj +LastEditTime: 2022-05-31 21:48:02 +Description: file content +''' + +class LayerOutHook(): + """ + Hook for LayerOut. + """ + def __init__(self, net, layer_name): + self.output = None + + for (name, module) in net.named_modules(): + if name == layer_name: + module.register_forward_hook(hook=self.hook) + + def hook(self, module, input, output): + self.output = output diff --git a/core/utils/__init__.py b/core/utils/__init__.py new file mode 100644 index 0000000..87e9476 --- /dev/null +++ b/core/utils/__init__.py @@ -0,0 +1,14 @@ +''' +Author: whj +Date: 2022-02-15 15:53:55 +LastEditors: whj +LastEditTime: 2022-05-31 21:47:27 +Description: file content +''' + +# coding: utf-8 +from .binvox_rw import read_as_3d_array +from .vis import * +from .logger import setup_logger, get_logger +from .IntermediateLayerGetter import IntermediateLayerGetter +from .LayerOutHook import LayerOutHook diff --git a/core/utils/binvox_rw.py b/core/utils/binvox_rw.py new file mode 100644 index 0000000..704fc1f --- /dev/null +++ b/core/utils/binvox_rw.py @@ -0,0 +1,284 @@ +# Copyright (C) 2012 Daniel Maturana +# This file is part of binvox-rw-py. +# +# binvox-rw-py is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# binvox-rw-py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with binvox-rw-py. If not, see . +# + +""" +Binvox to Numpy and back. + + +>>> import numpy as np +>>> import binvox_rw +>>> with open('chair.binvox', 'rb') as f: +... m1 = binvox_rw.read_as_3d_array(f) +... +>>> m1.dims +[32, 32, 32] +>>> m1.scale +41.133000000000003 +>>> m1.translate +[0.0, 0.0, 0.0] +>>> with open('chair_out.binvox', 'wb') as f: +... m1.write(f) +... +>>> with open('chair_out.binvox', 'rb') as f: +... m2 = binvox_rw.read_as_3d_array(f) +... +>>> m1.dims==m2.dims +True +>>> m1.scale==m2.scale +True +>>> m1.translate==m2.translate +True +>>> np.all(m1.data==m2.data) +True + +>>> with open('chair.binvox', 'rb') as f: +... md = binvox_rw.read_as_3d_array(f) +... +>>> with open('chair.binvox', 'rb') as f: +... ms = binvox_rw.read_as_coord_array(f) +... +>>> data_ds = binvox_rw.dense_to_sparse(md.data) +>>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32) +>>> np.all(data_sd==md.data) +True +>>> # the ordering of elements returned by numpy.nonzero changes with axis +>>> # ordering, so to compare for equality we first lexically sort the voxels. +>>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)]) +True +""" + +import numpy as np + +class Voxels(object): + """ Holds a binvox model. + data is either a three-dimensional numpy boolean array (dense representation) + or a two-dimensional numpy float array (coordinate representation). + + dims, translate and scale are the model metadata. + + dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. + + scale and translate relate the voxels to the original model coordinates. + + To translate voxel coordinates i, j, k to original coordinates x, y, z: + + x_n = (i+.5)/dims[0] + y_n = (j+.5)/dims[1] + z_n = (k+.5)/dims[2] + x = scale*x_n + translate[0] + y = scale*y_n + translate[1] + z = scale*z_n + translate[2] + + """ + + def __init__(self, data, dims, translate, scale, axis_order): + self.data = data + self.dims = dims + self.translate = translate + self.scale = scale + assert (axis_order in ('xzy', 'xyz')) + self.axis_order = axis_order + + def clone(self): + data = self.data.copy() + dims = self.dims[:] + translate = self.translate[:] + return Voxels(data, dims, translate, self.scale, self.axis_order) + + def write(self, fp): + write(self, fp) + +def read_header(fp): + """ Read binvox header. Mostly meant for internal use. + """ + line = fp.readline().strip() + if not line.startswith(b'#binvox'): + raise IOError('Not a binvox file') + dims = list(map(int, fp.readline().strip().split(b' ')[1:])) + translate = list(map(float, fp.readline().strip().split(b' ')[1:])) + scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0] + line = fp.readline() + return dims, translate, scale + +def read_as_3d_array(fp, fix_coords=True): + """ Read binary binvox format as array. + + Returns the model with accompanying metadata. + + Voxels are stored in a three-dimensional numpy array, which is simple and + direct, but may use a lot of memory for large models. (Storage requirements + are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy + boolean arrays use a byte per element). + + Doesn't do any checks on input except for the '#binvox' line. 
+    """
+    dims, translate, scale = read_header(fp)
+    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
+    # if just using reshape() on the raw data:
+    # indexing the array as array[i,j,k], the indices map into the
+    # coords as:
+    # i -> x
+    # j -> z
+    # k -> y
+    # if fix_coords is true, then data is rearranged so that
+    # mapping is
+    # i -> x
+    # j -> y
+    # k -> z
+    values, counts = raw_data[::2], raw_data[1::2]
+    data = np.repeat(values, counts).astype(bool)
+    data = data.reshape(dims)
+    if fix_coords:
+        # xzy to xyz TODO the right thing
+        data = np.transpose(data, (0, 2, 1))
+        axis_order = 'xyz'
+    else:
+        axis_order = 'xzy'
+    return Voxels(data, dims, translate, scale, axis_order)
+
+def read_as_coord_array(fp, fix_coords=True):
+    """ Read binary binvox format as coordinates.
+
+    Returns binvox model with voxels in a "coordinate" representation, i.e. a
+    3 x N array where N is the number of nonzero voxels. Each column
+    corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates
+    of the voxel. (The odd ordering is due to the way binvox format lays out
+    data). Note that coordinates refer to the binvox voxels, without any
+    scaling or translation.
+
+    Use this to save memory if your model is very sparse (mostly empty).
+
+    Doesn't do any checks on input except for the '#binvox' line.
+    """
+    dims, translate, scale = read_header(fp)
+    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
+
+    values, counts = raw_data[::2], raw_data[1::2]
+
+    sz = np.prod(dims)
+    index, end_index = 0, 0
+    end_indices = np.cumsum(counts)
+    indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype)
+
+    values = values.astype(bool)
+    indices = indices[values]
+    end_indices = end_indices[values]
+
+    nz_voxels = []
+    for index, end_index in zip(indices, end_indices):
+        nz_voxels.extend(range(index, end_index))
+    nz_voxels = np.array(nz_voxels)
+    # TODO are these dims correct?
+    # according to docs,
+    # index = x * wxh + z * width + y;  // wxh = width * height = d * d
+
+    # integer division keeps the coordinates integral under Python 3
+    x = nz_voxels // (dims[0]*dims[1])
+    zwpy = nz_voxels % (dims[0]*dims[1])  # z*w + y
+    z = zwpy // dims[0]
+    y = zwpy % dims[0]
+    if fix_coords:
+        data = np.vstack((x, y, z))
+        axis_order = 'xyz'
+    else:
+        data = np.vstack((x, z, y))
+        axis_order = 'xzy'
+
+    #return Voxels(data, dims, translate, scale, axis_order)
+    return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)
+
+def dense_to_sparse(voxel_data, dtype=int):
+    """ From dense representation to sparse (coordinate) representation.
+    No coordinate reordering.
+    """
+    if voxel_data.ndim != 3:
+        raise ValueError('voxel_data is wrong shape; should be 3D array.')
+    return np.asarray(np.nonzero(voxel_data), dtype)
+
+def sparse_to_dense(voxel_data, dims, dtype=bool):
+    if voxel_data.ndim != 2 or voxel_data.shape[0] != 3:
+        raise ValueError('voxel_data is wrong shape; should be 3xN array.')
+    if np.isscalar(dims):
+        dims = [dims]*3
+    dims = np.atleast_2d(dims).T
+    # truncate to integers
+    xyz = voxel_data.astype(int)
+    # discard voxels that fall outside dims
+    valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
+    xyz = xyz[:,valid_ix]
+    out = np.zeros(dims.flatten(), dtype=dtype)
+    out[tuple(xyz)] = True
+    return out
+
+#def get_linear_index(x, y, z, dims):
+    #""" Assuming xzy order. (y increasing fastest.
+    #TODO ensure this is right when dims are not all same
+    #"""
+    #return x*(dims[1]*dims[2]) + z*dims[1] + y
+
+def write(voxel_model, fp):
+    """ Write binary binvox format.
+ + Note that when saving a model in sparse (coordinate) format, it is first + converted to dense format. + + Doesn't check if the model is 'sane'. + + """ + if voxel_model.data.ndim==2: + # TODO avoid conversion to dense + dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) + else: + dense_voxel_data = voxel_model.data + + fp.write('#binvox 1\n') + fp.write('dim '+' '.join(map(str, voxel_model.dims))+'\n') + fp.write('translate '+' '.join(map(str, voxel_model.translate))+'\n') + fp.write('scale '+str(voxel_model.scale)+'\n') + fp.write('data\n') + if not voxel_model.axis_order in ('xzy', 'xyz'): + raise ValueError('Unsupported voxel model axis order') + + if voxel_model.axis_order=='xzy': + voxels_flat = dense_voxel_data.flatten() + elif voxel_model.axis_order=='xyz': + voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() + + # keep a sort of state machine for writing run length encoding + state = voxels_flat[0] + ctr = 0 + for c in voxels_flat: + if c==state: + ctr += 1 + # if ctr hits max, dump + if ctr==255: + fp.write(chr(state)) + fp.write(chr(ctr)) + ctr = 0 + else: + # if switch state, dump + fp.write(chr(state)) + fp.write(chr(ctr)) + state = c + ctr = 1 + # flush out remainders + if ctr > 0: + fp.write(chr(state)) + fp.write(chr(ctr)) + +if __name__ == '__main__': + import doctest + doctest.testmod() diff --git a/core/utils/logger.py b/core/utils/logger.py new file mode 100644 index 0000000..29dcb49 --- /dev/null +++ b/core/utils/logger.py @@ -0,0 +1,77 @@ +''' +Author: whj +Date: 2022-02-15 17:51:42 +LastEditors: whj +LastEditTime: 2022-02-15 17:52:00 +Description: file content +''' +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys + + + +logger_initialized = [] + + +def setup_logger(output=None, name="ppgan"): + """ + Initialize the ppgan logger and set its verbosity level to "INFO". + + Args: + output (str): a file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. 
+ name (str): the root module name of this logger + + Returns: + logging.Logger: a logger + """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + logger.setLevel(logging.INFO) + logger.propagate = False + + plain_formatter = logging.Formatter( + "[%(asctime)s] %(name)s %(levelname)s: %(message)s", + datefmt="%m/%d %H:%M:%S") + + # file logging: all workers + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "log.txt") + + # make dir if path not exist + os.makedirs(os.path.dirname(filename), exist_ok=True) + + fh = logging.FileHandler(filename, mode='a') + fh.setLevel(logging.DEBUG) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + logger_initialized.append(name) + return logger + + +def get_logger(name='ppgan'): + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + + return setup_logger(name=name) diff --git a/core/utils/vis.py b/core/utils/vis.py new file mode 100644 index 0000000..b86bf6b --- /dev/null +++ b/core/utils/vis.py @@ -0,0 +1,110 @@ +''' +Author: whj +Date: 2022-02-15 15:46:41 +LastEditors: whj +LastEditTime: 2022-06-01 13:11:44 +Description: file content +''' + +import random +import pyvista as pv +import numpy as np +import csv +import matplotlib.pyplot as plt +import mpl_toolkits.mplot3d + + +colors = ['#000080', + '#FF0000', + '#FF00FF', + '#00BFFF', + '#DC143C', + '#DAA520', + '#DDA0DD', + '#708090', + '#556B2F', + '#483D8B', + '#CD5C5C', + '#21618C', + '#1C2833', + '#4169E1', + '#1E90FF', + '#FFD700', + '#FF4500', + '#646464', + '#DC143C', + '#98FB98', + '#9370DB', + '#8B4513', + '#00FF00', + '#008080' + ] + + +def get_label(filename, factor): + retarr = np.zeros((0, 7)) + with open(filename, newline='') as csvfile: + spamreader = csv.reader(csvfile, delimiter=' ', quotechar='|') + for row in spamreader: + items = row[0].split(',') + retarr = np.insert(retarr, 0, np.asarray(items), 0) + + retarr[:, 0:6] = retarr[:, 0:6] * factor + + return retarr + + +def disp_model(filename): + pv.set_plot_theme("document") + mesh = pv.PolyData(filename+'.STL') + + plotter = pv.Plotter() + plotter.add_mesh(mesh, opacity=0.8, color='#FFFFFF') + + shapetypes = ['O ring', 'Through hole', 'Blind hole', 'Triangular passage', 'Rectangular passage', 'Circular through slot', 'Triangular through slot', 'Rectangular through slot', 'Rectangular blind slot', 'Triangular pocket', 'Rectangular pocket', 'Circular end pocket', + 'Triangular blind step', 'Circular blind step', 'Rectangular blind step', 'Rectangular through step', '2-sides through step', 'Slanted through step', 'Chamfer', 'Round', 'Vertical circular end blind slot', 'Horizontal circular end blind slot', '6-sides passage', '6-sides pocket'] + + items = get_label(filename+'.csv', 1000) + + flag = np.zeros(24) + + for i in range(items.shape[0]): + if flag[int(items[i, 6])] == 0: + plotter.add_mesh(pv.Cube((0, 0, 0), 0, 0, 0, (items[i, 0], items[i, 3], items[i, 1], items[i, 4], items[i, 2], items[i, 5])), opacity=1, color=colors[int( + items[i, 6])], style='wireframe', line_width=2, label=shapetypes[int(items[i, 6])]) + flag[int(items[i, 6])] = 1 + else: + plotter.add_mesh(pv.Cube((0, 0, 0), 0, 0, 0, (items[i, 0], items[i, 3], items[i, 1], items[i, 4], + items[i, 2], items[i, 5])), opacity=1, color=colors[int(items[i, 6])], style='wireframe', line_width=2) + + plotter.add_legend() + plotter.show() + + +def plot3D_with_labels(lowDWeights, labels): + fig = plt.figure() + ax = 
mpl_toolkits.mplot3d.Axes3D(fig) + X, Y, Z = lowDWeights[:, 0], lowDWeights[:, 1], lowDWeights[:, 2] + for x, y, z, s in zip(X, Y, Z, labels): + c = colors[int(s)] + ax.text3D(x, y, z, s, backgroundcolor=c, color='#FFFFFF', fontsize=9) + ax.set_xlim3d(X.min(), X.max()) + ax.set_ylim3d(Y.min(), Y.max()) + ax.set_zlim3d(Z.min(), Z.max()) + #plt.title('Visualize last layer') + plt.show() + plt.pause(0.05) + + +def plot_with_labels(lowDWeights, labels): + plt.cla() + X, Y = lowDWeights[:, 0], lowDWeights[:, 1] + for x, y, s in zip(X, Y, labels): + c = colors[int(s)] + plt.text(x, y, s, backgroundcolor=c, color='#FFFFFF', fontsize=12) + plt.xlim(X.min()-1, X.max()+1) + plt.ylim(Y.min()-1, Y.max()+1) + # plt.axis('off') + #plt.title('Visualize last layer') + plt.show() + plt.pause(0.05) diff --git a/main_simsiam.py b/main_simsiam.py new file mode 100644 index 0000000..b2da05f --- /dev/null +++ b/main_simsiam.py @@ -0,0 +1,393 @@ +import argparse +import builtins +import math +import os +import random +import shutil +import time +import warnings + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models + +from core.dataset import SimsiamDataset, createPartition +from core.simsiam import builder +from core.models import FeatureNet, FeatureNetLite, MsvNetLite + +# model_names = sorted(name for name in models.__dict__ +# if name.islower() and not name.startswith("__") +# and callable(models.__dict__[name])) + +parser = argparse.ArgumentParser(description='PyTorch Featurenet Training') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 32)') +parser.add_argument('--epochs', default=100, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=32, type=int, + metavar='N', + help='mini-batch size (default: 512), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=0.05, type=float, + metavar='LR', help='initial (base) learning rate', dest='lr') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum of SGD solver') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=100, type=int, + metavar='N', help='print frequency (default: 100)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + 
help='seed for initializing training. ') +parser.add_argument('--gpu', default=0, type=int, + help='GPU id to use.') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + +# simsiam specific configs: +parser.add_argument('--dim', default=2048, type=int, + help='feature dimension (default: 2048)') +parser.add_argument('--pred-dim', default=512, type=int, + help='hidden dimension of the predictor (default: 512)') +parser.add_argument('--fix-pred-lr', action='store_true', + help='Fix learning rate for the predictor') + +# featurenet dataset specific configs: +parser.add_argument('--resolution', dest='resolution', + default=64, type=int, help='model resolution: 16, 32, 64') +parser.add_argument('--data_path', dest='data_path', + default='data', type=str, help='path to the data') +parser.add_argument('--num_of_class', dest='num_of_class', + default=24, type=int, help='number of classes') +parser.add_argument('--num_train', dest='num_train', default=4800, + type=int, help='number of training examples per class') +parser.add_argument('--arch', dest='arch', default='FeatureNetLite', + type=str, help='network arch: FeatureNet, FeatureNetLite, MsvNetLite') +parser.add_argument('--output_dir', dest='output_dir', + default='output', type=str, help='directory to save output') +parser.add_argument('--pretrain', default='', type=str, metavar='PRETRAIN', + help='path to pretrain (default: none)') + +def main(): + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + if args.gpu is not None: + warnings.warn('You have chosen a specific GPU. 
This will completely ' + 'disable data parallelism.') + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + + args.distributed = args.world_size > 1 or args.multiprocessing_distributed + + ngpus_per_node = torch.cuda.device_count() + if args.multiprocessing_distributed: + # Since we have ngpus_per_node processes per node, the total world_size + # needs to be adjusted accordingly + args.world_size = ngpus_per_node * args.world_size + # Use torch.multiprocessing.spawn to launch distributed processes: the + # main_worker process function + mp.spawn(main_worker, nprocs=ngpus_per_node, + args=(ngpus_per_node, args)) + else: + # Simply call main_worker function + main_worker(args.gpu, ngpus_per_node, args) + + +def main_worker(gpu, ngpus_per_node, args): + args.gpu = gpu + + # suppress printing if not master + if args.multiprocessing_distributed and args.gpu != 0: + def print_pass(*args): + pass + builtins.print = print_pass + + if args.gpu is not None: + print("Use GPU: {} for training".format(args.gpu)) + + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_distributed: + # For multiprocessing distributed training, rank needs to be the + # global rank among all the processes + args.rank = args.rank * ngpus_per_node + gpu + dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + # create model + print("=> creating model '{}'".format(args.arch)) + + if args.arch == 'FeatureNet': + model = FeatureNet + output_type = '3d' + elif args.arch == 'FeatureNetLite': + model = FeatureNetLite + output_type = '3d' + elif args.arch == 'MsvNetLite': + model = MsvNetLite + output_type = '2d_multiple' + else: + raise ValueError('Invalid network type') + + model = builder.SimSiam(model, args.dim, args.pred_dim, args.pretrain) + + # infer learning rate before changing batch size + init_lr = args.lr * args.batch_size / 256 + + if args.distributed: + # Apply SyncBN + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + # For multiprocessing distributed, DistributedDataParallel constructor + # should always set the single device scope, otherwise, + # DistributedDataParallel will use all available devices. + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + # When using a single GPU per process and per + # DistributedDataParallel, we need to divide the batch size + # ourselves based on the total number of GPUs we have + args.batch_size = int(args.batch_size / ngpus_per_node) + args.workers = int( + (args.workers + ngpus_per_node - 1) / ngpus_per_node) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu]) + else: + model.cuda() + # DistributedDataParallel will divide and allocate batch_size to all + # available GPUs if device_ids are not set + model = torch.nn.parallel.DistributedDataParallel(model) + elif args.gpu is not None: + torch.cuda.set_device(args.gpu) + model = model.cuda(args.gpu) + # comment out the following line for debugging + # raise NotImplementedError("Only DistributedDataParallel is supported.") + else: + # AllGather implementation (batch shuffle, queue update, etc.) in + # this code only supports DistributedDataParallel. 
+ raise NotImplementedError("Only DistributedDataParallel is supported.") + print(model) # print model after SyncBatchNorm + + # define loss function (criterion) and optimizer + criterion = nn.CosineSimilarity(dim=1).cuda(args.gpu) + + if args.fix_pred_lr: + optim_params = [{'params': model.encoder.parameters(), 'fix_lr': False}, + {'params': model.predictor.parameters(), 'fix_lr': True}] + else: + optim_params = model.parameters() + + optimizer = torch.optim.SGD(optim_params, init_lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + loc = 'cuda:{}'.format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + + # Data loading code + augmentations = ['randomRotation', 'randomPadCrop', 'randomScale'] + + train_dataset = SimsiamDataset(args.data_path, + args.num_of_class, + args.num_train, + resolution=args.resolution, + output_type=output_type, + num_cuts=12, + transform=augmentations) + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=( + train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) + + + time_str = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + output_dir = os.path.join(args.output_dir, 'simsiam_ckpts_'+time_str) + makedirs(output_dir) + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, init_lr, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + if not args.multiprocessing_distributed or (args.multiprocessing_distributed + and args.rank % ngpus_per_node == 0): + file_path= os.path.join(output_dir, 'checkpoint_{:04d}.pth.tar'.format(epoch)) + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + }, is_best=False, filename=file_path) + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4f') + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for i, objs in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + objs[0] = objs[0].cuda(args.gpu, non_blocking=True) + objs[1] = objs[1].cuda(args.gpu, non_blocking=True) + + # compute output and loss + p1, p2, z1, z2 = model(x1=objs[0], x2=objs[1]) + loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5 + + losses.update(loss.item(), objs[0].size(0)) + + # compute gradient and do SGD 
step
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % args.print_freq == 0:
+            progress.display(i)
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+def adjust_learning_rate(optimizer, init_lr, epoch, args):
+    """Decay the learning rate based on schedule"""
+    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
+    for param_group in optimizer.param_groups:
+        if 'fix_lr' in param_group and param_group['fix_lr']:
+            param_group['lr'] = init_lr
+        else:
+            param_group['lr'] = cur_lr
+
+
+def makedirs(path):
+    if not os.path.exists(path):
+        # avoid a race when training with multiple GPUs
+        try:
+            os.makedirs(path)
+        except OSError:
+            pass
+
+
+if __name__ == '__main__':
+    main()
diff --git a/multi_test.py b/multi_test.py
new file mode 100644
index 0000000..6b28143
--- /dev/null
+++ b/multi_test.py
@@ -0,0 +1,39 @@
+'''
+Author: whj
+Date: 2022-02-15 15:48:15
+LastEditors: whj
+LastEditTime: 2022-02-17 00:31:34
+Description: file content
+'''
+import os
+import sys
+
+import torch
+
+import numpy as np
+
+# append root path of the project to the sys path
+cur_path = os.path.abspath(os.path.dirname(__file__))
+root_path = os.path.split(cur_path)[0]
+sys.path.append(root_path)
+
+from engine import multi as mlt
+
+torch.backends.cudnn.benchmark = True
+np.random.seed(1234)
+torch.manual_seed(1234)
+torch.cuda.manual_seed(1234)
+
+
+group_idx = 3
+
+p, r = mlt.test_msvnet(group_idx)
+
+print('Precision for the MsvNet on group ', group_idx, ': ', p)
+print('Recall for the MsvNet on group ', group_idx, ': ', r)
+
+p, r = mlt.test_featurenet(group_idx)
+
+
+print('Precision for the FeatureNet on group ', group_idx, ': ', p)
+print('Recall for the FeatureNet on group ', group_idx, ': ', r)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..c0c2b46
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+tqdm
+warmup_scheduler
+pyvista
+thop
+torchsummary
\ No newline at end of file
diff --git a/single_train.py b/single_train.py
new file mode 100644
index 0000000..b5bb835
--- /dev/null
+++ b/single_train.py
@@ -0,0 +1,110 @@
+'''
+Author: whj
+Date: 2021-10-27 13:43:16
+LastEditors: whj
+LastEditTime: 2022-05-31 21:19:32
+Description: file content
+'''
+from core.engine.trainer import (
+    train_eval_model, eval_model,
+    infer_time_test,
+    draw_TSNE, draw_ROC_CM
+    )
+import argparse
+import os
+import sys
+import random
+import warnings
+
+import torch
+
+import numpy as np
+
+# close warnings
+warnings.filterwarnings('ignore')
+
+# append root path of the project to the sys path
+cur_path = os.path.abspath(os.path.dirname(__file__))
+root_path = os.path.split(cur_path)[0]
+sys.path.append(root_path)
+
+
+# set flags / seeds / env
+torch.backends.cudnn.benchmark = True
+# torch.backends.cudnn.deterministic = True
+# random.seed(42)
+# os.environ['PYTHONHASHSEED'] = str(42)
+# np.random.seed(42)
+# torch.manual_seed(42)
+# torch.cuda.manual_seed(42)
+# torch.cuda.manual_seed_all(42)
+
+parser = argparse.ArgumentParser(description='Feature Recognition Training')
+parser.add_argument('--data_path', dest='data_path',
+                    default='data', type=str, help='path to the data')
+parser.add_argument('--resolution', dest='resolution',
+                    default=64, type=int, help='model resolution: 16, 32, 64')
+parser.add_argument('--num_of_class', dest='num_of_class',
+                    default=24, type=int, help='number of classes')
+parser.add_argument('--num_train', dest='num_train', default=2,
+                    type=int, help='number of training examples per class')
+parser.add_argument('--num_val_test', dest='num_val_test', default=600,
+                    type=int, help='number of val/test examples per class')
+parser.add_argument('--arch', dest='arch', default='FeatureNet',
+                    type=str, help='network arch: FeatureNet, FeatureNetLite, MsvNet, MsvNetLite, BaselineNet, BaselineNet2, VoxNet')
+parser.add_argument('--base_lr', dest='base_lr', default=0.001,
+                    type=float, help='base learning rate')
+parser.add_argument('--train_epochs', dest='train_epochs', default=100,
+                    type=int, help='number of epochs for supervised training')
+parser.add_argument('--train_batchsize', dest='train_batchsize',
+                    default=64, type=int, help='training batch size')
+parser.add_argument('--val_batchsize', dest='val_batchsize',
+                    default=64, type=int, help='validation batch size')
+parser.add_argument('--weight_decay', dest='weight_decay',
+                    default=0.0, type=float, help='weight decay')
+parser.add_argument('--warmup_epochs', dest='warmup_epochs',
+                    default=0, type=int, help='warmup epochs')
+parser.add_argument('--lr_sch', dest='lr_sch', default='constant',
+                    type=str, help='learning rate scheduler type: constant, exp, cos, multistep, step')
+parser.add_argument('--optim', dest='optim',
+                    default='adam', type=str, help='optimizer type: adam, sgdm, rmsprop, adamw')
+parser.add_argument('--data_aug', dest='data_aug', action='store_true',
+                    help='whether to use data augmentation')
+parser.add_argument('--num_cuts', dest='num_cuts', default=12,
+                    type=int, help='number of cuts')
+parser.add_argument('--pretrain', dest='pretrained', default=None,
+                    type=str, help='pretrain model directory')
+parser.add_argument('--simsiam_pretrain', dest='simsiam_pretrained', default=None,
+                    type=str, help='simsiam pretrain model directory')
+parser.add_argument('--freeze', dest='freeze', action='store_true',
+                    help='whether to freeze the encoder when loading an SSL pretrained model')
+parser.add_argument('--val_interval', dest='val_interval',
+                    default=10, type=int, help='validation interval')
+parser.add_argument('--output_dir', dest='output_dir',
+                    default='output', type=str, help='directory to save output')
+parser.add_argument('--model_path', dest='model_path',
+                    default=None, type=str, help='path to the trained model')
+parser.add_argument('--program_type', dest='program_type',
+                    default='train_eval', type=str,
+                    help='which program to run [train_eval, eval, infer_time_test, draw_TSNE, draw_ROC_CM]')
+parser.add_argument('--device', dest='device',
+                    default='gpu', type=str, help='which device to run on')
+parser.add_argument('-j', '--workers', default=0, type=int, metavar='N',
+                    help='number of data loading workers (default: 0)')
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+
+    if args.program_type == 'train_eval':
+        train_eval_model(args)
+    elif args.program_type == 'eval':
+        eval_model(args)
+    elif args.program_type == 'infer_time_test':
+        infer_time_test(args)
+    elif args.program_type == 'draw_TSNE':
+        draw_TSNE(args)
+    elif args.program_type == 'draw_ROC_CM':
+        draw_ROC_CM(args)
+    else:
+        raise ValueError('program type not supported')
diff --git a/visualize.py b/visualize.py
new file mode 100644
index 0000000..64fe856
--- /dev/null
+++ b/visualize.py
@@ -0,0 +1,11 @@
+'''
+Author: whj
+Date: 2021-10-27 13:43:16
+LastEditors: whj
+LastEditTime: 2022-02-15 16:24:21
+Description: file content
+'''
+
+from core.utils import vis
+
+vis.disp_model('data/set10/908')
\ No newline at end of file
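
Below is a minimal, self-contained sketch of the symmetric SimSiam loss computed in `main_simsiam.py`'s `train()` loop: the negative cosine similarity between each predictor output and the stop-gradient target from the other view. The batch size and feature dimension here are illustrative (2048 matches the `--dim` default).

```python
import torch
import torch.nn as nn

criterion = nn.CosineSimilarity(dim=1)

# Illustrative tensors: batch of 4, feature dim 2048 (the --dim default).
p1, p2 = torch.randn(4, 2048), torch.randn(4, 2048)  # predictor outputs
z1, z2 = torch.randn(4, 2048), torch.randn(4, 2048)  # encoder projections

# Targets are detached (stop-gradient), then the negative cosine similarity
# is symmetrized over both views, exactly as in train():
loss = -(criterion(p1, z2.detach()).mean() +
         criterion(p2, z1.detach()).mean()) * 0.5
print(loss.item())  # near 0 for random vectors; -1 if views align perfectly
```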