diff --git a/README.md b/README.md
index 51ae60ee2..a1f480103 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,10 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
## 🚀 Recent Updates
+- 🪄**Support for Training Visualizing with [Weights & Biases](https://docs.wandb.ai/)**🐝
+ - **[Tutorial](./docs/en_US/get_started.md)**
+ ![Wandb Dashboard](./docs/imgs/wandb_dashboard.png)
+
- 👶 **Young or Old?:[StyleGAN V2 Face Editing](./docs/en_US/tutorials/styleganv2editing.md)-Time Machine!** 👨🦳
- **[Online Toturials](https://aistudio.baidu.com/aistudio/projectdetail/3251280?channelType=0&channel=0)**
diff --git a/docs/en_US/config_doc.md b/docs/en_US/config_doc.md
index a4f3f7065..be3d9b41a 100644
--- a/docs/en_US/config_doc.md
+++ b/docs/en_US/config_doc.md
@@ -75,3 +75,21 @@ Take`lapstyle_rev_first.yaml` as an example.
| :--------------- | ---- | ------ |
| interval | log printing interval | 10 |
| visiual_interval | interval for saving the generated images during training | 500 |
+
+### VDLLogger
+
+| Field | Usage | Default |
+| :--------------- | ---- | ------ |
+| save_dir | Directory to save VisualDL records | None |
+
+### WandbLogger
+
+| Parameter | Use | Defaults |
+| :---------------------: | :---------------------: | :--------------: |
+| project | Project to which the run is to be logged | uncategorized |
+| name | Alias/Name of the run | Randomly generated by wandb |
+| id | ID of the run | Randomly generated by wandb |
+| entity | User or team to which the run is being logged | The logged in user |
+| save_dir | local directory in which all the models and other data is saved | wandb |
+| config | model configuration | None |
+| log_model | Whether checkpoints are to be logged to W&B or not | False |
\ No newline at end of file
diff --git a/docs/en_US/get_started.md b/docs/en_US/get_started.md
index 85de5fb56..d083492dd 100644
--- a/docs/en_US/get_started.md
+++ b/docs/en_US/get_started.md
@@ -84,6 +84,8 @@ output_dir
#### Visualize Training
+##### VisualDL
+
[VisualDL](https://github.com/PaddlePaddle/VisualDL) is a visual analysis tool developed for deep learning model development, providing real-time trend visualization of key metrics, sample training intermediate process visualization, network structure visualization, etc. It can visually show the relationship between the effects of super participant models and assist in efficient tuning.
Please make sure that you have installed [VisualDL](https://github.com/PaddlePaddle/VisualDL). Refer to the [VisualDL installation guide](https://github.com/PaddlePaddle/VisualDL/blob/develop/README.md#Installation).
@@ -106,6 +108,29 @@ visualdl --logdir output_dir/CycleGANModel-2020-10-29-09-21/
Please refer to the [VisualDL User's Guide](https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/components/README.md) for more guidance on how to start and use those visualization functions.
+##### Weights & Biases
+
+W&B is a MLOps tool that can be used for experiment tracking, dataset/model versioning, visualizing results and collaborating with colleagues. A W&B logger is integrated directly into PaddleOCR and to use it, first you need to install the `wandb` sdk and login to your wandb account.
+
+```shell
+pip install wandb
+wandb login
+```
+
+If you do not have a wandb account, you can make one [here](https://wandb.ai/site).
+
+To visualize and track your model training add the command `enable_wandb: True` to your config yaml file. To add more arguments to the `WandbLogger` listed [here](./config_doc.md) add the header `wandb` to the yaml file and add the arguments under it -
+
+![W&B Args](../imgs/wandb_args.png)
+
+These config variables from the yaml file are used to instantiate the `WandbLogger` object with the project name, entity name (the logged in user by default), directory to store metadata (`./wandb` by default) and more. During the training process, the `log_metrics` function is called to log training and evaluation metrics at the training and evaluation steps respectively from the rank 0 process only.
+
+At every model saving step, the WandbLogger, logs the model using the `log_model` function along with relavant metadata and tags to W&B if the `log_model` flag is set as true.
+
+The W&B logger, also supports visualizing images that are generated during the training process. An example W&B dashboard is available [here](https://wandb.ai/manan-goel/paddlegan/runs/24ezjmlz).
+
+P.S. - You can use both VisualDL and Weights & Biases together if you want by just enabling both of them.
+
#### Resume Training
The checkpoint of the previous epoch is saved in `output_dir` by default during the training process to facilitate resuming the training.
diff --git a/docs/imgs/wandb_args.png b/docs/imgs/wandb_args.png
new file mode 100644
index 000000000..585a7a0b7
Binary files /dev/null and b/docs/imgs/wandb_args.png differ
diff --git a/docs/imgs/wandb_dashboard.png b/docs/imgs/wandb_dashboard.png
new file mode 100644
index 000000000..6b742d53d
Binary files /dev/null and b/docs/imgs/wandb_dashboard.png differ
diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py
index fa2dfa0d6..41f2fa6f2 100755
--- a/ppgan/engine/trainer.py
+++ b/ppgan/engine/trainer.py
@@ -29,7 +29,7 @@
from ..utils.filesystem import makedirs, save, load
from ..utils.timer import TimeAverager
from ..utils.profiler import add_profiler_step
-
+from ..utils.loggers import *
class IterLoader:
def __init__(self, dataloader):
@@ -115,9 +115,20 @@ def __init__(self, cfg):
self.is_save_img = validate_cfg['save_img']
self.enable_visualdl = cfg.get('enable_visualdl', False)
+ self.enable_wandb = cfg.get('enable_wandb', False)
+
+ loggers = []
if self.enable_visualdl:
- import visualdl
- self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)
+ self.vdl_logger = VDLLogger(cfg.output_dir)
+ loggers.append(self.vdl_logger)
+
+ if self.enable_wandb:
+ if "wandb" in cfg:
+ wandb_config = cfg.wandb
+ self.wandb_logger = WandbLogger(config=cfg, **wandb_config)
+ loggers.append(self.wandb_logger)
+
+ self.loggers = Loggers(loggers)
# evaluate only
if not cfg.is_train:
@@ -284,6 +295,8 @@ def test(self):
for metric_name, metric in self.metrics.items():
self.logger.info("Metric {}: {:.4f}".format(
metric_name, metric.accumulate()))
+ if self.local_rank == 0:
+ self.loggers.log_metrics(self.metrics, prefix="test")
def print_log(self):
losses = self.model.get_current_losses()
@@ -298,10 +311,11 @@ def print_log(self):
message += f'lr: {self.current_learning_rate:.3e} '
+ if self.local_rank == 0:
+ self.loggers.log_metrics(losses, step=self.current_iter)
+
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
- if self.enable_visualdl:
- self.vdl_logger.add_scalar(k, v, step=self.global_steps)
if hasattr(self, 'step_time'):
message += 'batch_cost: %.5f sec ' % self.step_time
@@ -350,17 +364,22 @@ def visual(self,
min_max = (-1., 1.)
image_num = self.cfg.get('image_num', None)
- if (image_num is None) or (not self.enable_visualdl):
+ if image_num is None:
image_num = 1
+
+ if self.local_rank == 0:
+ self.loggers.log_images(
+ visual_results,
+ image_num,
+ min_max,
+ results_dir,
+ dataformats="HWC" if image_num == 1 else "NCHW",
+ step=step if step else self.current_iter
+ )
+
for label, image in visual_results.items():
image_numpy = tensor2img(image, min_max, image_num)
- if (not is_save_image) and self.enable_visualdl:
- self.vdl_logger.add_image(
- results_dir + '/' + label,
- image_numpy,
- step=step if step else self.global_steps,
- dataformats="HWC" if image_num == 1 else "NCHW")
- else:
+ if is_save_image:
if self.cfg.is_train:
if self.by_epoch:
msg = 'epoch%.3d_' % self.current_epoch
@@ -401,6 +420,7 @@ def save(self, epoch, name='checkpoint', keep=1):
state_dicts[opt_name] = opt.state_dict()
save(state_dicts, save_path)
+ self.loggers.log_model(save_path, aliases=[f"epoch {epoch}", name])
if keep > 0:
try:
diff --git a/ppgan/utils/loggers/__init__.py b/ppgan/utils/loggers/__init__.py
new file mode 100644
index 000000000..8f2618551
--- /dev/null
+++ b/ppgan/utils/loggers/__init__.py
@@ -0,0 +1,3 @@
+from .vdl_logger import VDLLogger
+from .wandb_logger import WandbLogger
+from .loggers import Loggers
\ No newline at end of file
diff --git a/ppgan/utils/loggers/base_logger.py b/ppgan/utils/loggers/base_logger.py
new file mode 100644
index 000000000..3a7fc3593
--- /dev/null
+++ b/ppgan/utils/loggers/base_logger.py
@@ -0,0 +1,15 @@
+import os
+from abc import ABC, abstractmethod
+
+class BaseLogger(ABC):
+ def __init__(self, save_dir):
+ self.save_dir = save_dir
+ os.makedirs(self.save_dir, exist_ok=True)
+
+ @abstractmethod
+ def log_metrics(self, metrics, prefix=None):
+ pass
+
+ @abstractmethod
+ def close(self):
+ pass
\ No newline at end of file
diff --git a/ppgan/utils/loggers/loggers.py b/ppgan/utils/loggers/loggers.py
new file mode 100644
index 000000000..08d2ce408
--- /dev/null
+++ b/ppgan/utils/loggers/loggers.py
@@ -0,0 +1,20 @@
+class Loggers(object):
+ def __init__(self, loggers):
+ super().__init__()
+ self.loggers = loggers
+
+ def log_metrics(self, metrics, prefix=None, step=None):
+ for logger in self.loggers:
+ logger.log_metrics(metrics, prefix=prefix, step=step)
+
+ def log_model(self, file_path, aliases=None, metadata=None):
+ for logger in self.loggers:
+ logger.log_model(file_path, aliases=aliases, metadata=metadata)
+
+ def log_images(self, results, image_num, min_max, results_dir, dataformats, step=None):
+ for logger in self.loggers:
+ logger.log_images(results, image_num, min_max, results_dir, dataformats, step=None)
+
+ def close(self):
+ for logger in self.loggers:
+ logger.close()
\ No newline at end of file
diff --git a/ppgan/utils/loggers/vdl_logger.py b/ppgan/utils/loggers/vdl_logger.py
new file mode 100644
index 000000000..a675f52a4
--- /dev/null
+++ b/ppgan/utils/loggers/vdl_logger.py
@@ -0,0 +1,43 @@
+import paddle
+
+from ..visual import tensor2img
+from .base_logger import BaseLogger
+
+
+class VDLLogger(BaseLogger):
+ def __init__(self, save_dir):
+ super().__init__(save_dir)
+ try:
+ from visualdl import LogWriter
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(
+ "Please install visualdl using `pip install visualdl`"
+ )
+
+ self.vdl_writer = LogWriter(logdir=save_dir)
+
+ def log_metrics(self, metrics, prefix=None, step=None):
+ if prefix:
+ updated_metrics = {
+ prefix.lower() + "/" + k: v.item() for k, v in metrics.items() if isinstance(v, paddle.Tensor)
+ }
+ else:
+ updated_metrics = {k: v.item() for k, v in metrics.items()}
+ for k, v in updated_metrics.items():
+ self.vdl_writer.add_scalar(tag=k, value=v, step=step)
+
+ def log_model(self, file_path, aliases=None, metadata=None):
+ pass
+
+ def log_images(self, results, image_num, min_max, results_dir, dataformats, step=None):
+ for label, image in results.items():
+ image_numpy = tensor2img(image, min_max, image_num)
+ self.vdl_writer.add_image(
+ results_dir + "/" + label,
+ image_numpy,
+ step,
+ dataformats=dataformats
+ )
+
+ def close(self):
+ self.vdl_writer.close()
diff --git a/ppgan/utils/loggers/wandb_logger.py b/ppgan/utils/loggers/wandb_logger.py
new file mode 100644
index 000000000..4d33bd1cf
--- /dev/null
+++ b/ppgan/utils/loggers/wandb_logger.py
@@ -0,0 +1,102 @@
+import os
+
+import paddle
+
+from ..visual import tensor2img
+from .base_logger import BaseLogger
+
+
+class WandbLogger(BaseLogger):
+ def __init__(self,
+ project=None,
+ name=None,
+ id=None,
+ entity=None,
+ save_dir=None,
+ config=None,
+ log_model=False,
+ **kwargs):
+ try:
+ import wandb
+ self.wandb = wandb
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(
+ "Please install wandb using `pip install wandb`"
+ )
+
+ self.project = project
+ self.name = name
+ self.id = id
+ self.save_dir = save_dir
+ self.config = config
+ self.kwargs = kwargs
+ self.entity = entity
+ self._run = None
+ self._wandb_init = dict(
+ project=self.project,
+ name=self.name,
+ id=self.id,
+ entity=self.entity,
+ dir=self.save_dir,
+ resume="allow"
+ )
+ self.model_logging = log_model
+ self._wandb_init.update(**kwargs)
+
+ _ = self.run
+
+ if self.config:
+ self.run.config.update(self.config)
+
+ @property
+ def run(self):
+ if self._run is None:
+ if self.wandb.run is not None:
+ print(
+ "There is a wandb run already in progress "
+ "and newly created instances of `WandbLogger` will reuse"
+ " this run. If this is not desired, call `wandb.finish()`"
+ "before instantiating `WandbLogger`."
+ )
+ self._run = self.wandb.run
+ else:
+ self._run = self.wandb.init(**self._wandb_init)
+ return self._run
+
+ def log_metrics(self, metrics, prefix=None, step=None):
+ if prefix:
+ updated_metrics = {
+ prefix.lower() + "/" + k: v.item() for k, v in metrics.items() if isinstance(v, paddle.Tensor)
+ }
+ else:
+ updated_metrics = {k: v.item() for k, v in metrics.items()}
+ self.run.log(updated_metrics, step=step)
+
+ def log_model(self, file_path, aliases=None, metadata=None):
+ if self.model_logging == False:
+ return
+ artifact = self.wandb.Artifact('model-{}'.format(self.run.id), type='model', metadata=metadata)
+ artifact.add_file(file_path, name="model_ckpt.pkl")
+
+ self.run.log_artifact(artifact, aliases=aliases)
+
+ def log_images(self, results, image_num, min_max, results_dir, dataformats, step=None):
+ reqd = dict()
+ for label, image in results.items():
+ image_numpy = tensor2img(image, min_max, image_num)
+
+ images = []
+ if dataformats == 'HWC':
+ images.append(self.wandb.Image(image_numpy))
+ elif dataformats == 'NCHW':
+ for img in image_numpy:
+ images.append(self.wandb.Image(img.transpose(1, 2, 0)))
+
+ reqd.update({
+ results_dir + "/" + label: images
+ })
+
+ self.run.log(reqd)
+
+ def close(self):
+ self.run.finish()