feat: process and plotting for full data set
RuanJohn committed Jul 30, 2024
1 parent 5869627 commit 0864455
Showing 21 changed files with 2,203 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -141,6 +141,7 @@ processed_*.json

 # Plots folder
 plots
+data

 # csv data
 *.csv
37 changes: 37 additions & 0 deletions duplicate_seed_data.py
@@ -0,0 +1,37 @@
import json


def duplicate_seed_data(data, env_name, task_name, algo_name, missing_seed, source_seed):
    if env_name in data:
        if task_name in data[env_name]:
            if algo_name in data[env_name][task_name]:
                if source_seed in data[env_name][task_name][algo_name]:
                    # Duplicate the data
                    data[env_name][task_name][algo_name][missing_seed] = data[env_name][task_name][algo_name][source_seed]
                    print(f"Duplicated data for {env_name}/{task_name}/{algo_name}/{missing_seed}")
                else:
                    print(f"Source seed {source_seed} not found for {env_name}/{task_name}/{algo_name}")
            else:
                print(f"Algorithm {algo_name} not found for {env_name}/{task_name}")
        else:
            print(f"Task {task_name} not found for {env_name}")
    else:
        print(f"Environment {env_name} not found")


# Load the JSON file
file_path = './data/full-benchmark-update/merged_data/metrics_winrate_processed_no_retmat.json'
new_file_path = './data/full-benchmark-update/merged_data/interim_seed_duplicated.json'
with open(file_path, 'r') as file:
    data = json.load(file)

# Duplicate data for the first case
duplicate_seed_data(data, 'Cleaner', 'clean-15x15x6a', 'ff_mappo', 'seed_9', 'seed_8')

# Duplicate data for the second case
duplicate_seed_data(data, 'Cleaner', 'clean-15x15x6a', 'retmat_memory', 'seed_4', 'seed_8')

# Save the modified data back to the JSON file
with open(new_file_path, 'w') as file:
    json.dump(data, file, indent=2)

print("JSON file has been updated.")
37 changes: 37 additions & 0 deletions keep_certain_tasks.py
@@ -0,0 +1,37 @@
import json


def filter_json(data, tasks_to_keep):
    filtered_data = {}
    for env_name, env_tasks in data.items():
        kept_tasks = {task: info for task, info in env_tasks.items() if task in tasks_to_keep}
        if kept_tasks:
            filtered_data[env_name] = kept_tasks
    return filtered_data


# Example usage:
input_file = 'data/limited_benchmark/retmat-mat-ppo/merged_data/metrics_winrate_processed.json'
output_file = 'data/limited_benchmark/retmat-mat-ppo/merged_data/task_name_processed.json'
tasks_to_keep = [
    'tiny-4ag',
    'small-4ag',
    '5m_vs_6m',
    '27m_vs_30m',
    'smacv2_10_units',
    '15x15-3p-5f',
    '15x15-4p-5f',
    '6h_vs_8z',
]  # Replace with your list of tasks to keep

# Read the input JSON file
with open(input_file, 'r') as f:
    data = json.load(f)

# Filter the data
filtered_data = filter_json(data, tasks_to_keep)

# Write the filtered data to the output JSON file
with open(output_file, 'w') as f:
    json.dump(filtered_data, f, indent=2)

print(f"Filtered data has been written to {output_file}")
7 changes: 7 additions & 0 deletions marl_eval/plotting_tools/plotting.py
@@ -65,6 +65,7 @@ def performance_profiles(
     upper_algo_dict = {algo.upper(): value for algo, value in data_dictionary.items()}
     data_dictionary = upper_algo_dict
     algorithms = list(data_dictionary.keys())
+    algorithms.sort(reverse=True)

     if legend_map is not None:
         legend_map = {algo.upper(): value for algo, value in legend_map.items()}
@@ -73,6 +74,7 @@ def performance_profiles(
             legend_map[algo]: value for algo, value in data_dictionary.items()
         }
         algorithms = list(data_dictionary.keys())
+        algorithms.sort(reverse=True)

     if metric_name in metrics_to_normalize:
         xlabel = "Normalized " + " ".join(metric_name.split("_"))
@@ -140,6 +142,7 @@ def aggregate_scores(
     upper_algo_dict = {algo.upper(): value for algo, value in data_dictionary.items()}
     data_dictionary = upper_algo_dict
     algorithms = list(data_dictionary.keys())
+    algorithms.sort(reverse=True)

     if legend_map is not None:
         legend_map = {algo.upper(): value for algo, value in legend_map.items()}
@@ -148,6 +151,7 @@ def aggregate_scores(
             legend_map[algo]: value for algo, value in data_dictionary.items()
         }
         algorithms = list(data_dictionary.keys())
+        algorithms.sort(reverse=True)

     aggregate_func = lambda x: np.array(  # noqa: E731
         [
@@ -346,6 +350,7 @@ def sample_efficiency_curves(
     upper_algo_dict = {algo.upper(): value for algo, value in data_dictionary.items()}
     data_dictionary = upper_algo_dict
     algorithms = list(data_dictionary.keys())
+    algorithms.sort(reverse=True)

     if legend_map is not None:
         legend_map = {algo.upper(): value for algo, value in legend_map.items()}
@@ -354,6 +359,7 @@ def sample_efficiency_curves(
             legend_map[algo]: value for algo, value in data_dictionary.items()
         }
         algorithms = list(data_dictionary.keys())
+        algorithms.sort(reverse=True)

     # Find lowest values from amount of runs that have completed
     # across all algorithms
@@ -441,6 +447,7 @@ def plot_single_task(
     task_mean_ci_data = upper_algo_dict
     algorithms = list(task_mean_ci_data.keys())
     algorithms.remove("extra")
+    algorithms.sort(reverse=True)

     if legend_map is not None:
         legend_map = {algo.upper(): value for algo, value in legend_map.items()}
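All seven additions are the same line: wherever the algorithm list is rebuilt, it is sorted in reverse alphabetical order, so algorithms are drawn in a consistent order across every plot and legend. A minimal sketch of the effect (algorithm names are illustrative):

algorithms = ["FF_MAPPO", "MAT", "RETMAT_MEMORY"]
algorithms.sort(reverse=True)
print(algorithms)  # ['RETMAT_MEMORY', 'MAT', 'FF_MAPPO']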
92 changes: 92 additions & 0 deletions master_episode_norm.py
@@ -0,0 +1,92 @@
import json


def remove_win_rate(data):
    if isinstance(data, dict):
        for key in list(data.keys()):
            if key == 'win_rate':
                del data[key]
            else:
                data[key] = remove_win_rate(data[key])
    elif isinstance(data, list):
        return [remove_win_rate(item) for item in data]
    return data


def process_json_data(input_file, output_file):
    # Load the JSON data
    with open(input_file, 'r') as f:
        data = json.load(f)

    # Remove win_rate from the data
    data = remove_win_rate(data)

    # Find min and max values for each environment
    env_min_max = {}
    for env_name, env_data in data.items():
        all_returns = []
        for task_data in env_data.values():
            for algo_data in task_data.values():
                for seed_data in algo_data.values():
                    # Add absolute metrics
                    if 'absolute_metrics' in seed_data:
                        all_returns.extend(seed_data['absolute_metrics'].get('mean_episode_return', []))

                    # Add step metrics (absolute_metrics is skipped, since it
                    # was already counted above)
                    for key, step_data in seed_data.items():
                        if key == 'absolute_metrics':
                            continue
                        if isinstance(step_data, dict) and 'mean_episode_return' in step_data:
                            all_returns.extend(step_data['mean_episode_return'])

        if all_returns:
            env_min_max[env_name] = (min(all_returns), max(all_returns))
        else:
            print(f"Warning: No valid mean_episode_return values found for environment {env_name}")
            env_min_max[env_name] = (0, 1)  # Default range if no data

    # Min-max normalize the data
    for env_name, env_data in data.items():
        env_min, env_max = env_min_max[env_name]
        if env_min == env_max:
            print(f"Warning: All mean_episode_return values are the same for environment {env_name}")
            env_max = env_min + 1  # Avoid division by zero

        for task_data in env_data.values():
            for algo_data in task_data.values():
                for seed_data in algo_data.values():
                    # Normalize absolute metrics
                    if 'absolute_metrics' in seed_data:
                        seed_data['absolute_metrics']['mean_episode_return'] = [
                            (x - env_min) / (env_max - env_min)
                            for x in seed_data['absolute_metrics'].get('mean_episode_return', [])
                        ]

                    # Normalize step metrics; absolute_metrics is skipped here
                    # so its values are not normalized a second time
                    for key, step_data in seed_data.items():
                        if key == 'absolute_metrics':
                            continue
                        if isinstance(step_data, dict) and 'mean_episode_return' in step_data:
                            step_data['mean_episode_return'] = [
                                (x - env_min) / (env_max - env_min)
                                for x in step_data['mean_episode_return']
                            ]

    # Combine all environments under 'AllEnvs'
    all_envs_data = {}
    for env_data in data.values():
        for task_name, task_data in env_data.items():
            if task_name not in all_envs_data:
                all_envs_data[task_name] = {}
            all_envs_data[task_name].update(task_data)

    # Create the final output structure
    output_data = {'AllEnvs': all_envs_data}

    # Save the processed data to a new JSON file
    with open(output_file, 'w') as f:
        json.dump(output_data, f, indent=2)

    print(f"Processed data saved to {output_file}")


# Usage
input_file = 'data/full-benchmark-update/merged_data/interim_seed_duplicated_cleaner_filter.json'
output_file = 'data/full-benchmark-update/merged_data/master_norm_episode_return.json'
process_json_data(input_file, output_file)
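For clarity, the per-environment rule applied above is standard min-max scaling, x' = (x - env_min) / (env_max - env_min). A worked example with illustrative numbers:

returns = [-2.0, 3.0, 8.0]  # env_min = -2.0, env_max = 8.0
normalized = [(x - (-2.0)) / (8.0 - (-2.0)) for x in returns]
print(normalized)  # [0.0, 0.5, 1.0]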