# viz_all_cumulative.py
# Forked from msylvester/Pokemon-Tonail.
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
def load_replay_data(replays_dir="replays"):
    """Load all replay files and group cumulative rewards by step number.

    Every file in ``replays_dir`` whose name ends in ``.pkl`` is unpickled and
    iterated as one episode. The ``"cumulative_reward"`` entry of its k-th
    experience (counting from 1) is appended to the list stored under key
    ``k``.

    Args:
        replays_dir: Directory containing the pickled replay buffers.

    Returns:
        dict mapping step number (starting at 1) to the list of cumulative
        rewards recorded at that step, one entry per replay that reached it.

    Raises:
        FileNotFoundError: If ``replays_dir`` does not exist (from
            ``os.listdir``).
    """
    all_step_rewards = {}  # step -> list of cumulative rewards at that step

    # os.listdir() returns names in arbitrary order; sorting keeps the
    # processing order (and the order inside each per-step list) reproducible.
    replay_files = sorted(
        name for name in os.listdir(replays_dir) if name.endswith(".pkl")
    )

    print("Loading replay files...")
    for filename in tqdm(replay_files):
        path = os.path.join(replays_dir, filename)
        with open(path, "rb") as f:
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only point this at replay files from a trusted source.
            replay_buffer = pickle.load(f)

        # The buffer is fully loaded, so the file is already closed here.
        for step, experience in enumerate(replay_buffer, 1):
            all_step_rewards.setdefault(step, []).append(
                experience["cumulative_reward"]
            )

    return all_step_rewards
def analyze_distributions(all_step_rewards, max_steps=3000):
    """Compute the median and selected percentiles of the rewards at each step.

    Args:
        all_step_rewards: Mapping of step number to the list of cumulative
            rewards recorded at that step.
        max_steps: Highest step number to include in the analysis.

    Returns:
        Tuple ``(steps, medians, percentile_05, percentile_25, percentile_75,
        percentile_10, percentile_90, percentile_99)``. ``steps`` is a
        ``range`` starting at 1 and ending at the smaller of ``max_steps`` and
        the largest key in ``all_step_rewards``; every other element is a list
        aligned with ``steps``. Steps that are absent from the mapping hold
        ``nan``. An empty mapping yields an empty range and empty lists.
    """
    # default=0 makes an empty mapping produce an empty step range instead of
    # max() raising ValueError.
    last_step = max(all_step_rewards, default=0)
    steps = range(1, min(max_steps, last_step) + 1)

    # Order matches the percentile positions in the returned tuple.
    quantiles = (5, 25, 75, 10, 90, 99)
    medians = []
    percentile_series = [[] for _ in quantiles]

    print("Analyzing distributions...")
    for step in tqdm(steps):
        rewards = all_step_rewards.get(step)
        if rewards is None:
            # Handle missing steps, if any, with NaN placeholders.
            medians.append(np.nan)
            values = [np.nan] * len(quantiles)
        else:
            medians.append(np.median(rewards))
            # One vectorized call computes all requested percentiles at once.
            values = np.percentile(rewards, quantiles)
        for series, value in zip(percentile_series, values):
            series.append(value)

    return (steps, medians, *percentile_series)
def create_visualization(data, output_file="cumulative_rewards_distribution.png"):
    """Plot the per-step reward distribution and save it to an image file.

    Args:
        data: 8-tuple ``(steps, medians, p05, p25, p75, p10, p90, p99)`` in
            the order returned by ``analyze_distributions``.
        output_file: Path the figure is written to.
    """
    steps, medians, p05, p25, p75, p10, p90, p99 = data

    fig = plt.figure(figsize=(15, 10))
    try:
        # Bands are added from the widest percentile range to the narrowest.
        # NOTE(review): the outer band spans the 5th to the 99th percentile,
        # which is asymmetric -- confirm that this is intended.
        plt.fill_between(
            steps, p05, p99, alpha=0.1, color="red", label="5th-99th percentile"
        )
        plt.fill_between(
            steps, p10, p90, alpha=0.2, color="blue", label="10th-90th percentile"
        )
        plt.fill_between(
            steps, p25, p75, alpha=0.3, color="blue", label="25th-75th percentile"
        )
        plt.plot(steps, medians, color="blue", label="Median", linewidth=2)

        plt.xlabel("Step Number")
        plt.ylabel("Cumulative Reward")
        plt.title("Distribution of Cumulative Rewards Across Steps")
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Save the plot
        plt.savefig(output_file)
        print(f"Visualization saved to {output_file}")
    finally:
        # Close this specific figure even when plotting or saving fails, so a
        # failed call does not leave an open figure behind.
        plt.close(fig)
def main():
    """Load the replays, plot the reward distribution and print a summary.

    The summary statistics are computed for the largest step number present
    in the loaded data.
    """
    # Load all replay data
    all_step_rewards = load_replay_data()

    if not all_step_rewards:
        # Without any recorded step, max() below would raise ValueError and
        # there is nothing to plot.
        print("No replay data found; nothing to visualize.")
        return

    # Analyze distributions
    distribution_data = analyze_distributions(all_step_rewards)

    # Create and save visualization
    create_visualization(distribution_data)

    # Print some summary statistics. The final step is the largest step number
    # seen in any replay, so these figures only cover the replays that
    # reached that step.
    final_step = max(all_step_rewards)
    final_rewards = all_step_rewards[final_step]
    print("\nSummary Statistics for Final Step:")
    print(f"Number of episodes: {len(final_rewards)}")
    print(f"Median reward: {np.median(final_rewards):.2f}")
    print(f"Mean reward: {np.mean(final_rewards):.2f}")
    print(f"90th percentile: {np.percentile(final_rewards, 90):.2f}")
    print(f"10th percentile: {np.percentile(final_rewards, 10):.2f}")
    print(f"Max reward: {np.max(final_rewards):.2f}")
    print(f"Min reward: {np.min(final_rewards):.2f}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()