Commit
Add sample showing how to log trajectories
mzat-msft committed Aug 16, 2023
1 parent 9e32391 commit 1c75540
Showing 3 changed files with 184 additions and 0 deletions.
58 changes: 58 additions & 0 deletions examples/logging-trajectories/README.md
@@ -0,0 +1,58 @@
# Logging Episode Trajectories

In this folder we will show how to log episode trajectories during training.


## Prerequisites

- You have followed one of our getting started samples and thus have Azure
  ML properly set up. From that sample we will be reusing:
- Environment: `aml-environment`
- Compute: `env-medium`


## How the sample is structured

This sample follows the usual structure. There is a ``job.yml`` that you can
use for sending a job to Azure ML, and a Python script in ``src`` that is
launched by the job.

The important bits for logging trajectories are the following.

In ``job.yml`` we have the following block:

```yaml
outputs:
output_data:
mode: rw_mount
path: azureml://datastores/workspaceblobstore/paths/trajectories
type: uri_folder
```
This block defines mount instructions for our Azure ML Datastore. Once the
job starts, the compute environment will have the datastore folder specified
in ``path`` available as a writing destination. The customizable parts of
``path`` are: ``workspaceblobstore`` (the name of an Azure ML datastore), and
``trajectories`` (the name of the folder where we will save the trajectories
in the datastore). The datastore *workspaceblobstore* should be available by
default in your Azure ML workspace. Should you want to, you can create a
custom [Azure ML
datastore](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-datastore?view=azureml-api-2&tabs=sdk-identity-based-access%2Csdk-adls-identity-access%2Csdk-azfiles-accountkey%2Csdk-adlsgen1-identity-access)
and replace ``workspaceblobstore``.
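
For example, with a hypothetical custom datastore named ``mydatastore`` and a
target folder named ``ppo-trajectories``, the block would become:

```yaml
outputs:
  output_data:
    mode: rw_mount
    path: azureml://datastores/mydatastore/paths/ppo-trajectories
    type: uri_folder
```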

In ``src/main.py``, two things are relevant for us. The first is that the
script must accept a folder path as a parameter (in our case
``--storage-path``). This is passed in ``job.yml`` and will be the folder
where we save the trajectories for the training run. The second is the class
named ``TrajectoryCallback``. In this class we define a method
``on_postprocess_trajectory`` that runs during training and contains the
instructions for saving the trajectories to the datastore. Note that we call
the ``.callbacks()`` method in the algorithm configuration (in our case
``PPOConfig``) to load the ``TrajectoryCallback`` functionality.

Once the job completes, a CSV file containing your agent's training
trajectories will be stored in your datastore. To access and download it,
open your Azure Machine Learning (AML) workspace. In your workspace go to
*Data*, then *Datastores*, then *workspaceblobstore*, and then *Browse*. In
the ``trajectories`` folder you will find the CSV file. The name of the CSV
is the time when the job started.
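
Once downloaded, the CSV can be inspected with a few lines of plain Python. A
minimal sketch, using an in-memory excerpt with illustrative values in place
of the real timestamped file:

```python
import csv
from io import StringIO

# Simulated excerpt of the trajectory CSV; the header matches the
# columns written by TrajectoryCallback in src/main.py.
sample = StringIO(
    "episode_id,step_id,state,action,reward,terminated,truncated\n"
    '101,0,"[0.01, 0.02, 0.03, 0.04]",1,1.0,False,False\n'
    '101,1,"[0.02, 0.05, 0.02, 0.01]",0,1.0,False,False\n'
)

# Accumulate the reward of each episode to get a quick return estimate
returns = {}
for row in csv.DictReader(sample):
    returns[row["episode_id"]] = returns.get(row["episode_id"], 0.0) + float(row["reward"])

print(returns)  # {'101': 2.0}
```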
20 changes: 20 additions & 0 deletions examples/logging-trajectories/job.yml
@@ -0,0 +1,20 @@
$schema: https://azuremlschemas.azureedge.net/latest/commandJob.schema.json
code: src
command: >-
python main.py --storage-path ${{outputs.output_data}}
environment: azureml:aml-environment@latest
compute: azureml:env-medium
outputs:
output_data:
mode: rw_mount
path: azureml://datastores/workspaceblobstore/paths/trajectories
type: uri_folder
display_name: logging-trajectories
experiment_name: logging-trajectories
description: Log episode trajectories in a Datastore.
# Needed for using ray on AML
distribution:
type: mpi
# Modify the following and num_rollout_workers in main to use more workers
resources:
instance_count: 1
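
As the comment in ``job.yml`` notes, scaling out means changing two values
together: ``instance_count`` here and ``num_rollout_workers`` in ``main.py``.
A sketch with illustrative counts (the exact mapping between nodes and
rollout workers depends on your cluster setup):

```yaml
# job.yml: request three nodes, i.e. one head node plus two extra nodes
resources:
  instance_count: 3
```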
106 changes: 106 additions & 0 deletions examples/logging-trajectories/src/main.py
@@ -0,0 +1,106 @@
import argparse
import csv
import datetime as dt
from pathlib import Path

from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print
from ray_on_aml.core import Ray_On_AML

parser = argparse.ArgumentParser()
parser.add_argument("--storage-path", type=Path)
args = parser.parse_args()


class TrajectoryCallback(DefaultCallbacks):
fname = None

def on_postprocess_trajectory(
self,
*,
worker,
episode,
agent_id,
policy_id,
policies,
postprocessed_batch,
original_batches,
):
obs = postprocessed_batch["obs"]
actions = postprocessed_batch["actions"]
episode_id = postprocessed_batch["eps_id"]
rewards = postprocessed_batch["rewards"]
terminated = postprocessed_batch["terminateds"]
truncated = postprocessed_batch["truncateds"]
step_id = postprocessed_batch["t"]

if self.fname is None:
output_name = f"{dt.datetime.now().isoformat(timespec='milliseconds')}.csv"
self.fname = args.storage_path / output_name
            print(f"Saving trajectories to {self.fname}...")

        first = not Path(self.fname).exists()

# Ordering of the tuple should be consistent with the zipped variables below
header = (
"episode_id",
"step_id",
"state",
"action",
"reward",
"terminated",
"truncated",
)

# Check that headers match with file
if not first:
with open(self.fname, "r") as fp:
reader = csv.reader(fp)
file_head = tuple(next(reader))
if file_head != header:
raise ValueError(f"Unexpected header in file {self.fname}")

        with open(self.fname, "a", newline="") as fp:
writer = csv.writer(fp)
if first:
writer.writerow(header)
for row in zip(
episode_id, step_id, obs, actions, rewards, terminated, truncated
):
writer.writerow(row)


def train():
# Define the algo for training the agent
algo = (
PPOConfig()
.callbacks(TrajectoryCallback)
.rollouts(num_rollout_workers=1)
.resources(num_gpus=0)
# Set the training batch size to the appropriate number of steps
.training(train_batch_size=4_000)
.environment(env="CartPole-v1")
.build()
)
# Train for 10 iterations
for i in range(10):
result = algo.train()
print(pretty_print(result))

# outputs can be found in AML Studio under the "Outputs + Logs" tab of your job
checkpoint_dir = algo.save(checkpoint_dir="./outputs")
print(f"Checkpoint saved in directory {checkpoint_dir}")


if __name__ == "__main__":
ray_on_aml = Ray_On_AML()
ray = ray_on_aml.getRay()

if ray:
print("head node detected")
ray.init(address="auto")
print(ray.cluster_resources())
train()
else:
print("in worker node")
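
The append-or-create pattern used by ``TrajectoryCallback`` can be exercised
in isolation with only the standard library. A sketch, with a temporary
directory standing in for the mounted datastore path:

```python
import csv
import tempfile
from pathlib import Path

HEADER = ("episode_id", "step_id", "state", "action", "reward", "terminated", "truncated")

def append_rows(fname: Path, rows):
    # Write the header only when the file is first created,
    # then verify it on every later append.
    first = not fname.exists()
    if not first:
        with open(fname) as fp:
            if tuple(next(csv.reader(fp))) != HEADER:
                raise ValueError(f"Unexpected header in file {fname}")
    with open(fname, "a", newline="") as fp:
        writer = csv.writer(fp)
        if first:
            writer.writerow(HEADER)
        writer.writerows(rows)

with tempfile.TemporaryDirectory() as d:
    out = Path(d) / "trajectories.csv"
    append_rows(out, [(1, 0, "[0.0]", 1, 1.0, False, False)])
    append_rows(out, [(1, 1, "[0.1]", 0, 1.0, True, False)])
    with open(out) as fp:
        n_lines = len(fp.readlines())

print(n_lines)  # header + two data rows -> 3
```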
