diff --git a/ding/policy/qtransformer.py b/ding/policy/qtransformer.py
index 4043037b8d..e3e317aaf1 100644
--- a/ding/policy/qtransformer.py
+++ b/ding/policy/qtransformer.py
@@ -270,6 +270,7 @@ def _init_learn(self) -> None:
         self._target_model.reset()
 
         self._forward_learn_cnt = 0
+        wandb.init(**self._cfg.wandb)
 
     def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         """
@@ -296,7 +297,6 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
             You can implement you own model rather than use the default model. For more information, please raise an \
             issue in GitHub repo and we will continue to follow up.
         """
-        wandb.init(**self._cfg.wandb)
         def merge_dict1_into_dict2(
             dict1: Union[Dict, EasyDict], dict2: Union[Dict, EasyDict]
         ):
@@ -393,21 +393,21 @@ def batch_select_indices(t, indices):
         q_pred_rest_actions, q_pred_last_action = q_pred[:, :-1], q_pred[:, -1:]
 
         with torch.no_grad():
-            q_next_target = self._target_model.forward(next_state)
+            # q_next_target = self._target_model.forward(next_state)
             q_target = self._target_model.forward(state, action=action)[:, :-1, :]
 
         q_target_rest_actions = q_target[:, 1:, :]
         max_q_target_rest_actions = q_target_rest_actions.max(dim=-1).values
 
-        q_next_target_first_action = q_next_target[:, 0:1, :]
-        max_q_next_target_first_action = q_next_target_first_action.max(dim=-1).values
+        # q_next_target_first_action = q_next_target[:, 0:1, :]
+        # max_q_next_target_first_action = q_next_target_first_action.max(dim=-1).values
 
         losses_all_actions_but_last = F.mse_loss(
             q_pred_rest_actions, max_q_target_rest_actions
         )
         q_target_last_action = (reward * (1.0 - done.int())).unsqueeze(
             1
-        ) + self._gamma * max_q_next_target_first_action
+        ) + self._gamma * data["mc"]
         losses_last_action = F.mse_loss(q_pred_last_action, q_target_last_action)
         td_loss = losses_all_actions_but_last + losses_last_action
         td_loss.mean()
diff --git a/qtransformer/algorithm/__init__.py b/qtransformer/algorithm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/qtransformer/algorithm/utils.py b/qtransformer/algorithm/utils.py
new file mode 100644
index 0000000000..e9b6a4a260
--- /dev/null
+++ b/qtransformer/algorithm/utils.py
@@ -0,0 +1,77 @@
+from typing import Optional, Callable, List, Any
+
+from ding.policy import PolicyFactory
+from ding.worker import IMetric, MetricSerialEvaluator
+
+
+class AccMetric(IMetric):
+
+    def eval(self, inputs: Any, label: Any) -> dict:
+        return {
+            "Acc": (inputs["logit"].sum(dim=1) == label).sum().item() / label.shape[0]
+        }
+
+    def reduce_mean(self, inputs: List[Any]) -> Any:
+        s = 0
+        for item in inputs:
+            s += item["Acc"]
+        return {"Acc": s / len(inputs)}
+
+    def gt(self, metric1: Any, metric2: Any) -> bool:
+        if metric2 is None:
+            return True
+        if isinstance(metric2, dict):
+            m2 = metric2["Acc"]
+        else:
+            m2 = metric2
+        return metric1["Acc"] > m2
+
+
+def mark_not_expert(ori_data: List[dict]) -> List[dict]:
+    for i in range(len(ori_data)):
+        # Set is_expert flag (expert 1, agent 0)
+        ori_data[i]["is_expert"] = 0
+    return ori_data
+
+
+def mark_warm_up(ori_data: List[dict]) -> List[dict]:
+    # for td3_vae
+    for i in range(len(ori_data)):
+        ori_data[i]["warm_up"] = True
+    return ori_data
+
+
+def random_collect(
+    policy_cfg: "EasyDict",  # noqa
+    policy: "Policy",  # noqa
+    collector: "ISerialCollector",  # noqa
+    collector_env: "BaseEnvManager",  # noqa
+    commander: "BaseSerialCommander",  # noqa
+    replay_buffer: "IBuffer",  # noqa
+    postprocess_data_fn: Optional[Callable] = None,
+) -> None:  # noqa
+    assert policy_cfg.random_collect_size > 0
+    if policy_cfg.get("transition_with_policy_data", False):
+        collector.reset_policy(policy.collect_mode)
+    else:
+        action_space = collector_env.action_space
+        random_policy = PolicyFactory.get_random_policy(
+            policy.collect_mode, action_space=action_space
+        )
+        collector.reset_policy(random_policy)
+    # collect_kwargs = commander.step()
+    if policy_cfg.collect.collector.type == "episode":
+        new_data = collector.collect(
+            n_episode=policy_cfg.random_collect_size, policy_kwargs=None
+        )
+    else:
+        new_data = collector.collect(
+            n_sample=policy_cfg.random_collect_size,
+            random_collect=True,
+            record_random_collect=False,
+            policy_kwargs=None,
+        )  # 'record_random_collect=False' means random collect without output log
+    if postprocess_data_fn is not None:
+        new_data = postprocess_data_fn(new_data)
+    replay_buffer.push(new_data, cur_collector_envstep=0)
+    collector.reset_policy(policy.collect_mode)
diff --git a/qtransformer/algorithm/walker2d_qtransformer_online.py b/qtransformer/algorithm/walker2d_qtransformer_online.py
index fcbb73beff..04382e84b9 100644
--- a/qtransformer/algorithm/walker2d_qtransformer_online.py
+++ b/qtransformer/algorithm/walker2d_qtransformer_online.py
@@ -36,8 +36,8 @@
             action_bin=256,
         ),
         learn=dict(
-            update_per_collect=1,
-            batch_size=2048,
+            update_per_collect=5,
+            batch_size=200,
             learning_rate_q=3e-4,
             learning_rate_policy=1e-4,
             learning_rate_alpha=1e-4,
@@ -57,7 +57,11 @@
             unroll_len=1,
         ),
         command=dict(),
-        eval=dict(),
+        eval=dict(
+            evaluator=dict(
+                eval_freq=10,
+            )
+        ),
         other=dict(
             replay_buffer=dict(
                 replay_buffer_size=1000000,
@@ -88,7 +92,7 @@
 
 if __name__ == "__main__":
     # or you can enter `ding -m serial -c walker2d_sac_config.py -s 0`
-    from ding.entry import serial_pipeline
+    from qtransformer.algorithm.serial_entry import serial_pipeline
 
     model = QTransformer(**main_config.policy.model)
     serial_pipeline([main_config, create_config], seed=0, model=model)
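
Note on the TD-target change in ding/policy/qtransformer.py (not part of the diff itself): the last action dimension is now regressed onto reward + gamma * data["mc"] instead of the target network's maximum Q-value for the next state, so the bootstrap target is replaced by a value supplied with the batch under the key "mc". The diff does not show how that entry is produced; the sketch below only illustrates one plausible way to precompute it as a discounted Monte Carlo return during episode preprocessing (the function name and the preprocessing step are assumptions, not code from this PR).

    import torch


    def discounted_mc_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
        """Return G_t = r_t + gamma * r_{t+1} + ... for every step of one episode."""
        returns = torch.zeros_like(rewards)
        running = 0.0
        # Walk the episode backwards so each step accumulates its discounted future reward.
        for t in reversed(range(rewards.shape[0])):
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns


    # Example: a 4-step episode; each transition would then store its return under "mc".
    mc = discounted_mc_returns(torch.tensor([1.0, 0.0, 0.5, 1.0]), gamma=0.99)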
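
Note on qtransformer/algorithm/utils.py (also outside the diff): random_collect reproduces the random-collection helper from DI-engine's entry utilities; it seeds the replay buffer with transitions from a random policy (or the collect-mode policy when transition_with_policy_data is set) before learning starts. A minimal sketch of how a serial entry such as the qtransformer/algorithm/serial_entry.py imported by the config might call it, assuming the usual DI-engine objects have already been built by the pipeline (this wrapper function and its name are hypothetical):

    from typing import Callable, Optional

    from qtransformer.algorithm.utils import random_collect


    def maybe_seed_buffer(cfg, policy, collector, collector_env, commander,
                          replay_buffer, postprocess_fn: Optional[Callable] = None) -> None:
        # Only warm up the buffer when the config asks for it; random_collect itself
        # asserts that random_collect_size > 0.
        if cfg.policy.get("random_collect_size", 0) > 0:
            random_collect(cfg.policy, policy, collector, collector_env, commander,
                           replay_buffer, postprocess_data_fn=postprocess_fn)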