feat: process and plotting for full data set
Showing 21 changed files with 2,203 additions and 0 deletions.
@@ -141,6 +141,7 @@ processed_*.json

# Plots folder
plots
data

# csv data
*.csv
@@ -0,0 +1,37 @@
import json


def duplicate_seed_data(data, env_name, task_name, algo_name, missing_seed, source_seed):
    if env_name in data:
        if task_name in data[env_name]:
            if algo_name in data[env_name][task_name]:
                if source_seed in data[env_name][task_name][algo_name]:
                    # Duplicate the data
                    data[env_name][task_name][algo_name][missing_seed] = data[env_name][task_name][algo_name][source_seed]
                    print(f"Duplicated data for {env_name}/{task_name}/{algo_name}/{missing_seed}")
                else:
                    print(f"Source seed {source_seed} not found for {env_name}/{task_name}/{algo_name}")
            else:
                print(f"Algorithm {algo_name} not found for {env_name}/{task_name}")
        else:
            print(f"Task {task_name} not found for {env_name}")
    else:
        print(f"Environment {env_name} not found")


# Load the JSON file
file_path = './data/full-benchmark-update/merged_data/metrics_winrate_processed_no_retmat.json'
new_file_path = './data/full-benchmark-update/merged_data/interim_seed_duplicated.json'
with open(file_path, 'r') as file:
    data = json.load(file)

# Duplicate data for the first case
duplicate_seed_data(data, 'Cleaner', 'clean-15x15x6a', 'ff_mappo', 'seed_9', 'seed_8')

# Duplicate data for the second case
duplicate_seed_data(data, 'Cleaner', 'clean-15x15x6a', 'retmat_memory', 'seed_4', 'seed_8')

# Save the modified data back to the JSON file
with open(new_file_path, 'w') as file:
    json.dump(data, file, indent=2)

print("JSON file has been updated.")
@@ -0,0 +1,37 @@
import json


def filter_json(data, tasks_to_keep):
    filtered_data = {}
    for env_name, env_tasks in data.items():
        kept_tasks = {task: info for task, info in env_tasks.items() if task in tasks_to_keep}
        if kept_tasks:
            filtered_data[env_name] = kept_tasks
    return filtered_data


# Example usage:
input_file = 'data/limited_benchmark/retmat-mat-ppo/merged_data/metrics_winrate_processed.json'
output_file = 'data/limited_benchmark/retmat-mat-ppo/merged_data/task_name_processed.json'
tasks_to_keep = [
    'tiny-4ag',
    'small-4ag',
    '5m_vs_6m',
    '27m_vs_30m',
    'smacv2_10_units',
    '15x15-3p-5f',
    '15x15-4p-5f',
    '6h_vs_8z',
]  # Replace with your list of tasks to keep

# Read the input JSON file
with open(input_file, 'r') as f:
    data = json.load(f)

# Filter the data
filtered_data = filter_json(data, tasks_to_keep)

# Write the filtered data to the output JSON file
with open(output_file, 'w') as f:
    json.dump(filtered_data, f, indent=2)

print(f"Filtered data has been written to {output_file}")
@@ -0,0 +1,92 @@
import json


def remove_win_rate(data):
    if isinstance(data, dict):
        for key in list(data.keys()):
            if key == 'win_rate':
                del data[key]
            else:
                data[key] = remove_win_rate(data[key])
    elif isinstance(data, list):
        return [remove_win_rate(item) for item in data]
    return data


def process_json_data(input_file, output_file):
    # Load the JSON data
    with open(input_file, 'r') as f:
        data = json.load(f)

    # Remove win_rate from the data
    data = remove_win_rate(data)

    # Find min and max values for each environment
    env_min_max = {}
    for env_name, env_data in data.items():
        all_returns = []
        for task_data in env_data.values():
            for algo_data in task_data.values():
                for seed_data in algo_data.values():
                    # Add absolute metrics
                    if 'absolute_metrics' in seed_data:
                        all_returns.extend(seed_data['absolute_metrics'].get('mean_episode_return', []))

                    # Add step metrics
                    for step_data in seed_data.values():
                        if isinstance(step_data, dict) and 'mean_episode_return' in step_data:
                            all_returns.extend(step_data['mean_episode_return'])

        if all_returns:
            env_min_max[env_name] = (min(all_returns), max(all_returns))
        else:
            print(f"Warning: No valid mean_episode_return values found for environment {env_name}")
            env_min_max[env_name] = (0, 1)  # Default range if no data

    # Min-max normalize the data
    for env_name, env_data in data.items():
        env_min, env_max = env_min_max[env_name]
        if env_min == env_max:
            print(f"Warning: All mean_episode_return values are the same for environment {env_name}")
            env_max = env_min + 1  # Avoid division by zero

        for task_data in env_data.values():
            for algo_data in task_data.values():
                for seed_data in algo_data.values():
                    # Normalize absolute metrics
                    if 'absolute_metrics' in seed_data:
                        seed_data['absolute_metrics']['mean_episode_return'] = [
                            (x - env_min) / (env_max - env_min) if env_max != env_min else 0.5
                            for x in seed_data['absolute_metrics'].get('mean_episode_return', [])
                        ]

                    # Normalize step metrics
                    for step_data in seed_data.values():
                        if isinstance(step_data, dict) and 'mean_episode_return' in step_data:
                            step_data['mean_episode_return'] = [
                                (x - env_min) / (env_max - env_min) if env_max != env_min else 0.5
                                for x in step_data['mean_episode_return']
                            ]

    # Combine all environments under 'AllEnvs'
    all_envs_data = {}
    for env_data in data.values():
        for task_name, task_data in env_data.items():
            if task_name not in all_envs_data:
                all_envs_data[task_name] = {}
            all_envs_data[task_name].update(task_data)

    # Create the final output structure
    output_data = {'AllEnvs': all_envs_data}

    # Save the processed data to a new JSON file
    with open(output_file, 'w') as f:
        json.dump(output_data, f, indent=2)

    print(f"Processed data saved to {output_file}")


# Usage
input_file = 'data/full-benchmark-update/merged_data/interim_seed_duplicated_cleaner_filter.json'
output_file = 'data/full-benchmark-update/merged_data/master_norm_episode_return.json'
process_json_data(input_file, output_file)
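The per-environment scaling above is plain min-max normalization: each return x maps to (x - env_min) / (env_max - env_min), so every environment's returns land in [0, 1] before tasks are pooled under 'AllEnvs'. A toy check of the arithmetic (illustrative numbers only):

env_min, env_max = -10.0, 30.0  # per-environment extremes found in the first pass
returns = [-10.0, 10.0, 30.0]
normalized = [(x - env_min) / (env_max - env_min) for x in returns]
print(normalized)  # [0.0, 0.5, 1.0]

Normalizing per environment rather than globally keeps environments with large return scales from dominating the combined 'AllEnvs' aggregate.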