Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: RL baselines as policy support #60

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,5 @@ dmypy.json
_build
logs
demos
prof/
prof/
runs
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,7 @@ coverage-view:
test-verbose:
pytest --cov-config=.coveragerc --cov=malib --cov-report html --doctest-modules tests -v -s
rm -f .coverage.*

.PHONY: compile
compile:
python -m grpc_tools.protoc -I malib/backend/protos --python_out=malib/backend/dataset_server --pyi_out=malib/backend/dataset_server --grpc_python_out=malib/backend/dataset_server malib/backend/protos/data.proto
File renamed without changes.
16 changes: 16 additions & 0 deletions docs/source/api/malib.common.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,26 @@ malib.common.payoff\_manager module
:undoc-members:
:show-inheritance:

malib.common.retrace module
---------------------------

.. automodule:: malib.common.retrace
:members:
:undoc-members:
:show-inheritance:

malib.common.strategy\_spec module
----------------------------------

.. automodule:: malib.common.strategy_spec
:members:
:undoc-members:
:show-inheritance:

malib.common.vtrace module
--------------------------

.. automodule:: malib.common.vtrace
:members:
:undoc-members:
:show-inheritance:
8 changes: 8 additions & 0 deletions docs/source/api/malib.rl.ppo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ malib.rl.ppo package
Submodules
----------

malib.rl.ppo.config module
--------------------------

.. automodule:: malib.rl.ppo.config
:members:
:undoc-members:
:show-inheritance:

malib.rl.ppo.policy module
--------------------------

Expand Down
99 changes: 0 additions & 99 deletions examples/run_gym.py

This file was deleted.

8 changes: 5 additions & 3 deletions examples/run_psro.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import os
import time

from malib.runner import run
from malib.agent import IndependentAgent
from malib.scenarios import psro_scenario
from malib.learner import IndependentAgent
from malib.scenarios.psro_scenario import PSROScenario
from malib.rl.dqn import DQNPolicy, DQNTrainer, DEFAULT_CONFIG
from malib.rollout.envs.open_spiel import env_desc_gen
Expand Down Expand Up @@ -99,4 +99,6 @@
},
)

run(scenario)
results = psro_scenario.execution_plan(scenario=scenario, verbose=True)

print(results)
99 changes: 99 additions & 0 deletions examples/sarl/ppo_gym.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
import time

from argparse import ArgumentParser

from gym import spaces

import numpy as np

from malib.rollout.episode import Episode
from malib.learner import IndependentAgent
from malib.scenarios import sarl_scenario
from malib.rl.config import Algorithm
from malib.rl.ppo import PPOPolicy, PPOTrainer, DEFAULT_CONFIG
from malib.learner.config import LearnerConfig
from malib.rollout.config import RolloutConfig
from malib.rollout.envs.gym import env_desc_gen
from malib.backend.dataset_server.feature import BaseFeature


class FeatureHandler(BaseFeature):
    # Concrete feature handler used by this example. It adds no behavior of its
    # own — everything is inherited from BaseFeature. Instances are built in
    # `feature_handler_meta_gen` as FeatureHandler(spaces, np_memory, device=...).
    pass


def feature_handler_meta_gen(env_desc, agent_id, capacity: int = 10000):
    """Build a factory that creates a :class:`FeatureHandler` for one agent.

    Args:
        env_desc: Environment description mapping; only its
            ``"observation_spaces"`` and ``"action_spaces"`` entries (keyed by
            agent id) are read here.
        agent_id: Key identifying the agent whose spaces define the data schema.
        capacity: Maximum number of transitions the replay buffer will hold.
            Storage is preallocated, so this must be known before training.

    Returns:
        A callable ``f(device="cpu")`` that instantiates the handler on demand.
    """

    def f(device="cpu"):
        # Define the data schema: one space per episode field, typed by the
        # environment's own observation/action spaces.
        _spaces = {
            Episode.DONE: spaces.Discrete(1),
            Episode.CUR_OBS: env_desc["observation_spaces"][agent_id],
            Episode.ACTION: env_desc["action_spaces"][agent_id],
            Episode.REWARD: spaces.Box(-np.inf, np.inf, shape=(), dtype=np.float32),
            Episode.NEXT_OBS: env_desc["observation_spaces"][agent_id],
        }

        # Preallocate flat numpy storage for every field up front; the replay
        # buffer's maximum size must be fixed before training starts.
        np_memory = {
            k: np.zeros((capacity,) + v.shape, dtype=v.dtype)
            for k, v in _spaces.items()
        }

        return FeatureHandler(_spaces, np_memory, device=device)

    return f


if __name__ == "__main__":
    # Example entry point: train a single-agent PPO policy on a Gym task.
    parser = ArgumentParser("Use PPO solve Gym tasks.")
    parser.add_argument("--log-dir", default="./logs/", help="Log directory.")
    parser.add_argument("--env-id", default="CartPole-v1", help="gym environment id.")
    parser.add_argument("--use-cuda", action="store_true")

    args = parser.parse_args()

    # NOTE(review): `TRAINING_CONIG` looks like a typo for `TRAINING_CONFIG`,
    # but it must match the attribute actually declared on DEFAULT_CONFIG —
    # confirm and rename upstream together with this call site.
    trainer_config = DEFAULT_CONFIG.TRAINING_CONIG.copy()
    trainer_config["total_timesteps"] = int(1e6)
    trainer_config["use_cuda"] = args.use_cuda

    # One log directory per run, keyed by env id and wall-clock start time.
    runtime_logdir = os.path.join(
        args.log_dir, f"gym/{args.env_id}/independent_ppo/{time.time()}"
    )

    if not os.path.exists(runtime_logdir):
        os.makedirs(runtime_logdir)

    scenario = sarl_scenario.SARLScenario(
        name=f"ppo-gym-{args.env_id}",
        log_dir=runtime_logdir,
        algorithm=Algorithm(
            trainer=PPOTrainer,
            policy=PPOPolicy,
            model_config=None,  # use default model architecture
            trainer_config=trainer_config,
        ),
        env_desc=env_desc_gen(env_id=args.env_id),
        learner_config=LearnerConfig(
            learner_type=IndependentAgent,
            feature_handler_meta_gen=feature_handler_meta_gen,
            custom_config={},
        ),
        rollout_config=RolloutConfig(
            num_workers=1,
        ),
        # "global" was previously misspelled "golbal", which would leave the
        # global stopping condition unmatched by any consumer looking it up.
        stopping_conditions={
            "global": {"max_iteration": 1000, "minimum_reward_improvement": 1.0},
            "rollout": {"max_iteration": 1},
            "training": {"max_iteration": 1},
        },
    )

    results = sarl_scenario.execution_plan(scenario=scenario, verbose=False)

    print(results)
Loading
Loading