diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py
index d77d43cbc..ccbcda053 100644
--- a/lzero/model/unizero_world_models/tokenizer.py
+++ b/lzero/model/unizero_world_models/tokenizer.py
@@ -123,11 +123,12 @@ def decode_to_language_logits(self, embeddings: torch.Tensor, target_ids: torch.
             encoder_hidden_states=embeddings,
         )
         logits = self.decoder_network.lm_head(outputs.last_hidden_state)
+        return logits

     @torch.no_grad()
     def decode_to_language_logits_for_inference(self, embeddings: torch.Tensor, max_length: int = 512, pad_token_id: int = 0, eos_token_id: int = 102) -> torch.Tensor:
-        self.decoder_network.eval()
+        self.projection_layer.eval()

         if not isinstance(embeddings, torch.Tensor):
             embeddings = torch.tensor(embeddings, dtype=torch.float32)
@@ -142,36 +143,56 @@ def decode_to_language_logits_for_inference(self, embeddings: torch.Tensor, max_
         embeddings = self.projection_layer(embeddings)

         batch_size = embeddings.shape[0]
+        device = embeddings.device

-        decoder_input_ids = torch.full(
+        current_input_ids = torch.full(
             (batch_size, 1), pad_token_id, dtype=torch.long, device=device
         )

-        generated_ids = []
+        # generated_ids = [1, 2, 3, 4]
+        generated_ids = [current_input_ids]
+        past_key_values = None
+
+        is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

-        for _ in range(max_length):
+        for step in range(max_length):
             outputs = self.decoder_network(
-                input_ids=decoder_input_ids,
+                input_ids=current_input_ids,
                 encoder_hidden_states=embeddings,
+                past_key_values=past_key_values,
+                use_cache=True,
                 return_dict=True
             )
+
             hidden_states = outputs.last_hidden_state
             logits = self.decoder_network.lm_head(hidden_states)
             next_token_logits = logits[:, -1, :]
-            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
+            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+            past_key_values = outputs.past_key_values

-            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
+            next_token = torch.where(is_finished.unsqueeze(-1),
+                                     torch.full_like(next_token, pad_token_id),
+                                     next_token)

             generated_ids.append(next_token)
-            if (next_token == eos_token_id).all():
+
+            just_finished = ~is_finished & (next_token.squeeze(-1) == eos_token_id)
+            is_finished |= just_finished
+            current_input_ids = next_token
+
+            if is_finished.all():
                 break

-        generated_ids = torch.cat(generated_ids, dim=1).cpu().tolist()
-
-        return generated_ids
+        all_generated_ids = torch.cat(generated_ids, dim=1)
+
+        return all_generated_ids.cpu().tolist()
+
+    # def decode_to_language_logits_for_inference(self, embeddings: torch.Tensor, max_length: int = 512, pad_token_id: int = 0, eos_token_id: int = 102) -> torch.Tensor:
+    #     return [0]

     @staticmethod
     def reconstruction_loss(original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor:
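
Note on the rewritten inference loop above: it now decodes incrementally, feeding only the newest token together with `past_key_values`, and keeps a per-sequence `is_finished` mask so that sequences which have already emitted `eos_token_id` keep producing `pad_token_id` until the whole batch is done. Below is a minimal, self-contained sketch of that masking / early-stop pattern; `TinyDecoder` is a toy stand-in, not the real `decoder_network`, and the KV-cache plumbing is omitted for brevity.

    # Illustrative sketch only: masked greedy decoding with early stop.
    import torch
    import torch.nn as nn

    class TinyDecoder(nn.Module):
        # Toy stand-in for decoder_network: embeds tokens and projects to vocab logits.
        def __init__(self, vocab_size: int = 16, hidden: int = 8):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)
            self.lm_head = nn.Linear(hidden, vocab_size)

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            return self.lm_head(self.embed(input_ids))  # (B, T, vocab)

    @torch.no_grad()
    def greedy_generate(model: nn.Module, batch_size: int = 2, max_length: int = 6,
                        pad_token_id: int = 0, eos_token_id: int = 3) -> torch.Tensor:
        device = next(model.parameters()).device
        current = torch.full((batch_size, 1), pad_token_id, dtype=torch.long, device=device)
        generated = [current]                      # seeded with the initial pad column
        is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for _ in range(max_length):
            logits = model(current)                # feed only the newest token each step
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            # Sequences that already hit EOS keep emitting pad tokens.
            next_token = torch.where(is_finished.unsqueeze(-1),
                                     torch.full_like(next_token, pad_token_id),
                                     next_token)
            generated.append(next_token)
            is_finished |= next_token.squeeze(-1) == eos_token_id
            current = next_token
            if is_finished.all():                  # stop once every sequence has finished
                break
        return torch.cat(generated, dim=1)

    print(greedy_generate(TinyDecoder()))

As in the patched method, the returned sequences include the leading pad column because the list of generated ids is seeded with the initial input ids.
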
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index d8782b34d..f303b45b8 100644
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -664,11 +664,12 @@ def _forward_collect(
             roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
             next_latent_state_with_env = self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep)
-
+            # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
             roots_visit_count_distributions = roots.get_distributions()
             roots_values = roots.get_values()  # shape: {list: batch_size}
+
             batch_action = []
             for i, env_id in enumerate(ready_env_id):
                 distributions, value = roots_visit_count_distributions[i], roots_values[i]
@@ -691,12 +692,13 @@ def _forward_collect(
                 # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
                 action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
-                next_latent_state = next_latent_state_with_env[env_id][action]
-                predicted_ids = self._collect_model.tokenizer.decode_to_language_logits_for_inference( embeddings=next_latent_state, max_length=512, pad_token_id=0, eos_token_id=102)
+                next_latent_state = next_latent_state_with_env[i][action]
+
+                predicted_ids = self._collect_model.tokenizer.decode_to_language_logits_for_inference(embeddings=next_latent_state, max_length=256, pad_token_id=0, eos_token_id=102)

                 # ============== TODO: only for visualize ==============
                 # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
-                #     distributions, temperature=self._collect_mcts_temperature, deterministic=True
+                #     distribuxxtions, temperature=self._collect_mcts_temperature, deterministic=True
                 # )
                 # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
                 # ============== TODO: only for visualize ==============
@@ -817,9 +819,9 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [
                 action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
                 # Predict the next latent state from the selected action and policy
-                next_latent_state = next_latent_state_with_env[env_id][action]
+                next_latent_state = next_latent_state_with_env[i][action]
                 predicted_ids = self._eval_model.tokenizer.decode_to_language_logits_for_inference(
                     embeddings=next_latent_state,
-                    max_length=512,
+                    max_length=256,
                     pad_token_id=0,
                     eos_token_id=102)
                 output[env_id] = {
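
One remark on the `next_latent_state_with_env[i][action]` change above: when only a subset of environments is ready, the per-root search outputs are ordered by position within `ready_env_id`, so they must be indexed with the loop index `i`, not the raw `env_id`. A tiny illustrative snippet (the ids and placeholder results are hypothetical):

    # Hypothetical example: three ready environments out of a larger pool.
    ready_env_id = [2, 5, 7]
    search_results = ['root_0', 'root_1', 'root_2']   # one entry per root, in ready order

    for i, env_id in enumerate(ready_env_id):
        assert search_results[i] is not None           # the position index is always in range
        # search_results[env_id] would raise IndexError for env_id = 5 or 7
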
diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py
index a50356bb6..4a2a0822d 100644
--- a/lzero/worker/muzero_collector.py
+++ b/lzero/worker/muzero_collector.py
@@ -141,7 +141,7 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana
         if _policy is not None:
             self.reset_policy(_policy)

-        self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}
+        self._env_info = {env_id: {'time': 0., 'step': 0, 'text_bleu': 0.} for env_id in range(self._env_num)}

         self._episode_info = []
         self._total_envstep_count = 0
@@ -163,7 +163,7 @@ def _reset_stat(self, env_id: int) -> None:
         Arguments:
             - env_id (:obj:`int`): the id where we need to reset the collector's state
         """
-        self._env_info[env_id] = {'time': 0., 'step': 0, 'text_bleu': 0}
+        self._env_info[env_id] = {'time': 0., 'step': 0, 'text_bleu': 0.}

     @property
     def envstep(self) -> int:
@@ -411,6 +411,8 @@ def collect(self,
         if collect_with_pure_policy:
             temp_visit_list = [0.0 for i in range(self._env.action_space.n)]

+        pred_text, groundtruth_text = [], []
+
         while True:
             with self._timer:
                 # Get current ready env obs.
@@ -520,6 +522,7 @@ def collect(self,
                 # Interact with the environment
                 # ==============================================================
                 timesteps = self._env.step(actions)
+                pred_text, groundtruth_text = [], []

             interaction_duration = self._timer.value / len(timesteps)

@@ -543,6 +546,9 @@ def collect(self,
                     groundtrut_next_text[env_id] = self._env._envs[env_id].tokenizer.decode(valid_input_ids, skip_special_tokens=True)
                     text_bleu = compute_bleu(reference=groundtrut_next_text[env_id], prediction=pred_next_text[env_id])

+                    pred_text.append(pred_next_text[env_id])
+                    groundtruth_text.append(groundtrut_next_text[env_id])
+
                     if collect_with_pure_policy:
                         game_segments[env_id].store_search_stats(temp_visit_list, 0)
@@ -638,6 +644,7 @@ def collect(self,
                     self._env_info[env_id]['step'] += 1
                     self._env_info[env_id]['text_bleu'] += text_bleu
+                    collected_step += 1

                     self._env_info[env_id]['time'] += self._timer.value + interaction_duration
@@ -771,10 +778,10 @@ def collect(self,
                     self._total_duration += collected_duration

         # log
-        self._output_log(train_iter)
+        self._output_log(train_iter, groundtruth_text, pred_text)
         return return_data

-    def _output_log(self, train_iter: int) -> None:
+    def _output_log(self, train_iter: int, groundtruth_text: list, pred_text: list) -> None:
         """
         Overview:
             Log the collector's data and output the log information.
@@ -813,12 +820,15 @@ def _output_log(self, train_iter: int) -> None:
             'total_episode_count': self._total_episode_count,
             'total_duration': self._total_duration,
             'visit_entropy': np.mean(visit_entropy),
-            'text_avg_bleu': np.mean(episode_bleu)
+            'text_avg_bleu': np.mean(episode_bleu),
         }
         if self.policy_config.gumbel_algo:
             info['completed_value'] = np.mean(completed_value)
         self._episode_info.clear()
         self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
+        self._logger.info(f"\npred_text: {pred_text}")
+        self._logger.info(f"\ngroundtruth_text: {groundtruth_text}")
+
         for k, v in info.items():
             if k in ['each_reward']:
                 continue
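
For reference, `compute_bleu` used in the collector above is the repository's own helper. An illustrative stand-in with a similar call signature, built on NLTK's sentence-level BLEU with smoothing, could look like the sketch below; the real implementation may differ in tokenization and weighting.

    # Illustrative stand-in only; not the repository's actual compute_bleu.
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

    def compute_bleu(reference: str, prediction: str) -> float:
        ref_tokens = reference.split()
        pred_tokens = prediction.split()
        if not pred_tokens:
            return 0.0
        # Smoothing avoids zero scores on short texts with missing n-gram orders.
        return sentence_bleu([ref_tokens], pred_tokens,
                             smoothing_function=SmoothingFunction().method1)

    print(compute_bleu("you are standing in an open field", "you are in an open field"))
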
diff --git a/zoo/jericho/configs/jericho_unizero_ddp_config.py b/zoo/jericho/configs/jericho_unizero_ddp_config.py
index 63984f01b..c1e212de9 100644
--- a/zoo/jericho/configs/jericho_unizero_ddp_config.py
+++ b/zoo/jericho/configs/jericho_unizero_ddp_config.py
@@ -19,7 +19,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     gpu_num = 4
     collector_env_num: int = 4  # Number of collector environments
     n_episode = int(collector_env_num*gpu_num)
-    batch_size = int(16*gpu_num)
+    batch_size = int(8*gpu_num)

     # ------------------------------------------------------------------
     # Base environment parameters (Note: these values might be adjusted for different env_id)
@@ -56,15 +56,12 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
     # User frequently modified configurations
     # ------------------------------------------------------------------
     evaluator_env_num: int = 3  # Number of evaluator environments
-    num_simulations: int = 100  # Number of simulations
+    num_simulations: int = 50  # Number of simulations

     # Project training parameters
     num_unroll_steps: int = 10  # Number of unroll steps (for rollout sequence expansion)
     infer_context_length: int = 4  # Inference context length
-    # num_unroll_steps: int = 20  # Number of unroll steps (for rollout sequence expansion)
-    # infer_context_length: int = 10  # Inference context length
-
     num_layers: int = 2  # Number of layers in the model

     # replay_ratio: float = 0.25  # Replay ratio for experience replay
     replay_ratio: float = 0.1  # Replay ratio for experience replay
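
A quick sanity check on the adjusted DDP numbers above, assuming `gpu_num = 4` as set in this config:

    gpu_num = 4
    collector_env_num = 4
    n_episode = int(collector_env_num * gpu_num)   # 16 episodes per collection cycle
    batch_size = int(8 * gpu_num)                  # 32 samples per batch (was 16 * gpu_num = 64)
    num_simulations = 50                           # MCTS simulations per move (was 100)
    print(n_episode, batch_size, num_simulations)
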