lzero/model/unizero_world_models/tokenizer.py (32 additions, 11 deletions)
@@ -123,11 +123,12 @@ def decode_to_language_logits(self, embeddings: torch.Tensor, target_ids: torch.
encoder_hidden_states=embeddings,
)
logits = self.decoder_network.lm_head(outputs.last_hidden_state)
return logits

@torch.no_grad()
def decode_to_language_logits_for_inference(self, embeddings: torch.Tensor, max_length: int = 512, pad_token_id: int = 0, eos_token_id: int = 102) -> torch.Tensor:

self.decoder_network.eval()
self.projection_layer.eval()

if not isinstance(embeddings, torch.Tensor):
embeddings = torch.tensor(embeddings, dtype=torch.float32)
@@ -142,36 +143,56 @@ def decode_to_language_logits_for_inference(self, embeddings: torch.Tensor, max_
embeddings = self.projection_layer(embeddings)

batch_size = embeddings.shape[0]

device = embeddings.device
decoder_input_ids = torch.full(
current_input_ids = torch.full(
(batch_size, 1),
pad_token_id,
dtype=torch.long,
device=device
)

generated_ids = []
# generated_ids = [1, 2, 3, 4]
generated_ids = [current_input_ids]
past_key_values = None

is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

for _ in range(max_length):
for step in range(max_length):
outputs = self.decoder_network(
input_ids=decoder_input_ids,
input_ids=current_input_ids,
encoder_hidden_states=embeddings,
past_key_values=past_key_values,
use_cache=True,
return_dict=True
)

hidden_states = outputs.last_hidden_state
logits = self.decoder_network.lm_head(hidden_states)

next_token_logits = logits[:, -1, :]
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

past_key_values = outputs.past_key_values

decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
next_token = torch.where(is_finished.unsqueeze(-1),
torch.full_like(next_token, pad_token_id),
next_token)
generated_ids.append(next_token)
if (next_token == eos_token_id).all():

just_finished = ~is_finished & (next_token.squeeze(-1) == eos_token_id)
is_finished |= just_finished
current_input_ids = next_token

if is_finished.all():
break

generated_ids = torch.cat(generated_ids, dim=1).cpu().tolist()

return generated_ids
all_generated_ids = torch.cat(generated_ids, dim=1)

return all_generated_ids.cpu().tolist()

# def decode_to_language_logits_for_inference(self, embeddings: torch.Tensor, max_length: int = 512, pad_token_id: int = 0, eos_token_id: int = 102) -> torch.Tensor:
# return [0]

@staticmethod
def reconstruction_loss(original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor:
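Note on the new inference path: decode_to_language_logits_for_inference now tracks a per-sample is_finished mask instead of stopping only when every sequence emits EOS at the same step, so finished sequences keep emitting pad tokens while the rest of the batch continues decoding. A minimal, self-contained sketch of that pattern follows; toy_step and greedy_decode_with_mask are hypothetical stand-ins for illustration and do not reproduce the real decoder_network / KV-cache API.

import torch

def greedy_decode_with_mask(step_fn, batch_size, max_length, pad_token_id=0, eos_token_id=102, device="cpu"):
    # Greedy decoding where each sequence stops at its own EOS and then pads.
    current = torch.full((batch_size, 1), pad_token_id, dtype=torch.long, device=device)
    generated = [current]
    past = None  # stand-in for past_key_values
    is_finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    for _ in range(max_length):
        next_logits, past = step_fn(current, past)              # (B, vocab_size)
        next_token = next_logits.argmax(dim=-1, keepdim=True)   # (B, 1)
        # Sequences that already emitted EOS keep producing pad tokens.
        next_token = torch.where(is_finished.unsqueeze(-1),
                                 torch.full_like(next_token, pad_token_id),
                                 next_token)
        generated.append(next_token)
        is_finished |= next_token.squeeze(-1) == eos_token_id
        current = next_token
        if is_finished.all():
            break
    return torch.cat(generated, dim=1)

if __name__ == "__main__":
    vocab_size = 16
    def toy_step(input_ids, past):
        # Random logits stand in for the decoder forward pass; past is ignored here.
        return torch.randn(input_ids.shape[0], vocab_size), past
    print(greedy_decode_with_mask(toy_step, batch_size=2, max_length=8, eos_token_id=3))
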
lzero/policy/unizero.py (8 additions, 6 deletions)
@@ -664,11 +664,12 @@ def _forward_collect(
roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)

next_latent_state_with_env = self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep)

# list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
roots_visit_count_distributions = roots.get_distributions()
roots_values = roots.get_values() # shape: {list: batch_size}


batch_action = []
for i, env_id in enumerate(ready_env_id):
distributions, value = roots_visit_count_distributions[i], roots_values[i]
@@ -691,12 +692,13 @@ def _forward_collect(
# NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]

next_latent_state = next_latent_state_with_env[env_id][action]
predicted_ids = self._collect_model.tokenizer.decode_to_language_logits_for_inference( embeddings=next_latent_state, max_length=512, pad_token_id=0, eos_token_id=102)
next_latent_state = next_latent_state_with_env[i][action]

predicted_ids = self._collect_model.tokenizer.decode_to_language_logits_for_inference(embeddings=next_latent_state, max_length=256, pad_token_id=0, eos_token_id=102)

# ============== TODO: only for visualize ==============
# action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
# distributions, temperature=self._collect_mcts_temperature, deterministic=True
# distribuxxtions, temperature=self._collect_mcts_temperature, deterministic=True
# )
# action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
# ============== TODO: only for visualize ==============
@@ -817,9 +819,9 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]

# Predict the next latent state from the selected action and the policy
next_latent_state = next_latent_state_with_env[env_id][action]
next_latent_state = next_latent_state_with_env[i][action]
predicted_ids = self._eval_model.tokenizer.decode_to_language_logits_for_inference( embeddings=next_latent_state,
max_length=512,
max_length=256,
pad_token_id=0,
eos_token_id=102)
output[env_id] = {
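The indexing change from next_latent_state_with_env[env_id] to next_latent_state_with_env[i] matters whenever only a subset of environments is ready: the search output is ordered by batch position, while env_id is the environment's global id. A toy illustration (the list contents and names are assumptions, not the actual MCTS return format):

# Hypothetical example: three envs exist, but only envs 0 and 2 are ready this step.
ready_env_id = [0, 2]
# One search result per ready env, ordered by batch position.
next_latent_state_with_env = ["latent_for_env_0", "latent_for_env_2"]

for i, env_id in enumerate(ready_env_id):
    latent = next_latent_state_with_env[i]  # correct: position-based lookup
    # next_latent_state_with_env[env_id] would raise IndexError for env_id == 2,
    # or silently fetch the wrong entry for other ready/id patterns.
    print(env_id, latent)
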
lzero/worker/muzero_collector.py (15 additions, 5 deletions)
@@ -141,7 +141,7 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMan
if _policy is not None:
self.reset_policy(_policy)

self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}
self._env_info = {env_id: {'time': 0., 'step': 0, 'text_bleu': 0.} for env_id in range(self._env_num)}

self._episode_info = []
self._total_envstep_count = 0
@@ -163,7 +163,7 @@ def _reset_stat(self, env_id: int) -> None:
Arguments:
- env_id (:obj:`int`): the id where we need to reset the collector's state
"""
self._env_info[env_id] = {'time': 0., 'step': 0, 'text_bleu': 0}
self._env_info[env_id] = {'time': 0., 'step': 0, 'text_bleu': 0.}

@property
def envstep(self) -> int:
@@ -411,6 +411,8 @@ def collect(self,
if collect_with_pure_policy:
temp_visit_list = [0.0 for i in range(self._env.action_space.n)]

pred_text, groundtruth_text = [], []

while True:
with self._timer:
# Get current ready env obs.
@@ -520,6 +522,7 @@ def collect(self,
# Interact with the environment
# ==============================================================
timesteps = self._env.step(actions)
pred_text, groundtruth_text = [], []

interaction_duration = self._timer.value / len(timesteps)

@@ -543,6 +546,9 @@ def collect(self,
groundtrut_next_text[env_id] = self._env._envs[env_id].tokenizer.decode(valid_input_ids, skip_special_tokens=True)

text_bleu = compute_bleu(reference=groundtrut_next_text[env_id], prediction=pred_next_text[env_id])
pred_text.append(pred_next_text[env_id])
groundtruth_text.append(groundtrut_next_text[env_id])


if collect_with_pure_policy:
game_segments[env_id].store_search_stats(temp_visit_list, 0)
@@ -638,6 +644,7 @@ def collect(self,

self._env_info[env_id]['step'] += 1
self._env_info[env_id]['text_bleu'] += text_bleu

collected_step += 1

self._env_info[env_id]['time'] += self._timer.value + interaction_duration
@@ -771,10 +778,10 @@ def collect(self,
self._total_duration += collected_duration

# log
self._output_log(train_iter)
self._output_log(train_iter, groundtruth_text, pred_text)
return return_data

def _output_log(self, train_iter: int) -> None:
def _output_log(self, train_iter: int, groundtruth_text: list, pred_text: list) -> None:
"""
Overview:
Log the collector's data and output the log information.
@@ -813,12 +820,15 @@ def _output_log(self, train_iter: int) -> None:
'total_episode_count': self._total_episode_count,
'total_duration': self._total_duration,
'visit_entropy': np.mean(visit_entropy),
'text_avg_bleu': np.mean(episode_bleu)
'text_avg_bleu': np.mean(episode_bleu),
}
if self.policy_config.gumbel_algo:
info['completed_value'] = np.mean(completed_value)
self._episode_info.clear()
self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
self._logger.info(f"\n'pred_text: {pred_text}")
self._logger.info(f"\n'groundtruth_text: {groundtruth_text}")

for k, v in info.items():
if k in ['each_reward']:
continue
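The collector now accumulates a per-step text_bleu and logs the collected prediction/ground-truth pairs at the end of collection. compute_bleu is provided elsewhere in the repo; purely as an illustrative assumption for readers who want to reproduce the metric locally, a smoothed sentence-level BLEU can be computed with NLTK roughly as follows (this is not the project's actual implementation):

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def bleu_example(reference: str, prediction: str) -> float:
    # Whitespace tokenization and method1 smoothing are assumptions for this sketch.
    smoothing = SmoothingFunction().method1
    return sentence_bleu([reference.split()], prediction.split(), smoothing_function=smoothing)

print(bleu_example("you open the wooden door", "you open a wooden door"))
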
zoo/jericho/configs/jericho_unizero_ddp_config.py (2 additions, 5 deletions)
@@ -19,7 +19,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
gpu_num = 4
collector_env_num: int = 4 # Number of collector environments
n_episode = int(collector_env_num*gpu_num)
batch_size = int(16*gpu_num)
batch_size = int(8*gpu_num)

# ------------------------------------------------------------------
# Base environment parameters (Note: these values might be adjusted for different env_id)
@@ -56,15 +56,12 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
# User frequently modified configurations
# ------------------------------------------------------------------
evaluator_env_num: int = 3 # Number of evaluator environments
num_simulations: int = 100 # Number of simulations
num_simulations: int = 50 # Number of simulations

# Project training parameters
num_unroll_steps: int = 10 # Number of unroll steps (for rollout sequence expansion)
infer_context_length: int = 4 # Inference context length

# num_unroll_steps: int = 20 # Number of unroll steps (for rollout sequence expansion)
# infer_context_length: int = 10 # Inference context length

num_layers: int = 2 # Number of layers in the model
# replay_ratio: float = 0.25 # Replay ratio for experience replay
replay_ratio: float = 0.1 # Replay ratio for experience replay