Deep RL with PPO for Karel Tasks

This project solves 4-by-4 Karel tasks, as specified in project.pdf, using deep RL. Concretely, it combines imitation learning with a variant of the PPO algorithm.

Executing a Policy

To print out the sequence of actions generated by a policy, run main.py from the main folder of the project, e.g.:

python3 src/main.py --path "./data/train/task/0_task.json"

It is important to run it from the main project folder, as the paths used to load models in main.py are relative to that folder. The following arguments can be passed:

  • --path (string) Path to the 4-by-4 Karel task in .json format.
  • --pretrained (bool) Boolean indicating whether the fully pretrained model should be used instead of the model trained with PPO (see my submission for milestone #2). Accepted values are y, yes, 1, t, true (for true) and n, no, 0, f, false (for false). For example, python3 src/main.py --path "./data/train/task/0_task.json" --pretrained t uses the model trained on all training data in a supervised way. python3 src/main.py --path "./data/train/task/0_task.json" --pretrained f is equivalent to python3 src/main.py --path "./data/train/task/0_task.json" and uses the model pretrained on the first 100 Karel tasks and then trained with PPO afterwards.
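For reference, a minimal sketch of how this kind of boolean flag can be parsed with argparse is shown below. The actual argument handling lives in src/config.py, and the helper name str2bool here is purely illustrative.

```python
# Illustrative sketch of parsing the --pretrained flag with argparse.
# The real argument handling is in src/config.py; str2bool is a
# hypothetical helper name, not necessarily the one used there.
import argparse

def str2bool(value: str) -> bool:
    if value.lower() in ("y", "yes", "1", "t", "true"):
        return True
    if value.lower() in ("n", "no", "0", "f", "false"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got '{value}'.")

parser = argparse.ArgumentParser()
parser.add_argument("--path", type=str, required=True,
                    help="Path to the 4-by-4 Karel task in .json format.")
parser.add_argument("--pretrained", type=str2bool, default=False,
                    help="Use the fully pretrained model instead of the PPO one.")
args = parser.parse_args()
```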

Project Structure

Folders

  • The data folder contains the training and validation datasets in the defined .json format. Additionally, it contains datasets in .csv format that can be used for supervised training. They are created from the optimal sequences given in data/train/seq, combined with the corresponding states obtained by executing the given actions on the corresponding task in data/train/task (see the sketch after this list).
    • supervised_first_<num>.csv contains data generated from the first <num> tasks in data/train/task. For example, supervised_first_50.csv contains data obtained from tasks with id 0 to 49.
    • supervised_full.csv contains a dataset for supervised training obtained from all tasks in data/train/task.
  • The logs folder contains logs produced when comparing models, e.g., with test_params from src/imitation_learning.py.
  • The saved_models folder contains the parameters of trained and saved networks.
  • The plots folder contains all generated plots, e.g., test_params from src/imitation_learning.py also produces plots in addition to the logs.
  • The src folder contains all Python source code.
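As referenced above, the supervised datasets pair each visited state with the corresponding optimal action. A minimal sketch of that idea follows; the environment API (reset_to, step) and the (task, sequence) pairing are assumed interfaces for illustration, and the actual logic lives in src/create_supervised_data.py.

```python
import csv

# Illustrative sketch of building a supervised dataset from optimal
# action sequences. The environment API (reset_to, step) is an assumed
# interface; see src/create_supervised_data.py for the actual code.
def build_supervised_rows(env, tasks):
    """tasks: iterable of (task_json_path, optimal_action_sequence) pairs."""
    rows = []
    for task_path, actions in tasks:
        state = env.reset_to(task_path)          # load the initial task state
        for action in actions:
            rows.append(list(state) + [action])  # state features + action label
            state, _, _ = env.step(action)       # advance to the next state
    return rows

def save_rows(rows, out_path):
    with open(out_path, "w", newline="") as f:
        csv.writer(f).writerows(rows)
```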

Python Files

  • config.py contains the functionality to parse arguments.
  • create_supervised_data.py contains a single function that uses an environment to generate datasets that can be used for supervised training (imitation learning) and saves them at the specified location.
  • data_loading.py handles all data loading tasks, including reading from the .json files. It has functions that produce the state feature representation as specified in project2_train.pdf. Additionally, it contains the Dataset class that handles data flow during supervised training.
  • environment.py contains a function for printing a visualization of a vectorized state and the Karel_Environment class, which handles sampling initial states as well as computing state transitions and rewards.
  • evaluation.py contains a function that executes and evaluates policies on multiple Karel tasks, computing the number of tasks solved as well as the average return obtained.
  • execute_policy.py contains functions used to print out or save the actions from a policy executed on a given Karel task.
  • imitation_learning.py contains functions to 1) pretrain a model in a supervised way and save it, and 2) test optimization of supervised learning on the same dataset for different learning rates (with SGD); 2) also produces plots and logs.
  • main.py already provides the functionality for Milestone #3.
  • networks.py contains the actor (policy) and critic (value) networks.
  • plot.py contains all plotting functions used for evaluation.
  • reinforcement_learning.py contains the function for PPO training, with options to load and save models.
  • rollout_buffer.py contains a rollout buffer that is heavily inspired by https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/buffers.py. This buffer stores the data produced in episodes generated during PPO training. It also contains the routines to compute the advantage values (see the sketch after this list).
  • trainer.py contains classes handling training. The PPO_Trainer applies the PPO algorithm, while the FF_trainer is used for supervised training. Additionally, the trainer classes compute some performance metrics. Please note that I index/count epochs starting from 0.
  • test_env.py contains a very simple environment for sanity-checking the PPO algorithm (see below).
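Since the rollout buffer computes advantage values, the sketch below shows generalized advantage estimation (GAE) in the style of the linked stable-baselines3 buffer; whether rollout_buffer.py uses exactly this formulation is an assumption.

```python
import numpy as np

# Sketch of generalized advantage estimation (GAE) over one rollout,
# following the stable-baselines3 buffer linked above. dones[t] marks
# whether the episode ended after step t. This exact formulation is an
# assumption about rollout_buffer.py.
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    n = len(rewards)
    advantages = np.zeros(n, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(n)):
        next_value = last_value if t == n - 1 else values[t + 1]
        next_non_terminal = 1.0 - dones[t]
        # TD error: r_t + gamma * V(s_{t+1}) - V(s_t), zeroed past episode ends.
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        gae = delta + gamma * lam * next_non_terminal * gae
        advantages[t] = gae
    # Value targets for the critic: A_t + V(s_t).
    returns = advantages + np.asarray(values, dtype=np.float32)
    return advantages, returns
```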

Test Environment

The test environment consists of 5 non-terminal states arranged in a straight line, with two terminal states, one at each end:

(goal) - (s1) - (s2) - (s3) - (s4) - (s5) - (goal)

There are two actions, move left and move right, which transition deterministically to the neighbouring states. The goal is to reach a terminal state using as few moves as possible.
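A minimal sketch of this environment might look as follows; the reward scheme and the initial-state distribution are assumptions, and src/test_env.py is the authoritative version.

```python
import random

# Sketch of the corridor test environment: positions 0 and 6 are the
# terminal goal states, 1..5 are the non-terminal states s1..s5.
# The per-step reward and the random start are assumptions; see
# src/test_env.py for the actual implementation.
class LineEnv:
    LEFT, RIGHT = 0, 1

    def reset(self):
        self.pos = random.randint(1, 5)  # assumed: uniform non-terminal start
        return self.pos

    def step(self, action):
        self.pos += 1 if action == self.RIGHT else -1
        done = self.pos in (0, 6)
        reward = 0.0 if done else -1.0   # assumed: -1 per step rewards short paths
        return self.pos, reward, done
```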

Reproducibility of Results

All of the results reported in project2_train.pdf can be reproduced by uncommenting the corresponding sections in reproduce_all.py and running it from the main folder of the project. Please note that I index/count epochs starting from 0, while I plot them starting from 1.

I trained all models on CPU with a fixed random seed. However, since data loading differs on GPU, the results may change when models are trained on GPU instead of CPU.

Please also note that some functions raise a runtime error if they would overwrite a file that already exists. This means some sections of reproduce_all.py may cause a runtime error if certain files they try to write already exist.
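The guard is essentially of the following form (an illustrative sketch, not the exact code):

```python
import os

# Illustrative sketch of the overwrite guard described above.
def check_not_exists(path):
    if os.path.exists(path):
        raise RuntimeError(f"Refusing to overwrite existing file: {path}")
```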
