Commit

add
rongkunxue committed Jul 4, 2024
1 parent 44d746e commit 5d59b3d
Showing 4 changed files with 90 additions and 9 deletions.
10 changes: 5 additions & 5 deletions ding/policy/qtransformer.py
@@ -270,6 +270,7 @@ def _init_learn(self) -> None:
self._target_model.reset()

self._forward_learn_cnt = 0
+ wandb.init(**self._cfg.wandb)

def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
@@ -296,7 +297,6 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
You can implement your own model rather than use the default model. For more information, please raise an \
issue in the GitHub repo and we will continue to follow up.
"""
- wandb.init(**self._cfg.wandb)

def merge_dict1_into_dict2(
dict1: Union[Dict, EasyDict], dict2: Union[Dict, EasyDict]
@@ -393,21 +393,21 @@ def batch_select_indices(t, indices):

q_pred_rest_actions, q_pred_last_action = q_pred[:, :-1], q_pred[:, -1:]
with torch.no_grad():
- q_next_target = self._target_model.forward(next_state)
+ # q_next_target = self._target_model.forward(next_state)
q_target = self._target_model.forward(state, action=action)[:, :-1, :]

q_target_rest_actions = q_target[:, 1:, :]
max_q_target_rest_actions = q_target_rest_actions.max(dim=-1).values

- q_next_target_first_action = q_next_target[:, 0:1, :]
- max_q_next_target_first_action = q_next_target_first_action.max(dim=-1).values
+ # q_next_target_first_action = q_next_target[:, 0:1, :]
+ # max_q_next_target_first_action = q_next_target_first_action.max(dim=-1).values

losses_all_actions_but_last = F.mse_loss(
q_pred_rest_actions, max_q_target_rest_actions
)
q_target_last_action = (reward * (1.0 - done.int())).unsqueeze(
1
- ) + self._gamma * max_q_next_target_first_action
+ ) + self._gamma * data["mc"]
losses_last_action = F.mse_loss(q_pred_last_action, q_target_last_action)
td_loss = losses_all_actions_but_last + losses_last_action
td_loss.mean()
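This file sees two changes: wandb.init moves from _forward_learn into _init_learn, so the Weights & Biases run is created once at policy initialisation instead of on every training call, and the TD target for the last action dimension now bootstraps from the Monte Carlo return carried in data["mc"] rather than from the next-state target Q, which is commented out. Below is a minimal, self-contained sketch of the resulting TD objective; the shapes and the "mc" field follow the diff, while the helper itself is illustrative and not the DI-engine code path.

# Sketch of the TD objective after this commit (illustrative only).
import torch
import torch.nn.functional as F

def qtransformer_td_loss(q_pred, q_target, reward, done, mc, gamma):
    # q_pred:   (B, A)       Q-value of the taken bin for each action dimension
    # q_target: (B, A, bins) target-network Q-values conditioned on the taken action
    # reward:   (B,)         environment reward
    # done:     (B,)         termination flags
    # mc:       (B, 1)       Monte Carlo return (data["mc"] in the diff); shape assumed
    q_pred_rest, q_pred_last = q_pred[:, :-1], q_pred[:, -1:]
    with torch.no_grad():
        # Dimensions 0..A-2 regress onto the max over bins of the next
        # dimension's target Q (the autoregressive backup).
        max_q_target_rest = q_target[:, 1:, :].max(dim=-1).values
        # The last dimension now bootstraps from the Monte Carlo return
        # instead of the next-state target Q.
        q_target_last = (reward * (1.0 - done.float())).unsqueeze(1) + gamma * mc
    loss_rest = F.mse_loss(q_pred_rest, max_q_target_rest)
    loss_last = F.mse_loss(q_pred_last, q_target_last)
    return loss_rest + loss_last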
Empty file.
77 changes: 77 additions & 0 deletions qtransformer/algorithm/utils.py
@@ -0,0 +1,77 @@
from typing import Optional, Callable, List, Any

from ding.policy import PolicyFactory
from ding.worker import IMetric, MetricSerialEvaluator


class AccMetric(IMetric):

def eval(self, inputs: Any, label: Any) -> dict:
return {
"Acc": (inputs["logit"].sum(dim=1) == label).sum().item() / label.shape[0]
}

def reduce_mean(self, inputs: List[Any]) -> Any:
s = 0
for item in inputs:
s += item["Acc"]
return {"Acc": s / len(inputs)}

def gt(self, metric1: Any, metric2: Any) -> bool:
if metric2 is None:
return True
if isinstance(metric2, dict):
m2 = metric2["Acc"]
else:
m2 = metric2
return metric1["Acc"] > m2


def mark_not_expert(ori_data: List[dict]) -> List[dict]:
for i in range(len(ori_data)):
# Set is_expert flag (expert 1, agent 0)
ori_data[i]["is_expert"] = 0
return ori_data


def mark_warm_up(ori_data: List[dict]) -> List[dict]:
# for td3_vae
for i in range(len(ori_data)):
ori_data[i]["warm_up"] = True
return ori_data


def random_collect(
policy_cfg: "EasyDict", # noqa
policy: "Policy", # noqa
collector: "ISerialCollector", # noqa
collector_env: "BaseEnvManager", # noqa
commander: "BaseSerialCommander", # noqa
replay_buffer: "IBuffer", # noqa
postprocess_data_fn: Optional[Callable] = None,
) -> None: # noqa
assert policy_cfg.random_collect_size > 0
if policy_cfg.get("transition_with_policy_data", False):
collector.reset_policy(policy.collect_mode)
else:
action_space = collector_env.action_space
random_policy = PolicyFactory.get_random_policy(
policy.collect_mode, action_space=action_space
)
collector.reset_policy(random_policy)
# collect_kwargs = commander.step()
if policy_cfg.collect.collector.type == "episode":
new_data = collector.collect(
n_episode=policy_cfg.random_collect_size, policy_kwargs=None
)
else:
new_data = collector.collect(
n_sample=policy_cfg.random_collect_size,
random_collect=True,
record_random_collect=False,
policy_kwargs=None,
) # 'record_random_collect=False' means the random collection is not written to the output log
if postprocess_data_fn is not None:
new_data = postprocess_data_fn(new_data)
replay_buffer.push(new_data, cur_collector_envstep=0)
collector.reset_policy(policy.collect_mode)
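The new utils module collects serial-entry helpers: AccMetric is a minimal IMetric implementation and random_collect pre-fills the replay buffer with randomly collected data before training starts. A hypothetical smoke test of AccMetric with dummy tensors (the values below are invented purely to show the expected inputs and outputs):

# Hypothetical check of AccMetric; tensors and values invented for illustration.
import torch
from qtransformer.algorithm.utils import AccMetric

metric = AccMetric()
inputs = {"logit": torch.tensor([[0, 1], [2, 0], [1, 1]])}  # sum over dim=1 -> [1, 2, 2]
label = torch.tensor([1, 2, 0])
step = metric.eval(inputs, label)                # {"Acc": 2/3}, two of three match
avg = metric.reduce_mean([step, {"Acc": 1.0}])   # {"Acc": ~0.83}
assert metric.gt(avg, step)                      # higher accuracy wins

random_collect follows the same call pattern as DI-engine's built-in serial entry (policy config, policy, collector, env manager, commander, replay buffer) and is presumably invoked from the new qtransformer.algorithm.serial_entry before the main training loop whenever random_collect_size > 0.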
12 changes: 8 additions & 4 deletions qtransformer/algorithm/walker2d_qtransformer_online.py
@@ -36,8 +36,8 @@
action_bin=256,
),
learn=dict(
- update_per_collect=1,
- batch_size=2048,
+ update_per_collect=5,
+ batch_size=200,
learning_rate_q=3e-4,
learning_rate_policy=1e-4,
learning_rate_alpha=1e-4,
@@ -57,7 +57,11 @@
unroll_len=1,
),
command=dict(),
- eval=dict(),
+ eval=dict(
+     evaluator=dict(
+         eval_freq=10,
+     )
+ ),
other=dict(
replay_buffer=dict(
replay_buffer_size=1000000,
@@ -88,7 +92,7 @@

if __name__ == "__main__":
# or you can enter `ding -m serial -c walker2d_sac_config.py -s 0`
- from ding.entry import serial_pipeline
+ from qtransformer.algorithm.serial_entry import serial_pipeline

model = QTransformer(**main_config.policy.model)
serial_pipeline([main_config, create_config], seed=0, model=model)
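For quick reference, the configuration fields touched by this commit, consolidated into their post-commit values (reconstructed from the diff above; every other field is unchanged and omitted here):

# Post-commit values of the changed fields in walker2d_qtransformer_online.py.
from easydict import EasyDict

changed = EasyDict(dict(
    policy=dict(
        learn=dict(
            update_per_collect=5,  # was 1
            batch_size=200,        # was 2048
        ),
        eval=dict(
            evaluator=dict(
                eval_freq=10,      # new evaluation-frequency setting
            )
        ),
    )
))

The entry point also switches from ding.entry.serial_pipeline to the repository's own qtransformer.algorithm.serial_entry.serial_pipeline while keeping the same call signature.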
