diff --git a/examples/meta_learning.py b/examples/meta_learning.py index 7e87b36..8ec0be5 100644 --- a/examples/meta_learning.py +++ b/examples/meta_learning.py @@ -181,7 +181,7 @@ def _classify_to_prototypes( query_targets: [N] return: dist: [N, M], labels: [N] """ - dist = pairwise_euclidean_distance(query_feats, prototypes) # [N, M] + dist = pairwise_euclidean_distance(query_feats, prototypes, True) # [N, M] labels = (query_targets[:, None] == proto_labels[None, :]).float().argmax(dim=1) return dist, labels diff --git a/examples/pre.py b/examples/pre.py index 0abf61e..bf1ff78 100644 --- a/examples/pre.py +++ b/examples/pre.py @@ -30,7 +30,7 @@ from torchmetrics.classification.auroc import AUROC from torchmetrics.classification.average_precision import AveragePrecision from torchmetrics.functional import ( - accuracy, auroc, pairwise_cosine_similarity, pairwise_euclidean_distance + accuracy, auroc, pairwise_cosine_similarity ) # diff --git a/examples/pre_cv.py b/examples/pre_cv.py index 102a3d1..fe67b2d 100644 --- a/examples/pre_cv.py +++ b/examples/pre_cv.py @@ -147,7 +147,7 @@ def load_densenet_state_dict( def NT_Xent_loss(features: Tensor, temperature: float = 0.1) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ - features: Tensor[float]. [2N, E] + features: FloatTensor. [2N, E] return: loss: [], pos_idx_mean: [2N], acc: [2N], acc_5: [2N] """ NINF = -torch.inf @@ -188,3 +188,22 @@ def backward(ctx, *grads: Tensor) -> Tensor: res = grads[ml.get_dist_setting()[0]] res *= dist.get_world_size() # for same grad with W * batch_size; mean operation in ddp across device. return res + + +def pairwise_euclidean_distance( + X: Tensor, + Y: Tensor, + squared: bool = False +) -> Tensor: + """ + X: shape[N1, F]. FloatTensor + Y: shape[N2, F]. FloatTensor + return: shape[N1, N2] + """ + XX = torch.einsum("ij,ij->i", X, X) + YY = torch.einsum("ij,ij->i", Y, Y) + # + res = X @ Y.T + res.mul_(-2).add_(XX[:, None]).add_(YY) + res.clamp_min_(0.) + return res if squared else res.sqrt_() \ No newline at end of file diff --git a/mini_lightning/_mini_lightning.py b/mini_lightning/_mini_lightning.py index 234e8b8..8e91ed5 100644 --- a/mini_lightning/_mini_lightning.py +++ b/mini_lightning/_mini_lightning.py @@ -386,7 +386,7 @@ def __init__( # self.lmodel = lmodel self.device_ids = device_ids - # + # if deterministic is not None: torch.backends.cudnn.deterministic = deterministic deterministic = torch.backends.cudnn.deterministic @@ -397,7 +397,7 @@ def __init__( torch.backends.cudnn.benchmark = benchmark logger.info(f"Setting deterministic: {deterministic}") logger.info(f"Setting benchmark: {benchmark}") - # + # self.device = select_device(device_ids) if self.rank == -1: parallel_mode = "DP" if len(device_ids) > 1 else None @@ -456,7 +456,7 @@ def __init__( self.model_checkpoint = model_checkpoint if model_checkpoint is not None else ModelCheckpoint() if self.rank in {-1, 0}: runs_dir = os.path.abspath(runs_dir) - self.version = self._get_version(runs_dir) + self.version = self._get_version(runs_dir) if platform.system().lower() == "windows": time = dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") # window not support `:` runs_dir = os.path.join(runs_dir, f"v{self.version}_{time}") @@ -502,17 +502,18 @@ def _get_version(runs_dir: str) -> int: v_list.append(int(v)) return max(v_list) + 1 - def _check_hparams(self, hparams: Any) -> Any: + @classmethod + def _check_hparams(cls, hparams: Any) -> Any: if hparams is None or isinstance(hparams, (int, float, str, complex)): # bool is a subclass of int return hparams if isinstance(hparams, Sequence): res = [] for hp in hparams: - res.append(self._check_hparams(hp)) + res.append(cls._check_hparams(hp)) elif isinstance(hparams, Mapping): res = {} for k, v in hparams.items(): - res[k] = self._check_hparams(v) + res[k] = cls._check_hparams(v) else: res = repr(hparams) # e.g. function return res @@ -540,7 +541,8 @@ def _metrics_update(metrics: Dict[str, MeanMetric], new_mes: Dict[str, float], p continue metrics[k].update(v) - def _metrics_compute(self, metrics: Dict[str, MeanMetric]) -> Dict[str, float]: + @staticmethod + def _metrics_compute(metrics: Dict[str, MeanMetric]) -> Dict[str, float]: res = {} for k in metrics.keys(): v: Tensor = metrics[k].compute() @@ -812,7 +814,7 @@ def _train_epoch(self, dataloader: DataLoader, val_dataloader: Optional[DataLoad prog_bar_mes = self._get_res_mes(_mean_metrics, _rec_mes, "prog_bar") if self.rank >= 0: prog_bar_mes = self._reduce_mes(prog_bar_mes, device) - # + # if self.version is not None: prog_bar_mes["v"] = self.version prog_bar.set_postfix(prog_bar_mes, refresh=False) # rank > 0 disable.