Commit: Add sample showing how to log trajectories
# Logging Episode Trajectories

In this folder we show how to log episode trajectories during training.
## Prerequisites

- You have followed one of our getting started samples and thus have Azure
  ML properly set up. From that sample we will be reusing:
  - Environment: `aml-environment`
  - Compute: `env-medium`
## How the sample is structured

This sample follows the usual structure. There is a ``job.yml`` that you can
use for sending a job to Azure ML, and a Python script in ``src`` that is
launched by the job.
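For reference, a job defined this way is typically submitted with the Azure ML CLI v2 (a sketch; the resource group and workspace names are placeholders you must fill in):

```
az ml job create --file job.yml \
  --resource-group <your-resource-group> \
  --workspace-name <your-workspace>
```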
The important bits for logging trajectories are the following.

In ``job.yml`` we have the following block:
```yaml
outputs:
  output_data:
    mode: rw_mount
    path: azureml://datastores/workspaceblobstore/paths/trajectories
    type: uri_folder
```

This block defines mount instructions for our Azure ML Datastore. Once the
job starts, the compute environment will have the datastore folder specified
in ``path`` available as a writing destination. The customizable parts of
``path`` are ``workspaceblobstore`` (the name of an Azure ML datastore) and
``trajectories`` (the name of the folder in the datastore where the
trajectories will be saved). The datastore *workspaceblobstore* should be
available by default in your Azure ML workspace. Should you want to, you can
create a custom [Azure ML
datastore](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-datastore?view=azureml-api-2&tabs=sdk-identity-based-access%2Csdk-adls-identity-access%2Csdk-azfiles-accountkey%2Csdk-adlsgen1-identity-access)
and replace ``workspaceblobstore``.
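For example, with a hypothetical custom datastore named ``mydatastore`` and an output folder named ``my-trajectories``, the block would become:

```yaml
outputs:
  output_data:
    mode: rw_mount
    path: azureml://datastores/mydatastore/paths/my-trajectories
    type: uri_folder
```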

In ``src/main.py``, two things are relevant for us. The first is that the
script must accept a folder path as a parameter (in our case
``--storage-path``). This is passed in ``job.yml`` and will be the folder
where the trajectories for the training run are saved. The second important
thing is the class named ``TrajectoryCallback``. In this class we define a
function ``on_postprocess_trajectory`` that runs during training; in it we
provide instructions on how to save the trajectories in the datastore. Note
that we added the ``.callbacks`` method in the algorithm configuration (in
our case ``PPOConfig``) to load the ``TrajectoryCallback`` functionality.
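The append-with-consistent-header pattern used by ``TrajectoryCallback`` can be sketched in isolation with only the standard library (a minimal sketch; the file name, helper name, and rows here are made up for illustration):

```python
import csv
from pathlib import Path

# Same column order as the sample's trajectory CSV.
HEADER = ("episode_id", "step_id", "state", "action", "reward", "terminated", "truncated")


def append_rows(fname: Path, rows) -> None:
    """Append trajectory rows, writing the header only on first use and
    refusing to append to a file whose header does not match."""
    first = not fname.exists()
    if not first:
        with open(fname, "r", newline="") as fp:
            file_head = tuple(next(csv.reader(fp)))
        if file_head != HEADER:
            raise ValueError(f"Unexpected header in file {fname}")
    with open(fname, "a", newline="") as fp:
        writer = csv.writer(fp)
        if first:
            writer.writerow(HEADER)
        writer.writerows(rows)


# Hypothetical usage with made-up rows:
out = Path("demo_trajectories.csv")
out.unlink(missing_ok=True)  # start fresh for the demo
append_rows(out, [(0, 0, "[0.1]", 1, 1.0, False, False)])
append_rows(out, [(0, 1, "[0.2]", 0, 1.0, True, False)])  # second call appends below the header
```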

Once the job completes, a CSV file containing your training agent's
trajectories will be stored in your datastore. To access and download it,
open your Azure Machine Learning (AML) workspace. In your workspace go to
*Data*, *Datastores*, *workspaceblobstore*, and then *Browse*. In the
trajectories folder you will find the CSV. The CSV is named after the
timestamp at which trajectory logging started during the job.
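Once downloaded, the CSV can be inspected with the standard library alone. The sketch below synthesizes a tiny file with the same column layout (the file name and values are made up, since the real file is named after a timestamp) and computes the total reward per episode:

```python
import csv
from collections import defaultdict

# Synthesize a tiny file with the sample's column layout (illustrative values).
with open("sample_trajectories.csv", "w", newline="") as fp:
    w = csv.writer(fp)
    w.writerow(["episode_id", "step_id", "state", "action", "reward", "terminated", "truncated"])
    w.writerow([7, 0, "[0.01, 0.0, 0.02, 0.0]", 1, 1.0, False, False])
    w.writerow([7, 1, "[0.02, 0.2, 0.01, -0.3]", 0, 1.0, True, False])

# Accumulate the total reward per episode.
totals = defaultdict(float)
with open("sample_trajectories.csv", newline="") as fp:
    for row in csv.DictReader(fp):
        totals[row["episode_id"]] += float(row["reward"])

print(dict(totals))  # {'7': 2.0}
```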

## `job.yml`

```yaml
$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
code: src
command: >-
  python main.py --storage-path ${{outputs.output_data}}
environment: azureml:aml-environment@latest
compute: azureml:env-medium
outputs:
  output_data:
    mode: rw_mount
    path: azureml://datastores/workspaceblobstore/paths/trajectories
    type: uri_folder
display_name: logging-trajectories
experiment_name: logging-trajectories
description: Log episode trajectories in a Datastore.
# Needed for using Ray on AML
distribution:
  type: mpi
# Modify the following and num_rollout_workers in main.py to use more workers
resources:
  instance_count: 1
```

## `src/main.py`

```python
import argparse
import csv
import datetime as dt
from pathlib import Path

from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print
from ray_on_aml.core import Ray_On_AML

parser = argparse.ArgumentParser()
parser.add_argument("--storage-path", type=Path)
args = parser.parse_args()


class TrajectoryCallback(DefaultCallbacks):
    fname = None

    def on_postprocess_trajectory(
        self,
        *,
        worker,
        episode,
        agent_id,
        policy_id,
        policies,
        postprocessed_batch,
        original_batches,
    ):
        obs = postprocessed_batch["obs"]
        actions = postprocessed_batch["actions"]
        episode_id = postprocessed_batch["eps_id"]
        rewards = postprocessed_batch["rewards"]
        terminated = postprocessed_batch["terminateds"]
        truncated = postprocessed_batch["truncateds"]
        step_id = postprocessed_batch["t"]

        if self.fname is None:
            output_name = f"{dt.datetime.now().isoformat(timespec='milliseconds')}.csv"
            self.fname = args.storage_path / output_name
            print(f"Saving trajectories to {self.fname}...")

        first = not Path(self.fname).exists()

        # Ordering of the tuple should be consistent with the zipped variables below
        header = (
            "episode_id",
            "step_id",
            "state",
            "action",
            "reward",
            "terminated",
            "truncated",
        )

        # Check that the header matches the one already in the file
        if not first:
            with open(self.fname, "r") as fp:
                reader = csv.reader(fp)
                file_head = tuple(next(reader))
            if file_head != header:
                raise ValueError(f"Unexpected header in file {self.fname}")

        with open(self.fname, "a") as fp:
            writer = csv.writer(fp)
            if first:
                writer.writerow(header)
            for row in zip(
                episode_id, step_id, obs, actions, rewards, terminated, truncated
            ):
                writer.writerow(row)


def train():
    # Define the algorithm for training the agent
    algo = (
        PPOConfig()
        .callbacks(TrajectoryCallback)
        .rollouts(num_rollout_workers=1)
        .resources(num_gpus=0)
        # Set the training batch size to the appropriate number of steps
        .training(train_batch_size=4_000)
        .environment(env="CartPole-v1")
        .build()
    )
    # Train for 10 iterations
    for i in range(10):
        result = algo.train()
        print(pretty_print(result))

    # Outputs can be found in AML Studio under the "Outputs + Logs" tab of your job
    checkpoint_dir = algo.save(checkpoint_dir="./outputs")
    print(f"Checkpoint saved in directory {checkpoint_dir}")


if __name__ == "__main__":
    ray_on_aml = Ray_On_AML()
    ray = ray_on_aml.getRay()

    if ray:
        print("head node detected")
        ray.init(address="auto")
        print(ray.cluster_resources())
        train()
    else:
        print("in worker node")
```