From 13de38d17dd5a1c67e5327328c01837c6e40f87d Mon Sep 17 00:00:00 2001
From: Dirk Groeneveld <dirkg@allenai.org>
Date: Thu, 19 Aug 2021 11:32:37 -0700
Subject: [PATCH] Log batch metrics (#5362)

* Lists should be lists

* Formatting

* By default, don't log parameter stats

* Log batch metrics

* Changelog

* Don't try to be more general than Patton
---
 CHANGELOG.md                               |  4 +++-
 allennlp/training/callbacks/log_writer.py  | 12 ++++++++++--
 allennlp/training/callbacks/tensorboard.py |  2 +-
 3 files changed, 14 insertions(+), 4 deletions(-)

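Note on the log_writer.py hunks below: besides `batch_loss` / `batch_reg_loss` (which keep their existing moving-average handling), every other entry in the per-batch `metrics` dict is now logged once per batch under a `batch_` prefix. A minimal standalone sketch of that prefixing step follows; the helper name and the example dict are illustrative assumptions, not part of the patch:

    from typing import Dict

    def collect_extra_batch_metrics(metrics: Dict[str, float]) -> Dict[str, float]:
        """Sketch of the new loop added to the log-writer callback's log_batch()."""
        batch_loss_metrics = {"batch_loss", "batch_reg_loss"}
        metrics_to_log: Dict[str, float] = {}
        for key, value in metrics.items():
            # Loss metrics are handled separately (moving averages), so skip them here.
            if key in batch_loss_metrics:
                continue
            prefixed = "batch_" + key
            if prefixed not in metrics_to_log:
                metrics_to_log[prefixed] = value
        return metrics_to_log

    # e.g. {"accuracy": 0.8, "batch_loss": 1.2} -> {"batch_accuracy": 0.8}
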
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d96ca710c1e..9beb3526d8d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   `self.ddp_accelerator` during distributed training. This is useful when, for example, instantiating submodules in your
   model's `__init__()` method by wrapping them with `self.ddp_accelerator.wrap_module()`. See the `allennlp.modules.transformer.t5`
   for an example.
+- We now log batch metrics to tensorboard and wandb.
 - Added Tango components, to be explored in detail in a later post
 - Added `ScaledDotProductMatrixAttention`, and converted the transformer toolkit to use it
 - Added tests to ensure that all `Attention` and `MatrixAttention` implementations are interchangeable
@@ -46,7 +47,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   with a default value of `False`. `False` means gradients are not rescaled and the gradient
   norm is never even calculated. `True` means the gradients are still not rescaled but the gradient
   norm is calculated and passed on to callbacks. A `float` value means gradients are rescaled.
-- `TensorCache` now supports more concurrent readers and writers. 
+- `TensorCache` now supports more concurrent readers and writers.
+- We no longer log parameter statistics to tensorboard or wandb by default.
 
 
 ## [v2.6.0](https://github.com/allenai/allennlp/releases/tag/v2.6.0) - 2021-07-19
diff --git a/allennlp/training/callbacks/log_writer.py b/allennlp/training/callbacks/log_writer.py
index 8b3183f28c3..244656c060c 100644
--- a/allennlp/training/callbacks/log_writer.py
+++ b/allennlp/training/callbacks/log_writer.py
@@ -227,7 +227,8 @@ def log_batch(
 
             # Now collect per-batch metrics to log.
             metrics_to_log: Dict[str, float] = {}
-            for key in ("batch_loss", "batch_reg_loss"):
+            batch_loss_metrics = {"batch_loss", "batch_reg_loss"}
+            for key in batch_loss_metrics:
                 if key not in metrics:
                     continue
                 value = metrics[key]
@@ -241,6 +242,13 @@ def log_batch(
                     self._batch_loss_moving_items[key]
                 )
 
+            for key, value in metrics.items():
+                if key in batch_loss_metrics:
+                    continue
+                key = "batch_" + key
+                if key not in metrics_to_log:
+                    metrics_to_log[key] = value
+
             self.log_scalars(
                 metrics_to_log,
                 log_prefix="train",
@@ -253,7 +261,7 @@ def log_batch(
 
         if self._batch_size_interval:
             # We're assuming here that `log_batch` will get called every batch, and only every
-            # batch.  This is true with our current usage of this code (version 1.0); if that
+            # batch. This is true with our current usage of this code (version 1.0); if that
             # assumption becomes wrong, this code will break.
             batch_group_size = sum(get_batch_size(batch) for batch in batch_group)  # type: ignore
             self._cumulative_batch_group_size += batch_group_size
diff --git a/allennlp/training/callbacks/tensorboard.py b/allennlp/training/callbacks/tensorboard.py
index 73bc04a686a..630c0b70996 100644
--- a/allennlp/training/callbacks/tensorboard.py
+++ b/allennlp/training/callbacks/tensorboard.py
@@ -21,7 +21,7 @@ def __init__(
         summary_interval: int = 100,
         distribution_interval: Optional[int] = None,
         batch_size_interval: Optional[int] = None,
-        should_log_parameter_statistics: bool = True,
+        should_log_parameter_statistics: bool = False,
         should_log_learning_rate: bool = False,
     ) -> None:
         super().__init__(
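
Note on the tensorboard.py hunk above: `should_log_parameter_statistics` now defaults to `False`, so parameter statistics are no longer written unless explicitly requested. A minimal sketch of restoring the old behaviour; the class name `TensorBoardCallback` and the `serialization_dir` argument are assumptions about the callback's constructor, since neither appears in this patch:

    from allennlp.training.callbacks.tensorboard import TensorBoardCallback

    callback = TensorBoardCallback(
        serialization_dir="/tmp/my_run",       # assumed constructor argument
        summary_interval=100,
        should_log_parameter_statistics=True,  # restores the pre-patch default
        should_log_learning_rate=False,
    )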