diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a932b82 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.DS_Store +__pycache__ +PPO_logs +PPO_figs +rocket \ No newline at end of file diff --git a/PPO.py b/PPO.py index 49dae36..2f85585 100644 --- a/PPO.py +++ b/PPO.py @@ -5,14 +5,18 @@ ################################## set device ################################## print("============================================================================================") -# set device to cpu or cuda -device = torch.device('cpu') -if(torch.cuda.is_available()): - device = torch.device('cuda:0') + +if torch.cuda.is_available(): + device = torch.device("cuda:0") torch.cuda.empty_cache() - print("Device set to : " + str(torch.cuda.get_device_name(device))) + print("Device set to:", torch.cuda.get_device_name(device)) +elif torch.backends.mps.is_available(): + device = torch.device("mps") + print("Device set to: MPS (Apple Silicon)") else: - print("Device set to : cpu") + device = torch.device("cpu") + print("Device set to: CPU") + print("============================================================================================") @@ -47,31 +51,31 @@ def __init__(self, state_dim, action_dim, has_continuous_action_space, action_st # actor if has_continuous_action_space : self.actor = nn.Sequential( - nn.Linear(state_dim, 64), - nn.Tanh(), - nn.Linear(64, 64), - nn.Tanh(), - nn.Linear(64, action_dim), - nn.Tanh() - ) + nn.Linear(state_dim, 64), + nn.Tanh(), + nn.Linear(64, 64), + nn.Tanh(), + nn.Linear(64, action_dim), + nn.Tanh() + ) else: self.actor = nn.Sequential( - nn.Linear(state_dim, 64), - nn.Tanh(), - nn.Linear(64, 64), - nn.Tanh(), - nn.Linear(64, action_dim), - nn.Softmax(dim=-1) - ) + nn.Linear(state_dim, 64), + nn.Tanh(), + nn.Linear(64, 64), + nn.Tanh(), + nn.Linear(64, action_dim), + nn.Softmax(dim=-1) + ) # critic self.critic = nn.Sequential( - nn.Linear(state_dim, 64), - nn.Tanh(), - nn.Linear(64, 64), - nn.Tanh(), - nn.Linear(64, 1) - ) - + nn.Linear(state_dim, 64), + nn.Tanh(), + nn.Linear(64, 64), + nn.Tanh(), + nn.Linear(64, 1) + ) + def set_action_std(self, new_action_std): if self.has_continuous_action_space: self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device) @@ -137,9 +141,9 @@ def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device) self.optimizer = torch.optim.Adam([ - {'params': self.policy.actor.parameters(), 'lr': lr_actor}, - {'params': self.policy.critic.parameters(), 'lr': lr_critic} - ]) + {'params': self.policy.actor.parameters(), 'lr': lr_actor}, + {'params': self.policy.critic.parameters(), 'lr': lr_critic} + ]) self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device) self.policy_old.load_state_dict(self.policy.state_dict()) @@ -173,29 +177,24 @@ def decay_action_std(self, action_std_decay_rate, min_action_std): print("--------------------------------------------------------------------------------------------") def select_action(self, state): - if self.has_continuous_action_space: with torch.no_grad(): state = torch.FloatTensor(state).to(device) action, action_logprob, state_val = self.policy_old.act(state) - - self.buffer.states.append(state) - self.buffer.actions.append(action) - self.buffer.logprobs.append(action_logprob) - self.buffer.state_values.append(state_val) - - return action.detach().cpu().numpy().flatten() + 
self.buffer.states.append(state) + self.buffer.actions.append(action) + self.buffer.logprobs.append(action_logprob) + self.buffer.state_values.append(state_val) + return action.detach().cpu().numpy().flatten() else: with torch.no_grad(): state = torch.FloatTensor(state).to(device) action, action_logprob, state_val = self.policy_old.act(state) - - self.buffer.states.append(state) - self.buffer.actions.append(action) - self.buffer.logprobs.append(action_logprob) - self.buffer.state_values.append(state_val) - - return action.item() + self.buffer.states.append(state) + self.buffer.actions.append(action) + self.buffer.logprobs.append(action_logprob) + self.buffer.state_values.append(state_val) + return action.item() def update(self): # Monte Carlo estimate of returns @@ -252,11 +251,13 @@ def update(self): def save(self, checkpoint_path): - torch.save(self.policy_old.state_dict(), checkpoint_path) - + torch.save(self.policy.state_dict(), checkpoint_path) + def load(self, checkpoint_path): - self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage)) - self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage)) - - - + self.policy_old.load_state_dict( + torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + ) + self.policy.load_state_dict( + torch.load(checkpoint_path, map_location=lambda storage, loc: storage) + ) diff --git a/PPO_preTrained/RocketLanding/PPO_RocketLanding_0_0.pth b/PPO_preTrained/RocketLanding/PPO_RocketLanding_0_0.pth new file mode 100644 index 0000000..e03e830 Binary files /dev/null and b/PPO_preTrained/RocketLanding/PPO_RocketLanding_0_0.pth differ diff --git a/README.md b/README.md index a8b2a39..d7ab127 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ The goal is to train a reinforcement learning agent to control a rocket to either hover or land safely using the PPO algorithm. The environment simulates physics for the rocket, and the agent learns to make decisions based on the state observations to achieve the task. -https://github.com/user-attachments/assets/2bc71416-0043-4e8d-8f00-cd0d85a834ec +https://github.com/user-attachments/assets/d1977412-2de8-49c3-b0d1-f602dc28bb61 ![RewardsChart](images/rewards-timesteps.png) @@ -90,22 +90,26 @@ These states provide the necessary information for the agent to understand the r source venv/bin/activate # On Windows use venv\Scripts\activate ``` -3. **Install Dependencies** +3. [**Install Dependencies**](requirements.txt) ```bash - pip install torch numpy matplotlib + pip install -r requirements.txt ``` -4. **Ensure CUDA Availability (Optional)** +4. **Ensure GPU Availability (Optional)** - If you have a CUDA-compatible GPU and want to utilize it: + If you have a CUDA-compatible GPU or an Apple Silicon chip and want to utilize it: - Install the appropriate CUDA toolkit version compatible with your PyTorch installation.
- - Verify CUDA availability in PyTorch: - + - Verify GPU availability in PyTorch: ```python - import torch - torch.cuda.is_available() + import torch + if torch.cuda.is_available(): + device = torch.device("cuda:0") + print("Device set to:", torch.cuda.get_device_name(device)) + elif torch.backends.mps.is_available(): + device = torch.device("mps") + print("Device set to: MPS (Apple Silicon)") ``` --- diff --git a/plot_graph.py b/plot_graph.py index 8a7ff6f..97f02f8 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -2,7 +2,6 @@ import pandas as pd import matplotlib.pyplot as plt - def save_graph(): print("============================================================================================") # env_name = 'CartPole-v1' @@ -28,115 +27,110 @@ def save_graph(): colors = ['red', 'blue', 'green', 'orange', 'purple', 'olive', 'brown', 'magenta', 'cyan', 'crimson','gray', 'black'] - # make directory for saving figures - figures_dir = "PPO_figs" - if not os.path.exists(figures_dir): - os.makedirs(figures_dir) - - # make environment directory for saving figures - figures_dir = figures_dir + '/' + env_name + '/' - if not os.path.exists(figures_dir): - os.makedirs(figures_dir) - - fig_save_path = figures_dir + '/PPO_' + env_name + '_fig_' + str(fig_num) + '.png' - - # get number of log files in directory - log_dir = "PPO_logs" + '/' + env_name + '/' + # Setup directories + figures_dir = os.path.join("PPO_figs", env_name) + os.makedirs(figures_dir, exist_ok=True) + fig_save_path = os.path.join(figures_dir, f'PPO_{env_name}_fig_{fig_num}.png') + log_dir = os.path.join("PPO_logs", env_name) + # Get log files current_num_files = next(os.walk(log_dir))[2] num_runs = len(current_num_files) - all_runs = [] + # Load and process data for run_num in range(num_runs): - - log_f_name = log_dir + '/PPO_' + env_name + "_log_" + str(run_num) + ".csv" - print("loading data from : " + log_f_name) - data = pd.read_csv(log_f_name) - data = pd.DataFrame(data) - - print("data shape : ", data.shape) - - all_runs.append(data) - print("--------------------------------------------------------------------------------------------") - - ax = plt.gca() + log_f_name = os.path.join(log_dir, f'PPO_{env_name}_log_{run_num}.csv') + print("Loading data from:", log_f_name) + + try: + # Read CSV with specific column names + data = pd.read_csv(log_f_name, names=['episode', 'timestep', 'reward']) + print("Data shape:", data.shape) + all_runs.append(data) + print("-" * 90) + + except Exception as e: + print(f"Error loading {log_f_name}: {str(e)}") + continue + + if not all_runs: + print("No valid data files found!") + return + + # Create plot + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) if plot_avg: - # average all runs + # Average all runs df_concat = pd.concat(all_runs) df_concat_groupby = df_concat.groupby(df_concat.index) data_avg = df_concat_groupby.mean() - # smooth out rewards to get a smooth and a less smooth (var) plot lines - data_avg['reward_smooth'] = data_avg['reward'].rolling(window=window_len_smooth, win_type='triang', min_periods=min_window_len_smooth).mean() - data_avg['reward_var'] = data_avg['reward'].rolling(window=window_len_var, win_type='triang', min_periods=min_window_len_var).mean() - - data_avg.plot(kind='line', x='timestep' , y='reward_smooth',ax=ax,color=colors[0], linewidth=linewidth_smooth, alpha=alpha_smooth) - data_avg.plot(kind='line', x='timestep' , y='reward_var',ax=ax,color=colors[0], linewidth=linewidth_var, alpha=alpha_var) - - # keep only reward_smooth in the legend and rename it + 
# Smooth out rewards + data_avg['reward_smooth'] = data_avg['reward'].rolling( + window=window_len_smooth, + win_type='triang', + min_periods=min_window_len_smooth + ).mean() + + data_avg['reward_var'] = data_avg['reward'].rolling( + window=window_len_var, + win_type='triang', + min_periods=min_window_len_var + ).mean() + + # Plot + data_avg.plot(kind='line', x='timestep', y='reward_smooth', + ax=ax, color=colors[0], + linewidth=linewidth_smooth, alpha=alpha_smooth) + data_avg.plot(kind='line', x='timestep', y='reward_var', + ax=ax, color=colors[0], + linewidth=linewidth_var, alpha=alpha_var) + + # Update legend handles, labels = ax.get_legend_handles_labels() - ax.legend([handles[0]], ["reward_avg_" + str(len(all_runs)) + "_runs"], loc=2) + ax.legend([handles[0]], [f"reward_avg_{len(all_runs)}_runs"], loc=2) else: for i, run in enumerate(all_runs): - # smooth out rewards to get a smooth and a less smooth (var) plot lines - run['reward_smooth_' + str(i)] = run['reward'].rolling(window=window_len_smooth, win_type='triang', min_periods=min_window_len_smooth).mean() - run['reward_var_' + str(i)] = run['reward'].rolling(window=window_len_var, win_type='triang', min_periods=min_window_len_var).mean() - - # plot the lines - run.plot(kind='line', x='timestep' , y='reward_smooth_' + str(i),ax=ax,color=colors[i % len(colors)], linewidth=linewidth_smooth, alpha=alpha_smooth) - run.plot(kind='line', x='timestep' , y='reward_var_' + str(i),ax=ax,color=colors[i % len(colors)], linewidth=linewidth_var, alpha=alpha_var) - - # keep alternate elements (reward_smooth_i) in the legend + run[f'reward_smooth_{i}'] = run['reward'].rolling( + window=window_len_smooth, + win_type='triang', + min_periods=min_window_len_smooth + ).mean() + + run[f'reward_var_{i}'] = run['reward'].rolling( + window=window_len_var, + win_type='triang', + min_periods=min_window_len_var + ).mean() + + run.plot(kind='line', x='timestep', y=f'reward_smooth_{i}', + ax=ax, color=colors[i % len(colors)], + linewidth=linewidth_smooth, alpha=alpha_smooth) + run.plot(kind='line', x='timestep', y=f'reward_var_{i}', + ax=ax, color=colors[i % len(colors)], + linewidth=linewidth_var, alpha=alpha_var) + + # Update legend handles, labels = ax.get_legend_handles_labels() - new_handles = [] - new_labels = [] - for i in range(len(handles)): - if(i%2 == 0): - new_handles.append(handles[i]) - new_labels.append(labels[i]) + new_handles = [handles[i] for i in range(0, len(handles), 2)] + new_labels = [labels[i] for i in range(0, len(labels), 2)] ax.legend(new_handles, new_labels, loc=2) - # ax.set_yticks(np.arange(0, 1800, 200)) - # ax.set_xticks(np.arange(0, int(4e6), int(5e5))) - + # Finalize plot ax.grid(color='gray', linestyle='-', linewidth=1, alpha=0.2) - ax.set_xlabel("Timesteps", fontsize=12) ax.set_ylabel("Rewards", fontsize=12) - plt.title(env_name, fontsize=14) - fig = plt.gcf() - fig.set_size_inches(fig_width, fig_height) - + # Save and show print("============================================================================================") plt.savefig(fig_save_path) - print("figure saved at : ", fig_save_path) + print("Figure saved at:", fig_save_path) print("============================================================================================") - plt.show() - if __name__ == '__main__': - - save_graph() - - - - - - - - - - - - - - - - - \ No newline at end of file + save_graph() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8ae20aa --- /dev/null +++ b/requirements.txt 
@@ -0,0 +1,6 @@ +matplotlib==3.9.2 +numpy==2.1.3 +opencv-python==4.10.0 +opencv-python-headless==4.10.0 +torch==2.5.1 +pandas==2.2.3 \ No newline at end of file diff --git a/test.py b/test.py index 4015ef5..7685d39 100644 --- a/test.py +++ b/test.py @@ -1,93 +1,187 @@ import os import time -from datetime import datetime - -import torch import numpy as np +import cv2 +import torch +from PPO import PPO +from rocket import Rocket +import re + +def create_confetti_particles(num_particles=30): # Reduced particles + """Create confetti particles""" + particles = [] + for _ in range(num_particles): + particles.append({ + 'x': np.random.randint(0, 800), + 'y': np.random.randint(-50, 0), + 'color': ( + np.random.randint(0, 255), + np.random.randint(0, 255), + np.random.randint(0, 255) + ), + 'size': np.random.randint(5, 15), + 'speed': np.random.randint(5, 15), + 'angle': np.random.uniform(-np.pi/4, np.pi/4) + }) + return particles + +def celebrate_landing(window_name="Perfect Landing!", duration=1.0): # Reduced duration + """Show celebration animation""" + width, height = 800, 600 + particles = create_confetti_particles() + + start_time = time.time() + while time.time() - start_time < duration: + frame = np.ones((height, width, 3), dtype=np.uint8) * 255 + + # Update and draw particles in one pass + for p in particles: + p['y'] += p['speed'] + p['x'] += np.sin(p['angle']) * 2 + p['speed'] += 0.5 + p['angle'] += np.random.uniform(-0.1, 0.1) + + cv2.circle(frame, + (int(p['x']), int(p['y'])), + p['size'], + p['color'], + -1) + + # Add celebration text + text = "Perfect Landing!" + cv2.putText(frame, text, + (width//4, height//2), + cv2.FONT_HERSHEY_DUPLEX, + 2.0, (0, 0, 0), 2) + + cv2.imshow(window_name, frame) + if cv2.waitKey(16) & 0xFF == 27: # ~60 FPS + break + + cv2.destroyWindow(window_name) -from PPO import PPO # Assuming PPO is your policy class -from rocket import Rocket # Import your Rocket environment class +def get_test_config(): + """Get test-specific configuration""" + config = {} + + print("\n====== Test Configuration ======") + + # Task selection + task = input("\nSelect task (hover/landing) [default: landing]: ").lower() + config['task'] = 'landing' if task in ['', 'landing'] else 'hover' + + # Rocket type selection + rocket = input("Select rocket type (falcon/starship) [default: starship]: ").lower() + config['rocket_type'] = 'starship' if rocket in ['', 'starship'] else 'falcon' + + # Rendering preference + config['render'] = input("Enable rendering? 
(y/n) [default: y]: ").lower() != 'n' + config['frame_delay'] = int(input("Frame delay in milliseconds [default: 16]: ") or 16) + + return config -#################################### Testing ################################### def test(): print("============================================================================================") + # Get test configuration + config = get_test_config() + ################## Hyperparameters ################## env_name = "RocketLanding" - task = 'landing' # 'hover' or 'landing' + max_ep_len = 1000 + total_test_episodes = 10 + # PPO hyperparameters has_continuous_action_space = False - max_ep_len = 1000 # Max timesteps in one episode - - render = True # Render environment on screen - frame_delay = 1 # Delay between frames (in seconds) - - total_test_episodes = 10 # Total number of testing episodes - - K_epochs = 80 # Update policy for K epochs - eps_clip = 0.2 # Clip parameter for PPO - gamma = 0.99 # Discount factor - - lr_actor = 0.0003 # Learning rate for actor - lr_critic = 0.001 # Learning rate for critic - ##################################################### - - # Initialize the Rocket environment - env = Rocket(max_steps=max_ep_len, task=task, rocket_type='starship') # Adjust for 'hover' task if needed - - # Set state and action dimensions based on Rocket's configuration + K_epochs = 80 + eps_clip = 0.2 + gamma = 0.99 + lr_actor = 0.0003 + lr_critic = 0.001 + + # Initialize environment + env = Rocket(max_steps=max_ep_len, + task=config['task'], + rocket_type=config['rocket_type']) + + # Set dimensions state_dim = env.state_dims action_dim = env.action_dims - - # Initialize a PPO agent - ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space) - - # Pretrained weights directory - random_seed = 0 # Set this to load a specific checkpoint trained on a random seed - run_num_pretrained = 13 # Set this to load a specific checkpoint number - - directory = "PPO_preTrained" + '/' + env_name + '/' - checkpoint_path = directory + "PPO_{}_{}_{}.pth".format(env_name, random_seed, run_num_pretrained) - print("loading network from : " + checkpoint_path) - - # Load pretrained model - ppo_agent.load(checkpoint_path) - - print("--------------------------------------------------------------------------------------------") - - test_running_reward = 0 + # Initialize PPO agent + ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, + K_epochs, eps_clip, has_continuous_action_space) + + checkpoint_dir = os.path.join("PPO_preTrained", env_name) + checkpoints = [f for f in os.listdir(checkpoint_dir) if re.match(r"PPO_RocketLanding_\d+_\d+\.pth", f)] + if not checkpoints: + print(f"\nError: No checkpoints found in {checkpoint_dir}") + print("Please ensure you have trained the model first.") + return + checkpoints.sort(key=lambda x: [int(num) for num in re.findall(r'\d+', x)]) + latest_checkpoint = checkpoints[-1] + checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint) + print(f"\nLoading model from: {checkpoint_path}") + try: + ppo_agent.load(checkpoint_path) + print("Model loaded successfully!") + except Exception as e: + print(f"Error loading model: {e}") + return + + print("\nStarting testing...") + test_running_reward = 0 + successful_landings = 0 for ep in range(1, total_test_episodes + 1): ep_reward = 0 state = env.reset() - + for t in range(1, max_ep_len + 1): - action = ppo_agent.select_action(state) - state, reward, done, _ = env.step(action) + # Select action + with 
torch.no_grad(): # Faster inference + action = ppo_agent.select_action(state) + + # Take step + next_state, reward, done, _ = env.step(action) ep_reward += reward - - if render: - env.render(window_name="Rocket Test", wait_time=frame_delay) # Adjust for Rocket render method - + + # Render if enabled + if config['render'] and t % 2 == 0: # Skip frames for speed + env.render(window_name="Rocket Test", + wait_time=config['frame_delay']) + if done: + # Check landing conditions + x_pos, y_pos = next_state[0], next_state[1] + vx, vy = next_state[2], next_state[3] + theta = next_state[4] + + # Stricter landing conditions + if (reward > 500 and # High reward + abs(x_pos) < 10.0 and # Close to center + abs(vx) < 5.0 and abs(vy) < 5.0 and # Low velocity + abs(theta) < 0.1): # Nearly vertical + + successful_landings += 1 + print(f"\nPerfect landing! Reward: {reward:.2f}") + celebrate_landing() break - - # Clear PPO agent buffer after each episode + + state = next_state + ppo_agent.buffer.clear() - test_running_reward += ep_reward - print('Episode: {} \t\t Reward: {}'.format(ep, round(ep_reward, 2))) - - env.close() - + print(f'Episode: {ep} \t\t Reward: {round(ep_reward, 2)}') + + cv2.destroyAllWindows() + print("============================================================================================") - avg_test_reward = test_running_reward / total_test_episodes - print("average test reward : " + str(round(avg_test_reward, 2))) - + success_rate = (successful_landings / total_test_episodes) * 100 + print(f"Average test reward: {round(avg_test_reward, 2)}") + print(f"Successful landings: {successful_landings}/{total_test_episodes} ({success_rate:.1f}%)") print("============================================================================================") - if __name__ == '__main__': - test() + test() \ No newline at end of file diff --git a/train.py b/train.py index 9ce0018..54bdec6 100644 --- a/train.py +++ b/train.py @@ -1,8 +1,6 @@ import os -import time from datetime import datetime - -import torch +import utils import numpy as np from PPO import PPO # Assuming PPO is your policy class @@ -10,196 +8,183 @@ import matplotlib.pyplot as plt +def get_latest_checkpoint(directory, env_name): + """Find the latest checkpoint in the directory.""" + if not os.path.exists(directory): + return None, 0, 0 + + files = [f for f in os.listdir(directory) if f.startswith(f"PPO_{env_name}")] + if not files: + return None, 0, 0 + + # Extract run numbers and find the latest + runs = [] + for f in files: + try: + # Format: PPO_RocketLanding_0_13.pth + parts = f.split('_') + seed, run = int(parts[-2]), int(parts[-1].split('.')[0]) + runs.append((seed, run, f)) + except: + continue + + if not runs: + return None, 0, 0 + + # Get the latest run + latest = max(runs, key=lambda x: x[1]) + return os.path.join(directory, latest[2]), latest[0], latest[1] + +def load_training_state(log_dir, env_name, run_num): + """Load the previous training state from logs.""" + log_file = os.path.join(log_dir, f'PPO_{env_name}_log_{run_num}.csv') + if not os.path.exists(log_file): + return 0, 0, [], 0 + + try: + data = np.genfromtxt(log_file, delimiter=',', skip_header=1) + if len(data) == 0: + return 0, 0, [], 0 + + last_episode = int(data[-1, 0]) + last_timestep = int(data[-1, 1]) + rewards = data[:, 2].tolist() + return last_episode, last_timestep, rewards, run_num + except: + return 0, 0, [], 0 + + ################################### Training ################################### def train(): 
print("============================================================================================") + # Get training configuration first + config = utils.get_training_config() + ####### initialize environment hyperparameters ###### env_name = "RocketLanding" - task = 'landing' # 'hover' or 'landing' - - render = True - - has_continuous_action_space = False # Discrete action space for Rocket - - max_ep_len = 1000 # Max timesteps in one episode - max_training_timesteps = int(6e6) # Break training loop if timeteps > max_training_timesteps - - print_freq = max_ep_len * 10 # Print avg reward in the interval (in num timesteps) - log_freq = max_ep_len * 2 # Log avg reward in the interval (in num timesteps) - save_model_freq = int(1e5) # Save model frequency (in num timesteps) - ##################################################### - - ################ PPO hyperparameters ################ - update_timestep = max_ep_len * 4 # Update policy every n timesteps - K_epochs = 80 # Update policy for K epochs in one PPO update - eps_clip = 0.2 # Clip parameter for PPO - gamma = 0.99 # Discount factor - lr_actor = 0.0003 # Learning rate for actor network - lr_critic = 0.001 # Learning rate for critic network - random_seed = 0 # Set random seed if required (0 = no random seed) - ##################################################### - - print("training environment name : " + env_name) + max_ep_len = 1000 + max_training_timesteps = int(6e6) + print_freq = max_ep_len * 10 + log_freq = max_ep_len * 2 + save_model_freq = int(1e5) # Initialize the Rocket environment - env = Rocket(max_steps=max_ep_len, task=task, rocket_type='starship') # Adjust as needed for the hover task + env = Rocket(max_steps=max_ep_len, task=config['task'], + rocket_type=config['rocket_type']) # Set state and action dimensions state_dim = env.state_dims action_dim = env.action_dims - ###################### logging ###################### - log_dir = "PPO_logs" - if not os.path.exists(log_dir): - os.makedirs(log_dir) - - log_dir = log_dir + '/' + env_name + '/' - if not os.path.exists(log_dir): - os.makedirs(log_dir) - - run_num = len(next(os.walk(log_dir))[2]) - log_f_name = log_dir + '/PPO_' + env_name + "_log_" + str(run_num) + ".csv" - print("logging at : " + log_f_name) - ##################################################### - - ################### checkpointing ################### - directory = "PPO_preTrained" - if not os.path.exists(directory): - os.makedirs(directory) - - directory = directory + '/' + env_name + '/' - if not os.path.exists(directory): - os.makedirs(directory) - - checkpoint_path = directory + "PPO_{}_{}_{}.pth".format(env_name, random_seed, run_num) - print("save checkpoint path : " + checkpoint_path) - ##################################################### - - # Initialize a PPO agent - ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space) + ################ PPO hyperparameters ################ + has_continuous_action_space = False + update_timestep = max_ep_len * 4 + K_epochs = 80 + eps_clip = 0.2 + gamma = 0.99 + lr_actor = 0.0003 + lr_critic = 0.001 + + # Setup directories + directory, log_dir = utils.setup_directories() + + # Initialize PPO agent + ppo_agent = PPO(state_dim, action_dim, lr_actor, lr_critic, gamma, + K_epochs, eps_clip, has_continuous_action_space) + + # Setup training state (includes checkpoint loading if available) + i_episode, time_step, episode_rewards, run_num, log_f, checkpoint_path = utils.setup_training_state( + directory, 
log_dir, env_name, ppo_agent + ) # Track total training time start_time = datetime.now().replace(microsecond=0) print("Started training at (GMT) : ", start_time) - log_f = open(log_f_name, "w+") - log_f.write('episode,timestep,reward\n') - # Initialize logging variables print_running_reward = 0 print_running_episodes = 0 log_running_reward = 0 log_running_episodes = 0 + window_size = 10 - time_step = 0 - i_episode = 0 - - episode_rewards = [] - window_size = 10 # Window size for moving average and standard deviation - - # Initialize the plot for real-time updating - plt.ion() # Turn on interactive mode - fig, ax = plt.subplots() - ax.set_xlabel('Episode') - ax.set_ylabel('Reward') - ax.set_title('Training Progress') - plt.show(block=False) - window_size = 10 # Window size for moving average and standard deviation + # Setup plotting if enabled + if config['plot_realtime'] or config['save_plots']: + fig, ax = utils.setup_plotting(config) + else: + fig, ax = None, None # Training loop - while time_step <= max_training_timesteps: - state = env.reset() - current_ep_reward = 0 - - for t in range(1, max_ep_len + 1): - # Select action with policy - action = ppo_agent.select_action(state) - state, reward, done, _ = env.step(action) - - # Save reward and terminal state - ppo_agent.buffer.rewards.append(reward) - ppo_agent.buffer.is_terminals.append(done) - - time_step += 1 - current_ep_reward += reward - - if render and i_episode % 50 == 0: - env.render() - - # Update PPO agent - if time_step % update_timestep == 0: - ppo_agent.update() - - # Log to file - if time_step % log_freq == 0: - log_avg_reward = log_running_reward / log_running_episodes - log_f.write('{},{},{}\n'.format(i_episode, time_step, round(log_avg_reward, 4))) - log_running_reward, log_running_episodes = 0, 0 - - # Print average reward - if time_step % print_freq == 0: - print_avg_reward = print_running_reward / print_running_episodes - print("Episode : {} \t\t Timestep : {} \t\t Average Reward : {}".format(i_episode, time_step, round(print_avg_reward, 2))) - print_running_reward, print_running_episodes = 0, 0 - - # Save model weights - if time_step % save_model_freq == 0: - ppo_agent.save(checkpoint_path) - print("Model saved at timestep: ", time_step) - - if done: - break - - print_running_reward += current_ep_reward - print_running_episodes += 1 - log_running_reward += current_ep_reward - log_running_episodes += 1 - i_episode += 1 - - episode_rewards.append(current_ep_reward) - - # Update the plot - if len(episode_rewards) >= window_size: - # Calculate moving average and standard deviation - moving_avg = np.convolve( - episode_rewards, np.ones(window_size)/window_size, mode='valid' - ) - moving_std = np.array([ - np.std(episode_rewards[i-window_size+1:i+1]) - for i in range(window_size-1, len(episode_rewards)) - ]) - episodes = np.arange(window_size-1, len(episode_rewards)) - - # Clear the axis and redraw - ax.clear() - ax.plot(episodes, moving_avg, label='Moving Average Reward') - - # Shade the area between (mean - std) and (mean + std) - lower_bound = moving_avg - moving_std - upper_bound = moving_avg + moving_std - ax.fill_between(episodes, lower_bound, upper_bound, color='blue', alpha=0.2, label='Standard Deviation') - - # Set labels and title - ax.set_xlabel('Episode') - ax.set_ylabel('Reward') - ax.set_title('Training Progress with Variability Shading') - ax.legend() - plt.draw() - plt.pause(0.01) - else: - # For initial episodes where we don't have enough data for moving average - ax.clear() - 
ax.plot(range(len(episode_rewards)), episode_rewards, label='Episode Reward') - ax.set_xlabel('Episode') - ax.set_ylabel('Reward') - ax.set_title('Training Progress') - ax.legend() - plt.draw() - plt.pause(0.01) - - log_f.close() - print("Finished training at : ", datetime.now().replace(microsecond=0)) + try: + while time_step <= max_training_timesteps: + state = env.reset() + current_ep_reward = 0 + + for t in range(1, max_ep_len + 1): + # Select action with policy + action = ppo_agent.select_action(state) + state, reward, done, _ = env.step(action) + + # Save reward and terminal state + ppo_agent.buffer.rewards.append(reward) + ppo_agent.buffer.is_terminals.append(done) + + time_step += 1 + current_ep_reward += reward + + if config['render'] and i_episode % 50 == 0: + env.render() + + # Update PPO agent + if time_step % update_timestep == 0: + ppo_agent.update() + + # Log to file + if time_step % log_freq == 0: + log_avg_reward = log_running_reward / log_running_episodes + log_f.write('{},{},{}\n'.format(i_episode, time_step, round(log_avg_reward, 4))) + log_running_reward, log_running_episodes = 0, 0 + + # Print average reward + if time_step % print_freq == 0: + print_avg_reward = print_running_reward / print_running_episodes + print("Episode : {} \t\t Timestep : {} \t\t Average Reward : {}".format( + i_episode, time_step, round(print_avg_reward, 2))) + print_running_reward, print_running_episodes = 0, 0 + + # Save model weights + if time_step % save_model_freq == 0: + ppo_agent.save(checkpoint_path) + print("Model saved at timestep: ", time_step) + + if done: + break + + # Update rewards and episodes + print_running_reward += current_ep_reward + print_running_episodes += 1 + log_running_reward += current_ep_reward + log_running_episodes += 1 + i_episode += 1 + episode_rewards.append(current_ep_reward) + + # Update plot if enabled + if fig is not None and len(episode_rewards) >= window_size: + utils.update_plots(fig, ax, episode_rewards, window_size, config) + + except KeyboardInterrupt: + print("\nTraining interrupted by user") + except Exception as e: + print(f"\nError during training: {e}") + finally: + if time_step > 0: + ppo_agent.save(checkpoint_path) + print("Final model saved at: ", checkpoint_path) + if fig is not None: + plt.close('all') + log_f.close() + print("Finished training at : ", datetime.now().replace(microsecond=0)) if __name__ == '__main__': train() diff --git a/utils.py b/utils.py index 058d0f2..ec873bb 100644 --- a/utils.py +++ b/utils.py @@ -1,20 +1,60 @@ import numpy as np import cv2 - +import pandas as pd +import json +import matplotlib.pyplot as plt +import torch +import os + +################ Checkpoint Management #################### +def find_checkpoints(directory, env_name): + """Find all available checkpoints""" + checkpoints = [] + if os.path.exists(directory): + for file in os.listdir(directory): + if file.startswith(f"PPO_{env_name}") and file.endswith(".pth"): + checkpoints.append(file) + return sorted(checkpoints) + +def load_checkpoint(directory, env_name): + """Handle checkpoint loading with user interaction""" + checkpoints = find_checkpoints(directory, env_name) + + if not checkpoints: + print("\nNo existing checkpoints found. Starting fresh training.") + return None, None + + print("\n====== Available Checkpoints ======") + for i, ckpt in enumerate(checkpoints): + print(f"{i+1}. 
{ckpt}") + + while True: + choice = input("\nSelect checkpoint number to load (or press Enter to start fresh): ") + if choice == "": + return None, None + try: + idx = int(choice) - 1 + if 0 <= idx < len(checkpoints): + return os.path.join(directory, checkpoints[idx]), checkpoints[idx] + except ValueError: + pass + print("Invalid choice. Please try again.") + +################ Logging Management #################### +def setup_logging(log_dir, env_name, run_num): + """Setup logging with continuation support""" + log_path = os.path.join(log_dir, f'PPO_{env_name}_log_{run_num}.csv') + + if os.path.exists(log_path): + print(f"\nFound existing log file: {log_path}") + choice = input("Continue logging to this file? (y/n) [default: n]: ").lower() + if choice == 'y': + return open(log_path, 'a'), True + + return open(log_path, 'w+'), False ################ Some helper functions... #################### -def moving_avg(x, N=500): - - if len(x) <= N: - return [] - - x_pad_left = x[0:N] - x_pad_right = x[-N:] - x_pad = x_pad_left[::-1] + x + x_pad_right[::-1] - y = np.convolve(x_pad, np.ones(N) / N, mode='same') - return y[N:-N] - - def load_bg_img(path_to_img, w, h): bg_img = cv2.imread(path_to_img, cv2.IMREAD_COLOR) bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB) @@ -79,7 +119,7 @@ def rotation_matrix(rx=0., ry=0., rz=0.): Rz[1, 1] = np.cos(rz) # RZ * RY * RX - RotationMatrix = np.mat(Rz) * np.mat(Ry) * np.mat(Rx) + RotationMatrix = np.asmatrix(Rz) * np.asmatrix(Ry) * np.asmatrix(Rx) return np.array(RotationMatrix) @@ -109,10 +149,366 @@ def create_pose_matrix(tx=0., ty=0., tz=0., TranslationMatrix = translation_matrix(tx, ty, tz) # TranslationMatrix * RotationMatrix * ScaleMatrix - PoseMatrix = np.mat(TranslationMatrix) \ - * np.mat(RotationMatrix) \ - * np.mat(ScaleMatrix) \ - * np.mat(base_correction) + PoseMatrix = np.asmatrix(TranslationMatrix) \ + * np.asmatrix(RotationMatrix) \ + * np.asmatrix(ScaleMatrix) \ + * np.asmatrix(base_correction) return np.array(PoseMatrix) +################ Training Management #################### + +def get_training_config(): + """Get all training configurations from user""" + config = {} + + print("\n====== Training Configuration ======") + + # Task selection + while True: + task = input("\nSelect task (hover/landing) [default: landing]: ").lower() + if task in ['', 'hover', 'landing']: + config['task'] = 'landing' if task == '' else task + break + print("Invalid choice. Please select 'hover' or 'landing'") + + # Rocket type selection + while True: + rocket = input("Select rocket type (falcon/starship) [default: starship]: ").lower() + if rocket in ['', 'falcon', 'starship']: + config['rocket_type'] = 'starship' if rocket == '' else rocket + break + print("Invalid choice. Please select 'falcon' or 'starship'") + + # Visualization preferences + config['render'] = input("Enable environment rendering? (y/n) [default: n]: ").lower() == 'y' + config['plot_realtime'] = input("Enable real-time plotting? (y/n) [default: y]: ").lower() != 'n' + config['save_plots'] = input("Save training plots? (y/n) [default: y]: ").lower() != 'n' + + # Training parameters + try: + config['max_episodes'] = int(input("Enter maximum episodes [default: 1000]: ") or 1000) + config['save_freq'] = int(input("Save checkpoint frequency (episodes) [default: 100]: ") or 100) + except ValueError: + print("Invalid input for episodes. 
Using defaults.") + config['max_episodes'] = 1000 + config['save_freq'] = 100 + + return config + +def setup_directories(base_dir="PPO_preTrained", env_name="RocketLanding"): + """Create necessary directories""" + # Create base directory + if not os.path.exists(base_dir): + os.makedirs(base_dir) + + # Create environment directory + env_dir = os.path.join(base_dir, env_name) + if not os.path.exists(env_dir): + os.makedirs(env_dir) + + # Create logs directory + log_dir = os.path.join("PPO_logs", env_name) + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + return env_dir, log_dir + +def get_latest_checkpoint(directory, env_name): + """Find the latest checkpoint in the directory.""" + if not os.path.exists(directory): + return None, 0, 0 + + files = [f for f in os.listdir(directory) if f.startswith(f"PPO_{env_name}")] + if not files: + return None, 0, 0 + + runs = [] + for f in files: + try: + parts = f.split('_') + seed, run = int(parts[-2]), int(parts[-1].split('.')[0]) + runs.append((seed, run, f)) + except: + continue + + if not runs: + return None, 0, 0 + + latest = max(runs, key=lambda x: x[1]) + return os.path.join(directory, latest[2]), latest[0], latest[1] + +def setup_training_state(directory, log_dir, env_name, ppo_agent): + """Setup training state and handle checkpoint loading.""" + latest_checkpoint, checkpoint_seed, checkpoint_run = get_latest_checkpoint(directory, env_name) + + if latest_checkpoint is not None: + print(f"Found existing checkpoint: {latest_checkpoint}") + response = input("Continue from previous checkpoint? (y/n) [default: n]: ").lower() + + if response == 'y': + # Load model weights + ppo_agent.load(latest_checkpoint) + + # Load training logs + log_path = os.path.join(log_dir, f'PPO_{env_name}_log_{checkpoint_run}.csv') + if os.path.exists(log_path): + data = np.genfromtxt(log_path, delimiter=',', skip_header=1) + i_episode = int(data[-1, 0]) + time_step = int(data[-1, 1]) + rewards = data[:, 2].tolist() + else: + i_episode, time_step, rewards = 0, 0, [] + + # Setup logging + log_f = open(log_path, 'a') + + print(f"Resuming training from episode {i_episode}, timestep {time_step}") + return i_episode, time_step, rewards, checkpoint_run, log_f, latest_checkpoint + + # Start fresh training + run_num = len(next(os.walk(log_dir))[2]) + log_path = os.path.join(log_dir, f'PPO_{env_name}_log_{run_num}.csv') + log_f = open(log_path, 'w+') + log_f.write('episode,timestep,reward\n') + checkpoint_path = os.path.join(directory, f"PPO_{env_name}_0_{run_num}.pth") + + return 0, 0, [], run_num, log_f, checkpoint_path + +def setup_plotting(config): + """Setup plotting based on configuration""" + if not config['plot_realtime'] and not config['save_plots']: + return None, None + + plt.close('all') # Close any existing plots + fig, ax = plt.subplots(figsize=(10, 6)) + ax.set_xlabel('Episode') + ax.set_ylabel('Reward') + ax.set_title('Training Progress') + + if config['plot_realtime']: + plt.ion() + plt.show(block=False) + + return fig, ax + +def update_plots(fig, ax, episode_rewards, window_size, config, save_dir=None): + """Update and optionally save plots""" + if fig is None or ax is None: + return + + if len(episode_rewards) >= window_size: + moving_avg = np.convolve(episode_rewards, np.ones(window_size)/window_size, mode='valid') + moving_std = np.array([np.std(episode_rewards[i-window_size+1:i+1]) + for i in range(window_size-1, len(episode_rewards))]) + episodes = np.arange(window_size-1, len(episode_rewards)) + + ax.clear() + ax.plot(episodes, moving_avg, 
label='Moving Average') + ax.fill_between(episodes, moving_avg-moving_std, moving_avg+moving_std, + alpha=0.2, label='Standard Deviation') + ax.set_xlabel('Episode') + ax.set_ylabel('Reward') + ax.set_title('Training Progress') + ax.legend() + + if config['plot_realtime']: + plt.draw() + plt.pause(0.01) + + if config['save_plots'] and save_dir: + plt.savefig(os.path.join(save_dir, 'training_progress.png')) + +################ Configuration Management #################### +def save_training_config(config, directory): + """Save training configuration for reproducibility""" + config_path = os.path.join(directory, 'training_config.json') + with open(config_path, 'w') as f: + json.dump(config, f, indent=4) + +def load_training_config(directory): + """Load previous training configuration""" + config_path = os.path.join(directory, 'training_config.json') + if os.path.exists(config_path): + with open(config_path, 'r') as f: + return json.load(f) + return None + +def get_best_model_path(log_dir, env_name): + """Find the best performing model based on logs""" + best_reward = float('-inf') + best_model = None + + log_files = [f for f in os.listdir(log_dir) if f.endswith('.csv')] + for log_file in log_files: + data = pd.read_csv(os.path.join(log_dir, log_file)) + avg_reward = data['reward'].mean() + if avg_reward > best_reward: + best_reward = avg_reward + best_model = log_file.replace('log', 'model').replace('.csv', '.pth') + + return best_model if best_model else None + +################ Performance Monitoring #################### +def track_training_stats(): + """Track various training statistics""" + return { + 'best_reward': float('-inf'), + 'best_episode': 0, + 'running_avg': [], + 'episode_lengths': [], + 'success_rate': [], + 'crash_rate': [] + } + +def update_training_stats(stats, reward, episode_length, success, crash): + """Update training statistics""" + stats['running_avg'].append(reward) + stats['episode_lengths'].append(episode_length) + stats['success_rate'].append(1 if success else 0) + stats['crash_rate'].append(1 if crash else 0) + + if reward > stats['best_reward']: + stats['best_reward'] = reward + stats['best_episode'] = len(stats['running_avg']) + + return stats + +################ Plot Management #################### +def setup_training_plots(plot_config): + """Setup multiple plots for training visualization""" + if not plot_config['enabled']: + return None + + figs = {} + figs['reward'] = plt.figure(figsize=(10, 5)) + figs['success_rate'] = plt.figure(figsize=(10, 5)) + figs['episode_length'] = plt.figure(figsize=(10, 5)) + + return figs + +def update_training_plots(figs, stats, save_dir=None): + """Update all training plots""" + if not figs: + return + + # Update reward plot + plt.figure(figs['reward'].number) + plt.clf() + plt.plot(stats['running_avg']) + plt.title('Training Rewards') + + # Update success rate plot + plt.figure(figs['success_rate'].number) + plt.clf() + window = 100 + success_rate = np.convolve(stats['success_rate'], + np.ones(window)/window, + mode='valid') + plt.plot(success_rate) + plt.title('Success Rate') + + if save_dir: + for name, fig in figs.items(): + fig.savefig(os.path.join(save_dir, f'{name}.png')) + +################ Error Handling and Logging #################### +def setup_logger(log_dir, env_name): + """Setup logging configuration""" + import logging + + logger = logging.getLogger('rocket_training') + logger.setLevel(logging.INFO) + + # File handler + fh = logging.FileHandler(os.path.join(log_dir, f'{env_name}_training.log')) + 
fh.setLevel(logging.INFO) + + # Console handler + ch = logging.StreamHandler() + ch.setLevel(logging.INFO) + + # Formatter + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + fh.setFormatter(formatter) + ch.setFormatter(formatter) + + logger.addHandler(fh) + logger.addHandler(ch) + + return logger + +################ Training Resume Management #################### +def save_training_state(directory, episode, timestep, stats, model): + """Save complete training state""" + state = { + 'episode': episode, + 'timestep': timestep, + 'stats': stats, + 'model_state': model.state_dict() + } + torch.save(state, os.path.join(directory, 'training_state.pth')) + +def load_training_state(directory): + """Load complete training state""" + state_path = os.path.join(directory, 'training_state.pth') + if os.path.exists(state_path): + return torch.load(state_path) + return None + +def load_existing_training(directory, log_dir, env_name, ppo_agent, checkpoint_path, checkpoint_seed, checkpoint_run): + """Load existing training state""" + try: + # Load model weights + ppo_agent.load(checkpoint_path) + + # Load training logs + log_path = os.path.join(log_dir, f'PPO_{env_name}_log_{checkpoint_run}.csv') + if os.path.exists(log_path): + data = np.genfromtxt(log_path, delimiter=',', skip_header=1) + episode = int(data[-1, 0]) + timestep = int(data[-1, 1]) + rewards = data[:, 2].tolist() + else: + episode, timestep, rewards = 0, 0, [] + + # Setup logging + log_f = open(log_path, 'a') + + return episode, timestep, rewards, checkpoint_run, log_f, checkpoint_path + except Exception as e: + print(f"Error loading existing training: {e}") + return setup_new_training(directory, log_dir, env_name) + +def setup_new_training(directory, log_dir, env_name): + """Setup new training session""" + # Get new run number + run_num = len(next(os.walk(log_dir))[2]) + + # Create new log file + log_path = os.path.join(log_dir, f'PPO_{env_name}_log_{run_num}.csv') + log_f = open(log_path, 'w+') + log_f.write('episode,timestep,reward\n') + + # Create new checkpoint path + checkpoint_path = os.path.join(directory, f"PPO_{env_name}_0_{run_num}.pth") + + return 0, 0, [], run_num, log_f, checkpoint_path + +def setup_directories(base_dir="PPO_preTrained", env_name="RocketLanding"): + """Create necessary directories""" + # Create base directory + if not os.path.exists(base_dir): + os.makedirs(base_dir) + + # Create environment directory + env_dir = os.path.join(base_dir, env_name) + if not os.path.exists(env_dir): + os.makedirs(env_dir) + + # Create logs directory + log_dir = os.path.join("PPO_logs", env_name) + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + return env_dir, log_dir \ No newline at end of file
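
Note on the device-selection change in `PPO.py`: the module now falls back from CUDA to Apple Silicon MPS to CPU at import time, and the README snippet repeats the same check. The sketch below is only an illustration of that cascade wrapped in a helper; the `get_device()` name is hypothetical (no such function exists in this diff), and nothing beyond a standard PyTorch install is assumed.

```python
import torch

def get_device() -> torch.device:
    # Hypothetical helper mirroring the CUDA -> MPS -> CPU fallback used in PPO.py.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.empty_cache()
        print("Device set to:", torch.cuda.get_device_name(device))
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Device set to: MPS (Apple Silicon)")
    else:
        device = torch.device("cpu")
        print("Device set to: CPU")
    return device

if __name__ == "__main__":
    # Quick smoke test: allocate a small tensor on whichever device was picked.
    x = torch.zeros(3, device=get_device())
    print(x.device)
```

Factoring the check out this way would let `PPO.py` and the README example share one definition; the diff as written keeps the logic inline in `PPO.py`.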