Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add plot_logs tool #426

Merged
merged 2 commits into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/useful_tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,25 @@ python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --ou
```shell
python tools/print_config.py ${CONFIG} [-h] [--options ${OPTIONS [OPTIONS...]}]
```

### Plot training logs

`tools/analyze_logs.py` plot s loss/mIoU curves given a training log file. `pip install seaborn` first to install the dependency.

```shell
python tools/analyze_logs.py xxx.log.json [--keys ${KEYS}] [--legend ${LEGEND}] [--backend ${BACKEND}] [--style ${STYLE}] [--out ${OUT_FILE}]
```

Examples:

- Plot the mIoU, mAcc, aAcc metrics.

```shell
python tools/analyze_logs.py log.json --keys mIoU mAcc aAcc --legend mIoU mAcc aAcc
```

- Plot loss metric.

```shell
python tools/analyze_logs.py log.json --keys loss --legend loss
```
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmseg
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,terminaltables,torch
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,seaborn,terminaltables,torch
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
123 changes: 123 additions & 0 deletions tools/analyze_logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Modified from https://github.com/open-
mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
import argparse
import json
from collections import defaultdict

import matplotlib.pyplot as plt
import seaborn as sns


def plot_curve(log_dicts, args):
if args.backend is not None:
plt.switch_backend(args.backend)
sns.set_style(args.style)
# if legend is None, use {filename}_{key} as legend
legend = args.legend
if legend is None:
legend = []
for json_log in args.json_logs:
for metric in args.keys:
legend.append(f'{json_log}_{metric}')
assert len(legend) == (len(args.json_logs) * len(args.keys))
metrics = args.keys

num_metrics = len(metrics)
for i, log_dict in enumerate(log_dicts):
epochs = list(log_dict.keys())
for j, metric in enumerate(metrics):
print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
plot_epochs = []
plot_iters = []
plot_values = []
for epoch in epochs:
epoch_logs = log_dict[epoch]
if metric not in epoch_logs.keys():
continue
if metric in ['mIoU', 'mAcc', 'aAcc']:
plot_epochs.append(epoch)
plot_values.append(epoch_logs[metric][0])
else:
for idx in range(len(epoch_logs[metric])):
plot_iters.append(epoch_logs['iter'][idx])
plot_values.append(epoch_logs[metric][idx])
ax = plt.gca()
label = legend[i * num_metrics + j]
if metric in ['mIoU', 'mAcc', 'aAcc']:
ax.set_xticks(plot_epochs)
plt.xlabel('epoch')
plt.plot(plot_epochs, plot_values, label=label, marker='o')
else:
plt.xlabel('iter')
plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)
plt.legend()
if args.title is not None:
plt.title(args.title)
if args.out is None:
plt.show()
else:
print(f'save curve to: {args.out}')
plt.savefig(args.out)
plt.cla()


def parse_args():
parser = argparse.ArgumentParser(description='Analyze Json Log')
parser.add_argument(
'json_logs',
type=str,
nargs='+',
help='path of train log in json format')
parser.add_argument(
'--keys',
type=str,
nargs='+',
default=['mIoU'],
help='the metric that you want to plot')
parser.add_argument('--title', type=str, help='title of figure')
parser.add_argument(
'--legend',
type=str,
nargs='+',
default=None,
help='legend of each plot')
parser.add_argument(
'--backend', type=str, default=None, help='backend of plt')
parser.add_argument(
'--style', type=str, default='dark', help='style of plt')
parser.add_argument('--out', type=str, default=None)
args = parser.parse_args()
return args


def load_json_logs(json_logs):
# load and convert json_logs to log_dict, key is epoch, value is a sub dict
# keys of sub dict is different metrics
# value of sub dict is a list of corresponding values of all iterations
log_dicts = [dict() for _ in json_logs]
for json_log, log_dict in zip(json_logs, log_dicts):
with open(json_log, 'r') as log_file:
for line in log_file:
log = json.loads(line.strip())
# skip lines without `epoch` field
if 'epoch' not in log:
continue
epoch = log.pop('epoch')
if epoch not in log_dict:
log_dict[epoch] = defaultdict(list)
for k, v in log.items():
log_dict[epoch][k].append(v)
return log_dicts


def main():
args = parse_args()
json_logs = args.json_logs
for json_log in json_logs:
assert json_log.endswith('.json')
log_dicts = load_json_logs(json_logs)
plot_curve(log_dicts, args)


if __name__ == '__main__':
main()