-
Notifications
You must be signed in to change notification settings - Fork 0
/
misc.py
62 lines (56 loc) · 1.91 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np
import os
from itertools import islice
from matplotlib import pyplot as plt
from matplotlib.ticker import ScalarFormatter, MultipleLocator
def window(seq, n=2):
    """Return a sliding window (of width *n*) over data from the iterable *seq*.

    Yields consecutive, overlapping tuples of exactly *n* items, e.g. for
    ``seq = [1, 2, 3, 4]`` and ``n = 2``: ``(1, 2), (2, 3), (3, 4)``.
    Nothing is yielded when *seq* contains fewer than *n* items.

    Args:
        seq: Any iterable; it is consumed lazily, one item per window step.
        n: Window width; must be a positive integer.

    Yields:
        Tuples of length *n*.

    Raises:
        ValueError: If *n* is smaller than 1. Because this is a generator,
            the error surfaces on the first iteration, not at call time.
    """
    if n < 1:
        # Width 0 would otherwise yield () followed by 1-tuples, which is
        # not a window of width 0; negative widths would fail in islice
        # with a less descriptive message.
        raise ValueError(f"window width must be >= 1, got {n}")
    it = iter(seq)
    result = tuple(islice(it, n))
    # islice returns fewer than n items only when the iterator is exhausted,
    # in which case the loop below does not run either.
    if len(result) == n:
        yield result
    for elem in it:
        # Drop the oldest item and append the newest one.
        result = result[1:] + (elem,)
        yield result
def plot_learning_stats(learning_history, title: str, grid=True, log_scale=False,
                        figsize=(10, 8), show=True, *args, **kwargs):
    """Plot a learning curve, or median/max/min curves over several runs.

    A one-dimensional *learning_history* (list or 1-D array, one value per
    episode) is drawn as a single curve. A history with two or more
    dimensions is reduced along axis 0 into median, max and min curves,
    which are shown with a legend. Presumably axis 0 indexes independent
    runs and axis 1 indexes episodes — TODO confirm with callers.

    Args:
        learning_history: Values of a single run (1-D), or of several runs
            stacked along axis 0 (2-D).
        title: Title shown above the axes.
        grid: If true, draw a grid on the axes.
        log_scale: If true, use a logarithmic y axis.
        figsize: Figure size in inches, forwarded to ``plt.subplots``.
        show: If true, call ``plt.show()`` before returning.
        *args: Accepted for call compatibility; ignored.
        **kwargs: Accepted for call compatibility; ignored.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=100)
    ax.set_title(title)
    if grid:
        ax.grid()
    # Dispatch on dimensionality rather than on ``list``: a 1-D numpy array
    # reduced along axis 0 would collapse to a single scalar per statistic
    # and be drawn as isolated points instead of a curve.
    if np.ndim(learning_history) < 2:
        ax.plot(learning_history)
    else:
        lines = {
            'median': ax.plot(np.median(learning_history, axis=0),
                              label='mediana')[0],
            'max': ax.plot(np.max(learning_history, axis=0), label='max')[0],
            'min': ax.plot(np.min(learning_history, axis=0), label='min')[0],
        }
        # The legend text comes from the dict keys, overriding the
        # ``label=`` values passed to ``ax.plot`` above.
        ax.legend(list(lines.values()), list(lines.keys()))
    # Axis labels are in Polish: x = episode, y = number of steps.
    ax.set_xlabel('episod')
    ax.set_ylabel('ilość kroków')
    if log_scale:
        ax.set_yscale('log')
    if show:
        plt.show()
    return fig, ax
def compare_learning_curves(named_learning_histories: dict, title: str, log_scale=False,
                            grid=True, figsize=(10, 8), show=True,
                            *args, **kwargs):
    """Overlay one learning curve per named experiment on a shared axes.

    Each history with two or more dimensions is reduced to its median along
    axis 0 (presumably across runs — TODO confirm with callers). A 1-D
    history (a single run) is drawn as-is. Every curve is labelled with its
    dictionary key in the legend.

    Args:
        named_learning_histories: Mapping from legend label to a learning
            history (1-D single run, or runs stacked along axis 0).
        title: Title shown above the axes.
        log_scale: If true, use a logarithmic y axis.
        grid: If true, draw a grid on the axes.
        figsize: Figure size in inches, forwarded to ``plt.subplots``.
        show: If true, call ``plt.show()`` before returning.
        *args: Accepted for call compatibility; ignored.
        **kwargs: Accepted for call compatibility; ignored.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.
    """
    fig, ax = plt.subplots(figsize=figsize, dpi=100)
    ax.set_title(title)
    if grid:
        ax.grid()
    lines = {}
    for name, history in named_learning_histories.items():
        # A single run has nothing to aggregate: taking its median along
        # axis 0 would collapse it to one scalar and plot a lone point.
        if np.ndim(history) < 2:
            curve = history
        else:
            curve = np.median(history, axis=0)
        lines[name] = ax.plot(curve, label=name)[0]
    ax.legend(list(lines.values()), list(lines.keys()))
    # Axis labels are in Polish: x = episode, y = number of steps.
    ax.set_xlabel('episod')
    ax.set_ylabel('ilość kroków')
    if log_scale:
        ax.set_yscale('log')
    if show:
        plt.show()
    return fig, ax