Save training logs into tensorboard files #46

Merged: 9 commits, Apr 11, 2023
7 changes: 3 additions & 4 deletions .github/workflows/testing.yml
@@ -6,10 +6,9 @@ on:
       - main
       - dev
       - temp_test_branch  # if needed, create a temporary branch like this to test some functions
-  # pull_request:
-  #   branches:
-  #     - main
-  #     - dev
+  pull_request:
+    branches:
+      - dev

jobs:
test:
3 changes: 3 additions & 0 deletions .github/workflows/testing_daily.yml
@@ -24,6 +24,9 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          # we have to install torch in advance because of torch_sparse,
+          # refer to https://github.com/rusty1s/pytorch_sparse/issues/156#issuecomment-1304869772 for details
+          pip install torch
           pip install -r pypots/tests/environment_for_pip_test.txt

- name: Fetch the test environment details
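Aside: the ordering in this workflow step matters because torch_sparse needs torch to already be importable when its own setup runs, so a single requirements file listing both can fail. A hypothetical Python equivalent of the two-stage install (the `pip_install` helper below is illustrative, not part of PyPOTS):

```python
# Hypothetical two-stage installer mirroring the workflow step above:
# torch must already be importable before torch_sparse's build starts.
import subprocess
import sys


def pip_install(*args: str) -> None:
    # run pip in the current interpreter's environment
    subprocess.check_call([sys.executable, "-m", "pip", "install", *args])


pip_install("torch")  # must come first
pip_install("-r", "pypots/tests/environment_for_pip_test.txt")
```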
14 changes: 8 additions & 6 deletions pypots/base.py
@@ -63,19 +63,21 @@ def __init__(
 
             # get the current time to append to the dir name,
             # so you can use the same tb_file_saving_path for multiple runs
-            time_now = datetime.now().__format__("%Y-%m-%d_T%H:%M:%S")
+            time_now = datetime.now().__format__("%Y%m%d_T%H%M%S")
             # the actual directory name to save the tensorboard file
             actual_tb_saving_dir_name = "tensorboard_" + time_now
             actual_tb_file_saving_path = os.path.join(
                 tb_file_saving_path, actual_tb_saving_dir_name
             )
-            os.makedirs(actual_tb_saving_dir_name)  # create the dir for file saving
-            self.summary_writer = SummaryWriter(actual_tb_file_saving_path)
+            # os.makedirs(actual_tb_file_saving_path)  # create the dir for file saving
+            self.summary_writer = SummaryWriter(
+                actual_tb_file_saving_path, filename_suffix=".pypots"
+            )
         else:
             # don't save the log if tb_file_saving_path isn't given, set summary_writer to None
             self.summary_writer = None
 
-    def save_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None:
+    def save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None:
"""Saving training logs into the tensorboard file.

Parameters
@@ -92,7 +94,8 @@ def save_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None:
         """
         while len(loss_dict) > 0:
             (item_name, loss) = loss_dict.popitem()
-            self.summary_writer.add_scalar(f"{item_name}/{stage}", loss, step)
+            if "loss" in item_name:  # save all items whose name contains the word "loss"
+                self.summary_writer.add_scalar(f"{stage}/{item_name}", loss, step)
 
     def save_model(
         self,
@@ -222,7 +225,6 @@ def __init__(
         self.optimizer = None
         self.best_model_dict = None
         self.best_loss = float("inf")
-        self.logger = {"training_loss": [], "validating_loss": []}
 
     def _print_model_size(self) -> None:
         """Print the number of trainable parameters in the initialized NN model."""
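Taken together, the changes to pypots/base.py mean a model given a `tb_file_saving_path` creates a timestamped TensorBoard directory, and only items whose names contain "loss" are written, under `{stage}/{item_name}` tags. A minimal standalone sketch of that behavior (the `DemoModel` class and the dummy loss values are hypothetical, for illustration only):

```python
import os
from datetime import datetime

from torch.utils.tensorboard import SummaryWriter


class DemoModel:
    """Hypothetical stand-in for the logging logic in pypots.base.BaseModel."""

    def __init__(self, tb_file_saving_path: str = None):
        if tb_file_saving_path:
            # timestamp the directory so one path can host many runs
            time_now = datetime.now().__format__("%Y%m%d_T%H%M%S")
            actual_path = os.path.join(tb_file_saving_path, "tensorboard_" + time_now)
            self.summary_writer = SummaryWriter(actual_path, filename_suffix=".pypots")
        else:
            self.summary_writer = None

    def save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None:
        while len(loss_dict) > 0:
            (item_name, loss) = loss_dict.popitem()
            if "loss" in item_name:  # only loss-like items get logged
                self.summary_writer.add_scalar(f"{stage}/{item_name}", loss, step)


model = DemoModel(tb_file_saving_path="runs")
model.save_log_into_tb_file(1, "training", {"loss": 0.93, "reconstruction_loss": 0.51})
# inspect with: tensorboard --logdir runs
```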
71 changes: 42 additions & 29 deletions pypots/classification/base.py
@@ -21,9 +21,9 @@ class BaseClassifier(BaseModel):
     """Abstract class for all classification models."""
 
     def __init__(
-            self,
-            device: Optional[Union[str, torch.device]] = None,
-            tb_file_saving_path: str = None,
+        self,
+        device: Optional[Union[str, torch.device]] = None,
+        tb_file_saving_path: str = None,
     ):
         super().__init__(
             device,
@@ -32,10 +32,10 @@ def __init__(
 
     @abstractmethod
     def fit(
-            self,
-            train_set: Union[dict, str],
-            val_set: Optional[Union[dict, str]] = None,
-            file_type: str = "h5py",
+        self,
+        train_set: Union[dict, str],
+        val_set: Optional[Union[dict, str]] = None,
+        file_type: str = "h5py",
     ) -> None:
         """Train the classifier on the given data.
 
@@ -67,9 +67,9 @@ def fit(
 
     @abstractmethod
     def classify(
-            self,
-            X: Union[dict, str],
-            file_type: str = "h5py",
+        self,
+        X: Union[dict, str],
+        file_type: str = "h5py",
     ) -> np.ndarray:
         """Classify the input data with the trained model.
 
@@ -92,16 +92,16 @@ def classify(
 
 class BaseNNClassifier(BaseNNModel, BaseClassifier):
     def __init__(
-            self,
-            n_classes: int,
-            batch_size: int,
-            epochs: int,
-            patience: int,
-            learning_rate: float,
-            weight_decay: float,
-            num_workers: int = 0,
-            device: Optional[Union[str, torch.device]] = None,
-            tb_file_saving_path: str = None,
+        self,
+        n_classes: int,
+        batch_size: int,
+        epochs: int,
+        patience: int,
+        learning_rate: float,
+        weight_decay: float,
+        num_workers: int = 0,
+        device: Optional[Union[str, torch.device]] = None,
+        tb_file_saving_path: str = None,
     ):
         super().__init__(
             batch_size,
@@ -172,9 +172,9 @@ def _assemble_input_for_testing(self, data) -> dict:
         pass
 
     def _train_model(
-            self,
-            training_loader: DataLoader,
-            val_loader: DataLoader = None,
+        self,
+        training_loader: DataLoader,
+        val_loader: DataLoader = None,
     ) -> None:
 
         self.optimizer = torch.optim.Adam(
@@ -186,21 +186,25 @@ def _train_model(
         self.best_model_dict = None
 
         try:
+            training_step = 0
             for epoch in range(self.epochs):
                 self.model.train()
                 epoch_train_loss_collector = []
                 for idx, data in enumerate(training_loader):
+                    training_step += 1
                     inputs = self._assemble_input_for_training(data)
                     self.optimizer.zero_grad()
                     results = self.model.forward(inputs)
                     results["loss"].backward()
                     self.optimizer.step()
                     epoch_train_loss_collector.append(results["loss"].item())
 
-                mean_train_loss = np.mean(
-                    epoch_train_loss_collector
-                )  # mean training loss of the current epoch
-                self.logger["training_loss"].append(mean_train_loss)
+                    # save training loss logs into the tensorboard file for every step if needed
+                    if self.summary_writer is not None:
+                        self.save_log_into_tb_file(training_step, "training", results)
+
+                # mean training loss of the current epoch
+                mean_train_loss = np.mean(epoch_train_loss_collector)
 
                 if val_loader is not None:
                     self.model.eval()
@@ -212,9 +216,18 @@ def _train_model(
                             epoch_val_loss_collector.append(results["loss"].item())
 
                     mean_val_loss = np.mean(epoch_val_loss_collector)
-                    self.logger["validating_loss"].append(mean_val_loss)
 
+                    # save validating loss logs into the tensorboard file for every epoch if needed
+                    if self.summary_writer is not None:
+                        val_loss_dict = {
+                            "classification_loss": mean_val_loss,
+                        }
+                        self.save_log_into_tb_file(epoch, "validating", val_loss_dict)
 
                     logger.info(
-                        f"epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}"
+                        f"epoch {epoch}: "
+                        f"training loss {mean_train_loss:.4f}, "
+                        f"validating loss {mean_val_loss:.4f}"
                     )
                     mean_loss = mean_val_loss
                 else:
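The reworked training loop above settles on one logging cadence: training losses go to TensorBoard once per optimization step (hence the new `training_step` counter), while the validation loss is averaged and written once per epoch. A self-contained sketch of that cadence with toy stand-ins (the linear model, random data, and loader lists are hypothetical, not the PyPOTS API):

```python
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

# Toy stand-ins: a linear classifier on random data, purely for illustration.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
make_batch = lambda: (torch.randn(8, 4), torch.randint(0, 2, (8,)))
training_loader = [make_batch() for _ in range(5)]
val_loader = [make_batch() for _ in range(2)]
writer = SummaryWriter("runs/demo")

training_step = 0
for epoch in range(3):
    model.train()
    for x, y in training_loader:
        training_step += 1
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # training losses are logged per optimization step
        writer.add_scalar("training/loss", loss.item(), training_step)

    model.eval()
    with torch.no_grad():
        val_losses = [criterion(model(x), y).item() for x, y in val_loader]
    # validation loss is averaged and logged once per epoch
    writer.add_scalar("validating/classification_loss", np.mean(val_losses), epoch)
```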
63 changes: 27 additions & 36 deletions pypots/classification/brits.py
@@ -60,32 +60,14 @@ def __init__(
         self.classification_weight = classification_weight
         self.reconstruction_weight = reconstruction_weight
 
-    def merge_ret(self, ret_f: dict, ret_b: dict) -> dict:
-        """Merge (average) results from two RITS models into one.
-
-        Parameters
-        ----------
-        ret_f : dict,
-            Results from the forward RITS.
-        ret_b : dict,
-            Results from the backward RITS.
-
-        Returns
-        -------
-        dict,
-            Merged results in a dictionary.
-        """
-        results = {
-            "imputed_data": (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2,
-            "prediction": (ret_f["prediction"] + ret_b["prediction"]) / 2,
-        }
-        return results
+    def impute(self, inputs: dict) -> torch.Tensor:
+        return super().impute(inputs)
 
-    def classify(self, inputs: dict) -> Tuple[dict, dict, dict]:
+    def classify(self, inputs: dict) -> torch.Tensor:
         ret_f = self.rits_f(inputs, "forward")
-        ret_b = self.reverse(self.rits_b(inputs, "backward"))
-        merged_ret = self.merge_ret(ret_f, ret_b)
-        return merged_ret, ret_f, ret_b
+        ret_b = self._reverse(self.rits_b(inputs, "backward"))
+        classification_pred = (ret_f["prediction"] + ret_b["prediction"]) / 2
+        return classification_pred

def forward(self, inputs: dict) -> dict:
"""Forward processing of BRITS.
@@ -99,29 +81,38 @@ def forward(self, inputs: dict) -> dict:
         -------
         dict, A dictionary including all results.
         """
-        merged_ret, ret_f, ret_b = self.classify(inputs)
+        ret_f = self.rits_f(inputs, "forward")
+        ret_b = self._reverse(self.rits_b(inputs, "backward"))
 
         ret_f["classification_loss"] = F.nll_loss(
             torch.log(ret_f["prediction"]), inputs["label"]
         )
         ret_b["classification_loss"] = F.nll_loss(
             torch.log(ret_b["prediction"]), inputs["label"]
         )
-        consistency_loss = self.get_consistency_loss(
+        consistency_loss = self._get_consistency_loss(
             ret_f["imputed_data"], ret_b["imputed_data"]
         )
         classification_loss = (
             ret_f["classification_loss"] + ret_b["classification_loss"]
         ) / 2
-        merged_ret["consistency_loss"] = consistency_loss
-        merged_ret["classification_loss"] = classification_loss
-        merged_ret["loss"] = (
+        reconstruction_loss = (
+            ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"]
+        ) / 2
+
+        loss = (
             consistency_loss
-            + (ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"])
-            * self.reconstruction_weight
-            + (ret_f["classification_loss"] + ret_b["classification_loss"])
-            * self.classification_weight
+            + reconstruction_loss * self.reconstruction_weight
+            + classification_loss * self.classification_weight
         )
-        return merged_ret
+
+        results = {
+            "consistency_loss": consistency_loss,
+            "classification_loss": classification_loss,
+            "reconstruction_loss": reconstruction_loss,
+            "loss": loss,
+        }
+        return results


class BRITS(BaseNNClassifier):
@@ -395,8 +386,8 @@ def classify(self, X: Union[dict, str], file_type: str = "h5py"):
         with torch.no_grad():
             for idx, data in enumerate(test_loader):
                 inputs = self._assemble_input_for_testing(data)
-                results, _, _ = self.model.classify(inputs)
-                prediction_collector.append(results["prediction"])
+                classification_pred = self.model.classify(inputs)
+                prediction_collector.append(classification_pred)
 
         predictions = torch.cat(prediction_collector)
         return predictions.cpu().detach().numpy()
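The refactor above also makes the BRITS classification objective easy to read off: forward and backward RITS losses are averaged, then combined as consistency_loss + reconstruction_loss * reconstruction_weight + classification_loss * classification_weight. A tiny worked example of that sum (all numbers and weights below are invented for illustration, not PyPOTS defaults):

```python
# Worked example of the BRITS loss composition above; values are made up.
consistency_loss = 0.10
reconstruction_loss = (0.40 + 0.60) / 2  # mean of forward/backward RITS
classification_loss = (0.30 + 0.50) / 2

reconstruction_weight = 1.0  # assumed weight, not necessarily the default
classification_weight = 2.0  # assumed weight

loss = (
    consistency_loss
    + reconstruction_loss * reconstruction_weight
    + classification_loss * classification_weight
)
print(loss)  # 0.1 + 0.5 * 1.0 + 0.4 * 2.0 = 1.4
```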
8 changes: 4 additions & 4 deletions pypots/classification/raindrop.py
@@ -156,10 +156,10 @@ def reset_parameters(self) -> None:
         self.lin_query.reset_parameters()
         self.lin_value.reset_parameters()
         if self.edge_dim:
-            self.lin_edge.reset_parameters()
+            self.lin_edge._reset_parameters()
         self.lin_skip.reset_parameters()
         if self.beta:
-            self.lin_beta.reset_parameters()
+            self.lin_beta._reset_parameters()
         glorot(self.weight)
         if self.bias is not None:
             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
@@ -446,7 +446,7 @@ def init_weights(self):
         self.emb.weight.data.uniform_(-init_range, init_range)
         glorot(self.R_u)
 
-    def classify(self, inputs):
+    def classify(self, inputs: dict) -> torch.Tensor:
         """Forward processing of Raindrop.

Parameters
@@ -847,8 +847,8 @@ def classify(self, X: Union[dict, str], file_type: str = "h5py") -> np.ndarray:
             shuffle=False,
             num_workers=self.num_workers,
         )
-        prediction_collector = []
 
+        prediction_collector = []
         with torch.no_grad():
             for idx, data in enumerate(test_loader):
                 inputs = self._assemble_input_for_testing(data)
26 changes: 20 additions & 6 deletions pypots/clustering/base.py
@@ -169,6 +169,21 @@ def _train_model(
         val_loader: DataLoader = None,
     ) -> None:
 
+        """Train the model on the given data.
+
+        Parameters
+        ----------
+        training_loader : DataLoader,
+            DataLoader for the training set.
+        val_loader : DataLoader,
+            DataLoader for the validation set.
+
+        Notes
+        -----
+        The training procedures of NN clustering models differ a lot from one another. For example, VaDER needs
+        pretraining while CRLI doesn't, and VaDER needs only one optimizer while CRLI needs two, one for its
+        generator and one for its discriminator. So far, I'd suggest implementing _train_model() for each
+        model individually.
+        """
+
         self.optimizer = torch.optim.Adam(
             self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
         )
@@ -189,10 +204,8 @@ def _train_model(
                     self.optimizer.step()
                     epoch_train_loss_collector.append(results["loss"].item())
 
-                mean_train_loss = np.mean(
-                    epoch_train_loss_collector
-                )  # mean training loss of the current epoch
-                self.logger["training_loss"].append(mean_train_loss)
+                # mean training loss of the current epoch
+                mean_train_loss = np.mean(epoch_train_loss_collector)
 
                 if val_loader is not None:
                     self.model.eval()
@@ -204,9 +217,10 @@ def _train_model(
                             epoch_val_loss_collector.append(results["loss"].item())
 
                     mean_val_loss = np.mean(epoch_val_loss_collector)
-                    self.logger["validating_loss"].append(mean_val_loss)
                     logger.info(
-                        f"epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}"
+                        f"epoch {epoch}: "
+                        f"training loss {mean_train_loss:.4f}, "
+                        f"validating loss {mean_val_loss:.4f}"
                    )
                     mean_loss = mean_val_loss
                 else:
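The new docstring's note is the key design point here: a single shared _train_model can't cover, say, CRLI's adversarial training, which alternates between two optimizers. A hypothetical sketch of that alternating two-optimizer pattern (the linear generator/discriminator pair and the losses below are illustrative only, not the CRLI implementation):

```python
import torch

# Hypothetical generator/discriminator pair standing in for a CRLI-style model;
# this only illustrates the two-optimizer alternation, not the real CRLI losses.
generator = torch.nn.Linear(8, 8)
discriminator = torch.nn.Linear(8, 1)
g_opt = torch.optim.Adam(generator.parameters(), lr=1e-3)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
bce = torch.nn.BCEWithLogitsLoss()

for _ in range(3):  # epochs
    for real in [torch.randn(4, 8) for _ in range(5)]:  # toy batches
        fake = generator(torch.randn(4, 8))

        # discriminator step: score real samples as 1, generated samples as 0
        d_loss = bce(discriminator(real), torch.ones(4, 1)) + bce(
            discriminator(fake.detach()), torch.zeros(4, 1)
        )
        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()

        # generator step: push generated samples to be scored as real
        g_loss = bce(discriminator(fake), torch.ones(4, 1))
        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()
```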