Skip to content

Commit

Permalink
Merge pull request #76 from salesforce/qlearner
Browse files Browse the repository at this point in the history
model factory
  • Loading branch information
Emerald01 authored Mar 22, 2023
2 parents b5d46d4 + b37b73f commit 2632a07
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 33 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# Changelog
# Release 2.3 (2022-03-22)
- Add ModelFactory class to manage custom models
- Add Xavier initialization for the model
- Improve trainer.fetch_episode_states() so it can fetch (s, a, r) and can replay with argmax.

# Release 2.2 (2022-12-20)
- Factorize the data loading for placeholders and batches (obs, actions and rewards) for the trainer.

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

setup(
name="rl-warp-drive",
version="2.2.2",
version="2.3",
author="Tian Lan, Sunil Srinivasa, Brenton Chu, Stephan Zheng",
author_email="stephan.zheng@salesforce.com",
author_email="tian.lan@salesforce.com",
description="Framework for fast end-to-end "
"multi-agent reinforcement learning on GPUs.",
long_description=open("README.md", "r", encoding="utf-8").read(),
Expand Down
55 changes: 55 additions & 0 deletions warp_drive/training/models/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import importlib

# warpdrive reserved models
default_models = {
"fully_connected": "warp_drive.training.models.fully_connected:FullyConnected",
}


def dynamic_import(model_name: str, model_pool: dict):
"""
Dynamically import a member from the specified module.
:param model_name: the name of the model, e.g., fully_connected
:param model_pool: the dictionary of all available models
:return: imported class
"""

if model_name not in model_pool:
raise ValueError(
f"model_name {model_name} should be registered in the model factory in the form of,"
f"e.g. {'fully_connected': 'warp_drive.training.models.fully_connected:FullyConnected' } "
)
if ":" not in model_pool[model_name]:
raise ValueError(
f"Invalid model path format, expect ':' to separate the path and the object name"
f"e.g. 'warp_drive.training.models.fully_connected:FullyConnected' "
)

module_name, objname = model_pool[model_name].split(":")
m = importlib.import_module(module_name)
return getattr(m, objname)


class ModelFactory:

model_pool = {}

@classmethod
def add(cls, model_name: str, model_path: str, object_name: str):
"""
:param model_name: e.g., "fully_connected"
:param model_path: e.g., "warp_drive.training.models.fully_connected"
:param object_name: e.g., "FullyConnected"
:return:
:rtype:
"""
assert model_name not in default_models and model_name not in cls.model_pool, \
f"{model_name} has already been used by the model collection"

cls.model_pool.update({model_name: f"{model_path}:{object_name}"})

@classmethod
def create(cls, model_name):
cls.model_pool.update(default_models)
return dynamic_import(model_name, model_pool=cls.model_pool)
3 changes: 2 additions & 1 deletion warp_drive/training/models/fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class FullyConnected(nn.Module):
def __init__(
self,
env,
fc_dims,
model_config,
policy,
policy_tag_to_agent_id_map,
create_separate_placeholders_for_each_policy=False,
Expand All @@ -59,6 +59,7 @@ def __init__(
super().__init__()

self.env = env
fc_dims = model_config["fc_dims"]
assert isinstance(fc_dims, list)
num_fc_layers = len(fc_dims)
self.policy = policy
Expand Down
22 changes: 10 additions & 12 deletions warp_drive/training/pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from warp_drive.training.algorithms.policygradient.a2c import A2C
from warp_drive.training.algorithms.policygradient.ppo import PPO
from warp_drive.training.models.fully_connected import FullyConnected
from warp_drive.training.models.factory import ModelFactory
from warp_drive.training.trainer import Metrics
from warp_drive.training.utils.data_loader import create_and_push_data_placeholders
from warp_drive.training.utils.param_scheduler import LRScheduler, ParamScheduler
Expand Down Expand Up @@ -353,17 +353,15 @@ def _initialize_policy_algorithm(self, policy):

def _initialize_policy_model(self, policy):
policy_model_config = self._get_config(["policy", policy, "model"])
if policy_model_config["type"] == "fully_connected":
model = FullyConnected(
self.cuda_envs,
policy_model_config["fc_dims"],
policy,
self.policy_tag_to_agent_id_map,
self.create_separate_placeholders_for_each_policy,
self.obs_dim_corresponding_to_num_agents,
)
else:
raise NotImplementedError
model_obj = ModelFactory.create(policy_model_config["type"])
model = model_obj(
env=self.cuda_envs,
model_config=policy_model_config,
policy=policy,
policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map,
create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy,
obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents,
)
self.models[policy] = model

def _get_config(self, args):
Expand Down
36 changes: 18 additions & 18 deletions warp_drive/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from warp_drive.training.algorithms.policygradient.a2c import A2C
from warp_drive.training.algorithms.policygradient.ppo import PPO
from warp_drive.training.models.fully_connected import FullyConnected
from warp_drive.training.models.factory import ModelFactory
from warp_drive.training.utils.data_loader import create_and_push_data_placeholders
from warp_drive.training.utils.param_scheduler import ParamScheduler
from warp_drive.utils.common import get_project_root
Expand Down Expand Up @@ -368,24 +368,24 @@ def _initialize_policy_algorithm(self, policy):

def _initialize_policy_model(self, policy):
policy_model_config = self._get_config(["policy", policy, "model"])
if policy_model_config["type"] == "fully_connected":
model = FullyConnected(
self.cuda_envs,
policy_model_config["fc_dims"],
policy,
self.policy_tag_to_agent_id_map,
self.create_separate_placeholders_for_each_policy,
self.obs_dim_corresponding_to_num_agents,
)
if "init_method" in policy_model_config and \
policy_model_config["init_method"] == "xavier":
def init_weights_by_xavier_uniform(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform(m.weight)
model_obj = ModelFactory.create(policy_model_config["type"])
model = model_obj(
env=self.cuda_envs,
model_config=policy_model_config,
policy=policy,
policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map,
create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy,
obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents,
)

if "init_method" in policy_model_config and \
policy_model_config["init_method"] == "xavier":
def init_weights_by_xavier_uniform(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform(m.weight)

model.apply(init_weights_by_xavier_uniform)

model.apply(init_weights_by_xavier_uniform)
else:
raise NotImplementedError
self.models[policy] = model

def _get_config(self, args):
Expand Down

0 comments on commit 2632a07

Please sign in to comment.