Commit 2753ef6

update
Jintao-Huang committed Apr 2, 2023
1 parent aeb5b2d commit 2753ef6
Showing 4 changed files with 32 additions and 11 deletions.
2 changes: 1 addition & 1 deletion examples/meta_learning.py
@@ -181,7 +181,7 @@ def _classify_to_prototypes(
query_targets: [N]
return: dist: [N, M], labels: [N]
"""
- dist = pairwise_euclidean_distance(query_feats, prototypes)  # [N, M]
+ dist = pairwise_euclidean_distance(query_feats, prototypes, True)  # [N, M]
labels = (query_targets[:, None] == proto_labels[None, :]).float().argmax(dim=1)
return dist, labels

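Editor's note: the call now passes `True` for the new `squared` parameter of the local `pairwise_euclidean_distance` helper added in `examples/pre_cv.py` below, skipping the final `sqrt`. Squared Euclidean distance is what Prototypical Networks use, and since `sqrt` is monotonic it does not change which prototype is nearest. A minimal usage sketch (names and shapes are illustrative; assumes the helper is in scope):

import torch

query_feats = torch.randn(8, 64)  # [N, E], hypothetical query embeddings
prototypes = torch.randn(5, 64)   # [M, E], hypothetical class prototypes

d2 = pairwise_euclidean_distance(query_feats, prototypes, True)  # [N, M], squared
pred = d2.argmin(dim=1)  # nearest prototype; identical to argmin over sqrt(d2)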
2 changes: 1 addition & 1 deletion examples/pre.py
@@ -30,7 +30,7 @@
from torchmetrics.classification.auroc import AUROC
from torchmetrics.classification.average_precision import AveragePrecision
from torchmetrics.functional import (
- accuracy, auroc, pairwise_cosine_similarity, pairwise_euclidean_distance
+ accuracy, auroc, pairwise_cosine_similarity
)

#
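Editor's note: `pairwise_euclidean_distance` is removed from the torchmetrics import here, presumably because this commit adds a local implementation with a `squared` flag in `examples/pre_cv.py` (next file); the torchmetrics version does not take such a flag.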
21 changes: 20 additions & 1 deletion examples/pre_cv.py
@@ -147,7 +147,7 @@ def load_densenet_state_dict(

def NT_Xent_loss(features: Tensor, temperature: float = 0.1) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
- features: Tensor[float]. [2N, E]
+ features: FloatTensor. [2N, E]
return: loss: [], pos_idx_mean: [2N], acc: [2N], acc_5: [2N]
"""
NINF = -torch.inf
@@ -188,3 +188,22 @@ def backward(ctx, *grads: Tensor) -> Tensor:
res = grads[ml.get_dist_setting()[0]]
res *= dist.get_world_size()  # DDP averages gradients across devices; multiply by world size to recover the gradient of a W * batch_size batch.
return res
+
+
+ def pairwise_euclidean_distance(
+     X: Tensor,
+     Y: Tensor,
+     squared: bool = False
+ ) -> Tensor:
+     """
+     X: shape[N1, F]. FloatTensor
+     Y: shape[N2, F]. FloatTensor
+     return: shape[N1, N2]
+     """
+     XX = torch.einsum("ij,ij->i", X, X)  # row-wise squared norms, [N1]
+     YY = torch.einsum("ij,ij->i", Y, Y)  # [N2]
+     # ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2
+     res = X @ Y.T
+     res.mul_(-2).add_(XX[:, None]).add_(YY)
+     res.clamp_min_(0.)  # absorb small negatives from floating-point error
+     return res if squared else res.sqrt_()
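Editor's note: the new helper computes squared distances via the expansion above, using in-place ops to avoid temporaries, with `clamp_min_(0.)` guarding against small negative values from floating-point cancellation. A quick sanity check against `torch.cdist` (my sketch, not part of the commit; assumes the function above is in scope):

import torch

X = torch.randn(4, 16)
Y = torch.randn(3, 16)
d = pairwise_euclidean_distance(X, Y)                 # [4, 3]
d2 = pairwise_euclidean_distance(X, Y, squared=True)  # [4, 3]
assert torch.allclose(d, torch.cdist(X, Y), atol=1e-5)
assert torch.allclose(d2, d.pow(2), atol=1e-4)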
18 changes: 10 additions & 8 deletions mini_lightning/_mini_lightning.py
@@ -386,7 +386,7 @@ def __init__(
#
self.lmodel = lmodel
self.device_ids = device_ids
- #
+ #
if deterministic is not None:
torch.backends.cudnn.deterministic = deterministic
deterministic = torch.backends.cudnn.deterministic
@@ -397,7 +397,7 @@
torch.backends.cudnn.benchmark = benchmark
logger.info(f"Setting deterministic: {deterministic}")
logger.info(f"Setting benchmark: {benchmark}")
- #
+ #
self.device = select_device(device_ids)
if self.rank == -1:
parallel_mode = "DP" if len(device_ids) > 1 else None
@@ -456,7 +456,7 @@ def __init__(
self.model_checkpoint = model_checkpoint if model_checkpoint is not None else ModelCheckpoint()
if self.rank in {-1, 0}:
runs_dir = os.path.abspath(runs_dir)
- self.version = self._get_version(runs_dir)
+ self.version = self._get_version(runs_dir)
if platform.system().lower() == "windows":
time = dt.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # Windows does not allow `:` in paths
runs_dir = os.path.join(runs_dir, f"v{self.version}_{time}")
@@ -502,17 +502,18 @@ def _get_version(runs_dir: str) -> int:
v_list.append(int(v))
return max(v_list) + 1

- def _check_hparams(self, hparams: Any) -> Any:
+ @classmethod
+ def _check_hparams(cls, hparams: Any) -> Any:
if hparams is None or isinstance(hparams, (int, float, str, complex)): # bool is a subclass of int
return hparams
if isinstance(hparams, Sequence):
res = []
for hp in hparams:
- res.append(self._check_hparams(hp))
+ res.append(cls._check_hparams(hp))
elif isinstance(hparams, Mapping):
res = {}
for k, v in hparams.items():
- res[k] = self._check_hparams(v)
+ res[k] = cls._check_hparams(v)
else:
res = repr(hparams) # e.g. function
return res
@@ -540,7 +541,8 @@ def _metrics_update(metrics: Dict[str, MeanMetric], new_mes: Dict[str, float], p
continue
metrics[k].update(v)

- def _metrics_compute(self, metrics: Dict[str, MeanMetric]) -> Dict[str, float]:
+ @staticmethod
+ def _metrics_compute(metrics: Dict[str, MeanMetric]) -> Dict[str, float]:
res = {}
for k in metrics.keys():
v: Tensor = metrics[k].compute()
@@ -812,7 +814,7 @@ def _train_epoch(self, dataloader: DataLoader, val_dataloader: Optional[DataLoad
prog_bar_mes = self._get_res_mes(_mean_metrics, _rec_mes, "prog_bar")
if self.rank >= 0:
prog_bar_mes = self._reduce_mes(prog_bar_mes, device)
- #
+ #
if self.version is not None:
prog_bar_mes["v"] = self.version
prog_bar.set_postfix(prog_bar_mes, refresh=False) # rank > 0 disable.
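Editor's note: the remaining changes turn `_check_hparams` into a `@classmethod` and `_metrics_compute` into a `@staticmethod`; neither touches instance state, so both become callable without constructing a trainer. A minimal sketch of the recursive hparams sanitization, assuming the enclosing class is mini-lightning's `Trainer`:

import torch

hparams = {"lr": 1e-3, "layers": [64, 128], "act": torch.nn.ReLU}
safe = Trainer._check_hparams(hparams)  # no instance needed after this commit
# primitives and containers pass through; anything else falls back to repr():
# {"lr": 0.001, "layers": [64, 128], "act": "<class 'torch.nn.modules.activation.ReLU'>"}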
