diff --git a/dover/.ipynb_checkpoints/__init__-checkpoint.py b/dover/.ipynb_checkpoints/__init__-checkpoint.py deleted file mode 100644 index 4e943a6..0000000 --- a/dover/.ipynb_checkpoints/__init__-checkpoint.py +++ /dev/null @@ -1,2 +0,0 @@ -from .datasets import * -from .models import * diff --git a/dover/.ipynb_checkpoints/version-checkpoint.py b/dover/.ipynb_checkpoints/version-checkpoint.py deleted file mode 100644 index ce20dbd..0000000 --- a/dover/.ipynb_checkpoints/version-checkpoint.py +++ /dev/null @@ -1,16 +0,0 @@ -__version__ = "0.5.0" - - -def parse_version_info(version_str): - version_info = [] - for x in version_str.split("."): - if x.isdigit(): - version_info.append(int(x)) - elif x.find("rc") != -1: - patch_version = x.split("rc") - version_info.append(int(patch_version[0])) - version_info.append(f"rc{patch_version[1]}") - return tuple(version_info) - - -version_info = parse_version_info(__version__) diff --git a/dover/datasets/.ipynb_checkpoints/__init__-checkpoint.py b/dover/datasets/.ipynb_checkpoints/__init__-checkpoint.py deleted file mode 100644 index ef3cf11..0000000 --- a/dover/datasets/.ipynb_checkpoints/__init__-checkpoint.py +++ /dev/null @@ -1,3 +0,0 @@ -## API for DOVER and its variants -from .basic_datasets import * -from .dover_datasets import * diff --git a/dover/datasets/.ipynb_checkpoints/basic_datasets-checkpoint.py b/dover/datasets/.ipynb_checkpoints/basic_datasets-checkpoint.py deleted file mode 100644 index e6fa3f7..0000000 --- a/dover/datasets/.ipynb_checkpoints/basic_datasets-checkpoint.py +++ /dev/null @@ -1,812 +0,0 @@ -import os.path as osp -import random - -import cv2 -import decord -import numpy as np -import skvideo.io -import torch -import torchvision -from decord import VideoReader, cpu, gpu -from tqdm import tqdm - -random.seed(42) - -decord.bridge.set_bridge("torch") - - -def get_spatial_fragments( - video, - fragments_h=7, - fragments_w=7, - fsize_h=32, - fsize_w=32, - aligned=32, - nfrags=1, - random=False, - fallback_type="upsample", -): - size_h = fragments_h * fsize_h - size_w = fragments_w * fsize_w - - ## situation for images - if video.shape[1] == 1: - aligned = 1 - - dur_t, res_h, res_w = video.shape[-3:] - ratio = min(res_h / size_h, res_w / size_w) - if fallback_type == "upsample" and ratio < 1: - - ovideo = video - video = torch.nn.functional.interpolate( - video / 255.0, scale_factor=1 / ratio, mode="bilinear" - ) - video = (video * 255.0).type_as(ovideo) - - assert dur_t % aligned == 0, "Please provide match vclip and align index" - size = size_h, size_w - - ## make sure that sampling will not run out of the picture - hgrids = torch.LongTensor( - [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)] - ) - wgrids = torch.LongTensor( - [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)] - ) - hlength, wlength = res_h // fragments_h, res_w // fragments_w - - if random: - print("This part is deprecated. 
Please remind that.") - if res_h > fsize_h: - rnd_h = torch.randint( - res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) - ) - else: - rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() - if res_w > fsize_w: - rnd_w = torch.randint( - res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) - ) - else: - rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() - else: - if hlength > fsize_h: - rnd_h = torch.randint( - hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) - ) - else: - rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() - if wlength > fsize_w: - rnd_w = torch.randint( - wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) - ) - else: - rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() - - target_video = torch.zeros(video.shape[:-2] + size).to(video.device) - # target_videos = [] - - for i, hs in enumerate(hgrids): - for j, ws in enumerate(wgrids): - for t in range(dur_t // aligned): - t_s, t_e = t * aligned, (t + 1) * aligned - h_s, h_e = i * fsize_h, (i + 1) * fsize_h - w_s, w_e = j * fsize_w, (j + 1) * fsize_w - if random: - h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h - w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w - else: - h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h - w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w - target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[ - :, t_s:t_e, h_so:h_eo, w_so:w_eo - ] - # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo]) - # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6) - # target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments - return target_video - - -class FragmentSampleFrames: - def __init__(self, fsize_t, fragments_t, frame_interval=1, num_clips=1): - - self.fragments_t = fragments_t - self.fsize_t = fsize_t - self.size_t = fragments_t * fsize_t - self.frame_interval = frame_interval - self.num_clips = num_clips - - def get_frame_indices(self, num_frames): - - tgrids = np.array( - [num_frames // self.fragments_t * i for i in range(self.fragments_t)], - dtype=np.int32, - ) - tlength = num_frames // self.fragments_t - - if tlength > self.fsize_t * self.frame_interval: - rnd_t = np.random.randint( - 0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids) - ) - else: - rnd_t = np.zeros(len(tgrids), dtype=np.int32) - - ranges_t = ( - np.arange(self.fsize_t)[None, :] * self.frame_interval - + rnd_t[:, None] - + tgrids[:, None] - ) - return np.concatenate(ranges_t) - - def __call__(self, total_frames, train=False, start_index=0): - frame_inds = [] - for i in range(self.num_clips): - frame_inds += [self.get_frame_indices(total_frames)] - frame_inds = np.concatenate(frame_inds) - frame_inds = np.mod(frame_inds + start_index, total_frames) - return frame_inds - - -class SampleFrames: - def __init__(self, clip_len, frame_interval=1, num_clips=1): - - self.clip_len = clip_len - self.frame_interval = frame_interval - self.num_clips = num_clips - - def _get_train_clips(self, num_frames): - """Get clip offsets in train mode. - - It will calculate the average interval for selected frames, - and randomly shift them within offsets between [0, avg_interval]. - If the total number of frames is smaller than clips num or origin - frames length, it will return all zero indices. - - Args: - num_frames (int): Total number of frame in the video. 
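# A minimal sketch (values illustrative) of the offset rule this docstring
# describes: split the usable range into num_clips average intervals, then
# jitter each clip start uniformly within its own interval.
import numpy as np

num_frames, clip_len, frame_interval, num_clips = 240, 32, 2, 4
ori_clip_len = clip_len * frame_interval                      # 64 frames spanned per clip
avg_interval = (num_frames - ori_clip_len + 1) // num_clips   # 44
base_offsets = np.arange(num_clips) * avg_interval            # [0, 44, 88, 132]
clip_offsets = base_offsets + np.random.randint(avg_interval, size=num_clips)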
- - Returns: - np.ndarray: Sampled frame indices in train mode. - """ - ori_clip_len = self.clip_len * self.frame_interval - avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips - - if avg_interval > 0: - base_offsets = np.arange(self.num_clips) * avg_interval - clip_offsets = base_offsets + np.random.randint( - avg_interval, size=self.num_clips - ) - elif num_frames > max(self.num_clips, ori_clip_len): - clip_offsets = np.sort( - np.random.randint(num_frames - ori_clip_len + 1, size=self.num_clips) - ) - elif avg_interval == 0: - ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips - clip_offsets = np.around(np.arange(self.num_clips) * ratio) - else: - clip_offsets = np.zeros((self.num_clips,), dtype=np.int) - return clip_offsets - - def _get_test_clips(self, num_frames, start_index=0): - """Get clip offsets in test mode. - - Calculate the average interval for selected frames, and shift them - fixedly by avg_interval/2. - - Args: - num_frames (int): Total number of frame in the video. - - Returns: - np.ndarray: Sampled frame indices in test mode. - """ - ori_clip_len = self.clip_len * self.frame_interval - avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips) - if num_frames > ori_clip_len - 1: - base_offsets = np.arange(self.num_clips) * avg_interval - clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int32) - else: - clip_offsets = np.zeros((self.num_clips,), dtype=np.int32) - return clip_offsets - - def __call__(self, total_frames, train=False, start_index=0): - """Perform the SampleFrames loading. - - Args: - results (dict): The resulting dict to be modified and passed - to the next transform in pipeline. - """ - if train: - clip_offsets = self._get_train_clips(total_frames) - else: - clip_offsets = self._get_test_clips(total_frames) - frame_inds = ( - clip_offsets[:, None] - + np.arange(self.clip_len)[None, :] * self.frame_interval - ) - frame_inds = np.concatenate(frame_inds) - - frame_inds = frame_inds.reshape((-1, self.clip_len)) - frame_inds = np.mod(frame_inds, total_frames) - frame_inds = np.concatenate(frame_inds) + start_index - return frame_inds.astype(np.int32) - - -class FastVQAPlusPlusDataset(torch.utils.data.Dataset): - def __init__( - self, - ann_file, - data_prefix, - frame_interval=2, - aligned=32, - fragments=(8, 8, 8), - fsize=(4, 32, 32), - num_clips=1, - nfrags=1, - cache_in_memory=False, - phase="test", - fallback_type="oversample", - ): - """ - Fragments. - args: - fragments: G_f as in the paper. - fsize: S_f as in the paper. - nfrags: number of samples (spatially) as in the paper. - num_clips: number of samples (temporally) as in the paper. 
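# A minimal usage sketch of get_spatial_fragments (defined earlier in this
# file) with the G_f / S_f values documented above; the input clip is a dummy
# tensor, not repository data. A (C, T, H, W) clip is tiled into a
# fragments_h x fragments_w grid, one fsize_h x fsize_w mini-patch is cut per
# cell, and the patches are spliced into a 224 x 224 view.
import torch

video = torch.randint(0, 256, (3, 32, 540, 960)).float()   # dummy decoded clip
vfrag = get_spatial_fragments(video, fragments_h=7, fragments_w=7,
                              fsize_h=32, fsize_w=32, aligned=32)
print(vfrag.shape)   # torch.Size([3, 32, 224, 224]); 7 * 32 = 224 per side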
- """ - self.ann_file = ann_file - self.data_prefix = data_prefix - self.frame_interval = frame_interval - self.num_clips = num_clips - self.fragments = fragments - self.fsize = fsize - self.nfrags = nfrags - self.clip_len = fragments[0] * fsize[0] - self.aligned = aligned - self.fallback_type = fallback_type - self.sampler = FragmentSampleFrames( - fsize[0], fragments[0], frame_interval, num_clips - ) - self.video_infos = [] - self.phase = phase - self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) - self.std = torch.FloatTensor([58.395, 57.12, 57.375]) - if isinstance(self.ann_file, list): - self.video_infos = self.ann_file - else: - with open(self.ann_file, "r") as fin: - for line in fin: - line_split = line.strip().split(",") - filename, _, _, label = line_split - label = float(label) - filename = osp.join(self.data_prefix, filename) - self.video_infos.append(dict(filename=filename, label=label)) - if cache_in_memory: - self.cache = {} - for i in tqdm(range(len(self)), desc="Caching fragments"): - self.cache[i] = self.__getitem__(i, tocache=True) - else: - self.cache = None - - def __getitem__( - self, index, tocache=False, need_original_frames=False, - ): - if tocache or self.cache is None: - fx, fy = self.fragments[1:] - fsx, fsy = self.fsize[1:] - video_info = self.video_infos[index] - filename = video_info["filename"] - label = video_info["label"] - if filename.endswith(".yuv"): - video = skvideo.io.vread( - filename, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"} - ) - frame_inds = self.sampler(video.shape[0], self.phase == "train") - imgs = [torch.from_numpy(video[idx]) for idx in frame_inds] - else: - vreader = VideoReader(filename) - frame_inds = self.sampler(len(vreader), self.phase == "train") - frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)} - imgs = [frame_dict[idx] for idx in frame_inds] - img_shape = imgs[0].shape - video = torch.stack(imgs, 0) - video = video.permute(3, 0, 1, 2) - if self.nfrags == 1: - vfrag = get_spatial_fragments( - video, - fx, - fy, - fsx, - fsy, - aligned=self.aligned, - fallback_type=self.fallback_type, - ) - else: - vfrag = get_spatial_fragments( - video, - fx, - fy, - fsx, - fsy, - aligned=self.aligned, - fallback_type=self.fallback_type, - ) - for i in range(1, self.nfrags): - vfrag = torch.cat( - ( - vfrag, - get_spatial_fragments( - video, - fragments, - fx, - fy, - fsx, - fsy, - aligned=self.aligned, - fallback_type=self.fallback_type, - ), - ), - 1, - ) - if tocache: - return (vfrag, frame_inds, label, img_shape) - else: - vfrag, frame_inds, label, img_shape = self.cache[index] - vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2) - data = { - "video": vfrag.reshape( - (-1, self.nfrags * self.num_clips, self.clip_len) + vfrag.shape[2:] - ).transpose( - 0, 1 - ), # B, V, T, C, H, W - "frame_inds": frame_inds, - "gt_label": label, - "original_shape": img_shape, - } - if need_original_frames: - data["original_video"] = video.reshape( - (-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:] - ).transpose(0, 1) - return data - - def __len__(self): - return len(self.video_infos) - - -class FragmentVideoDataset(torch.utils.data.Dataset): - def __init__( - self, - ann_file, - data_prefix, - clip_len=32, - frame_interval=2, - num_clips=4, - aligned=32, - fragments=7, - fsize=32, - nfrags=1, - cache_in_memory=False, - phase="test", - ): - """ - Fragments. - args: - fragments: G_f as in the paper. - fsize: S_f as in the paper. - nfrags: number of samples as in the paper. 
- """ - self.ann_file = ann_file - self.data_prefix = data_prefix - self.clip_len = clip_len - self.frame_interval = frame_interval - self.num_clips = num_clips - self.fragments = fragments - self.fsize = fsize - self.nfrags = nfrags - self.aligned = aligned - self.sampler = SampleFrames(clip_len, frame_interval, num_clips) - self.video_infos = [] - self.phase = phase - self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) - self.std = torch.FloatTensor([58.395, 57.12, 57.375]) - if isinstance(self.ann_file, list): - self.video_infos = self.ann_file - else: - with open(self.ann_file, "r") as fin: - for line in fin: - line_split = line.strip().split(",") - filename, _, _, label = line_split - label = float(label) - filename = osp.join(self.data_prefix, filename) - self.video_infos.append(dict(filename=filename, label=label)) - if cache_in_memory: - self.cache = {} - for i in tqdm(range(len(self)), desc="Caching fragments"): - self.cache[i] = self.__getitem__(i, tocache=True) - else: - self.cache = None - - def __getitem__( - self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False, - ): - if tocache or self.cache is None: - if fragments == -1: - fragments = self.fragments - if fsize == -1: - fsize = self.fsize - video_info = self.video_infos[index] - filename = video_info["filename"] - label = video_info["label"] - if filename.endswith(".yuv"): - video = skvideo.io.vread( - filename, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"} - ) - frame_inds = self.sampler(video.shape[0], self.phase == "train") - imgs = [torch.from_numpy(video[idx]) for idx in frame_inds] - else: - vreader = VideoReader(filename) - frame_inds = self.sampler(len(vreader), self.phase == "train") - frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)} - imgs = [frame_dict[idx] for idx in frame_inds] - img_shape = imgs[0].shape - video = torch.stack(imgs, 0) - video = video.permute(3, 0, 1, 2) - if self.nfrags == 1: - vfrag = get_spatial_fragments( - video, fragments, fragments, fsize, fsize, aligned=self.aligned - ) - else: - vfrag = get_spatial_fragments( - video, fragments, fragments, fsize, fsize, aligned=self.aligned - ) - for i in range(1, self.nfrags): - vfrag = torch.cat( - ( - vfrag, - get_spatial_fragments( - video, - fragments, - fragments, - fsize, - fsize, - aligned=self.aligned, - ), - ), - 1, - ) - if tocache: - return (vfrag, frame_inds, label, img_shape) - else: - vfrag, frame_inds, label, img_shape = self.cache[index] - vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2) - data = { - "video": vfrag.reshape( - (-1, self.nfrags * self.num_clips, self.clip_len) + vfrag.shape[2:] - ).transpose( - 0, 1 - ), # B, V, T, C, H, W - "frame_inds": frame_inds, - "gt_label": label, - "original_shape": img_shape, - } - if need_original_frames: - data["original_video"] = video.reshape( - (-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:] - ).transpose(0, 1) - return data - - def __len__(self): - return len(self.video_infos) - - -class ResizedVideoDataset(torch.utils.data.Dataset): - def __init__( - self, - ann_file, - data_prefix, - clip_len=32, - frame_interval=2, - num_clips=4, - aligned=32, - size=224, - cache_in_memory=False, - phase="test", - ): - """ - Using resizing. 
- """ - self.ann_file = ann_file - self.data_prefix = data_prefix - self.clip_len = clip_len - self.frame_interval = frame_interval - self.num_clips = num_clips - self.size = size - self.aligned = aligned - self.sampler = SampleFrames(clip_len, frame_interval, num_clips) - self.video_infos = [] - self.phase = phase - self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) - self.std = torch.FloatTensor([58.395, 57.12, 57.375]) - if isinstance(self.ann_file, list): - self.video_infos = self.ann_file - else: - with open(self.ann_file, "r") as fin: - for line in fin: - line_split = line.strip().split(",") - filename, _, _, label = line_split - label = float(label) - filename = osp.join(self.data_prefix, filename) - self.video_infos.append(dict(filename=filename, label=label)) - if cache_in_memory: - self.cache = {} - for i in tqdm(range(len(self)), desc="Caching resized videos"): - self.cache[i] = self.__getitem__(i, tocache=True) - else: - self.cache = None - - def __getitem__(self, index, tocache=False, need_original_frames=False): - if tocache or self.cache is None: - video_info = self.video_infos[index] - filename = video_info["filename"] - label = video_info["label"] - vreader = VideoReader(filename) - frame_inds = self.sampler(len(vreader), self.phase == "train") - frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)} - imgs = [frame_dict[idx] for idx in frame_inds] - img_shape = imgs[0].shape - video = torch.stack(imgs, 0) - video = video.permute(3, 0, 1, 2) - video = torch.nn.functional.interpolate(video, size=(self.size, self.size)) - if tocache: - return (vfrag, frame_inds, label, img_shape) - else: - vfrag, frame_inds, label, img_shape = self.cache[index] - vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2) - data = { - "video": vfrag.reshape( - (-1, self.num_clips, self.clip_len) + vfrag.shape[2:] - ).transpose( - 0, 1 - ), # B, V, T, C, H, W - "frame_inds": frame_inds, - "gt_label": label, - "original_shape": img_shape, - } - if need_original_frames: - data["original_video"] = video.reshape( - (-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:] - ).transpose(0, 1) - return data - - def __len__(self): - return len(self.video_infos) - - -class CroppedVideoDataset(FragmentVideoDataset): - def __init__( - self, - ann_file, - data_prefix, - clip_len=32, - frame_interval=2, - num_clips=4, - aligned=32, - size=224, - ncrops=1, - cache_in_memory=False, - phase="test", - ): - - """ - Regard Cropping as a special case for Fragments in Grid 1*1. 
- """ - super().__init__( - ann_file, - data_prefix, - clip_len=clip_len, - frame_interval=frame_interval, - num_clips=num_clips, - aligned=aligned, - fragments=1, - fsize=224, - nfrags=ncrops, - cache_in_memory=cache_in_memory, - phase=phase, - ) - - -class FragmentImageDataset(torch.utils.data.Dataset): - def __init__( - self, - ann_file, - data_prefix, - fragments=7, - fsize=32, - nfrags=1, - cache_in_memory=False, - phase="test", - ): - self.ann_file = ann_file - self.data_prefix = data_prefix - self.fragments = fragments - self.fsize = fsize - self.nfrags = nfrags - self.image_infos = [] - self.phase = phase - self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) - self.std = torch.FloatTensor([58.395, 57.12, 57.375]) - if isinstance(self.ann_file, list): - self.image_infos = self.ann_file - else: - with open(self.ann_file, "r") as fin: - for line in fin: - line_split = line.strip().split(",") - filename, _, _, label = line_split - label = float(label) - filename = osp.join(self.data_prefix, filename) - self.image_infos.append(dict(filename=filename, label=label)) - if cache_in_memory: - self.cache = {} - for i in tqdm(range(len(self)), desc="Caching fragments"): - self.cache[i] = self.__getitem__(i, tocache=True) - else: - self.cache = None - - def __getitem__( - self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False - ): - if tocache or self.cache is None: - if fragments == -1: - fragments = self.fragments - if fsize == -1: - fsize = self.fsize - image_info = self.image_infos[index] - filename = image_info["filename"] - label = image_info["label"] - try: - img = torchvision.io.read_image(filename) - except: - img = cv2.imread(filename) - img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1) - img_shape = img.shape[1:] - image = img.unsqueeze(1) - if self.nfrags == 1: - ifrag = get_spatial_fragments(image, fragments, fragments, fsize, fsize) - else: - ifrag = get_spatial_fragments(image, fragments, fragments, fsize, fsize) - for i in range(1, self.nfrags): - ifrag = torch.cat( - ( - ifrag, - get_spatial_fragments( - image, fragments, fragments, fsize, fsize - ), - ), - 1, - ) - if tocache: - return (ifrag, label, img_shape) - else: - ifrag, label, img_shape = self.cache[index] - if self.nfrags == 1: - ifrag = ( - ((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std) - .squeeze(0) - .permute(2, 0, 1) - ) - else: - ### During testing, one image as a batch - ifrag = ( - ((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std) - .squeeze(0) - .permute(0, 3, 1, 2) - ) - data = { - "image": ifrag, - "gt_label": label, - "original_shape": img_shape, - "name": filename, - } - if need_original_frames: - data["original_image"] = image.squeeze(1) - return data - - def __len__(self): - return len(self.image_infos) - - -class ResizedImageDataset(torch.utils.data.Dataset): - def __init__( - self, ann_file, data_prefix, size=224, cache_in_memory=False, phase="test", - ): - self.ann_file = ann_file - self.data_prefix = data_prefix - self.size = size - self.image_infos = [] - self.phase = phase - self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) - self.std = torch.FloatTensor([58.395, 57.12, 57.375]) - if isinstance(self.ann_file, list): - self.image_infos = self.ann_file - else: - with open(self.ann_file, "r") as fin: - for line in fin: - line_split = line.strip().split(",") - filename, _, _, label = line_split - label = float(label) - filename = osp.join(self.data_prefix, filename) - self.image_infos.append(dict(filename=filename, label=label)) - if 
cache_in_memory: - self.cache = {} - for i in tqdm(range(len(self)), desc="Caching fragments"): - self.cache[i] = self.__getitem__(i, tocache=True) - else: - self.cache = None - - def __getitem__( - self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False - ): - if tocache or self.cache is None: - if fragments == -1: - fragments = self.fragments - if fsize == -1: - fsize = self.fsize - image_info = self.image_infos[index] - filename = image_info["filename"] - label = image_info["label"] - img = torchvision.io.read_image(filename) - img_shape = img.shape[1:] - image = img.unsqueeze(1) - if self.nfrags == 1: - ifrag = get_spatial_fragments(image, fragments, fsize) - else: - ifrag = get_spatial_fragments(image, fragments, fsize) - for i in range(1, self.nfrags): - ifrag = torch.cat( - (ifrag, get_spatial_fragments(image, fragments, fsize)), 1 - ) - if tocache: - return (ifrag, label, img_shape) - else: - ifrag, label, img_shape = self.cache[index] - ifrag = ( - ((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std) - .squeeze(0) - .permute(2, 0, 1) - ) - data = { - "image": ifrag, - "gt_label": label, - "original_shape": img_shape, - } - if need_original_frames: - data["original_image"] = image.squeeze(1) - return data - - def __len__(self): - return len(self.image_infos) - - -class CroppedImageDataset(FragmentImageDataset): - def __init__( - self, - ann_file, - data_prefix, - size=224, - ncrops=1, - cache_in_memory=False, - phase="test", - ): - - """ - Regard Cropping as a special case for Fragments in Grid 1*1. - """ - super().__init__( - ann_file, - data_prefix, - fragments=1, - fsize=224, - nfrags=ncrops, - cache_in_memory=cache_in_memory, - phase=phase, - ) diff --git a/dover/datasets/.ipynb_checkpoints/dover_datasets-checkpoint.py b/dover/datasets/.ipynb_checkpoints/dover_datasets-checkpoint.py deleted file mode 100644 index 28e72ec..0000000 --- a/dover/datasets/.ipynb_checkpoints/dover_datasets-checkpoint.py +++ /dev/null @@ -1,433 +0,0 @@ -import copy -import glob -import os -import os.path as osp -import random -from functools import lru_cache - -import cv2 -import decord -import numpy as np -import skvideo.io -import torch -import torchvision -from decord import VideoReader, cpu, gpu -from tqdm import tqdm - -random.seed(42) - -decord.bridge.set_bridge("torch") - - -def get_spatial_fragments( - video, - fragments_h=7, - fragments_w=7, - fsize_h=32, - fsize_w=32, - aligned=32, - nfrags=1, - random=False, - random_upsample=False, - fallback_type="upsample", - upsample=-1, - **kwargs, -): - if upsample > 0: - old_h, old_w = video.shape[-2], video.shape[-1] - if old_h >= old_w: - w = upsample - h = int(upsample * old_h / old_w) - else: - h = upsample - w = int(upsample * old_w / old_h) - - video = get_resized_video(video, h, w) - size_h = fragments_h * fsize_h - size_w = fragments_w * fsize_w - ## video: [C,T,H,W] - ## situation for images - if video.shape[1] == 1: - aligned = 1 - - dur_t, res_h, res_w = video.shape[-3:] - ratio = min(res_h / size_h, res_w / size_w) - if fallback_type == "upsample" and ratio < 1: - - ovideo = video - video = torch.nn.functional.interpolate( - video / 255.0, scale_factor=1 / ratio, mode="bilinear" - ) - video = (video * 255.0).type_as(ovideo) - - if random_upsample: - - randratio = random.random() * 0.5 + 1 - video = torch.nn.functional.interpolate( - video / 255.0, scale_factor=randratio, mode="bilinear" - ) - video = (video * 255.0).type_as(ovideo) - - assert dur_t % aligned == 0, "Please provide match vclip and align index" - size 
= size_h, size_w - - ## make sure that sampling will not run out of the picture - hgrids = torch.LongTensor( - [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)] - ) - wgrids = torch.LongTensor( - [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)] - ) - hlength, wlength = res_h // fragments_h, res_w // fragments_w - - if random: - print("This part is deprecated. Please remind that.") - if res_h > fsize_h: - rnd_h = torch.randint( - res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) - ) - else: - rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() - if res_w > fsize_w: - rnd_w = torch.randint( - res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) - ) - else: - rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() - else: - if hlength > fsize_h: - rnd_h = torch.randint( - hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) - ) - else: - rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() - if wlength > fsize_w: - rnd_w = torch.randint( - wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) - ) - else: - rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() - - target_video = torch.zeros(video.shape[:-2] + size).to(video.device) - # target_videos = [] - - for i, hs in enumerate(hgrids): - for j, ws in enumerate(wgrids): - for t in range(dur_t // aligned): - t_s, t_e = t * aligned, (t + 1) * aligned - h_s, h_e = i * fsize_h, (i + 1) * fsize_h - w_s, w_e = j * fsize_w, (j + 1) * fsize_w - if random: - h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h - w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w - else: - h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h - w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w - target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[ - :, t_s:t_e, h_so:h_eo, w_so:w_eo - ] - # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo]) - # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6) - # target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments - return target_video - - -@lru_cache -def get_resize_function(size_h, size_w, target_ratio=1, random_crop=False): - if random_crop: - return torchvision.transforms.RandomResizedCrop( - (size_h, size_w), scale=(0.40, 1.0) - ) - if target_ratio > 1: - size_h = int(target_ratio * size_w) - assert size_h > size_w - elif target_ratio < 1: - size_w = int(size_h / target_ratio) - assert size_w > size_h - return torchvision.transforms.Resize((size_h, size_w)) - - -def get_resized_video( - video, size_h=224, size_w=224, random_crop=False, arp=False, **kwargs, -): - video = video.permute(1, 0, 2, 3) - resize_opt = get_resize_function( - size_h, size_w, video.shape[-2] / video.shape[-1] if arp else 1, random_crop - ) - video = resize_opt(video).permute(1, 0, 2, 3) - return video - - -def get_arp_resized_video( - video, short_edge=224, train=False, **kwargs, -): - if train: ## if during training, will random crop into square and then resize - res_h, res_w = video.shape[-2:] - ori_short_edge = min(video.shape[-2:]) - if res_h > ori_short_edge: - rnd_h = random.randrange(res_h - ori_short_edge) - video = video[..., rnd_h : rnd_h + ori_short_edge, :] - elif res_w > ori_short_edge: - rnd_w = random.randrange(res_w - ori_short_edge) - video = video[..., :, rnd_h : rnd_h + ori_short_edge] - ori_short_edge = 
min(video.shape[-2:]) - scale_factor = short_edge / ori_short_edge - ovideo = video - video = torch.nn.functional.interpolate( - video / 255.0, scale_factors=scale_factor, mode="bilinear" - ) - video = (video * 255.0).type_as(ovideo) - return video - - -def get_arp_fragment_video( - video, short_fragments=7, fsize=32, train=False, **kwargs, -): - if ( - train - ): ## if during training, will random crop into square and then get fragments - res_h, res_w = video.shape[-2:] - ori_short_edge = min(video.shape[-2:]) - if res_h > ori_short_edge: - rnd_h = random.randrange(res_h - ori_short_edge) - video = video[..., rnd_h : rnd_h + ori_short_edge, :] - elif res_w > ori_short_edge: - rnd_w = random.randrange(res_w - ori_short_edge) - video = video[..., :, rnd_h : rnd_h + ori_short_edge] - kwargs["fsize_h"], kwargs["fsize_w"] = fsize, fsize - res_h, res_w = video.shape[-2:] - if res_h > res_w: - kwargs["fragments_w"] = short_fragments - kwargs["fragments_h"] = int(short_fragments * res_h / res_w) - else: - kwargs["fragments_h"] = short_fragments - kwargs["fragments_w"] = int(short_fragments * res_w / res_h) - return get_spatial_fragments(video, **kwargs) - - -def get_cropped_video( - video, size_h=224, size_w=224, **kwargs, -): - kwargs["fragments_h"], kwargs["fragments_w"] = 1, 1 - kwargs["fsize_h"], kwargs["fsize_w"] = size_h, size_w - return get_spatial_fragments(video, **kwargs) - - -def get_single_view( - video, sample_type="aesthetic", **kwargs, -): - if sample_type.startswith("aesthetic"): - video = get_resized_video(video, **kwargs) - elif sample_type.startswith("technical"): - video = get_spatial_fragments(video, **kwargs) - elif sample_type == "original": - return video - - return video - - -def spatial_temporal_view_decomposition( - video_path, sample_types, samplers, is_train=False, augment=False, -): - video = {} - if video_path.endswith(".yuv"): - print("This part will be deprecated due to large memory cost.") - ## This is only an adaptation to LIVE-Qualcomm - ovideo = skvideo.io.vread( - video_path, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"} - ) - for stype in samplers: - frame_inds = samplers[stype](ovideo.shape[0], is_train) - imgs = [torch.from_numpy(ovideo[idx]) for idx in frame_inds] - video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2) - del ovideo - else: - decord.bridge.set_bridge("torch") - vreader = VideoReader(video_path) - ### Avoid duplicated video decoding!!! Important!!!! - all_frame_inds = [] - frame_inds = {} - for stype in samplers: - frame_inds[stype] = samplers[stype](len(vreader), is_train) - all_frame_inds.append(frame_inds[stype]) - - ### Each frame is only decoded one time!!! 
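# Standalone sketch of the decode-once trick flagged above: gather the frame
# indices of every view first, decode each unique frame exactly once, then fan
# the cached frames out per view. (The video path is hypothetical.)
import numpy as np
from decord import VideoReader

vreader = VideoReader("example.mp4")
frame_inds = {"technical": np.array([0, 8, 16]),
              "aesthetic": np.array([0, 12, 24])}
all_inds = np.concatenate(list(frame_inds.values()))
frame_dict = {int(i): vreader[int(i)] for i in np.unique(all_inds)}
views = {k: [frame_dict[int(i)] for i in v] for k, v in frame_inds.items()}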
- all_frame_inds = np.concatenate(all_frame_inds, 0) - frame_dict = {idx: vreader[idx] for idx in np.unique(all_frame_inds)} - - for stype in samplers: - imgs = [frame_dict[idx] for idx in frame_inds[stype]] - video[stype] = torch.stack(imgs, 0).permute(3, 0, 1, 2) - - sampled_video = {} - for stype, sopt in sample_types.items(): - sampled_video[stype] = get_single_view(video[stype], stype, **sopt) - return sampled_video, frame_inds - - -import random - -import numpy as np - - -class UnifiedFrameSampler: - def __init__( - self, fsize_t, fragments_t, frame_interval=1, num_clips=1, drop_rate=0.0, - ): - - self.fragments_t = fragments_t - self.fsize_t = fsize_t - self.size_t = fragments_t * fsize_t - self.frame_interval = frame_interval - self.num_clips = num_clips - self.drop_rate = drop_rate - - def get_frame_indices(self, num_frames, train=False): - - tgrids = np.array( - [num_frames // self.fragments_t * i for i in range(self.fragments_t)], - dtype=np.int32, - ) - tlength = num_frames // self.fragments_t - - if tlength > self.fsize_t * self.frame_interval: - rnd_t = np.random.randint( - 0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids) - ) - else: - rnd_t = np.zeros(len(tgrids), dtype=np.int32) - - ranges_t = ( - np.arange(self.fsize_t)[None, :] * self.frame_interval - + rnd_t[:, None] - + tgrids[:, None] - ) - - drop = random.sample( - list(range(self.fragments_t)), int(self.fragments_t * self.drop_rate) - ) - dropped_ranges_t = [] - for i, rt in enumerate(ranges_t): - if i not in drop: - dropped_ranges_t.append(rt) - return np.concatenate(dropped_ranges_t) - - def __call__(self, total_frames, train=False, start_index=0): - frame_inds = [] - - for i in range(self.num_clips): - frame_inds += [self.get_frame_indices(total_frames)] - - frame_inds = np.concatenate(frame_inds) - frame_inds = np.mod(frame_inds + start_index, total_frames) - return frame_inds.astype(np.int32) - - -class ViewDecompositionDataset(torch.utils.data.Dataset): - def __init__(self, opt): - ## opt is a dictionary that includes options for video sampling - - super().__init__() - - self.weight = opt.get("weight", 0.5) - - self.fully_supervised = opt.get("fully_supervised", False) - print("Fully supervised:", self.fully_supervised) - - self.video_infos = [] - self.ann_file = opt["anno_file"] - self.data_prefix = opt["data_prefix"] - self.opt = opt - self.sample_types = opt["sample_types"] - self.data_backend = opt.get("data_backend", "disk") - self.augment = opt.get("augment", False) - if self.data_backend == "petrel": - from petrel_client import client - - self.client = client.Client(enable_mc=True) - - self.phase = opt["phase"] - self.crop = opt.get("random_crop", False) - self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) - self.std = torch.FloatTensor([58.395, 57.12, 57.375]) - self.samplers = {} - for stype, sopt in opt["sample_types"].items(): - if "t_frag" not in sopt: - # resized temporal sampling for TQE in DOVER - self.samplers[stype] = UnifiedFrameSampler( - sopt["clip_len"], sopt["num_clips"], sopt["frame_interval"] - ) - else: - # temporal sampling for AQE in DOVER - self.samplers[stype] = UnifiedFrameSampler( - sopt["clip_len"] // sopt["t_frag"], - sopt["t_frag"], - sopt["frame_interval"], - sopt["num_clips"], - ) - print( - stype + " branch sampled frames:", - self.samplers[stype](240, self.phase == "train"), - ) - - if isinstance(self.ann_file, list): - self.video_infos = self.ann_file - else: - try: - with open(self.ann_file, "r") as fin: - for line in fin: - line_split = 
line.strip().split(",") - filename, a, t, label = line_split - if self.fully_supervised: - label = float(a), float(t), float(label) - else: - label = float(label) - filename = osp.join(self.data_prefix, filename) - self.video_infos.append(dict(filename=filename, label=label)) - except: - #### No Label Testing - video_filenames = [] - for (root, dirs, files) in os.walk(self.data_prefix, topdown=True): - for file in files: - if file.endswith(".mp4"): - video_filenames += [os.path.join(root, file)] - print(len(video_filenames)) - video_filenames = sorted(video_filenames) - for filename in video_filenames: - self.video_infos.append(dict(filename=filename, label=-1)) - - def __getitem__(self, index): - video_info = self.video_infos[index] - filename = video_info["filename"] - label = video_info["label"] - - try: - ## Read Original Frames - ## Process Frames - data, frame_inds = spatial_temporal_view_decomposition( - filename, - self.sample_types, - self.samplers, - self.phase == "train", - self.augment and (self.phase == "train"), - ) - - for k, v in data.items(): - data[k] = ((v.permute(1, 2, 3, 0) - self.mean) / self.std).permute( - 3, 0, 1, 2 - ) - - data["num_clips"] = {} - for stype, sopt in self.sample_types.items(): - data["num_clips"][stype] = sopt["num_clips"] - data["frame_inds"] = frame_inds - data["gt_label"] = label - data["name"] = filename # osp.basename(video_info["filename"]) - except: - # exception flow - return {"name": filename} - - return data - - def __len__(self): - return len(self.video_infos) diff --git a/dover/models/.ipynb_checkpoints/__init__-checkpoint.py b/dover/models/.ipynb_checkpoints/__init__-checkpoint.py deleted file mode 100644 index 47a4840..0000000 --- a/dover/models/.ipynb_checkpoints/__init__-checkpoint.py +++ /dev/null @@ -1,17 +0,0 @@ -from .conv_backbone import convnext_3d_small, convnext_3d_tiny -from .evaluator import DOVER, BaseEvaluator, BaseImageEvaluator -from .head import IQAHead, VARHead, VQAHead, MaxVQAHead -from .swin_backbone import SwinTransformer2D as IQABackbone -from .swin_backbone import SwinTransformer3D as VQABackbone -from .swin_backbone import swin_3d_small, swin_3d_tiny - -__all__ = [ - "VQABackbone", - "IQABackbone", - "VQAHead", - "IQAHead", - "VARHead", - "BaseEvaluator", - "BaseImageEvaluator", - "DOVER", -] diff --git a/dover/models/.ipynb_checkpoints/backbone_v0_1-checkpoint.py b/dover/models/.ipynb_checkpoints/backbone_v0_1-checkpoint.py deleted file mode 100644 index dee5b4a..0000000 --- a/dover/models/.ipynb_checkpoints/backbone_v0_1-checkpoint.py +++ /dev/null @@ -1,862 +0,0 @@ -from functools import lru_cache, reduce -from operator import mul - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from timm.models.layers import DropPath, trunc_normal_ - - -class Mlp(nn.Module): - """Multilayer perceptron.""" - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0.0, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, D, H, W, C) - 
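# Shape sketch for window_partition with illustrative sizes: window size
# (8, 7, 7) splits a (B, D, H, W, C) map into B * (D/8) * (H/7) * (W/7)
# windows of 8 * 7 * 7 tokens each.
import torch

x = torch.zeros(2, 8, 14, 14, 96)     # (B, D, H, W, C)
windows = window_partition(x, (8, 7, 7))
print(windows.shape)                  # torch.Size([8, 392, 96])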
window_size (tuple[int]): window size - - Returns: - windows: (B*num_windows, window_size*window_size, C) - """ - B, D, H, W, C = x.shape - x = x.view( - B, - D // window_size[0], - window_size[0], - H // window_size[1], - window_size[1], - W // window_size[2], - window_size[2], - C, - ) - windows = ( - x.permute(0, 1, 3, 5, 2, 4, 6, 7) - .contiguous() - .view(-1, reduce(mul, window_size), C) - ) - return windows - - -def window_reverse(windows, window_size, B, D, H, W): - """ - Args: - windows: (B*num_windows, window_size, window_size, C) - window_size (tuple[int]): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, D, H, W, C) - """ - x = windows.view( - B, - D // window_size[0], - H // window_size[1], - W // window_size[2], - window_size[0], - window_size[1], - window_size[2], - -1, - ) - x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) - return x - - -def get_window_size(x_size, window_size, shift_size=None): - use_window_size = list(window_size) - if shift_size is not None: - use_shift_size = list(shift_size) - for i in range(len(x_size)): - if x_size[i] <= window_size[i]: - use_window_size[i] = x_size[i] - if shift_size is not None: - use_shift_size[i] = 0 - - if shift_size is None: - return tuple(use_window_size) - else: - return tuple(use_window_size), tuple(use_shift_size) - - -class WindowAttention3D(nn.Module): - """Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The temporal length, height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__( - self, - dim, - window_size, - num_heads, - qkv_bias=False, - qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, - ): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wd, Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros( - (2 * window_size[0] - 1) - * (2 * window_size[1] - 1) - * (2 * window_size[2] - 1), - num_heads, - ) - ) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_d = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack( - torch.meshgrid(coords_d, coords_h, coords_w) - ) # 3, Wd, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww - relative_coords = ( - coords_flatten[:, :, None] - coords_flatten[:, None, :] - ) # 3, Wd*Wh*Ww, Wd*Wh*Ww - relative_coords = relative_coords.permute( - 1, 2, 0 - ).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * ( - 2 * self.window_size[2] - 1 - ) - relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 - relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=0.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """Forward function. - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, N, N) or None - """ - B_, N, C = x.shape - qkv = ( - self.qkv(x) - .reshape(B_, N, 3, self.num_heads, C // self.num_heads) - .permute(2, 0, 3, 1, 4) - ) - q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C - - q = q * self.scale - attn = q @ k.transpose(-2, -1) - - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index[:N, :N].reshape(-1) - ].reshape( - N, N, -1 - ) # Wd*Wh*Ww,Wd*Wh*Ww,nH - relative_position_bias = relative_position_bias.permute( - 2, 0, 1 - ).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( - 1 - ).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class SwinTransformerBlock3D(nn.Module): - """Swin Transformer Block. - - Args: - dim (int): Number of input channels. - num_heads (int): Number of attention heads. - window_size (tuple[int]): Window size. - shift_size (tuple[int]): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__( - self, - dim, - num_heads, - window_size=(2, 7, 7), - shift_size=(0, 0, 0), - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - use_checkpoint=False, - jump_attention=False, - ): - super().__init__() - self.dim = dim - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - self.use_checkpoint = use_checkpoint - self.jump_attention = jump_attention - - assert ( - 0 <= self.shift_size[0] < self.window_size[0] - ), "shift_size must in 0-window_size" - assert ( - 0 <= self.shift_size[1] < self.window_size[1] - ), "shift_size must in 0-window_size" - assert ( - 0 <= self.shift_size[2] < self.window_size[2] - ), "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention3D( - dim, - window_size=self.window_size, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - ) - - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - ) - - def forward_part1(self, x, mask_matrix): - B, D, H, W, C = x.shape - window_size, shift_size = get_window_size( - (D, H, W), self.window_size, self.shift_size - ) - - x = self.norm1(x) - # pad feature maps to multiples of window size - pad_l = pad_t = pad_d0 = 0 - pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] - pad_b = (window_size[1] - H % window_size[1]) % window_size[1] - pad_r = (window_size[2] - W % window_size[2]) % window_size[2] - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) - _, Dp, Hp, Wp, _ = x.shape - # cyclic shift - if any(i > 0 for i in shift_size): - shifted_x = torch.roll( - x, - shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), - dims=(1, 2, 3), - ) - attn_mask = mask_matrix - else: - shifted_x = x - attn_mask = None - # partition windows - x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=attn_mask) # B*nW, Wd*Wh*Ww, C - # merge windows - attn_windows = attn_windows.view(-1, *(window_size + (C,))) - shifted_x = window_reverse( - attn_windows, window_size, B, Dp, Hp, Wp - ) # B D' H' W' C - # reverse cyclic shift - if any(i > 0 for i in shift_size): - x = torch.roll( - shifted_x, - shifts=(shift_size[0], shift_size[1], shift_size[2]), - dims=(1, 2, 3), - ) - else: - x = shifted_x - - if pad_d1 > 0 or pad_r > 0 or pad_b > 0: - x = x[:, :D, :H, :W, :].contiguous() - return x - - def forward_part2(self, x): - return self.drop_path(self.mlp(self.norm2(x))) - - def forward(self, x, mask_matrix): - """Forward function. - - Args: - x: Input feature, tensor size (B, D, H, W, C). - mask_matrix: Attention mask for cyclic shift. 
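# Sketch of the cyclic shift behind SW-MSA in forward_part1 above: roll the
# feature map by -shift_size before windowed attention and roll it back
# afterwards, so windows straddle the former window boundaries.
import torch

x = torch.arange(16.0).view(1, 1, 4, 4, 1)      # (B, D, H, W, C)
shifted = torch.roll(x, shifts=(-2, -2), dims=(2, 3))
restored = torch.roll(shifted, shifts=(2, 2), dims=(2, 3))
assert torch.equal(x, restored)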
- """ - - shortcut = x - if not self.jump_attention: - if self.use_checkpoint: - x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) - else: - x = self.forward_part1(x, mask_matrix) - x = shortcut + self.drop_path(x) - - if self.use_checkpoint: - x = x + checkpoint.checkpoint(self.forward_part2, x) - else: - x = x + self.forward_part2(x) - - return x - - -class PatchMerging(nn.Module): - """Patch Merging Layer - - Args: - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """Forward function. - - Args: - x: Input feature, tensor size (B, D, H, W, C). - """ - B, D, H, W, C = x.shape - - # padding - pad_input = (H % 2 == 1) or (W % 2 == 1) - if pad_input: - x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) - - x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C - x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C - x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C - x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - -# cache each stage results -@lru_cache() -def compute_mask(D, H, W, window_size, shift_size, device): - img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 - cnt = 0 - for d in ( - slice(-window_size[0]), - slice(-window_size[0], -shift_size[0]), - slice(-shift_size[0], None), - ): - for h in ( - slice(-window_size[1]), - slice(-window_size[1], -shift_size[1]), - slice(-shift_size[1], None), - ): - for w in ( - slice(-window_size[2]), - slice(-window_size[2], -shift_size[2]), - slice(-shift_size[2], None), - ): - img_mask[:, d, h, w, :] = cnt - cnt += 1 - mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 - mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( - attn_mask == 0, float(0.0) - ) - return attn_mask - - -class BasicLayer(nn.Module): - """A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of feature channels - depth (int): Depths of this stage. - num_heads (int): Number of attention head. - window_size (tuple[int]): Local window size. Default: (1,7,7). - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None - """ - - def __init__( - self, - dim, - depth, - num_heads, - window_size=(1, 7, 7), - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - norm_layer=nn.LayerNorm, - downsample=None, - use_checkpoint=False, - jump_attention=False, - ): - super().__init__() - self.window_size = window_size - self.shift_size = tuple(i // 2 for i in window_size) - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList( - [ - SwinTransformerBlock3D( - dim=dim, - num_heads=num_heads, - window_size=window_size, - shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] - if isinstance(drop_path, list) - else drop_path, - norm_layer=norm_layer, - use_checkpoint=use_checkpoint, - jump_attention=jump_attention, - ) - for i in range(depth) - ] - ) - - self.downsample = downsample - if self.downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) - - def forward(self, x): - """Forward function. - - Args: - x: Input feature, tensor size (B, C, D, H, W). - """ - # calculate attention mask for SW-MSA - B, C, D, H, W = x.shape - window_size, shift_size = get_window_size( - (D, H, W), self.window_size, self.shift_size - ) - x = rearrange(x, "b c d h w -> b d h w c") - Dp = int(np.ceil(D / window_size[0])) * window_size[0] - Hp = int(np.ceil(H / window_size[1])) * window_size[1] - Wp = int(np.ceil(W / window_size[2])) * window_size[2] - attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) - for blk in self.blocks: - x = blk(x, attn_mask) - x = x.view(B, D, H, W, -1) - - if self.downsample is not None: - x = self.downsample(x) - x = rearrange(x, "b d h w c -> b c d h w") - return x - - -class PatchEmbed3D(nn.Module): - """Video to Patch Embedding. - - Args: - patch_size (int): Patch token size. Default: (2,4,4). - in_chans (int): Number of input video channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, patch_size=(2, 4, 4), in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - self.patch_size = patch_size - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv3d( - in_chans, embed_dim, kernel_size=patch_size, stride=patch_size - ) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - """Forward function.""" - # padding - _, _, D, H, W = x.size() - if W % self.patch_size[2] != 0: - x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) - if H % self.patch_size[1] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) - if D % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) - - x = self.proj(x) # B C D Wh Ww - if self.norm is not None: - D, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) - - return x - - -class SwinTransformer3D(nn.Module): - """Swin Transformer backbone. - A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - - Args: - patch_size (int | tuple(int)): Patch size. Default: (4,4,4). 
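# Shape sketch for PatchEmbed3D (defined above): patch_size (2, 4, 4) maps a
# (B, 3, D, H, W) clip to a (B, embed_dim, D/2, H/4, W/4) token grid.
import torch

embed = PatchEmbed3D(patch_size=(2, 4, 4), in_chans=3, embed_dim=96)
x = torch.randn(1, 3, 32, 224, 224)
print(embed(x).shape)   # torch.Size([1, 96, 16, 56, 56])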
- in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - depths (tuple[int]): Depths of each Swin Transformer stage. - num_heads (tuple[int]): Number of attention head of each stage. - window_size (int): Window size. Default: 7. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. - drop_rate (float): Dropout rate. - attn_drop_rate (float): Attention dropout rate. Default: 0. - drop_path_rate (float): Stochastic depth rate. Default: 0.2. - norm_layer: Normalization layer. Default: nn.LayerNorm. - patch_norm (bool): If True, add normalization after patch embedding. Default: False. - frozen_stages (int): Stages to be frozen (stop grad and set eval mode). - -1 means not freezing any parameters. - """ - - def __init__( - self, - pretrained=None, - pretrained2d=False, - patch_size=(2, 4, 4), - in_chans=3, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[3, 6, 12, 24], - window_size=(8, 7, 7), - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - drop_path_rate=0.1, - norm_layer=nn.LayerNorm, - patch_norm=True, - frozen_stages=-1, - use_checkpoint=True, - jump_attention=[False, False, False, False], - ): - super().__init__() - - self.pretrained = pretrained - self.pretrained2d = pretrained2d - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.patch_norm = patch_norm - self.frozen_stages = frozen_stages - self.window_size = window_size - self.patch_size = patch_size - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed3D( - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None, - ) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) - ] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer( - dim=int(embed_dim * 2 ** i_layer), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if i_layer < self.num_layers - 1 else None, - use_checkpoint=use_checkpoint, - jump_attention=jump_attention[i_layer], - ) - self.layers.append(layer) - - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - - # add a norm layer for each output - self.norm = norm_layer(self.num_features) - - self._freeze_stages() - - def _freeze_stages(self): - if self.frozen_stages >= 0: - self.patch_embed.eval() - for param in self.patch_embed.parameters(): - param.requires_grad = False - - if self.frozen_stages >= 1: - self.pos_drop.eval() - for i in range(0, self.frozen_stages): - m = self.layers[i] - m.eval() - for param in m.parameters(): - param.requires_grad = False - - def inflate_weights(self, logger): - """Inflate the swin2d parameters to swin3d. - - The differences between swin3d and swin2d mainly lie in an extra - axis. To utilize the pretrained parameters in 2d model, - the weight of swin2d models should be inflated to fit in the shapes of - the 3d counterpart. 
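# Sketch of the temporal inflation performed below: repeat the 2D patch
# embedding kernel along a new depth axis and divide by that depth, so the
# inflated 3D convolution initially reproduces the 2D output on static input.
import torch

w2d = torch.randn(96, 3, 4, 4)                    # (out, in, kH, kW)
t = 2                                             # temporal patch size
w3d = w2d.unsqueeze(2).repeat(1, 1, t, 1, 1) / t  # (out, in, t, kH, kW)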
- - Args: - logger (logging.Logger): The logger used to print - debugging information. - """ - checkpoint = torch.load(self.pretrained, map_location="cpu") - state_dict = checkpoint["model"] - - # delete relative_position_index since we always re-init it - relative_position_index_keys = [ - k for k in state_dict.keys() if "relative_position_index" in k - ] - for k in relative_position_index_keys: - del state_dict[k] - - # delete attn_mask since we always re-init it - attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] - for k in attn_mask_keys: - del state_dict[k] - - state_dict["patch_embed.proj.weight"] = ( - state_dict["patch_embed.proj.weight"] - .unsqueeze(2) - .repeat(1, 1, self.patch_size[0], 1, 1) - / self.patch_size[0] - ) - - # bicubic interpolate relative_position_bias_table if not match - relative_position_bias_table_keys = [ - k for k in state_dict.keys() if "relative_position_bias_table" in k - ] - for k in relative_position_bias_table_keys: - relative_position_bias_table_pretrained = state_dict[k] - relative_position_bias_table_current = self.state_dict()[k] - L1, nH1 = relative_position_bias_table_pretrained.size() - L2, nH2 = relative_position_bias_table_current.size() - L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) - wd = self.window_size[0] - if nH1 != nH2: - logger.warning(f"Error in loading {k}, passing") - else: - if L1 != L2: - S1 = int(L1 ** 0.5) - relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( - relative_position_bias_table_pretrained.permute(1, 0).view( - 1, nH1, S1, S1 - ), - size=( - 2 * self.window_size[1] - 1, - 2 * self.window_size[2] - 1, - ), - mode="bicubic", - ) - relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view( - nH2, L2 - ).permute( - 1, 0 - ) - state_dict[k] = relative_position_bias_table_pretrained.repeat( - 2 * wd - 1, 1 - ) - - msg = self.load_state_dict(state_dict, strict=False) - logger.info(msg) - logger.info(f"=> loaded successfully '{self.pretrained}'") - del checkpoint - torch.cuda.empty_cache() - - def load_checkpoint(self, load_path, strict=False): - from collections import OrderedDict - - model_state_dict = self.state_dict() - state_dict = torch.load(load_path) - if "state_dict" in state_dict.keys(): - state_dict = state_dict["state_dict"] - - clean_dict = OrderedDict() - for key, value in state_dict.items(): - if "backbone" in key: - clean_key = key[9:] - clean_dict[clean_key] = value - - if not strict: - for key, value in model_state_dict.items(): - if key in clean_dict: - if value.shape != clean_dict[key].shape: - clean_dict.pop(key) - - self.load_state_dict(clean_dict, strict=strict) - - def init_weights(self, pretrained=None): - """Initialize the weights in backbone. - - Args: - pretrained (str, optional): Path to pre-trained weights. - Defaults to None. - """ - - def _init_weights(m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - if pretrained: - self.pretrained = pretrained - if isinstance(self.pretrained, str): - self.apply(_init_weights) - logger = get_root_logger() - logger.info(f"load model from: {self.pretrained}") - - if self.pretrained2d: - # Inflate 2D model into 3D model. - self.inflate_weights(logger) - else: - # Directly load 3D model.
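The patch-embedding inflation above divides the temporally repeated 2D kernel by patch_size[0]; a minimal standalone sketch of why this preserves the 2D response on a temporally constant clip (shapes are illustrative):

import torch
import torch.nn as nn

t = 2
conv2d = nn.Conv2d(3, 96, kernel_size=4, stride=4, bias=False)
conv3d = nn.Conv3d(3, 96, kernel_size=(t, 4, 4), stride=(t, 4, 4), bias=False)
# Inflate exactly as inflate_weights does: repeat along time, divide by t.
conv3d.weight.data = conv2d.weight.data.unsqueeze(2).repeat(1, 1, t, 1, 1) / t

img = torch.randn(1, 3, 224, 224)
clip = img.unsqueeze(2).repeat(1, 1, t, 1, 1)  # a "frozen" clip of t identical frames
out2d = conv2d(img)
out3d = conv3d(clip).squeeze(2)
print(torch.allclose(out2d, out3d, atol=1e-5))  # True: summing t copies of w/t recovers w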
- self.load_checkpoint(self.pretrained, strict=False) # , logger=logger) - elif self.pretrained is None: - self.apply(_init_weights) - else: - raise TypeError("pretrained must be a str or None") - - def forward(self, x, multi=False): - """Forward function.""" - x = self.patch_embed(x) - - x = self.pos_drop(x) - - if multi: - feats = [x] - - for layer in self.layers: - x = layer(x.contiguous()) - if multi: - feats += [x] - - x = rearrange(x, "n c d h w -> n d h w c") - x = self.norm(x) - x = rearrange(x, "n d h w c -> n c d h w") - - if multi: - return feats[:-1] + [x] - else: - return x - - def train(self, mode=True): - """Convert the model into training mode while keep layers freezed.""" - super(SwinTransformer3D, self).train(mode) - self._freeze_stages() diff --git a/dover/models/.ipynb_checkpoints/conv_backbone-checkpoint.py b/dover/models/.ipynb_checkpoints/conv_backbone-checkpoint.py deleted file mode 100644 index 2c358fb..0000000 --- a/dover/models/.ipynb_checkpoints/conv_backbone-checkpoint.py +++ /dev/null @@ -1,651 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from timm.models.layers import trunc_normal_, DropPath -from timm.models.registry import register_model - -from open_clip import CLIP3D -import open_clip - -class GRN(nn.Module): - """ GRN (Global Response Normalization) layer - """ - def __init__(self, dim): - super().__init__() - self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) - self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) - - def forward(self, x): - Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True) - Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) - return self.gamma * (x * Nx) + self.beta + x - -class Block(nn.Module): - r""" ConvNeXt Block. There are two equivalent implementations: - (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) - (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back - We use (2) as we find it slightly faster in PyTorch - - Args: - dim (int): Number of input channels. - drop_path (float): Stochastic depth rate. Default: 0.0 - layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. - """ - def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): - super().__init__() - self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv - self.norm = LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.pwconv2 = nn.Linear(4 * dim, dim) - self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), - requires_grad=True) if layer_scale_init_value > 0 else None - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - def forward(self, x): - input = x - x = self.dwconv(x) - x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) - if self.gamma is not None: - x = self.gamma * x - x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) - - x = input + self.drop_path(x) - return x - -class ConvNeXt(nn.Module): - r""" ConvNeXt - A PyTorch impl of : `A ConvNet for the 2020s` - - https://arxiv.org/pdf/2201.03545.pdf - Args: - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - depths (tuple(int)): Number of blocks at each stage. 
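A quick numerical check of the GRN layer defined above; gamma and beta start at zero, so GRN is an identity mapping at initialization:

import torch

x = torch.randn(2, 7, 7, 64)                       # (N, H, W, C), channels last
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)  # per-channel spatial L2 norm
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)   # each channel vs. the channel mean
out = 0.0 * (x * nx) + 0.0 + x                     # gamma = beta = 0 at init
print(torch.allclose(out, x))  # True: GRN starts as a no-op and learns to diverge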
Default: [3, 3, 9, 3] - dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] - drop_path_rate (float): Stochastic depth rate. Default: 0. - layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. - head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. - """ - def __init__(self, in_chans=3, num_classes=1000, - depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., - layer_scale_init_value=1e-6, head_init_scale=1., - ): - super().__init__() - - self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers - stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), - LayerNorm(dims[0], eps=1e-6, data_format="channels_first") - ) - self.downsample_layers.append(stem) - for i in range(3): - downsample_layer = nn.Sequential( - LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), - nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), - ) - self.downsample_layers.append(downsample_layer) - - self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks - dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - cur = 0 - for i in range(4): - stage = nn.Sequential( - *[Block(dim=dims[i], drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] - ) - self.stages.append(stage) - cur += depths[i] - - self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer - self.head = nn.Linear(dims[-1], num_classes) - - self.apply(self._init_weights) - self.head.weight.data.mul_(head_init_scale) - self.head.bias.data.mul_(head_init_scale) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - nn.init.constant_(m.bias, 0) - - def forward_features(self, x): - for i in range(4): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - return x - -class LayerNorm(nn.Module): - r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. - The ordering of the dimensions in the inputs. channels_last corresponds to inputs with - shape (batch_size, height, width, channels) while channels_first corresponds to inputs - with shape (batch_size, channels, height, width). - """ - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape, ) - - def forward(self, x): - if self.data_format == "channels_last": - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self.data_format == "channels_first": - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - if len(x.shape) == 4: - x = self.weight[:, None, None] * x + self.bias[:, None, None] - elif len(x.shape) == 5: - x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None] - return x - - -class Block3D(nn.Module): - r""" ConvNeXt Block. 
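The channels_first branch of the LayerNorm above normalizes dim 1 by hand; a small sketch checking it against F.layer_norm applied channels-last (illustrative shapes):

import torch
import torch.nn.functional as F

x = torch.randn(2, 96, 14, 14)
w, b, eps = torch.ones(96), torch.zeros(96), 1e-6

u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
manual = (x - u) / torch.sqrt(s + eps)
manual = w[:, None, None] * manual + b[:, None, None]

ref = F.layer_norm(x.permute(0, 2, 3, 1), (96,), w, b, eps).permute(0, 3, 1, 2)
print(torch.allclose(manual, ref, atol=1e-5))  # True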
There are two equivalent implementations: - (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, T, H, W) - (2) DwConv -> Permute to (N, T, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back - We use (2) as we find it slightly faster in PyTorch - - Args: - dim (int): Number of input channels. - drop_path (float): Stochastic depth rate. Default: 0.0 - inflate_len (int): Temporal extent of the inflated depthwise kernel. Default: 3 - layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. - """ - def __init__(self, dim, drop_path=0., inflate_len=3, layer_scale_init_value=1e-6): - super().__init__() - self.dwconv = nn.Conv3d(dim, dim, kernel_size=(inflate_len,7,7), padding=(inflate_len // 2,3,3), groups=dim) # depthwise conv - self.norm = LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.pwconv2 = nn.Linear(4 * dim, dim) - self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), - requires_grad=True) if layer_scale_init_value > 0 else None - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - def forward(self, x): - input = x - x = self.dwconv(x) - x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) - if self.gamma is not None: - x = self.gamma * x - x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W) - - x = input + self.drop_path(x) - return x - -class BlockV2(nn.Module): - """ ConvNeXtV2 Block. - - Args: - dim (int): Number of input channels. - drop_path (float): Stochastic depth rate. Default: 0.0 - """ - def __init__(self, dim, drop_path=0.): - super().__init__() - self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv - self.norm = LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.grn = GRN(4 * dim) - self.pwconv2 = nn.Linear(4 * dim, dim) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - def forward(self, x): - input = x - x = self.dwconv(x) - x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.grn(x) - x = self.pwconv2(x) - x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) - - x = input + self.drop_path(x) - return x - -class BlockV23D(nn.Module): - """ ConvNeXtV2 Block (3D-inflated). - - Args: - dim (int): Number of input channels. - drop_path (float): Stochastic depth rate. Default: 0.0 - inflate_len (int): Temporal extent of the inflated depthwise kernel. Default: 3 - """ - def __init__(self, dim, drop_path=0., inflate_len=3,): - super().__init__() - self.dwconv = nn.Conv3d(dim, dim, kernel_size=(inflate_len,7,7), padding=(inflate_len // 2,3,3), groups=dim) # depthwise conv - self.norm = LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.grn = GRN(4 * dim) - self.pwconv2 = nn.Linear(4 * dim, dim) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - - def forward(self, x): - input = x - x = self.dwconv(x) - x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.grn(x) - x = self.pwconv2(x) - x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W) - - x = input + self.drop_path(x) - return x - -class ConvNeXtV2(nn.Module): - """ ConvNeXt V2 - - Args: - in_chans (int): Number of input image channels.
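A shape check for the inflated depthwise convolutions used by Block3D and BlockV23D above; odd temporal kernels with padding inflate_len // 2 preserve (T, H, W):

import torch
import torch.nn as nn

for inflate_len in (1, 3):
    dw = nn.Conv3d(96, 96, kernel_size=(inflate_len, 7, 7),
                   padding=(inflate_len // 2, 3, 3), groups=96)
    y = dw(torch.randn(1, 96, 8, 14, 14))
    print(inflate_len, tuple(y.shape))  # both print (1, 96, 8, 14, 14)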
Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] - dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] - drop_path_rate (float): Stochastic depth rate. Default: 0. - head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. - """ - def __init__(self, in_chans=3, num_classes=1000, - depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], - drop_path_rate=0., head_init_scale=1. - ): - super().__init__() - self.depths = depths - self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers - stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), - LayerNorm(dims[0], eps=1e-6, data_format="channels_first") - ) - self.downsample_layers.append(stem) - for i in range(3): - downsample_layer = nn.Sequential( - LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), - nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), - ) - self.downsample_layers.append(downsample_layer) - - self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks - dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - cur = 0 - for i in range(4): - stage = nn.Sequential( - *[BlockV2(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] - ) - self.stages.append(stage) - cur += depths[i] - - self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer - self.head = nn.Linear(dims[-1], num_classes) - - self.apply(self._init_weights) - self.head.weight.data.mul_(head_init_scale) - self.head.bias.data.mul_(head_init_scale) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - nn.init.constant_(m.bias, 0) - - def forward_features(self, x): - for i in range(4): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - return x - -def convnextv2_atto(**kwargs): - model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) - return model - -def convnextv2_femto(**kwargs): - model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) - return model - -def convnext_pico(**kwargs): - model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) - return model - -def convnextv2_nano(**kwargs): - model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs) - return model - -def convnextv2_tiny(**kwargs): - model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) - return model - -def convnextv2_base(**kwargs): - model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) - return model - -def convnextv2_large(**kwargs): - model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) - return model - -def convnextv2_huge(**kwargs): - model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs) - return model - -class ConvNeXt3D(nn.Module): - r""" ConvNeXt - A PyTorch impl of : `A ConvNet for the 2020s` - - https://arxiv.org/pdf/2201.03545.pdf - Args: - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - depths (tuple(int)): Number of blocks at each stage. 
Default: [3, 3, 9, 3] - dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] - drop_path_rate (float): Stochastic depth rate. Default: 0. - layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. - head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. - """ - def __init__(self, in_chans=3, num_classes=1000, - inflate_strategy='131', - depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., - layer_scale_init_value=1e-6, head_init_scale=1., - ): - super().__init__() - - self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers - stem = nn.Sequential( - nn.Conv3d(in_chans, dims[0], kernel_size=(2,4,4), stride=(2,4,4)), - LayerNorm(dims[0], eps=1e-6, data_format="channels_first") - ) - self.downsample_layers.append(stem) - for i in range(3): - downsample_layer = nn.Sequential( - LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), - nn.Conv3d(dims[i], dims[i+1], kernel_size=(1,2,2), stride=(1,2,2)), - ) - self.downsample_layers.append(downsample_layer) - - self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks - dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - cur = 0 - for i in range(4): - stage = nn.Sequential( - *[Block3D(dim=dims[i], inflate_len=int(inflate_strategy[j%len(inflate_strategy)]), - drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] - ) - self.stages.append(stage) - cur += depths[i] - - self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer - - self.apply(self._init_weights) - - def inflate_weights(self, s_state_dict): - t_state_dict = self.state_dict() - from collections import OrderedDict - for key in t_state_dict.keys(): - if key not in s_state_dict: - print(key) - continue - if t_state_dict[key].shape != s_state_dict[key].shape: - t = t_state_dict[key].shape[2] - s_state_dict[key] = s_state_dict[key].unsqueeze(2).repeat(1,1,t,1,1) / t - self.load_state_dict(s_state_dict, strict=False) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv3d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - nn.init.constant_(m.bias, 0) - - def forward_features(self, x, return_spatial=False, multi=False, layer=-1): - if multi: - xs = [] - for i in range(4): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - if multi: - xs.append(x) - if return_spatial: - if multi: - shape = xs[-1].shape[2:] - return torch.cat([F.interpolate(x,size=shape, mode="trilinear") for x in xs[:-1]], 1) #+ [self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)], 1) - elif layer > -1: - return xs[layer] - else: - return self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3) - return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C) - - def forward(self, x, multi=False, layer=-1): - x = self.forward_features(x, True, multi=multi, layer=layer) - return x - - -class ConvNeXtV23D(nn.Module): - """ ConvNeXt V2 - - Args: - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] - dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] - drop_path_rate (float): Stochastic depth rate. Default: 0. - head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
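How the inflate_strategy string used by ConvNeXt3D and ConvNeXtV23D above unrolls into per-block temporal kernel sizes, a sketch with the default '131' and tiny depths:

inflate_strategy, depths = "131", [3, 3, 9, 3]
for stage, depth in enumerate(depths):
    # block j in a stage uses int(inflate_strategy[j % len(inflate_strategy)])
    lens = [int(inflate_strategy[j % len(inflate_strategy)]) for j in range(depth)]
    print(stage, lens)
# stage 2 -> [1, 3, 1, 1, 3, 1, 1, 3, 1]: temporal mixing in every third block.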
- """ - def __init__(self, in_chans=3, num_classes=1000, - inflate_strategy='131', - depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], - drop_path_rate=0., head_init_scale=1. - ): - super().__init__() - self.depths = depths - self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers - stem = nn.Sequential( - nn.Conv3d(in_chans, dims[0], kernel_size=(2,4,4), stride=(2,4,4)), - LayerNorm(dims[0], eps=1e-6, data_format="channels_first") - ) - self.downsample_layers.append(stem) - for i in range(3): - downsample_layer = nn.Sequential( - LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), - nn.Conv3d(dims[i], dims[i+1], kernel_size=(1,2,2), stride=(1,2,2)), - ) - self.downsample_layers.append(downsample_layer) - - self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks - dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - cur = 0 - for i in range(4): - stage = nn.Sequential( - *[BlockV23D(dim=dims[i], drop_path=dp_rates[cur + j], - inflate_len=int(inflate_strategy[j%len(inflate_strategy)]), - ) for j in range(depths[i])] - ) - self.stages.append(stage) - cur += depths[i] - - self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer - self.head = nn.Linear(dims[-1], num_classes) - - self.apply(self._init_weights) - self.head.weight.data.mul_(head_init_scale) - self.head.bias.data.mul_(head_init_scale) - - def inflate_weights(self, pretrained_path): - t_state_dict = self.state_dict() - s_state_dict = torch.load(pretrained_path)["model"] - from collections import OrderedDict - for key in t_state_dict.keys(): - if key not in s_state_dict: - print(key) - continue - if t_state_dict[key].shape != s_state_dict[key].shape: - print(t_state_dict[key].shape, s_state_dict[key].shape) - t = t_state_dict[key].shape[2] - s_state_dict[key] = s_state_dict[key].unsqueeze(2).repeat(1,1,t,1,1) / t - self.load_state_dict(s_state_dict, strict=False) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv3d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - nn.init.constant_(m.bias, 0) - - def forward_features(self, x, return_spatial=False, multi=False, layer=-1): - if multi: - xs = [] - for i in range(4): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - if multi: - xs.append(x) - if return_spatial: - if multi: - shape = xs[-1].shape[2:] - return torch.cat([F.interpolate(x,size=shape, mode="trilinear") for x in xs[:-1]], 1) #+ [self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)], 1) - elif layer > -1: - return xs[layer] - else: - return self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3) - return self.norm(x.mean([-3, -2, -1])) # global average pooling, (N, C, T, H, W) -> (N, C) - - def forward(self, x, multi=False, layer=-1): - x = self.forward_features(x, True, multi=multi, layer=layer) - return x - - -model_urls = { - "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", - "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", - "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", - "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", - "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", - "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", - "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", - 
"convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", - "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", -} - -def convnext_tiny(pretrained=False,in_22k=False, **kwargs): - model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) - if pretrained: - url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] - checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) - model.load_state_dict(checkpoint["model"]) - return model - -def convnext_small(pretrained=False,in_22k=False, **kwargs): - model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) - if pretrained: - url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] - checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - return model - -def convnext_base(pretrained=False, in_22k=False, **kwargs): - model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) - if pretrained: - url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] - checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - return model - - -def convnext_large(pretrained=False, in_22k=False, **kwargs): - model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) - if pretrained: - url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] - checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - return model - -def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): - model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) - if pretrained: - assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" - url = model_urls['convnext_xlarge_22k'] - checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - - return model - -def convnext_3d_tiny(pretrained=False, in_22k=False, **kwargs): - print("Using Imagenet 22K pretrain", in_22k) - model = ConvNeXt3D(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) - if pretrained: - url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] - checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) - model.inflate_weights(checkpoint["model"]) - return model - -def convnext_3d_small(pretrained=False, in_22k=False, **kwargs): - model = ConvNeXt3D(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) - if pretrained: - url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] - checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) - model.inflate_weights(checkpoint["model"]) - - return model - -def convnextv2_3d_atto(**kwargs): - model = ConvNeXtV23D(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) - - return model - -def convnextv2_3d_femto(pretrained="../pretrained/convnextv2_femto_1k_224_ema.pt", **kwargs): - model = ConvNeXtV23D(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) - #model.inflate_weights(pretrained) - return model - -def convnextv2_3d_pico(pretrained="../pretrained/convnextv2_pico_1k_224_ema.pt", **kwargs): - model = ConvNeXtV23D(depths=[2, 2, 6, 2], 
dims=[64, 128, 256, 512], **kwargs) - #model.inflate_weights(pretrained) - return model - -def convnextv2_3d_nano(pretrained="../pretrained/convnextv2_nano_1k_224_ema.pt", **kwargs): - model = ConvNeXtV23D(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs) - #model.inflate_weights(pretrained) - return model - -def convnextv2_tiny(**kwargs): - model = ConvNeXtV23D(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) - return model - -def convnextv2_base(**kwargs): - model = ConvNeXtV23D(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) - return model - -def convnextv2_large(**kwargs): - model = ConvNeXtV23D(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) - return model - -def convnextv2_huge(**kwargs): - model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs) - return model - - - - -if __name__ == "__main__": - - device = "cuda" if torch.cuda.is_available() else "cpu" - model = convnext_3d_tiny(True).to(device) - print(model) - from thop import profile - print(profile(model, (torch.randn(4,3,32,224,224).to(device),))[0] / 1e9) - - \ No newline at end of file diff --git a/dover/models/.ipynb_checkpoints/evaluator-checkpoint.py b/dover/models/.ipynb_checkpoints/evaluator-checkpoint.py deleted file mode 100644 index 11510ac..0000000 --- a/dover/models/.ipynb_checkpoints/evaluator-checkpoint.py +++ /dev/null @@ -1,262 +0,0 @@ -import time -from functools import partial, reduce - -import torch -import torch.nn as nn -from torch.nn.functional import adaptive_avg_pool3d - -from .conv_backbone import convnext_3d_small, convnext_3d_tiny, convnextv2_3d_pico, convnextv2_3d_femto -from .head import IQAHead, VARHead, VQAHead -from .swin_backbone import SwinTransformer2D as ImageBackbone -from .swin_backbone import SwinTransformer3D as VideoBackbone -from .swin_backbone import swin_3d_small, swin_3d_tiny - - -class BaseEvaluator(nn.Module): - def __init__( - self, backbone=dict(), vqa_head=dict(), - ): - super().__init__() - self.backbone = VideoBackbone(**backbone) - self.vqa_head = VQAHead(**vqa_head) - - def forward(self, vclip, inference=True, **kwargs): - if inference: - self.eval() - with torch.no_grad(): - feat = self.backbone(vclip) - score = self.vqa_head(feat) - self.train() - return score - else: - feat = self.backbone(vclip) - score = self.vqa_head(feat) - return score - - def forward_with_attention(self, vclip): - self.eval() - with torch.no_grad(): - feat, avg_attns = self.backbone(vclip, require_attn=True) - score = self.vqa_head(feat) - return score, avg_attns - - -class DOVER(nn.Module): - def __init__( - self, - backbone_size="divided", - backbone_preserve_keys="fragments,resize", - multi=False, - layer=-1, - backbone=dict( - resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)} - ), - divide_head=False, - vqa_head=dict(in_channels=768), - var=False, - ): - self.backbone_preserve_keys = backbone_preserve_keys.split(",") - self.multi = multi - self.layer = layer - super().__init__() - for key, hypers in backbone.items(): - print(backbone_size) - if key not in self.backbone_preserve_keys: - continue - if backbone_size == "divided": - t_backbone_size = hypers["type"] - else: - t_backbone_size = backbone_size - if t_backbone_size == "swin_tiny": - b = swin_3d_tiny(**backbone[key]) - elif t_backbone_size == "swin_tiny_grpb": - # to reproduce fast-vqa - b = VideoBackbone() - elif t_backbone_size == "swin_tiny_grpb_m": - # to reproduce fast-vqa-m - b = VideoBackbone(window_size=(4, 4, 4), frag_biases=[0, 0, 0, 0]) - elif 
t_backbone_size == "swin_small": - b = swin_3d_small(**backbone[key]) - elif t_backbone_size == "conv_tiny": - b = convnext_3d_tiny(pretrained=True) - elif t_backbone_size == "conv_small": - b = convnext_3d_small(pretrained=True) - elif t_backbone_size == "conv_femto": - b = convnextv2_3d_femto(pretrained=True) - elif t_backbone_size == "conv_pico": - b = convnextv2_3d_pico(pretrained=True) - elif t_backbone_size == "xclip": - raise NotImplementedError - # b = build_x_clip_model(**backbone[key]) - else: - raise NotImplementedError - print("Setting backbone:", key + "_backbone") - setattr(self, key + "_backbone", b) - if divide_head: - for key in backbone: - pre_pool = False #if key == "technical" else True - if key not in self.backbone_preserve_keys: - continue - b = VQAHead(pre_pool=pre_pool, **vqa_head) - print("Setting head:", key + "_head") - setattr(self, key + "_head", b) - else: - if var: - self.vqa_head = VARHead(**vqa_head) - print(b) - else: - self.vqa_head = VQAHead(**vqa_head) - - def forward( - self, - vclips, - inference=True, - return_pooled_feats=False, - return_raw_feats=False, - reduce_scores=False, - pooled=False, - **kwargs - ): - assert (return_pooled_feats & return_raw_feats) == False, "Please only choose one kind of features to return" - if inference: - self.eval() - with torch.no_grad(): - scores = [] - feats = {} - for key in vclips: - feat = getattr(self, key.split("_")[0] + "_backbone")( - vclips[key], multi=self.multi, layer=self.layer, **kwargs - ) - if hasattr(self, key.split("_")[0] + "_head"): - scores += [getattr(self, key.split("_")[0] + "_head")(feat)] - else: - scores += [getattr(self, "vqa_head")(feat)] - if return_pooled_feats: - feats[key] = feat - if return_raw_feats: - feats[key] = feat - if reduce_scores: - if len(scores) > 1: - scores = reduce(lambda x, y: x + y, scores) - else: - scores = scores[0] - if pooled: - scores = torch.mean(scores, (1, 2, 3, 4)) - self.train() - if return_pooled_feats or return_raw_feats: - return scores, feats - return scores - else: - self.train() - scores = [] - feats = {} - for key in vclips: - feat = getattr(self, key.split("_")[0] + "_backbone")( - vclips[key], multi=self.multi, layer=self.layer, **kwargs - ) - if hasattr(self, key.split("_")[0] + "_head"): - scores += [getattr(self, key.split("_")[0] + "_head")(feat)] - else: - scores += [getattr(self, "vqa_head")(feat)] - if return_pooled_feats: - feats[key] = feat.mean((-3, -2, -1)) - if reduce_scores: - if len(scores) > 1: - scores = reduce(lambda x, y: x + y, scores) - else: - scores = scores[0] - if pooled: - print(scores.shape) - scores = torch.mean(scores, (1, 2, 3, 4)) - print(scores.shape) - - if return_pooled_feats: - return scores, feats - return scores - - def forward_head( - self, - feats, - inference=True, - reduce_scores=False, - pooled=False, - **kwargs - ): - if inference: - self.eval() - with torch.no_grad(): - scores = [] - feats = {} - for key in feats: - feat = feats[key] - if hasattr(self, key.split("_")[0] + "_head"): - scores += [getattr(self, key.split("_")[0] + "_head")(feat)] - else: - scores += [getattr(self, "vqa_head")(feat)] - if reduce_scores: - if len(scores) > 1: - scores = reduce(lambda x, y: x + y, scores) - else: - scores = scores[0] - if pooled: - scores = torch.mean(scores, (1, 2, 3, 4)) - self.train() - return scores - else: - self.train() - scores = [] - feats = {} - for key in vclips: - feat = getattr(self, key.split("_")[0] + "_backbone")( - vclips[key], multi=self.multi, layer=self.layer, **kwargs - ) - if hasattr(self, 
key.split("_")[0] + "_head"): - scores += [getattr(self, key.split("_")[0] + "_head")(feat)] - else: - scores += [getattr(self, "vqa_head")(feat)] - if return_pooled_feats: - feats[key] = feat - if reduce_scores: - if len(scores) > 1: - scores = reduce(lambda x, y: x + y, scores) - else: - scores = scores[0] - if pooled: - print(scores.shape) - scores = torch.mean(scores, (1, 2, 3, 4)) - print(scores.shape) - - if return_pooled_feats: - return scores, feats - return scores - - - - -class BaseImageEvaluator(nn.Module): - def __init__( - self, backbone=dict(), iqa_head=dict(), - ): - super().__init__() - self.backbone = ImageBackbone(**backbone) - self.iqa_head = IQAHead(**iqa_head) - - def forward(self, image, inference=True, **kwargs): - if inference: - self.eval() - with torch.no_grad(): - feat = self.backbone(image) - score = self.iqa_head(feat) - self.train() - return score - else: - feat = self.backbone(image) - score = self.iqa_head(feat) - return score - - def forward_with_attention(self, image): - self.eval() - with torch.no_grad(): - feat, avg_attns = self.backbone(image, require_attn=True) - score = self.iqa_head(feat) - return score, avg_attns diff --git a/dover/models/.ipynb_checkpoints/head-checkpoint.py b/dover/models/.ipynb_checkpoints/head-checkpoint.py deleted file mode 100644 index 8995821..0000000 --- a/dover/models/.ipynb_checkpoints/head-checkpoint.py +++ /dev/null @@ -1,101 +0,0 @@ -import math - -import numpy as np -import torch -import torch.nn as nn -from torch.nn import functional as F -from torchvision.ops import roi_align, roi_pool - - -class VQAHead(nn.Module): - """MLP Regression Head for VQA. - Args: - in_channels: input channels for MLP - hidden_channels: hidden channels for MLP - dropout_ratio: the dropout ratio for features before the MLP (default 0.5) - pre_pool: whether pre-pool the features or not (True for Aesthetic Attributes, False for Technical Attributes) - """ - - def __init__( - self, in_channels=768, hidden_channels=64, dropout_ratio=0.5, pre_pool=False, **kwargs - ): - super().__init__() - self.dropout_ratio = dropout_ratio - self.in_channels = in_channels - self.hidden_channels = hidden_channels - self.pre_pool = pre_pool - if self.dropout_ratio != 0: - self.dropout = nn.Dropout(p=self.dropout_ratio) - else: - self.dropout = None - self.fc_hid = nn.Conv3d(self.in_channels, self.hidden_channels, (1, 1, 1)) - self.fc_last = nn.Conv3d(self.hidden_channels, 1, (1, 1, 1)) - self.gelu = nn.GELU() - - self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) - - def forward(self, x, rois=None): - if self.pre_pool: - x = self.avg_pool(x) - x = self.dropout(x) - qlt_score = self.fc_last(self.dropout(self.gelu(self.fc_hid(x)))) - return qlt_score - - - - - -class VARHead(nn.Module): - """MLP Regression Head for Video Action Recognition. 
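A usage sketch for the DOVER evaluator above; the branch names, backbone types, and input shapes here are illustrative assumptions, not values fixed by this code:

import torch

# Hypothetical two-branch setup; dict keys must match the "<branch>_backbone"
# attributes that DOVER.forward looks up via getattr.
model = DOVER(
    backbone=dict(technical={"type": "swin_tiny_grpb"}, aesthetic={"type": "conv_tiny"}),
    backbone_preserve_keys="technical,aesthetic",
    divide_head=True,
)
vclips = {
    "technical": torch.randn(1, 3, 32, 224, 224),  # fragment-crops branch
    "aesthetic": torch.randn(1, 3, 32, 224, 224),  # resized-frames branch
}
score = model(vclips, reduce_scores=True, pooled=True)  # shape (1,): fused quality score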
- Args: - in_channels: input channels for MLP - hidden_channels: hidden channels for MLP - dropout_ratio: the dropout ratio for features before the MLP (default 0.5) - """ - - def __init__(self, in_channels=768, out_channels=400, dropout_ratio=0.5, **kwargs): - super().__init__() - self.dropout_ratio = dropout_ratio - self.in_channels = in_channels - self.out_channels = out_channels - if self.dropout_ratio != 0: - self.dropout = nn.Dropout(p=self.dropout_ratio) - else: - self.dropout = None - self.fc = nn.Conv3d(self.in_channels, self.out_channels, (1, 1, 1)) - self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) - - def forward(self, x, rois=None): - x = self.dropout(x) - x = self.avg_pool(x) - out = self.fc(x) - return out - - -class IQAHead(nn.Module): - """MLP Regression Head for IQA. - Args: - in_channels: input channels for MLP - hidden_channels: hidden channels for MLP - dropout_ratio: the dropout ratio for features before the MLP (default 0.5) - """ - - def __init__( - self, in_channels=768, hidden_channels=64, dropout_ratio=0.5, **kwargs - ): - super().__init__() - self.dropout_ratio = dropout_ratio - self.in_channels = in_channels - self.hidden_channels = hidden_channels - if self.dropout_ratio != 0: - self.dropout = nn.Dropout(p=self.dropout_ratio) - else: - self.dropout = None - self.fc_hid = nn.Linear(self.in_channels, self.hidden_channels) - self.fc_last = nn.Linear(self.hidden_channels, 1) - self.gelu = nn.GELU() - - def forward(self, x): - x = self.dropout(x) - qlt_score = self.fc_last(self.dropout(self.gelu(self.fc_hid(x)))) - return qlt_score diff --git a/dover/models/.ipynb_checkpoints/swin_backbone-checkpoint.py b/dover/models/.ipynb_checkpoints/swin_backbone-checkpoint.py deleted file mode 100644 index 86faf57..0000000 --- a/dover/models/.ipynb_checkpoints/swin_backbone-checkpoint.py +++ /dev/null @@ -1,1097 +0,0 @@ -import math -from functools import lru_cache, reduce -from operator import mul - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from timm.models.layers import DropPath, trunc_normal_ - - -def fragment_infos(D, H, W, fragments=7, device="cuda"): - m = torch.arange(fragments).unsqueeze(-1).float() - m = (m + m.t() * fragments).reshape(1, 1, 1, fragments, fragments) - m = F.interpolate(m.to(device), size=(D, H, W)).permute(0, 2, 3, 4, 1) - return m.long() - - -@lru_cache -def global_position_index( - D, - H, - W, - fragments=(1, 7, 7), - window_size=(8, 7, 7), - shift_size=(0, 0, 0), - device="cuda", -): - frags_d = torch.arange(fragments[0]) - frags_h = torch.arange(fragments[1]) - frags_w = torch.arange(fragments[2]) - frags = torch.stack( - torch.meshgrid(frags_d, frags_h, frags_w) - ).float() # 3, Fd, Fh, Fw - coords = ( - torch.nn.functional.interpolate(frags[None].to(device), size=(D, H, W)) - .long() - .permute(0, 2, 3, 4, 1) - ) - # print(shift_size) - coords = torch.roll( - coords, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3) - ) - window_coords = window_partition(coords, window_size) - relative_coords = ( - window_coords[:, None, :] - window_coords[:, :, None] - ) # Wd*Wh*Ww, Wd*Wh*Ww, 3 - return relative_coords # relative_coords - - -@lru_cache -def get_adaptive_window_size( - base_window_size, input_x_size, base_x_size, -): - tw, hw, ww = base_window_size - tx_, hx_, wx_ = input_x_size - tx, hx, wx = base_x_size - print((tw * tx_) // tx, (hw * hx_) // hx, (ww * wx_) // wx) - return (tw * tx_) // tx, (hw * 
hx_) // hx, (ww * wx_) // wx - - -class Mlp(nn.Module): - """Multilayer perceptron.""" - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0.0, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, D, H, W, C) - window_size (tuple[int]): window size - - Returns: - windows: (B*num_windows, window_size*window_size, C) - """ - B, D, H, W, C = x.shape - x = x.view( - B, - D // window_size[0], - window_size[0], - H // window_size[1], - window_size[1], - W // window_size[2], - window_size[2], - C, - ) - windows = ( - x.permute(0, 1, 3, 5, 2, 4, 6, 7) - .contiguous() - .view(-1, reduce(mul, window_size), C) - ) - return windows - - -def window_reverse(windows, window_size, B, D, H, W): - """ - Args: - windows: (B*num_windows, window_size, window_size, C) - window_size (tuple[int]): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, D, H, W, C) - """ - x = windows.view( - B, - D // window_size[0], - H // window_size[1], - W // window_size[2], - window_size[0], - window_size[1], - window_size[2], - -1, - ) - x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1) - return x - - -def get_window_size(x_size, window_size, shift_size=None): - use_window_size = list(window_size) - if shift_size is not None: - use_shift_size = list(shift_size) - for i in range(len(x_size)): - if x_size[i] <= window_size[i]: - use_window_size[i] = x_size[i] - if shift_size is not None: - use_shift_size[i] = 0 - - if shift_size is None: - return tuple(use_window_size) - else: - return tuple(use_window_size), tuple(use_shift_size) - - -class WindowAttention3D(nn.Module): - """Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The temporal length, height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
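A shape walk-through of window_partition above (a standalone sketch; the same permute/view pair is undone by window_reverse):

import torch
from functools import reduce
from operator import mul

B, D, H, W, C = 1, 8, 14, 14, 96
ws = (8, 7, 7)
x = torch.randn(B, D, H, W, C)
windows = (
    x.view(B, D // ws[0], ws[0], H // ws[1], ws[1], W // ws[2], ws[2], C)
    .permute(0, 1, 3, 5, 2, 4, 6, 7)
    .contiguous()
    .view(-1, reduce(mul, ws), C)
)
print(windows.shape)  # (4, 392, 96): 1*2*2 windows of 8*7*7 tokens each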
Default: 0.0 - """ - - def __init__( - self, - dim, - window_size, - num_heads, - qkv_bias=False, - qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, - frag_bias=False, - ): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wd, Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros( - (2 * window_size[0] - 1) - * (2 * window_size[1] - 1) - * (2 * window_size[2] - 1), - num_heads, - ) - ) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH - if frag_bias: - self.fragment_position_bias_table = nn.Parameter( - torch.zeros( - (2 * window_size[0] - 1) - * (2 * window_size[1] - 1) - * (2 * window_size[2] - 1), - num_heads, - ) - ) - - # get pair-wise relative position index for each token inside the window - coords_d = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack( - torch.meshgrid(coords_d, coords_h, coords_w) - ) # 3, Wd, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww - relative_coords = ( - coords_flatten[:, :, None] - coords_flatten[:, None, :] - ) # 3, Wd*Wh*Ww, Wd*Wh*Ww - relative_coords = relative_coords.permute( - 1, 2, 0 - ).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * ( - 2 * self.window_size[2] - 1 - ) - relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 - relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=0.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None, fmask=None, resized_window_size=None): - """Forward function. 
- Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, N, N) or None - """ - # print(x.shape) - B_, N, C = x.shape - qkv = ( - self.qkv(x) - .reshape(B_, N, 3, self.num_heads, C // self.num_heads) - .permute(2, 0, 3, 1, 4) - ) - q, k, v = qkv[0], qkv[1], qkv[2] # B_, nH, N, C - - q = q * self.scale - attn = q @ k.transpose(-2, -1) - - if resized_window_size is None: - rpi = self.relative_position_index[:N, :N] - else: - relative_position_index = self.relative_position_index.reshape( - *self.window_size, *self.window_size - ) - d, h, w = resized_window_size - - rpi = relative_position_index[:d, :h, :w, :d, :h, :w] - relative_position_bias = self.relative_position_bias_table[ - rpi.reshape(-1) - ].reshape( - N, N, -1 - ) # Wd*Wh*Ww,Wd*Wh*Ww,nH - relative_position_bias = relative_position_bias.permute( - 2, 0, 1 - ).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww - if hasattr(self, "fragment_position_bias_table"): - fragment_position_bias = self.fragment_position_bias_table[ - rpi.reshape(-1) - ].reshape( - N, N, -1 - ) # Wd*Wh*Ww,Wd*Wh*Ww,nH - fragment_position_bias = fragment_position_bias.permute( - 2, 0, 1 - ).contiguous() # nH, Wd*Wh*Ww, Wd*Wh*Ww - - ### Mask Position Bias - if fmask is not None: - # fgate = torch.where(fmask - fmask.transpose(-1, -2) == 0, 1, 0).float() - fgate = fmask.abs().sum(-1) - nW = fmask.shape[0] - relative_position_bias = relative_position_bias.unsqueeze(0) - fgate = fgate.unsqueeze(1) - # print(fgate.shape, relative_position_bias.shape) - if hasattr(self, "fragment_position_bias_table"): - relative_position_bias = ( - relative_position_bias * fgate - + fragment_position_bias * (1 - fgate) - ) - - attn = attn.view( - B_ // nW, nW, self.num_heads, N, N - ) + relative_position_bias.unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - else: - attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( - 1 - ).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - - return x - - -class SwinTransformerBlock3D(nn.Module): - """Swin Transformer Block. - - Args: - dim (int): Number of input channels. - num_heads (int): Number of attention heads. - window_size (tuple[int]): Window size. - shift_size (tuple[int]): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm - """ - - def __init__( - self, - dim, - num_heads, - window_size=(2, 7, 7), - shift_size=(0, 0, 0), - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - use_checkpoint=False, - jump_attention=False, - frag_bias=False, - ): - super().__init__() - self.dim = dim - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - self.use_checkpoint = use_checkpoint - self.jump_attention = jump_attention - self.frag_bias = frag_bias - - assert ( - 0 <= self.shift_size[0] < self.window_size[0] - ), "shift_size must in 0-window_size" - assert ( - 0 <= self.shift_size[1] < self.window_size[1] - ), "shift_size must in 0-window_size" - assert ( - 0 <= self.shift_size[2] < self.window_size[2] - ), "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention3D( - dim, - window_size=self.window_size, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - frag_bias=frag_bias, - ) - - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - ) - - def forward_part1(self, x, mask_matrix, resized_window_size=None): - B, D, H, W, C = x.shape - window_size, shift_size = get_window_size( - (D, H, W), - self.window_size if resized_window_size is None else resized_window_size, - self.shift_size, - ) - - x = self.norm1(x) - # pad feature maps to multiples of window size - pad_l = pad_t = pad_d0 = 0 - pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0] - pad_b = (window_size[1] - H % window_size[1]) % window_size[1] - pad_r = (window_size[2] - W % window_size[2]) % window_size[2] - - x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) - _, Dp, Hp, Wp, _ = x.shape - if False: # not hasattr(self, 'finfo_windows'): - finfo = fragment_infos(Dp, Hp, Wp) - - # cyclic shift - if any(i > 0 for i in shift_size): - shifted_x = torch.roll( - x, - shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), - dims=(1, 2, 3), - ) - if False: # not hasattr(self, 'finfo_windows'): - shifted_finfo = torch.roll( - finfo, - shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), - dims=(1, 2, 3), - ) - attn_mask = mask_matrix - else: - shifted_x = x - if False: # not hasattr(self, 'finfo_windows'): - shifted_finfo = finfo - attn_mask = None - # partition windows - x_windows = window_partition(shifted_x, window_size) # B*nW, Wd*Wh*Ww, C - if False: # not hasattr(self, 'finfo_windows'): - self.finfo_windows = window_partition(shifted_finfo, window_size) - # W-MSA/SW-MSA - # print(shift_size) - gpi = global_position_index( - Dp, - Hp, - Wp, - fragments=(1,) + window_size[1:], - window_size=window_size, - shift_size=shift_size, - device=x.device, - ) - attn_windows = self.attn( - x_windows, - mask=attn_mask, - fmask=gpi, - resized_window_size=window_size - if resized_window_size is not None - else None, - ) # self.finfo_windows) # B*nW, Wd*Wh*Ww, C - # merge windows - attn_windows = attn_windows.view(-1, *(window_size + (C,))) - shifted_x = window_reverse( - attn_windows, window_size, B, Dp, Hp, Wp - ) # B D' H' W' C - # reverse cyclic shift - if any(i > 0 for i in shift_size): - x = torch.roll( - shifted_x, - shifts=(shift_size[0], shift_size[1], 
shift_size[2]), - dims=(1, 2, 3), - ) - else: - x = shifted_x - - if pad_d1 > 0 or pad_r > 0 or pad_b > 0: - x = x[:, :D, :H, :W, :].contiguous() - return x - - def forward_part2(self, x): - return self.drop_path(self.mlp(self.norm2(x))) - - def forward(self, x, mask_matrix, resized_window_size=None): - """Forward function. - - Args: - x: Input feature, tensor size (B, D, H, W, C). - mask_matrix: Attention mask for cyclic shift. - """ - - shortcut = x - if not self.jump_attention: - if self.use_checkpoint: - x = checkpoint.checkpoint( - self.forward_part1, x, mask_matrix, resized_window_size - ) - else: - x = self.forward_part1(x, mask_matrix, resized_window_size) - x = shortcut + self.drop_path(x) - - if self.use_checkpoint: - x = x + checkpoint.checkpoint(self.forward_part2, x) - else: - x = x + self.forward_part2(x) - - return x - - -class PatchMerging(nn.Module): - """Patch Merging Layer - - Args: - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """Forward function. - - Args: - x: Input feature, tensor size (B, D, H, W, C). - """ - B, D, H, W, C = x.shape - - # padding - pad_input = (H % 2 == 1) or (W % 2 == 1) - if pad_input: - x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) - - x0 = x[:, :, 0::2, 0::2, :] # B D H/2 W/2 C - x1 = x[:, :, 1::2, 0::2, :] # B D H/2 W/2 C - x2 = x[:, :, 0::2, 1::2, :] # B D H/2 W/2 C - x3 = x[:, :, 1::2, 1::2, :] # B D H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B D H/2 W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - -# cache each stage results -@lru_cache() -def compute_mask(D, H, W, window_size, shift_size, device): - img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1 - cnt = 0 - for d in ( - slice(-window_size[0]), - slice(-window_size[0], -shift_size[0]), - slice(-shift_size[0], None), - ): - for h in ( - slice(-window_size[1]), - slice(-window_size[1], -shift_size[1]), - slice(-shift_size[1], None), - ): - for w in ( - slice(-window_size[2]), - slice(-window_size[2], -shift_size[2]), - slice(-shift_size[2], None), - ): - img_mask[:, d, h, w, :] = cnt - cnt += 1 - mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1 - mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2] - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( - attn_mask == 0, float(0.0) - ) - return attn_mask - - -class BasicLayer(nn.Module): - """A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of feature channels - depth (int): Depths of this stage. - num_heads (int): Number of attention head. - window_size (tuple[int]): Local window size. Default: (1,7,7). - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - """ - - def __init__( - self, - dim, - depth, - num_heads, - window_size=(1, 7, 7), - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - norm_layer=nn.LayerNorm, - downsample=None, - use_checkpoint=False, - jump_attention=False, - frag_bias=False, - ): - super().__init__() - self.window_size = window_size - self.shift_size = tuple(i // 2 for i in window_size) - self.depth = depth - self.use_checkpoint = use_checkpoint - # print(window_size) - # build blocks - self.blocks = nn.ModuleList( - [ - SwinTransformerBlock3D( - dim=dim, - num_heads=num_heads, - window_size=window_size, - shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] - if isinstance(drop_path, list) - else drop_path, - norm_layer=norm_layer, - use_checkpoint=use_checkpoint, - jump_attention=jump_attention, - frag_bias=frag_bias, - ) - for i in range(depth) - ] - ) - - self.downsample = downsample - if self.downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) - - def forward(self, x, resized_window_size=None): - """Forward function. - - Args: - x: Input feature, tensor size (B, C, D, H, W). - """ - # calculate attention mask for SW-MSA - B, C, D, H, W = x.shape - - window_size, shift_size = get_window_size( - (D, H, W), - self.window_size if resized_window_size is None else resized_window_size, - self.shift_size, - ) - # print(window_size) - x = rearrange(x, "b c d h w -> b d h w c") - Dp = int(np.ceil(D / window_size[0])) * window_size[0] - Hp = int(np.ceil(H / window_size[1])) * window_size[1] - Wp = int(np.ceil(W / window_size[2])) * window_size[2] - attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) - for blk in self.blocks: - x = blk(x, attn_mask, resized_window_size=resized_window_size) - x = x.view(B, D, H, W, -1) - - if self.downsample is not None: - x = self.downsample(x) - x = rearrange(x, "b d h w c -> b c d h w") - return x - - -class PatchEmbed3D(nn.Module): - """Video to Patch Embedding. - - Args: - patch_size (int): Patch token size. Default: (2,4,4). - in_chans (int): Number of input video channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, patch_size=(2, 4, 4), in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - self.patch_size = patch_size - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv3d( - in_chans, embed_dim, kernel_size=patch_size, stride=patch_size - ) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - """Forward function.""" - # padding - _, _, D, H, W = x.size() - if W % self.patch_size[2] != 0: - x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) - if H % self.patch_size[1] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) - if D % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) - - x = self.proj(x) # B C D Wh Ww - if self.norm is not None: - D, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2) - x = self.norm(x) - x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) - - return x - - -class SwinTransformer3D(nn.Module): - """Swin Transformer backbone. - A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - - https://arxiv.org/pdf/2103.14030 - - Args: - patch_size (int | tuple(int)): Patch size. Default: (2,4,4). - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - depths (tuple[int]): Depths of each Swin Transformer stage. - num_heads (tuple[int]): Number of attention heads for each stage. - window_size (tuple[int]): Window size. Default: (8,7,7). - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True. - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. - drop_rate (float): Dropout rate. - attn_drop_rate (float): Attention dropout rate. Default: 0. - drop_path_rate (float): Stochastic depth rate. Default: 0.1. - norm_layer: Normalization layer. Default: nn.LayerNorm. - patch_norm (bool): If True, add normalization after patch embedding. Default: True. - frozen_stages (int): Stages to be frozen (stop grad and set eval mode). - -1 means not freezing any parameters.
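# A minimal shape check for the PatchEmbed3D projection above (PyTorch
# only; the (2, 4, 4) patch size matches this file's default): a Conv3d
# with stride equal to its kernel tokenizes the clip into patch embeddings.
import torch
import torch.nn as nn
proj = nn.Conv3d(3, 96, kernel_size=(2, 4, 4), stride=(2, 4, 4))
clip = torch.randn(1, 3, 32, 224, 224)  # B, C, D, H, W
print(proj(clip).shape)  # torch.Size([1, 96, 16, 56, 56])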
- """ - - def __init__( - self, - pretrained=None, - pretrained2d=False, - patch_size=(2, 4, 4), - in_chans=3, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[3, 6, 12, 24], - window_size=(8, 7, 7), - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - drop_path_rate=0.1, - norm_layer=nn.LayerNorm, - patch_norm=True, - frozen_stages=-1, - use_checkpoint=True, - jump_attention=[False, False, False, False], - frag_biases=[True, True, True, False], - base_x_size=(32, 224, 224), - ): - super().__init__() - - self.pretrained = pretrained - self.pretrained2d = pretrained2d - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.patch_norm = patch_norm - self.frozen_stages = frozen_stages - self.window_size = window_size - self.patch_size = patch_size - self.base_x_size = base_x_size - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed3D( - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None, - ) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) - ] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer( - dim=int(embed_dim * 2 ** i_layer), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size[i_layer] - if isinstance(window_size, list) - else window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging if i_layer < self.num_layers - 1 else None, - use_checkpoint=use_checkpoint, - jump_attention=jump_attention[i_layer], - frag_bias=frag_biases[i_layer], - ) - self.layers.append(layer) - - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) - - # add a norm layer for each output - self.norm = norm_layer(self.num_features) - - self._freeze_stages() - - self.init_weights() - - def _freeze_stages(self): - if self.frozen_stages >= 0: - self.patch_embed.eval() - for param in self.patch_embed.parameters(): - param.requires_grad = False - - if self.frozen_stages >= 1: - self.pos_drop.eval() - for i in range(0, self.frozen_stages): - m = self.layers[i] - m.eval() - for param in m.parameters(): - param.requires_grad = False - - def inflate_weights(self): - """Inflate the swin2d parameters to swin3d. - - The differences between swin3d and swin2d mainly lie in an extra - axis. To utilize the pretrained parameters in 2d model, - the weight of swin2d models should be inflated to fit in the shapes of - the 3d counterpart. - - Args: - logger (logging.Logger): The logger used to print - debugging infomation. 
- """ - checkpoint = torch.load(self.pretrained, map_location="cpu") - state_dict = checkpoint["model"] - - # delete relative_position_index since we always re-init it - relative_position_index_keys = [ - k for k in state_dict.keys() if "relative_position_index" in k - ] - for k in relative_position_index_keys: - del state_dict[k] - - # delete attn_mask since we always re-init it - attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] - for k in attn_mask_keys: - del state_dict[k] - - state_dict["patch_embed.proj.weight"] = ( - state_dict["patch_embed.proj.weight"] - .unsqueeze(2) - .repeat(1, 1, self.patch_size[0], 1, 1) - / self.patch_size[0] - ) - - # bicubic interpolate relative_position_bias_table if not match - relative_position_bias_table_keys = [ - k for k in state_dict.keys() if "relative_position_bias_table" in k - ] - for k in relative_position_bias_table_keys: - relative_position_bias_table_pretrained = state_dict[k] - relative_position_bias_table_current = self.state_dict()[k] - L1, nH1 = relative_position_bias_table_pretrained.size() - L2, nH2 = relative_position_bias_table_current.size() - L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) - wd = self.window_size[0] - if nH1 != nH2: - print(f"Error in loading {k}, passing") - else: - if L1 != L2: - S1 = int(L1 ** 0.5) - relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( - relative_position_bias_table_pretrained.permute(1, 0).view( - 1, nH1, S1, S1 - ), - size=( - 2 * self.window_size[1] - 1, - 2 * self.window_size[2] - 1, - ), - mode="bicubic", - ) - relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view( - nH2, L2 - ).permute( - 1, 0 - ) - state_dict[k] = relative_position_bias_table_pretrained.repeat( - 2 * wd - 1, 1 - ) - - msg = self.load_state_dict(state_dict, strict=False) - print(msg) - print(f"=> loaded successfully '{self.pretrained}'") - del checkpoint - torch.cuda.empty_cache() - - def load_swin(self, load_path, strict=False): - print("loading swin lah") - from collections import OrderedDict - - model_state_dict = self.state_dict() - state_dict = torch.load(load_path)["state_dict"] - - clean_dict = OrderedDict() - for key, value in state_dict.items(): - if "backbone" in key: - clean_key = key[9:] - clean_dict[clean_key] = value - if "relative_position_bias_table" in clean_key: - forked_key = clean_key.replace( - "relative_position_bias_table", "fragment_position_bias_table" - ) - if forked_key in clean_dict: - print("load_swin_error?") - else: - clean_dict[forked_key] = value - - # bicubic interpolate relative_position_bias_table if not match - relative_position_bias_table_keys = [ - k for k in clean_dict.keys() if "relative_position_bias_table" in k - ] - for k in relative_position_bias_table_keys: - print(k) - relative_position_bias_table_pretrained = clean_dict[k] - relative_position_bias_table_current = model_state_dict[k] - L1, nH1 = relative_position_bias_table_pretrained.size() - L2, nH2 = relative_position_bias_table_current.size() - if isinstance(self.window_size, list): - i_layer = int(k.split(".")[1]) - L2 = (2 * self.window_size[i_layer][1] - 1) * ( - 2 * self.window_size[i_layer][2] - 1 - ) - wd = self.window_size[i_layer][0] - else: - L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) - wd = self.window_size[0] - if nH1 != nH2: - print(f"Error in loading {k}, passing") - else: - if L1 != L2: - S1 = int((L1 / 15) ** 0.5) - print( - relative_position_bias_table_pretrained.shape, 15, 
nH1, S1, S1 - ) - relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( - relative_position_bias_table_pretrained.permute(1, 0) - .view(nH1, 15, S1, S1) - .transpose(0, 1), - size=( - 2 * self.window_size[i_layer][1] - 1, - 2 * self.window_size[i_layer][2] - 1, - ), - mode="bicubic", - ) - relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.transpose( - 0, 1 - ).view( - nH2, 15, L2 - ) - clean_dict[k] = relative_position_bias_table_pretrained # .repeat(2*wd-1,1) - - ## Clean Mismatched Keys - for key, value in model_state_dict.items(): - if key in clean_dict: - if value.shape != clean_dict[key].shape: - print(key) - clean_dict.pop(key) - - self.load_state_dict(clean_dict, strict=strict) - - def init_weights(self, pretrained=None): - print(self.pretrained, self.pretrained2d) - """Initialize the weights in backbone. - - Args: - pretrained (str, optional): Path to pre-trained weights. - Defaults to None. - """ - - def _init_weights(m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - if pretrained: - self.pretrained = pretrained - if isinstance(self.pretrained, str): - self.apply(_init_weights) - # logger = get_root_logger() - # logger.info(f"load model from: {self.pretrained}") - - if self.pretrained2d: - # Inflate 2D model into 3D model. - self.inflate_weights() - else: - # Directly load 3D model. - self.load_swin(self.pretrained, strict=False) # , logger=logger) - elif self.pretrained is None: - self.apply(_init_weights) - else: - raise TypeError("pretrained must be a str or None") - - def forward(self, x, multi=False, layer=-1, adaptive_window_size=False): - - """Forward function.""" - if adaptive_window_size: - resized_window_size = get_adaptive_window_size( - self.window_size, x.shape[2:], self.base_x_size - ) - else: - resized_window_size = None - - x = self.patch_embed(x) - - x = self.pos_drop(x) - feats = [x] - - for l, mlayer in enumerate(self.layers): - x = mlayer(x.contiguous(), resized_window_size) - feats += [x] - - x = rearrange(x, "n c d h w -> n d h w c") - x = self.norm(x) - x = rearrange(x, "n d h w c -> n c d h w") - - if multi: - shape = x.shape[2:] - return torch.cat( - [F.interpolate(xi, size=shape, mode="trilinear") for xi in feats[:-1]], - 1, - ) - elif layer > -1: - print("something", len(feats)) - return feats[layer] - else: - return x - - def train(self, mode=True): - """Convert the model into training mode while keep layers freezed.""" - super(SwinTransformer3D, self).train(mode) - self._freeze_stages() - - -def swin_3d_tiny(**kwargs): - ## Original Swin-3D Tiny with reduced windows - return SwinTransformer3D(depths=[2, 2, 6, 2], frag_biases=[0, 0, 0, 0], **kwargs) - - -def swin_3d_small(**kwargs): - # Original Swin-3D Small with reduced windows - return SwinTransformer3D(depths=[2, 2, 18, 2], frag_biases=[0, 0, 0, 0], **kwargs) - - -class SwinTransformer2D(nn.Sequential): - def __init__(self): - ## Only backbone for Swin Transformer 2D - from timm.models import swin_tiny_patch4_window7_224 - - super().__init__(*list(swin_tiny_patch4_window7_224().children())[:-2]) diff --git a/dover/models/.ipynb_checkpoints/xclip_backbone-checkpoint.py b/dover/models/.ipynb_checkpoints/xclip_backbone-checkpoint.py deleted file mode 100644 index d17fb21..0000000 --- 
a/dover/models/.ipynb_checkpoints/xclip_backbone-checkpoint.py +++ /dev/null @@ -1,902 +0,0 @@ -import copy -import math -from collections import OrderedDict -from typing import Tuple, Union - -import clip -import numpy as np -import torch -import torch.nn.functional as F -from einops import rearrange -from timm.models.layers import trunc_normal_ -from torch import nn -from torch.utils.checkpoint import checkpoint_sequential - - -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, - the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for - changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use - 'survival rate' as the argument. - """ - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * ( - x.ndim - 1 - ) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - # orig_type = x.dtype - # ret = super().forward(x.type(torch.float32)) - # return ret.type(orig_type) - return super().forward(x) - - -class QuickGELU(nn.Module): - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class ResidualAttentionBlock(nn.Module): - def __init__( - self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, - ): - super().__init__() - - self.attn = nn.MultiheadAttention(d_model, n_head,) - self.ln_1 = LayerNorm(d_model) - - self.mlp = nn.Sequential( - OrderedDict( - [ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model)), - ] - ) - ) - self.ln_2 = LayerNorm(d_model) - self.attn_mask = attn_mask - - def attention(self, x: torch.Tensor): - self.attn_mask = ( - self.attn_mask.to(dtype=x.dtype, device=x.device) - if self.attn_mask is not None - else None - ) - return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] - - def forward(self, x: torch.Tensor): - x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -class Transformer(nn.Module): - def __init__( - self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None - ): - super().__init__() - self.width = width - self.layers = layers - self.resblocks = nn.Sequential( - *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)] - ) - - def forward(self, x: torch.Tensor): - return self.resblocks(x) - - -class VisionTransformer(nn.Module): - def __init__( - self, - input_resolution: int, - patch_size: int, - width: int, - layers: int, - heads: int, - output_dim: int, - ): - super().__init__() - self.input_resolution = input_resolution - 
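# QuickGELU above is the sigmoid approximation x * sigmoid(1.702 * x) used
# by CLIP; a quick numeric comparison against exact GELU (a sketch,
# PyTorch only):
import torch
import torch.nn.functional as F
x = torch.linspace(-3.0, 3.0, 7)
print(x * torch.sigmoid(1.702 * x))  # QuickGELU
print(F.gelu(x))                     # exact GELU: close, but not identical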
self.output_dim = output_dim - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False, - ) - - scale = width ** -0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter( - scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width) - ) - self.ln_pre = LayerNorm(width) - - self.transformer = Transformer(width, layers, heads) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat( - [ - self.class_embedding.to(x.dtype) - + torch.zeros( - x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device - ), - x, - ], - dim=1, - ) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.ln_post(x[:, 0, :]) - - if self.proj is not None: - x = x @ self.proj - return x - - -class CLIP(nn.Module): - def __init__( - self, - embed_dim: int, - # vision - image_resolution: int, - vision_layers: Union[Tuple[int, int, int, int], int], - vision_width: int, - vision_patch_size: int, - # text - context_length: int, - vocab_size: int, - transformer_width: int, - transformer_heads: int, - transformer_layers: int, - ): - super().__init__() - - self.context_length = context_length - - # vision_heads = vision_width // 64 - # self.visual = VisionTransformer( - # input_resolution=image_resolution, - # patch_size=vision_patch_size, - # width=vision_width, - # layers=vision_layers, - # heads=vision_heads, - # output_dim=embed_dim - # ) - - # self.transformer = Transformer( - # width=transformer_width, - # layers=transformer_layers, - # heads=transformer_heads, - # attn_mask=self.build_attention_mask() - # ) - - # self.vocab_size = vocab_size - # self.token_embedding = nn.Embedding(vocab_size, transformer_width) - # self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) - # self.ln_final = LayerNorm(transformer_width) - - # self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - - # self.initialize_parameters() - - def initialize_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - - proj_std = (self.transformer.width ** -0.5) * ( - (2 * self.transformer.layers) ** -0.5 - ) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero 
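# What build_attention_mask above produces, shown at a toy context length
# (a sketch, PyTorch only): -inf strictly above the diagonal blocks
# attention to future tokens; zeros on and below it leave past and current
# positions attendable.
import torch
L = 4
mask = torch.empty(L, L).fill_(float("-inf")).triu_(1)
print(mask)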
out the lower diagonal - return mask - - @property - def dtype(self): - return self.visual.conv1.weight.dtype - - def encode_image(self, image): - return self.visual(image.type(self.dtype)) - - def encode_text(self, text): - x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.type(self.dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x).type(self.dtype) - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x - - def forward(self, image, text): - image_features = self.encode_image(image) - text_features = self.encode_text(text) - - # normalized features - image_features = image_features / image_features.norm(dim=1, keepdim=True) - text_features = text_features / text_features.norm(dim=1, keepdim=True) - - # cosine similarity as logits - logit_scale = self.logit_scale.exp() - logits_per_image = logit_scale * image_features @ text_features.t() - logits_per_text = logits_per_image.t() - - # shape = [global_batch_size, global_batch_size] - return logits_per_image, logits_per_text - - -class CrossFramelAttentionBlock(nn.Module): - def __init__( - self, - d_model: int, - n_head: int, - attn_mask: torch.Tensor = None, - droppath=0.0, - T=0, - ): - super().__init__() - self.T = T - - self.message_fc = nn.Linear(d_model, d_model) - self.message_ln = LayerNorm(d_model) - self.message_attn = nn.MultiheadAttention(d_model, n_head,) - - self.attn = nn.MultiheadAttention(d_model, n_head,) - self.ln_1 = LayerNorm(d_model) - - self.drop_path = DropPath(droppath) if droppath > 0.0 else nn.Identity() - self.mlp = nn.Sequential( - OrderedDict( - [ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model)), - ] - ) - ) - self.ln_2 = LayerNorm(d_model) - self.attn_mask = attn_mask - - def attention(self, x: torch.Tensor): - self.attn_mask = ( - self.attn_mask.to(dtype=x.dtype, device=x.device) - if self.attn_mask is not None - else None - ) - return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] - - def forward(self, x): - l, bt, d = x.size() - b = bt // self.T - x = x.view(l, b, self.T, d) - - msg_token = self.message_fc(x[0, :, :, :]) - msg_token = msg_token.view(b, self.T, 1, d) - - msg_token = msg_token.permute(1, 2, 0, 3).view(self.T, b, d) - msg_token = msg_token + self.drop_path( - self.message_attn( - self.message_ln(msg_token), - self.message_ln(msg_token), - self.message_ln(msg_token), - need_weights=False, - )[0] - ) - msg_token = msg_token.view(self.T, 1, b, d).permute(1, 2, 0, 3) - - x = torch.cat([x, msg_token], dim=0) - - x = x.view(l + 1, -1, d) - x = x + self.drop_path(self.attention(self.ln_1(x))) - x = x[:l, :, :] - x = x + self.drop_path(self.mlp(self.ln_2(x))) - return x - - -class Transformer(nn.Module): - def __init__( - self, - width: int, - layers: int, - heads: int, - attn_mask: torch.Tensor = None, - droppath=None, - use_checkpoint=False, - T=8, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - if droppath is None: - droppath = [0.0 for i in range(layers)] - self.width = width - self.layers = layers - - self.resblocks = nn.Sequential( - *[ - CrossFramelAttentionBlock(width, heads, attn_mask, droppath[i], T) - for i in range(layers) - ] - ) - - def forward(self, x: 
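# EOT pooling in encode_text above, in isolation (a sketch with toy token
# ids; 49407 is CLIP's end-of-text id, the largest id in any sequence, so
# argmax locates it): each sequence is summarized by the feature at EOT.
import torch
text = torch.tensor([[49406, 320, 1125, 49407, 0], [49406, 1237, 49407, 0, 0]])
feats = torch.randn(2, 5, 512)
pooled = feats[torch.arange(2), text.argmax(dim=-1)]
print(pooled.shape)  # torch.Size([2, 512])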
torch.Tensor): - if not self.use_checkpoint: - return self.resblocks(x) - else: - return checkpoint_sequential(self.resblocks, 3, x) - - -class CrossFrameCommunicationTransformer(nn.Module): - def __init__( - self, - input_resolution: int, - patch_size: int, - width: int, - layers: int, - heads: int, - output_dim: int, - droppath=None, - T=8, - use_checkpoint=False, - ): - super().__init__() - self.input_resolution = input_resolution - self.output_dim = output_dim - - self.conv1 = nn.Conv2d( - in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False, - ) - - scale = width ** -0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter( - scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width) - ) - self.ln_pre = LayerNorm(width) - - ## Attention Blocks - self.transformer = Transformer( - width, layers, heads, droppath=droppath, use_checkpoint=use_checkpoint, T=T, - ) - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def init_weights(self): - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat( - [ - self.class_embedding.to(x.dtype) - + torch.zeros( - x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device - ), - x, - ], - dim=1, - ) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) - x = self.transformer(x) - x = x.permute(1, 0, 2) - - cls_x = self.ln_post(x[:, 0, :]) - - if self.proj is not None: - cls_x = cls_x @ self.proj - - return cls_x, x[:, 1:, :] - - -class MulitHeadAttention(nn.Module): - def __init__( - self, - dim, - num_heads=8, - qkv_bias=False, - qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, - ): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - - self.scale = qk_scale or head_dim ** -0.5 - - self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) - self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) - self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) - - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, q, k, v): - B, N, C = q.shape - B, M, C = k.shape - q = ( - self.q_proj(q) - .reshape(B, N, self.num_heads, C // self.num_heads) - .permute(0, 2, 1, 3) - ) - k = ( - self.k_proj(k) - .reshape(B, M, self.num_heads, C // self.num_heads) - .permute(0, 2, 1, 3) - ) - v = ( - self.v_proj(v) - .reshape(B, M, self.num_heads, C // self.num_heads) - .permute(0, 2, 1, 3) - ) - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class PromptGeneratorLayer(nn.Module): - def __init__( - self, d_model, nhead, dropout=0.0, - ): - super().__init__() - self.cross_attn = MulitHeadAttention(d_model, nhead, proj_drop=dropout) - - self.norm1 = nn.LayerNorm(d_model) - self.norm3 = 
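# Shape conventions in the MulitHeadAttention above (a sketch, PyTorch
# only; toy sizes): queries of length N attend over M keys/values, so the
# attention map is (B, heads, N, M) and the output returns to (B, N, C).
import torch
B, N, M, C, H = 2, 4, 9, 64, 8
q = torch.randn(B, H, N, C // H)
k = torch.randn(B, H, M, C // H)
v = torch.randn(B, H, M, C // H)
attn = (q @ k.transpose(-2, -1)) * (C // H) ** -0.5
out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)
print(attn.shape, out.shape)  # torch.Size([2, 8, 4, 9]) torch.Size([2, 4, 64])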
nn.LayerNorm(d_model) - - self.dropout = nn.Dropout(dropout) - - self.mlp = nn.Sequential( - nn.Linear(d_model, d_model * 4), - QuickGELU(), - nn.Dropout(dropout), - nn.Linear(d_model * 4, d_model), - ) - - def forward(self, x, visual): - q = k = v = self.norm1(x) - x = x + self.cross_attn(q, visual, visual) - x = x + self.dropout(self.mlp(self.norm3(x))) - return x - - -class VideoSpecificPrompt(nn.Module): - def __init__( - self, layers=2, embed_dim=512, alpha=0.1, - ): - super().__init__() - self.norm = nn.LayerNorm(embed_dim) - self.decoder = nn.ModuleList( - [PromptGeneratorLayer(embed_dim, embed_dim // 64) for _ in range(layers)] - ) - self.alpha = nn.Parameter(torch.ones(embed_dim) * alpha) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, text, visual): - B, N, C = visual.shape - visual = self.norm(visual) - for layer in self.decoder: - text = layer(text, visual) - - -from collections import OrderedDict - -from timm.models.layers import trunc_normal_ - - -class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): - super().__init__() - - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ln_1 = nn.LayerNorm(d_model) - self.mlp = nn.Sequential( - OrderedDict( - [ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model)), - ] - ) - ) - self.ln_2 = nn.LayerNorm(d_model) - self.attn_mask = attn_mask - - def attention(self, x: torch.Tensor): - self.attn_mask = ( - self.attn_mask.to(dtype=x.dtype, device=x.device) - if self.attn_mask is not None - else None - ) - return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] - - def forward(self, x: torch.Tensor): - x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -class MultiframeIntegrationTransformer(nn.Module): - def __init__( - self, T, embed_dim=512, layers=1, - ): - super().__init__() - self.T = T - transformer_heads = embed_dim // 64 - self.positional_embedding = nn.Parameter(torch.empty(1, T, embed_dim)) - trunc_normal_(self.positional_embedding, std=0.02) - self.resblocks = nn.Sequential( - *[ - ResidualAttentionBlock(d_model=embed_dim, n_head=transformer_heads) - for _ in range(layers) - ] - ) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Linear,)): - trunc_normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.LayerNorm): - nn.init.zeros_(m.bias) - nn.init.ones_(m.weight) - - def forward(self, x): - ori_x = x - x = x + self.positional_embedding - x = x.permute(1, 0, 2) - x = self.resblocks(x) - x = x.permute(1, 0, 2) - x = x.type(ori_x.dtype) + ori_x - - return x.mean(dim=1, keepdim=False) - - -class XCLIP(CLIP): - def __init__( - self, - embed_dim: int, - # vision - image_resolution: int, - vision_layers: Union[Tuple[int, int, int, int], int], - vision_width: int, - vision_patch_size: int, - # text - context_length: int, - vocab_size: int, - transformer_width: int, - transformer_heads: int, - transformer_layers: int, - # video - T=8, - droppath=0.0, - mit_layers=1, - # prompt - prompts_alpha=1e-4, - prompts_layers=1, - # other - use_cache=True, - use_checkpoint=False, - ): - 
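# Two notes on the modules above (a sketch, PyTorch only; toy sizes).
# VideoSpecificPrompt.forward appears to fall off the end without a
# return; the reference X-CLIP implementation returns `self.alpha * text`
# after the decoder loop. MultiframeIntegrationTransformer pools the T
# per-frame embeddings into one clip embedding:
import torch
B, T, D = 2, 8, 512
frame_cls = torch.randn(B, T, D)      # per-frame [CLS] features
pos = torch.zeros(1, T, D)            # stand-in temporal position embedding
x = frame_cls + pos                   # (self-attention over T omitted here)
video = (x + frame_cls).mean(dim=1)   # residual add, then temporal mean-pool
print(video.shape)                    # torch.Size([2, 512])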
super().__init__( - embed_dim, - image_resolution, - vision_layers, - vision_width, - vision_patch_size, - context_length, - vocab_size, - transformer_width, - transformer_heads, - transformer_layers, - ) - - self.prompts_generator = VideoSpecificPrompt( - layers=prompts_layers, embed_dim=embed_dim, alpha=prompts_alpha, - ) - self.use_cache = use_cache - self.mit = MultiframeIntegrationTransformer( - T=T, embed_dim=embed_dim, layers=mit_layers, - ) - - dpr = ( - [x.item() for x in torch.linspace(0, droppath, vision_layers)] - if droppath > 0.0 - else None - ) - - vision_heads = vision_width // 64 - self.visual = CrossFrameCommunicationTransformer( - input_resolution=image_resolution, - patch_size=vision_patch_size, - width=vision_width, - layers=vision_layers, - heads=vision_heads, - output_dim=embed_dim, - droppath=dpr, - T=T, - use_checkpoint=use_checkpoint, - ) - - self.transformer = Transformer( - width=transformer_width, - layers=transformer_layers, - heads=transformer_heads, - attn_mask=self.build_attention_mask(), - ) - self.vocab_size = vocab_size - self.token_embedding = nn.Embedding(vocab_size, transformer_width) - self.positional_embedding = nn.Parameter( - torch.empty(self.context_length, transformer_width) - ) - self.ln_final = LayerNorm(transformer_width) - self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - - self.cache_text_features = None - self.prompts_visual_ln = LayerNorm(vision_width) - self.prompts_visual_proj = nn.Parameter(torch.randn(vision_width, embed_dim)) - - self.initialize_parameters() - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {"positional_embedding"} - - def encode_image(self, image): - return self.visual(image) - - def encode_text(self, text): - x = self.token_embedding(text) - eos_indx = text.argmax(dim=-1) - K, N1, C = x.shape - - x = x + self.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), eos_indx] @ self.text_projection - x = x.reshape(K, -1) - return x - - def encode_video(self, image): - b, t, c, h, w = image.size() - image = image.reshape(-1, c, h, w) - - cls_features, img_features = self.encode_image(image) - img_features = self.prompts_visual_ln(img_features) - img_features = img_features @ self.prompts_visual_proj - - cls_features = cls_features.view(b, t, -1) - img_features = img_features.view(b, t, -1, cls_features.shape[-1]) - - video_features = self.mit(cls_features) - - return video_features, img_features - - def forward(self, image, **kwargs): - image = rearrange(image, "b c t h w -> b t c h w") - video_features, _ = self.encode_video(image) - return video_features.reshape(*video_features.shape, 1, 1, 1) - - def cache_text(self, text): - self.eval() - with torch.no_grad(): - if self.cache_text_features is None: - self.cache_text_features = self.encode_text(text) - self.train() - return self.cache_text_features - - def forward_original(self, image, text): - b = image.shape[0] - video_features, img_features = self.encode_video(image) - img_features = img_features.mean(dim=1, keepdim=False) - - if self.use_cache: - text_features = self.cache_text(text) - else: - text_features = self.encode_text(text) - - text_features = 
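# The reshape discipline of encode_video above, standalone (a sketch,
# PyTorch only; toy sizes): frames are folded into the batch for the 2D
# vision tower, then unfolded back to (B, T, ...) afterwards.
import torch
b, t, c, h, w = 2, 8, 3, 224, 224
video = torch.randn(b, t, c, h, w)
frames = video.reshape(-1, c, h, w)       # (16, 3, 224, 224) through the ViT
cls_features = torch.randn(b * t, 512)    # stand-in for the per-frame output
print(cls_features.view(b, t, -1).shape)  # torch.Size([2, 8, 512])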
text_features.unsqueeze(0).expand(b, -1, -1) - text_features = text_features + self.prompts_generator( - text_features, img_features - ) - - video_features = video_features / video_features.norm(dim=-1, keepdim=True) - text_features = text_features / text_features.norm(dim=-1, keepdim=True) - logit_scale = self.logit_scale.exp() - logits = torch.einsum("bd,bkd->bk", video_features, logit_scale * text_features) - - return logits - - -def build_x_clip_model( - pretrained_path="./pretrained_weights/k400_32_8.pth", - droppath=0.0, - use_checkpoint=False, - logger=None, - prompts_alpha=1e-1, - prompts_layers=2, - use_cache=True, - mit_layers=4, - **kwargs, -): - state_dict = torch.load(pretrained_path, map_location="cpu")["model"] - T = int(pretrained_path.split("_")[-1].split(".")[0]) - print(T) - vit = "visual.proj" in state_dict - - if vit: - vision_width = state_dict["visual.conv1.weight"].shape[0] - vision_layers = len( - [ - k - for k in state_dict.keys() - if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") - ] - ) - vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] - grid_size = round( - (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 - ) - image_resolution = vision_patch_size * grid_size - else: - counts: list = [ - len( - set( - k.split(".")[2] - for k in state_dict - if k.startswith(f"visual.layer{b}") - ) - ) - for b in [1, 2, 3, 4] - ] - vision_layers = tuple(counts) - - vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] - output_width = round( - (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 - ) - vision_patch_size = None - assert ( - output_width ** 2 + 1 - == state_dict["visual.attnpool.positional_embedding"].shape[0] - ) - image_resolution = output_width * 32 - - embed_dim = state_dict["text_projection"].shape[1] - context_length = state_dict["positional_embedding"].shape[0] - vocab_size = state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] - transformer_heads = transformer_width // 64 - transformer_layers = len( - set( - k.split(".")[2] - for k in state_dict - if k.startswith(f"transformer.resblocks") - ) - ) - - model = XCLIP( - embed_dim, - image_resolution, - vision_layers, - vision_width, - vision_patch_size, - context_length, - vocab_size, - transformer_width, - transformer_heads, - transformer_layers, - T=T, - droppath=droppath, - mit_layers=mit_layers, - prompts_alpha=prompts_alpha, - prompts_layers=prompts_layers, - use_checkpoint=use_checkpoint, - use_cache=use_cache, - ) - - for key in ["input_resolution", "context_length", "vocab_size"]: - if key in state_dict: - del state_dict[key] - - msg = model.load_state_dict(state_dict, strict=False) - - return model.eval() diff --git a/figs/.ipynb_checkpoints/approach-checkpoint.png b/figs/.ipynb_checkpoints/approach-checkpoint.png deleted file mode 100644 index 84f4a98..0000000 Binary files a/figs/.ipynb_checkpoints/approach-checkpoint.png and /dev/null differ
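# The per-video logits at the end of forward_original above, in isolation
# (a sketch, PyTorch only; random unit-normalized stand-ins): each video
# embedding is dotted with its own batch of K scaled prompt embeddings.
import torch
import torch.nn.functional as F
B, K, D = 2, 5, 512
video = F.normalize(torch.randn(B, D), dim=-1)
text = F.normalize(torch.randn(B, K, D), dim=-1)
scale = 1 / 0.07                                  # logit_scale.exp() at init
logits = torch.einsum("bd,bkd->bk", video, scale * text)
print(logits.shape)  # torch.Size([2, 5])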