diff --git a/.gitignore b/.gitignore
index fed96ee2f..1cda7444c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1450,4 +1450,4 @@
 events.*
 !/assets/pooltool/**
 lzero/mcts/ctree/ctree_alphazero/pybind11
-zoo/jericho/envs/z-machine-games-master
\ No newline at end of file
+zoo/jericho/envs/z-machine-games-master
diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py
index 7dc0328f7..68f3a66aa 100644
--- a/lzero/entry/__init__.py
+++ b/lzero/entry/__init__.py
@@ -9,4 +9,7 @@
 from .train_rezero import train_rezero
 from .train_unizero import train_unizero
 from .train_unizero_segment import train_unizero_segment
+from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp
+from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp
+from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval
 from .utils import *
diff --git a/lzero/entry/compute_task_weight.py b/lzero/entry/compute_task_weight.py
new file mode 100644
index 000000000..84204a9a2
--- /dev/null
+++ b/lzero/entry/compute_task_weight.py
@@ -0,0 +1,80 @@
+
+
+
+import numpy as np
+import torch
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+    """
+    Symlog normalization, which reduces the magnitude differences between target values.
+    symlog(x) = sign(x) * log(|x| + 1)
+    """
+    return torch.sign(x) * torch.log(torch.abs(x) + 1)
+
+
+def inv_symlog(x: torch.Tensor) -> torch.Tensor:
+    """
+    Inverse of symlog, used to recover the original values.
+    inv_symlog(x) = sign(x) * (exp(|x|) - 1)
+    """
+    return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
+
+
+def compute_task_weights(
+    task_rewards: dict,
+    epsilon: float = 1e-6,
+    min_weight: float = 0.1,
+    max_weight: float = 0.5,
+    temperature: float = 1.0,
+    use_symlog: bool = True,
+) -> dict:
+    """
+    Improved task-weight computation that adds symlog preprocessing and robustness safeguards.
+
+    Args:
+        task_rewards (dict): Dictionary keyed by task_id whose values are the evaluation rewards.
+        epsilon (float): Small value that prevents division by zero.
+        min_weight (float): Minimum weight used for clipping.
+        max_weight (float): Maximum weight used for clipping.
+        temperature (float): Temperature coefficient that controls the weight distribution.
+        use_symlog (bool): Whether to rescale task_rewards with symlog.
+
+    Returns:
+        dict: Weight for each task, keyed by task_id, normalized and clipped.
+    """
+    # Step 1: optionally rescale the reward values with symlog
+    if use_symlog:
+        rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32)
+        corrected_rewards = symlog(rewards_tensor).numpy()  # symlog rescaling
+        task_rewards = dict(zip(task_rewards.keys(), corrected_rewards))
+
+    # Step 2: compute the initial weights (inversely proportional to the rewards)
+    raw_weights = {task_id: 1 / (reward + epsilon) for task_id, reward in task_rewards.items()}
+
+    # Step 3: temperature scaling
+    scaled_weights = {task_id: weight ** (1 / temperature) for task_id, weight in raw_weights.items()}
+
+    # Step 4: normalize the weights
+    total_weight = sum(scaled_weights.values())
+    normalized_weights = {task_id: weight / total_weight for task_id, weight in scaled_weights.items()}
+
+    # Step 5: clip the weights into the [min_weight, max_weight] range
+    clipped_weights = {task_id: np.clip(weight, min_weight, max_weight) for task_id, weight in normalized_weights.items()}

+    final_weights = clipped_weights
+    return final_weights
+
+task_rewards_list = [
+    {"task1": 10, "task2": 100, "task3": 1000, "task4": 500, "task5": 300},
+    {"task1": 1, "task2": 10, "task3": 100, "task4": 1000, "task5": 10000},
+    {"task1": 0.1, "task2": 0.5, "task3": 0.9, "task4": 5, "task5": 10},
+]
+
+for i, task_rewards in enumerate(task_rewards_list, start=1):
+    print(f"Case {i}: Original Rewards: {task_rewards}")
+    print("Original Weights:")
+    print(compute_task_weights(task_rewards, use_symlog=False))
+    print("Improved Weights with Symlog:")
+    print(compute_task_weights(task_rewards, use_symlog=True))
+    print()
\ No newline at end
of file diff --git a/lzero/entry/train_muzero_multitask_segment_ddp.py b/lzero/entry/train_muzero_multitask_segment_ddp.py new file mode 100644 index 000000000..5ece29f28 --- /dev/null +++ b/lzero/entry/train_muzero_multitask_segment_ddp.py @@ -0,0 +1,582 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.mcts import MuZeroGameBuffer as GameBuffer +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.distributed as dist + +import concurrent.futures + +# ========== 超时时间设置 ========== +TIMEOUT = 3600 # 例如,60分钟 + +timer = EasyTimer() + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + 安全地执行评估操作,防止因超时导致训练过程阻塞。 + + Args: + evaluator (Evaluator): 评估器实例。 + learner (BaseLearner): 学习器实例。 + collector (Collector): 数据收集器实例。 + rank (int): 当前进程的排名。 + world_size (int): 总进程数。 + + Returns: + Tuple[Optional[bool], Optional[float]]: + - stop (Optional[bool]): 评估是否停止的标志。 + - reward (Optional[float]): 评估得到的奖励。 + """ + print(f"=========评估前 Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交 evaluator.eval 任务 + future = executor.submit( + evaluator.eval, + learner.save_checkpoint, + learner.train_iter, + collector.envstep + ) + + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # 超时,设置 evaluator 的 stop_event + evaluator.stop_event.set() + print(f"评估操作在 Rank {rank}/{world_size} 上超过 {TIMEOUT} 秒超时。") + return None, None + + print(f"======评估后 Rank {rank}/{world_size}======") + return stop, reward + + +def allocate_batch_size( + cfgs: List, + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + 根据不同任务的 num_of_collected_episodes 反比分配 batch_size, + 并动态调整 batch_size 限制范围以提高训练的稳定性和效率。 + + Args: + cfgs (List): 每个任务的配置列表。 + game_buffers (List[GameBuffer]): 每个任务的 replay_buffer 实例列表。 + alpha (float): 控制反比程度的超参数 (默认为1.0)。 + clip_scale (int): 动态调整的缩放因子 (默认为1)。 + + Returns: + List[int]: 分配后的 batch_size 列表。 + """ + # 提取每个任务的 num_of_collected_episodes + buffer_num_of_collected_episodes = [ + buffer.num_of_collected_episodes for buffer in game_buffers + ] + + # 获取当前的 world_size 和 rank + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # 收集所有 rank 的 num_of_collected_episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + torch.distributed.all_gather_object( + all_task_num_of_collected_episodes, + buffer_num_of_collected_episodes + ) + + # 将所有 rank 的 num_of_collected_episodes 拼接成一个大列表 + all_task_num_of_collected_episodes = [ + item for sublist in all_task_num_of_collected_episodes for item in sublist + ] + if rank == 0: + print(f'all_task_num_of_collected_episodes: 
{all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes = np.array([ + 1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes + ]) + inv_sum = np.sum(inv_episodes) + + # 计算总的 batch_size (所有任务 cfg.policy.max_batch_size 的和) + max_batch_size = cfgs[0].policy.max_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = max_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = max_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + # 返回最终分配的 batch_size 列表 + return batch_sizes + + +def train_muzero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The train entry for multi-task MuZero, adapted from UniZero's multi-task training. + This script aims to enhance the planning capabilities of reinforcement learning agents + by leveraging multi-task learning to address diverse environments. + + Args: + input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): + Configurations for different tasks as a list of tuples containing task ID and configuration dictionaries. + seed (int): + Random seed for reproducibility. + model (Optional[torch.nn.Module]): + Predefined model instance. If provided, it will be used instead of creating a new one. + model_path (Optional[str]): + Path to the pretrained model checkpoint. Should point to the ckpt file of the pretrained model. + max_train_iter (Optional[int]): + Maximum number of training iterations. Defaults to 1e10. + max_env_step (Optional[int]): + Maximum number of environment interaction steps. Defaults to 1e10. + + Returns: + Policy: + The trained policy instance. 
+ """ + # 获取当前进程的 rank 和总的进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: 未分配任何任务,继续运行但无任务处理。") + # 初始化一些空列表以避免后续代码报错 + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + return + + print(f"Rank {rank}/{world_size}, 处理任务 {start_idx} 到 {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + # 使用第一个任务的配置来创建共享的 policy + task_id, [cfg, create_cfg] = tasks_for_this_rank[0] + + # 设置每个任务的随机种子和任务编号 + for config in tasks_for_this_rank: + config[1][0].policy.task_num = len(tasks_for_this_rank) + + # 根据 CUDA 可用性设置设备 + cfg.policy.device = cfg.policy.model.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config( + cfg, + seed=seed, + env=None, + auto=True, + create_cfg=create_cfg, + save_cfg=True + ) + # 创建共享的 policy + policy = create_policy( + cfg.policy, + model=model, + enable_field=['learn', 'collect', 'eval'] + ) + + # 如果指定了预训练模型,则加载 + if model_path is not None: + logging.info(f'开始加载模型来自 {model_path}...') + policy.learn_mode.load_state_dict( + torch.load(model_path, map_location=cfg.policy.device) + ) + logging.info(f'完成加载模型来自 {model_path}.') + + # 创建 TensorBoard 的日志记录器 + log_dir = os.path.join(f'./{cfg.exp_name}/log', f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的 learner + learner = BaseLearner( + cfg.policy.learn.learner, + policy.learn_mode, + tb_logger, + exp_name=cfg.exp_name + ) + + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + # 只处理当前进程分配到的任务 + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + # 设置每个任务自己的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config( + cfg, + seed=seed + task_id, + env=None, + auto=True, + create_cfg=create_cfg, + save_cfg=True + ) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager( + cfg.env.manager, + [partial(env_fn, cfg=c) for c in collector_env_cfg] + ) + evaluator_env = create_env_manager( + cfg.env.manager, + [partial(env_fn, cfg=c) for c in evaluator_env_cfg] + ) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 为每个任务创建不同的 game buffer、collector、evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + 
tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + torch.cuda.empty_cache() + + if cfg.policy.allocated_batch_sizes: + # TODO========== + # 线性变化的 随着 train_epoch 从 0 增加到 1000, clip_scale 从 1 线性增加到 4 + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size( + cfgs, + game_buffers, + alpha=1.0, + clip_scale=clip_scale + ) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers) + ): + cfg.policy.batch_size = allocated_batch_sizes[idx] + policy._cfg.batch_size[idx] = allocated_batch_sizes[idx] + + # 对于当前进程的每个任务,进行数据收集和评估 + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers) + ): + + log_buffer_memory_usage( + learner.train_iter, + replay_buffer, + tb_logger, + cfg.policy.task_id + ) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的 epsilon 值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + # if learner.train_iter > 1 and evaluator.should_eval(learner.train_iter): # TODO: debug + print('=' * 20) + print(f'Rank {rank} 评估 task_id: {cfg.policy.task_id}...') + + # 在训练进程中调用 safe_eval + stop, reward = safe_eval( + evaluator, + learner, + collector, + rank, + world_size + ) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} 在评估期间遇到问题。继续训练中...") + else: + print(f"评估成功: stop={stop}, reward={reward}") + + print('=' * 20) + print(f'entry: Rank {rank} 收集 task_id: {cfg.policy.task_id}...') + + # 收集数据 + new_data = collector.collect( + train_iter=learner.train_iter, + policy_kwargs=collect_kwargs + ) + + # 更新 replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # 每 <1/buffer_reanalyze_freq> 个训练 epoch 重新分析一次缓冲区 + if ( + train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition) + ): + with timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析计数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析时间: 
{timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: 完成任务 {cfg.policy.task_id} 的数据收集') + + # 检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfg.policy.batch_size[cfg.policy.task_id] + for cfg, replay_buffer in zip(cfgs, game_buffers) + ) + assert not not_enough_data, f"Rank {rank}: 某些任务的数据量不足以进行训练。请确保所有任务的 replay buffer 中有足够的数据。" + + # 同步训练前所有 rank 的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练前的 barrier') + except Exception as e: + logging.error(f'Rank {rank}: Barrier 失败,错误: {e}') + break # 或者进行其他错误处理 + + # 学习策略 + if not not_enough_data: + # Learner 将在一次迭代中训练 update_per_collect 次 + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate( + zip(cfgs, collectors, game_buffers) + ): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + if ( + i % reanalyze_interval == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition) + ): + with timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析计数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析时间: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + # 追加 task_id,以便在训练时区分任务 + train_data.append(cfg.policy.task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'Replay buffer 中的数据不足以采样一个 mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # 在训练时,DDP 会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate( + zip(cfgs, game_buffers) + ): + # 更新任务特定的 replay buffer 的优先级 + task_id = cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'value_priority_task{task_id}'] + ) + + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + alpha = 0.1 # 运行均值的平滑因子 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + # 如果不存在,则初始化运行均值 + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + # 更新运行均值 + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # 使用运行均值计算归一化的优先级 + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = ( + current_priorities - running_mean_priority + ) / (std_priority + 1e-6) + + # 如果需要,可以将归一化的优先级存储回 replay buffer + # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) + + # 如果设置了 print_task_priority_logs 标志,则记录统计信息 + if cfg.policy.print_task_priority_logs: + print( + f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " + f"运行平均优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}" + ) + + train_epoch += 1 + + # 同步所有 Rank,确保所有 Rank 都完成了训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练后的 barrier') + except Exception as e: + logging.error(f'Rank {rank}: Barrier 失败,错误: {e}') + break # 或者进行其他错误处理 + + # 
检查是否需要终止训练 + try: + # local_envsteps 不再需要填充 + local_envsteps = [collector.envstep for collector in collectors] + + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + # 将所有 envsteps 拼接在一起 + all_envsteps = torch.cat([ + torch.tensor(envsteps, device=cfg.policy.device) + for envsteps in total_envsteps + ]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的 train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any( + torch.stack(all_train_iters) >= max_train_iter + ) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: 满足终止条件') + dist.barrier() # 确保所有进程同步 + break + else: + pass + + except Exception as e: + logging.error(f'Rank {rank}: 终止检查失败,错误: {e}') + break # 或者进行其他错误处理 + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py index cb5712d0b..d09f963b7 100644 --- a/lzero/entry/train_unizero.py +++ b/lzero/entry/train_unizero.py @@ -136,6 +136,9 @@ def train_unizero( else: world_size = 1 rank = 0 + # TODO: for visualize + # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + # import sys; sys.exit(0) while True: # Log memory usage of the replay buffer diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py new file mode 100644 index 000000000..8c3d6c15f --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_ddp.py @@ -0,0 +1,722 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.nn.functional as F + +import torch.distributed as dist + +import concurrent.futures + + +# 设置超时时间 (秒) +TIMEOUT = 12000 # 例如200分钟 + +timer = EasyTimer() + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Safely执行评估任务,避免超时。 + + Args: + evaluator (Evaluator): 评估器实例。 + learner (BaseLearner): 学习器实例。 + collector (Collector): 数据收集器实例。 + rank (int): 当前进程的rank。 + world_size (int): 总进程数。 + + Returns: + Tuple[Optional[bool], Optional[float]]: 如果评估成功,返回停止标志和奖励,否则返回(None, None)。 + """ + try: + print(f"=========评估开始 Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交评估任务 + future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) + try: + stop, reward = future.result(timeout=TIMEOUT) + except 
concurrent.futures.TimeoutError: + # 超时,设置 stop_event + evaluator.stop_event.set() + print(f"评估操作在 Rank {rank}/{world_size} 上超时,耗时 {TIMEOUT} 秒。") + return None, None + + print(f"======评估结束 Rank {rank}/{world_size}======") + return stop, reward + except Exception as e: + print(f"Rank {rank}/{world_size} 评估过程中发生错误: {e}") + return None, None + + +def allocate_batch_size( + cfgs: List[dict], + game_buffers, + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + 根据不同任务的收集剧集数反比分配batch_size,并动态调整batch_size范围以提高训练稳定性和效率。 + + Args: + cfgs (List[dict]): 每个任务的配置列表。 + game_buffers (List[GameBuffer]): 每个任务的重放缓冲区实例列表。 + alpha (float, optional): 控制反比程度的超参数。默认为1.0。 + clip_scale (int, optional): 动态调整的clip比例。默认为1。 + + Returns: + List[int]: 分配后的batch_size列表。 + """ + # 提取每个任务的 collected episodes 数量 + buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + # 获取当前的 world_size 和 rank + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # 收集所有 rank 的 collected episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + torch.distributed.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + + # 将所有 rank 的 collected episodes 合并为一个大列表 + all_task_num_of_collected_episodes = [ + episode for sublist in all_task_num_of_collected_episodes for episode in sublist + ] + if rank == 0: + print(f'所有任务的 collected episodes: {all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes]) + inv_sum = np.sum(inv_episodes) + + # 计算总的batch_size (所有任务 cfg.policy.batch_size 的和) + total_batch_size = cfgs[0].policy.total_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + return batch_sizes + +import numpy as np + + +def symlog(x: torch.Tensor) -> torch.Tensor: + """ + Symlog 归一化,减少目标值的幅度差异。 + symlog(x) = sign(x) * log(|x| + 1) + """ + return torch.sign(x) * torch.log(torch.abs(x) + 1) + +def inv_symlog(x: torch.Tensor) -> torch.Tensor: + """ + Symlog 的逆操作,用于恢复原始值。 + inv_symlog(x) = sign(x) * (exp(|x|) - 1) + """ + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) + +# 全局最大值和最小值(用于 "run-max-min") +GLOBAL_MAX = -float('inf') +GLOBAL_MIN = float('inf') + +def compute_task_weights( + task_rewards: dict, + option: str = "symlog", + epsilon: float = 1e-6, + temperature: float = 1.0, + use_softmax: bool = False, # 是否使用 Softmax + reverse: bool = False, # 正比 (False) 或反比 (True) + clip_min: float = 1e-2, # 权重的最小值 + clip_max: float = 1.0, # 权重的最大值 +) -> dict: + """ + 改进后的任务权重计算函数,支持多种标准化方式、Softmax 和正反比权重计算,并增加权重范围裁剪功能。 + + Args: + task_rewards (dict): 每个任务的字典,键为 task_id,值为评估奖励或损失。 + option (str): 标准化方式,可选值为 "symlog", "max-min", "run-max-min", "rank", "none"。 + epsilon (float): 避免分母为零的小值。 + temperature (float): 控制权重分布的温度系数。 + use_softmax (bool): 是否使用 Softmax 进行权重分配。 + reverse (bool): 若为 True,权重与值反比;若为 False,权重与值正比。 + clip_min (float): 权重的最小值,用于裁剪。 + clip_max (float): 权重的最大值,用于裁剪。 + + Returns: + dict: 每个任务的权重,键为 
task_id,值为归一化后的权重。 + """ + import torch + import torch.nn.functional as F + + global GLOBAL_MAX, GLOBAL_MIN + + # 如果输入为空字典,直接返回空结果 + if not task_rewards: + return {} + + # Step 1: 对 task_rewards 的值构造张量 + task_ids = list(task_rewards.keys()) + rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32) + + if option == "symlog": + # 使用 symlog 标准化 + scaled_rewards = symlog(rewards_tensor) + elif option == "max-min": + # 使用最大最小值归一化 + max_reward = rewards_tensor.max().item() + min_reward = rewards_tensor.min().item() + scaled_rewards = (rewards_tensor - min_reward) / (max_reward - min_reward + epsilon) + elif option == "run-max-min": + # 使用全局最大最小值归一化 + GLOBAL_MAX = max(GLOBAL_MAX, rewards_tensor.max().item()) + GLOBAL_MIN = min(GLOBAL_MIN, rewards_tensor.min().item()) + scaled_rewards = (rewards_tensor - GLOBAL_MIN) / (GLOBAL_MAX - GLOBAL_MIN + epsilon) + elif option == "rank": + # 使用 rank 标准化 + # Rank 是基于值大小的排名,1 表示最小值,越大排名越高 + sorted_indices = torch.argsort(rewards_tensor) + scaled_rewards = torch.empty_like(rewards_tensor) + rank_values = torch.arange(1, len(rewards_tensor) + 1, dtype=torch.float32) # 1 到 N + scaled_rewards[sorted_indices] = rank_values + elif option == "none": + # 不进行标准化 + scaled_rewards = rewards_tensor + else: + raise ValueError(f"Unsupported option: {option}") + + # Step 2: 根据 reverse 确定权重是正比还是反比 + if not reverse: + # 正比:权重与值正相关 + raw_weights = scaled_rewards + else: + # 反比:权重与值负相关 + # 避免 scaled_rewards 为负数或零 + scaled_rewards = torch.clamp(scaled_rewards, min=epsilon) + raw_weights = 1.0 / scaled_rewards + + # Step 3: 根据是否使用 Softmax 进行权重计算 + if use_softmax: + # 使用 Softmax 进行权重分配 + beta = 1.0 / max(temperature, epsilon) # 确保 temperature 不为零 + logits = -beta * raw_weights + softmax_weights = F.softmax(logits, dim=0).numpy() + weights = dict(zip(task_ids, softmax_weights)) + else: + # 不使用 Softmax,直接计算权重 + # 温度缩放 + scaled_weights = raw_weights ** (1 / max(temperature, epsilon)) # 确保温度不为零 + + # 归一化权重 + total_weight = scaled_weights.sum() + normalized_weights = scaled_weights / total_weight + + # 转换为字典 + weights = dict(zip(task_ids, normalized_weights.numpy())) + + # Step 4: Clip 权重范围 + for task_id in weights: + weights[task_id] = max(min(weights[task_id], clip_max), clip_min) + + return weights + +def train_unizero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + UniZero的训练入口,旨在通过解决MuZero类算法在需要捕捉长期依赖环境中的局限性,提高强化学习代理的规划能力。 + 详细信息请参阅 https://arxiv.org/abs/2406.10667。 + + Args: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): 不同任务的配置列表。 + - seed (:obj:`int`): 随机种子。 + - model (:obj:`Optional[torch.nn.Module]`): torch.nn.Module实例。 + - model_path (:obj:`Optional[str]`): 预训练模型路径,应指向预训练模型的ckpt文件。 + - max_train_iter (:obj:`Optional[int]`): 训练中的最大策略更新迭代次数。 + - max_env_step (:obj:`Optional[int]`): 最大收集环境交互步数。 + + Returns: + - policy (:obj:`Policy`): 收敛的策略。 + """ + # 初始化温度调度器 + initial_temperature = 10.0 + final_temperature = 1.0 + threshold_steps = int(1e4) # 训练步数达到 10k 时,温度降至 1.0 + temperature_scheduler = TemperatureScheduler( + initial_temp=initial_temperature, + final_temp=final_temperature, + threshold_steps=threshold_steps, + mode='linear' # 或 'exponential' + ) + + # 获取当前进程的rank和总进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = 
total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: 未分配任务,继续执行。") + # 初始化空列表以避免后续代码报错 + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + else: + print(f"Rank {rank}/{world_size}, 处理任务 {start_idx} 到 {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + if tasks_for_this_rank: + # 使用第一个任务的配置创建共享的policy + task_id, [cfg, create_cfg] = tasks_for_this_rank[0] + + for config in tasks_for_this_rank: + config[1][0].policy.task_num = tasks_per_rank + + # 确保指定的策略类型受支持 + assert create_cfg.policy.type in ['unizero_multitask', + 'sampled_unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'" + + if create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GameBuffer + if create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + + # 根据CUDA可用性设置设备 + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'配置的设备: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # 创建共享的policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # 加载预训练模型(如果提供) + if model_path is not None: + logging.info(f'开始加载模型: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'完成加载模型: {model_path}') + + # 创建TensorBoard日志记录器 + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的learner + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + + # 处理当前进程分配到的每个任务 + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + # 设置每个任务的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # 创建环境 + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 创建不同的game buffer、collector和evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + 
n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + # 调用learner的before_run钩子 + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + task_complexity_weight = cfg.policy.task_complexity_weight + use_task_exploitation_weight = cfg.policy.use_task_exploitation_weight + task_exploitation_weight = None + + # 创建任务奖励字典 + task_rewards = {} # {task_id: reward} + + while True: + # 动态调整batch_size + if cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # 对于当前进程的每个任务,进行数据收集和评估 + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + # 记录缓冲区内存使用情况 + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的epsilon值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # 判断是否需要进行评估 + # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + if learner.train_iter > 10 and evaluator.should_eval(learner.train_iter): # only for debug + # if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...') + + # =========TODO========= + evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + + # 执行安全评估 + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} 在评估过程中遇到问题,继续训练...") + task_rewards[cfg.policy.task_id] = float('inf') # 如果评估失败,将任务难度设为最大值 + else: + # 确保从评估结果中提取 `eval_episode_return_mean` 作为奖励值 + try: + eval_mean_reward = reward.get('eval_episode_return_mean', float('inf')) + print(f"任务 {cfg.policy.task_id} 的评估奖励: {eval_mean_reward}") + task_rewards[cfg.policy.task_id] = eval_mean_reward + except Exception as e: + print(f"提取评估奖励时发生错误: {e}") + task_rewards[cfg.policy.task_id] = float('inf') # 出现问题时,将奖励设为最大值 + + + print('=' * 20) + print(f'开始收集 Rank {rank} 的任务_id: {cfg.policy.task_id}...') + print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ') + + # 
在每次收集之前重置初始数据,这对于多任务设置非常重要 + collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + # 收集数据 + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 更新重放缓冲区 + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # # ===== only for debug ===== + # if train_epoch > 2: + # with timer: + # replay_buffer.reanalyze_buffer(2, policy) + # buffer_reanalyze_count += 1 + # logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + # logging.info(f'缓冲区重新分析耗时: {timer.value}') + # # ===== only for debug ===== + + + # 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析耗时: {timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: 完成任务 {cfg.policy.task_id} 的数据收集') + + # 检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + + # 获取当前温度 + current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter) + # collector._policy._task_weight_temperature = current_temperature_task_weight + # policy.collect_mode.get_attribute('task_weight_temperature') = current_temperature_task_weight + + # 计算任务权重 + try: + # 汇聚任务奖励 + dist.barrier() + if task_complexity_weight: + all_task_rewards = [None for _ in range(world_size)] + dist.all_gather_object(all_task_rewards, task_rewards) + # 合并任务奖励 + merged_task_rewards = {} + for rewards in all_task_rewards: + if rewards: + merged_task_rewards.update(rewards) + # 计算全局任务权重 + task_weights = compute_task_weights(merged_task_rewards, temperature=current_temperature_task_weight) + # 同步任务权重 + dist.broadcast_object_list([task_weights], src=0) + print(f"rank{rank}, 全局任务权重 (按 task_id 排列): {task_weights}") + else: + task_weights = None + except Exception as e: + logging.error(f'Rank {rank}: 同步任务权重失败,错误: {e}') + break + + + # 学习策略 + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析耗时: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) # 追加task_id以区分任务 + train_data_multi_task.append(train_data) + else: + logging.warning( + f'重放缓冲区中的数据不足以采样mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # learn_kwargs = 
{'task_exploitation_weight':task_exploitation_weight, 'task_weights':task_weights, } + learn_kwargs = {'task_weights':task_exploitation_weight} + + # 在训练时,DDP会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs) + + # 判断是否需要计算task_exploitation_weight + if i == 0: + # 计算任务权重 + try: + dist.barrier() # 等待所有进程同步 + if use_task_exploitation_weight: + # 收集所有任务的 obs_loss + all_obs_loss = [None for _ in range(world_size)] + # 构建当前进程的任务 obs_loss 数据 + merged_obs_loss_task = {} + for cfg, replay_buffer in zip(cfgs, game_buffers): + task_id = cfg.policy.task_id + if f'noreduce_obs_loss_task{task_id}' in log_vars[0]: + merged_obs_loss_task[task_id] = log_vars[0][f'noreduce_obs_loss_task{task_id}'] + # 汇聚所有进程的 obs_loss 数据 + dist.all_gather_object(all_obs_loss, merged_obs_loss_task) + # 合并所有进程的 obs_loss 数据 + global_obs_loss_task = {} + for obs_loss_task in all_obs_loss: + if obs_loss_task: + global_obs_loss_task.update(obs_loss_task) + # 计算全局任务权重 + if global_obs_loss_task: + task_exploitation_weight = compute_task_weights( + global_obs_loss_task, + option="rank", + # temperature=current_temperature_task_weight # TODO + temperature=1, + ) + # 广播任务权重到所有进程 + dist.broadcast_object_list([task_exploitation_weight], src=0) + print(f"rank{rank}, task_exploitation_weight (按 task_id 排列): {task_exploitation_weight}") + else: + logging.warning(f"Rank {rank}: 未能计算全局 obs_loss 任务权重,obs_loss 数据为空。") + task_exploitation_weight = None + else: + task_exploitation_weight = None + # 更新训练参数,使其包含计算后的任务权重 + learn_kwargs['task_weight'] = task_exploitation_weight + except Exception as e: + logging.error(f'Rank {rank}: 同步任务权重失败,错误: {e}') + raise e # 保留异常抛出,便于外部捕获和分析 + + + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): + # 更新任务特定的重放缓冲区优先级 + task_id = cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'value_priority_task{task_id}'] + ) + + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + alpha = 0.1 # 平滑因子 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # 使用运行均值计算归一化的优先级 + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + # 如果需要,可以将归一化的优先级存储回重放缓冲区 + # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) + + # 记录优先级统计信息 + if cfg.policy.print_task_priority_logs: + print(f"任务 {task_id} - 平均优先级: {mean_priority:.8f}, " + f"运行平均优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}") + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 同步所有Rank,确保所有Rank完成训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练后的同步障碍') + except Exception as e: + logging.error(f'Rank {rank}: 同步障碍失败,错误: {e}') + break + + # 检查是否需要终止训练 + try: + local_envsteps = [collector.envstep for collector in collectors] + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in 
total_envsteps]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: 达到终止条件') + dist.barrier() # 确保所有进程同步 + break + except Exception as e: + logging.error(f'Rank {rank}: 终止检查失败,错误: {e}') + break + + # 调用learner的after_run钩子 + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_eval.py b/lzero/entry/train_unizero_multitask_segment_eval.py new file mode 100644 index 000000000..f98e4c41b --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_eval.py @@ -0,0 +1,480 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List, Dict, Any + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size, EasyTimer +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.mcts import UniZeroGameBuffer as GameBuffer +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector + +import torch.distributed as dist +import concurrent.futures + +# 设置超时时间 (秒) +TIMEOUT = 12000 # 例如200分钟 + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Safely evaluates the policy using the evaluator with a timeout. + + Args: + evaluator (Evaluator): The evaluator instance. + learner (BaseLearner): The learner instance. + collector (Collector): The collector instance. + rank (int): The rank of the current process. + world_size (int): Total number of processes. + + Returns: + Tuple[Optional[bool], Optional[float]]: A tuple containing the stop flag and reward. + """ + try: + print(f"=========before eval Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交 evaluator.eval 任务 + future = executor.submit( + evaluator.eval, + learner.save_checkpoint, + learner.train_iter, + collector.envstep + ) + + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # 超时,设置 evaluator 的 stop_event + evaluator.stop_event.set() + print(f"Eval operation timed out after {TIMEOUT} seconds on Rank {rank}/{world_size}.") + return None, None + + print(f"======after eval Rank {rank}/{world_size}======") + return stop, reward + except Exception as e: + print(f"An error occurred during evaluation on Rank {rank}/{world_size}: {e}") + return None, None + + +def allocate_batch_size( + cfgs: List[Any], + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + Allocates batch sizes inversely proportional to the number of collected episodes for each task. 
+ Dynamically adjusts batch size within a specified range to enhance training stability and efficiency. + + Args: + cfgs (List[Any]): List of configurations for each task. + game_buffers (List[GameBuffer]): List of replay buffer instances for each task. + alpha (float): The hyperparameter controlling the degree of inverse proportionality. Default is 1.0. + clip_scale (int): The scaling factor to clip the batch size. Default is 1. + + Returns: + List[int]: A list of allocated batch sizes for each task. + """ + # 提取每个任务的 num_of_collected_episodes + buffer_num_of_collected_episodes = [ + buffer.num_of_collected_episodes for buffer in game_buffers + ] + + # 获取当前的 world_size 和 rank + world_size = get_world_size() + rank = get_rank() + + # 收集所有 rank 的 num_of_collected_episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + dist.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + + # 将所有 rank 的 num_of_collected_episodes 拼接成一个大列表 + all_task_num_of_collected_episodes = [ + item for sublist in all_task_num_of_collected_episodes for item in sublist + ] + if rank == 0: + print(f'all_task_num_of_collected_episodes: {all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes = np.array([ + 1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes + ]) + inv_sum = np.sum(inv_episodes) + + # 计算总的 batch_size (所有任务 cfg.policy.batch_size 的和) + total_batch_size = cfgs[0].policy.total_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + # 返回最终分配的 batch_size 列表 + return batch_sizes + + +def train_unizero_multitask_segment_eval( + input_cfg_list: List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The training entry point for UniZero, as proposed in the paper "UniZero: Generalized and Efficient Planning with Scalable Latent World Models". + UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing limitations found in MuZero-style algorithms, + particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + + Args: + input_cfg_list (List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]]): + List of configurations for different tasks. Each item is a tuple containing a task ID and a tuple of configuration dictionaries. + seed (int): + Random seed for reproducibility. + model (Optional[torch.nn.Module]): + Instance of torch.nn.Module representing the model. If None, a new model will be created. + model_path (Optional[str]): + Path to a pretrained model checkpoint. Should point to the ckpt file of the pretrained model. + max_train_iter (Optional[int]): + Maximum number of policy update iterations during training. Default is a very large number. 
+ max_env_step (Optional[int]): + Maximum number of environment interaction steps to collect. Default is a very large number. + + Returns: + 'Policy': + The converged policy after training. + """ + # 获取当前进程的 rank 和总的进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: No tasks assigned, continuing without tasks.") + # 初始化一些空列表以避免后续代码报错 + cfgs, game_buffers, collectors, evaluators = [], [], [], [] + else: + print(f"Rank {rank}/{world_size}, handling tasks {start_idx} to {end_idx - 1}") + + cfgs: List[Any] = [] + game_buffers: List[GameBuffer] = [] + collectors: List[Collector] = [] + evaluators: List[Evaluator] = [] + + # 使用本rank的第一个任务的配置来创建共享的 policy + task_id, (cfg, create_cfg) = tasks_for_this_rank[0] + + # 设置每个任务的 task_num 以用于 learner_log + for config in tasks_for_this_rank: + config[1][0].policy.task_num = tasks_per_rank + + # 确保指定的 policy 类型是支持的 + assert create_cfg.policy.type in [ + 'unizero_multitask'], "train_unizero entry now only supports 'unizero_multitask'" + + # 根据 CUDA 可用性设置设备 + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # 创建共享的 policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # 如果指定了预训练模型,则加载 + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # 创建 TensorBoard 的日志记录器 + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的 learner + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + # 只处理当前进程分配到的任务 + for local_task_id, (task_id, (cfg, create_cfg)) in enumerate(tasks_for_this_rank): + # 设置每个任务自己的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 为每个任务创建不同的 game buffer、collector、evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + 
env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + while True: + # 预先计算位置嵌入矩阵(如果需要) + # policy._collect_model.world_model.precompute_pos_emb_diff_kv() + # policy._target_model.world_model.precompute_pos_emb_diff_kv() + + if cfg.policy.allocated_batch_sizes: + # 动态调整 clip_scale 随着 train_epoch 从 0 增加到 1000, clip_scale 从 1 线性增加到 4 + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for cfg, _collector, _evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # 对于当前进程的每个任务,进行数据收集和评估 + for cfg, collector, evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的 epsilon 值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} evaluates task_id: {cfg.policy.task_id}...') + + # 在训练进程中调用 safe_eval + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} encountered an issue during evaluation. 
Continuing training...") + else: + print(f"Evaluation successful: stop={stop}, reward={reward}") + + print('=' * 20) + print(f'entry: Rank {rank} collects task_id: {cfg.policy.task_id}...') + + # NOTE: 在每次收集之前重置初始数据,这对于多任务设置非常重要 + collector._policy.reset(reset_init_data=True) + # 收集数据 + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 更新 replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + # 每 <1/buffer_reanalyze_freq> 个训练 epoch 重新分析一次缓冲区 + if (train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition)): + with EasyTimer() as timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: Completed data collection for task {cfg.policy.task_id}') + + # 检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + + # 同步训练前所有 rank 的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: Passed barrier before training') + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed with error {e}') + break # 或者进行其他错误处理 + + # 学习策略 + if not not_enough_data: + # Learner 将在一次迭代中训练 update_per_collect 次 + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for cfg, collector, replay_buffer in zip(cfgs, collectors, game_buffers): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + # 在一个训练 epoch 中重新分析缓冲区 次 + if (i % reanalyze_interval == 0 and + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > + int(reanalyze_batch_size / cfg.policy.reanalyze_partition)): + with EasyTimer() as timer: + # 每个重新分析过程将重新分析 个序列 + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + # 追加 task_id,以便在训练时区分任务 + train_data.append(cfg.policy.task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # 在训练时,DDP 会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + # 同步训练前所有 rank 的准备状态 + try: + dist.barrier() + logging.info(f'Rank {rank}: Passed barrier during training') + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed with error {e}') + break # 或者进行其他错误处理 + + # TODO: 可选:终止进程 + import sys + sys.exit(0) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 同步所有 Rank,确保所有 Rank 都完成了训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: Passed barrier after training') + except 
Exception as e: + logging.error(f'Rank {rank}: Barrier failed with error {e}') + break # 或者进行其他错误处理 + + # 检查是否需要终止训练 + try: + # 收集本地的 envsteps + local_envsteps = [collector.envstep for collector in collectors] + + # 收集所有进程的 envsteps + total_envsteps: List[Optional[int]] = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + # 将所有 envsteps 拼接在一起进行检查 + all_envsteps = torch.cat([ + torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps + ]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的 train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: Termination condition met') + dist.barrier() # 确保所有进程同步 + break + except Exception as e: + logging.error(f'Rank {rank}: Termination check failed with error {e}') + break # 或者进行其他错误处理 + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index e107beae6..525ac0812 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -9,7 +9,49 @@ import torch -import torch.distributed as dist +import numpy as np +import torch +import torch.nn.functional as F +import matplotlib.pyplot as plt + +class TemperatureScheduler: + def __init__(self, initial_temp: float, final_temp: float, threshold_steps: int, mode: str = 'linear'): + """ + 温度调度器,用于根据当前训练步数逐渐调整温度。 + + Args: + initial_temp (float): 初始温度值。 + final_temp (float): 最终温度值。 + threshold_steps (int): 温度衰减到最终温度所需的训练步数。 + mode (str): 衰减方式,可选 'linear' 或 'exponential'。默认 'linear'。 + """ + self.initial_temp = initial_temp + self.final_temp = final_temp + self.threshold_steps = threshold_steps + assert mode in ['linear', 'exponential'], "Mode must be 'linear' or 'exponential'." + self.mode = mode + + def get_temperature(self, current_step: int) -> float: + """ + 根据当前步数计算温度。 + + Args: + current_step (int): 当前的训练步数。 + + Returns: + float: 当前温度值。 + """ + if current_step >= self.threshold_steps: + return self.final_temp + progress = current_step / self.threshold_steps + if self.mode == 'linear': + temp = self.initial_temp - (self.initial_temp - self.final_temp) * progress + elif self.mode == 'exponential': + # 指数衰减,确保温度逐渐接近 final_temp + decay_rate = np.log(self.final_temp / self.initial_temp) / self.threshold_steps + temp = self.initial_temp * np.exp(decay_rate * current_step) + temp = max(temp, self.final_temp) + return temp def is_ddp_enabled(): """ @@ -102,7 +144,7 @@ def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]], Returns: - zeros (:obj:`torch.Tensor`): The zeros tensor. 
""" - if isinstance(observation_shape, (list, tuple)): + if isinstance(observation_shape, (list,tuple)): shape = [batch_size, *observation_shape] elif isinstance(observation_shape, int): shape = [batch_size, observation_shape] @@ -146,7 +188,7 @@ def random_collect( collector.reset_policy(policy.collect_mode) -def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: +def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter, task_id=0) -> None: """ Overview: Log the memory usage of the buffer and the current process to TensorBoard. @@ -157,9 +199,9 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa """ # "writer is None" means we are in a slave process in the DDP setup. if writer is not None: - writer.add_scalar('Buffer/num_of_all_collected_episodes', buffer.num_of_collected_episodes, train_iter) - writer.add_scalar('Buffer/num_of_game_segments', len(buffer.game_segment_buffer), train_iter) - writer.add_scalar('Buffer/num_of_transitions', len(buffer.game_segment_game_pos_look_up), train_iter) + writer.add_scalar(f'Buffer/num_of_all_collected_episodes_{task_id}', buffer.num_of_collected_episodes, train_iter) + writer.add_scalar(f'Buffer/num_of_game_segments_{task_id}', len(buffer.game_segment_buffer), train_iter) + writer.add_scalar(f'Buffer/num_of_transitions_{task_id}', len(buffer.game_segment_game_pos_look_up), train_iter) game_segment_buffer = buffer.game_segment_buffer @@ -170,7 +212,7 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa buffer_memory_usage_mb = buffer_memory_usage / (1024 * 1024) # Record the memory usage of self.game_segment_buffer to TensorBoard. - writer.add_scalar('Buffer/memory_usage/game_segment_buffer', buffer_memory_usage_mb, train_iter) + writer.add_scalar(f'Buffer/memory_usage/game_segment_buffer_{task_id}', buffer_memory_usage_mb, train_iter) # Get the amount of memory currently used by the process (in bytes). process = psutil.Process(os.getpid()) @@ -180,7 +222,7 @@ def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: Summa process_memory_usage_mb = process_memory_usage / (1024 * 1024) # Record the memory usage of the process to TensorBoard. - writer.add_scalar('Buffer/memory_usage/process', process_memory_usage_mb, train_iter) + writer.add_scalar(f'Buffer/memory_usage/process_{task_id}', process_memory_usage_mb, train_iter) def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index fe5e28090..1097636f3 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -102,22 +102,23 @@ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]: """ pass - def _sample_orig_data(self, batch_size: int) -> Tuple: + def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) -> Tuple: """ Overview: - sample orig_data that contains: - game_segment_list: a list of game segments - pos_in_game_segment_list: transition index in game (relative index) - batch_index_list: the index of start transition of sampled minibatch in replay buffer - weights_list: the weight concerning the priority - make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) + Sample original data which includes: + - game_segment_list: A list of game segments. 
+ - pos_in_game_segment_list: Transition index in the game (relative index). + - batch_index_list: The index of the start transition of the sampled mini-batch in the replay buffer. + - weights_list: The weight concerning the priority. + - make_time: The time the batch is made (for correctly updating the replay buffer when data is deleted). Arguments: - - batch_size (:obj:`int`): batch size - - beta: float the parameter in PER for calculating the priority + - batch_size (:obj:`int`): The size of the batch. + - print_priority_logs (:obj:`bool`): Whether to print logs related to priority statistics, defaults to False. """ - assert self._beta > 0 + assert self._beta > 0, "Beta should be greater than 0" num_of_transitions = self.get_num_of_transitions() - if self._cfg.use_priority is False: + if not self._cfg.use_priority: + # If priority is not used, set all priorities to 1 self.game_pos_priorities = np.ones_like(self.game_pos_priorities) # +1e-6 for numerical stability @@ -126,20 +127,21 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # sample according to transition index batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) - - if self._cfg.reanalyze_outdated is True: - # NOTE: used in reanalyze part + + if self._cfg.reanalyze_outdated: + # Sort the batch indices if reanalyze is enabled batch_index_list.sort() - + + # Calculate weights for the sampled transitions weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) - weights_list /= weights_list.max() + weights_list /= weights_list.max() # Normalize weights game_segment_list = [] pos_in_game_segment_list = [] for idx in batch_index_list: game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] - game_segment_idx -= self.base_idx + game_segment_idx -= self.base_idx # Adjust index based on base index game_segment = self.game_segment_buffer[game_segment_idx] game_segment_list.append(game_segment) @@ -151,14 +153,10 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # Indices exceeding `game_segment_length` are padded with the next segment and are not updated # in the current implementation. Therefore, we need to sample `pos_in_game_segment` within # [0, game_segment_length - num_unroll_steps] to avoid padded data. - # TODO: Consider increasing `self._cfg.game_segment_length` to ensure sampling efficiency. 
- # if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps: - # pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() - # NOTE: Sample the init position from the whole segment, but not from the padded part - if pos_in_game_segment >= self._cfg.game_segment_length: - pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item() + if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps: + pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps, 1).item() pos_in_game_segment_list.append(pos_in_game_segment) @@ -166,6 +164,12 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: make_time = [time.time() for _ in range(len(batch_index_list))] orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + + if print_priority_logs: + print(f"Sampled batch indices: {batch_index_list}") + print(f"Sampled priorities: {self.game_pos_priorities[batch_index_list]}") + print(f"Sampled weights: {weights_list}") + return orig_data def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: @@ -589,7 +593,8 @@ def remove_oldest_data_to_fit(self) -> None: Overview: remove some oldest data if the replay buffer is full. """ - assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" + if isinstance(self._cfg.batch_size, int): + assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" nums_of_game_segments = self.get_num_of_game_segments() total_transition = self.get_num_of_transitions() if total_transition > self.replay_buffer_size: @@ -601,8 +606,15 @@ def remove_oldest_data_to_fit(self) -> None: # find the max game_segment index to keep in the buffer index = i break - if total_transition >= self._cfg.batch_size: - self._remove(index + 1) + if isinstance(self._cfg.batch_size, int): + if total_transition >= self._cfg.batch_size: + self._remove(index + 1) + else: + try: + if total_transition >= self._cfg.batch_size[0]: + self._remove(index + 1) + except Exception as e: + print(e) def _remove(self, excess_game_segment_index: List[int]) -> None: """ diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 1e4c9d698..664b2042f 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -61,6 +61,19 @@ def __init__(self, cfg: dict): self.sample_times = 0 self.active_root_num = 0 + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + try: + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + except Exception as e: + self.action_space_size = self._cfg.model.action_space_size + + else: + self.task_id = None + print("No task_id found in configuration. 
Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size + def reset_runtime_metrics(self): """ Overview: @@ -146,7 +159,7 @@ def sample( self.compute_target_re_time += self._compute_target_timer.value batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size + policy_non_re_context, self.action_space_size ) # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies @@ -448,17 +461,21 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device) # calculate the target value - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) + if self.task_id is not None: + m_output = model.initial_inference(m_obs, task_id=self.task_id) + else: + m_output = model.initial_inference(m_obs) + + + # if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) network_output.append(m_output) @@ -573,17 +590,20 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device) - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) + if self.task_id is not None: + m_output = model.initial_inference(m_obs, task_id=self.task_id) + else: + m_output = model.initial_inference(m_obs) + + # if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) network_output.append(m_output) @@ -591,7 +611,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: reward_pool = reward_pool.squeeze().tolist() policy_logits_pool = policy_logits_pool.tolist() noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size ).astype(np.float32).tolist() for _ in range(transition_batch_size) ] if self._cfg.mcts_ctree: @@ -603,7 +623,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model with self._origin_search_timer: - MCTSCtree(self._cfg).search(roots, model, 
latent_state_roots, to_play) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + self.origin_search_time += self._origin_search_timer.value else: # python mcts_tree @@ -613,7 +637,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: else: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -629,7 +657,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if policy_mask[policy_index] == 0: # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + target_policies.append([0 for _ in range(self.action_space_size)]) else: # NOTE: It is very important to use the latest MCTS visit count distribution. sum_visits = sum(distributions) @@ -638,7 +666,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if distributions is None: # if at some obs, the legal_action is None, add the fake target_policy target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + list(np.ones(self.action_space_size) / self.action_space_size) ) else: # Update the data in game segment: @@ -655,7 +683,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: target_policies.append(policy) else: # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + policy_tmp = [0 for _ in range(self.action_space_size)] # to make sure target_policies have the same dimension sum_visits = sum(distributions) policy = [visit_count / sum_visits for visit_count in distributions] @@ -684,7 +712,7 @@ def _compute_target_policy_non_reanalyzed( - game_segment_lens - action_mask_segment - to_play_segment - - policy_shape: self._cfg.model.action_space_size + - policy_shape: self.action_space_size Returns: - batch_target_policies_non_re """ @@ -707,7 +735,7 @@ def _compute_target_policy_non_reanalyzed( ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size) ] else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] @@ -757,6 +785,7 @@ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) - NOTE: train_data = [current_batch, target_batch] current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list] + target_batch = [batch_rewards, batch_target_values, batch_target_policies] """ indices = train_data[0][-3] metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} diff --git 
a/lzero/mcts/buffer/game_buffer_sampled_unizero.py b/lzero/mcts/buffer/game_buffer_sampled_unizero.py index 651d7e4ef..5af5228a2 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_unizero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_unizero.py @@ -48,9 +48,19 @@ def __init__(self, cfg: dict): self.game_segment_buffer = [] self.game_pos_priorities = [] self.game_segment_game_pos_look_up = [] - # self.task_id = self._cfg.task_id self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + + else: + self.task_id = None + print("No task_id found in configuration. Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size + + def reanalyze_buffer( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -112,21 +122,22 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> mask_tmp = [1. for i in range(len(root_sampled_actions_tmp))] mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # pad random action if self._cfg.model.continuous_action_space: actions_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] root_sampled_actions_tmp += [ - np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) + np.random.rand(self._cfg.model.num_of_sampled_actions, self.action_space_size) for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) ] else: # generate random `padded actions_tmp` actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps - len(actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, 1 # Number of sampled actions for actions_tmp is 1 ) @@ -135,7 +146,7 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> reshape = True if self._cfg.mcts_ctree else False root_sampled_actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, self._cfg.model.num_of_sampled_actions, reshape=reshape ) @@ -274,18 +285,18 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: # pad random action if self._cfg.model.continuous_action_space: actions_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] root_sampled_actions_tmp += [ - np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) + np.random.rand(self._cfg.model.num_of_sampled_actions, self.action_space_size) for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) ] else: # generate random `padded actions_tmp` actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps - len(actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, 1 # Number of sampled actions for actions_tmp is 1 ) @@ -294,7 +305,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: reshape = True if self._cfg.mcts_ctree else False root_sampled_actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps + 
1 - len(root_sampled_actions_tmp),
-                self._cfg.model.action_space_size,
+                self.action_space_size,
                self._cfg.model.num_of_sampled_actions,
                reshape=reshape
            )
@@ -323,7 +334,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
            if self._cfg.model.continuous_action_space:
                # pad random action
                bootstrap_action_tmp += [
-                    np.random.randn(self._cfg.model.action_space_size)
+                    np.random.randn(self.action_space_size)
                    for _ in range(self._cfg.num_unroll_steps - len(bootstrap_action_tmp))
                ]
                bootstrap_action_list.append(bootstrap_action_tmp)
@@ -486,6 +497,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
            # calculate the target value
            # batch_action.shape (32, 10)
            # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352
-            m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
+
+            if self.task_id is not None:
+                m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id)  # NOTE: :self.reanalyze_num
+            else:
+                m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
+
            # =======================================================================
@@ -511,18 +528,24 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
            # cpp mcts_tree
            # roots = MCTSCtree.roots(transition_batch_size, legal_actions)
            roots = MCTSCtree.roots(
-                transition_batch_size, legal_actions, self._cfg.model.action_space_size,
+                transition_batch_size, legal_actions, self.action_space_size,
                self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
            )
            roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
            # do MCTS for a new policy with the recent target model
-            MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+            if self.task_id is not None:
+                MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id)
+            else:
+                MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
        else:
            # python mcts_tree
            roots = MCTSPtree.roots(transition_batch_size, legal_actions)
            roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
            # do MCTS for a new policy with the recent target model
-            MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+            if self.task_id is not None:
+                MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id)
+            else:
+                MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)
        roots_legal_actions_list = legal_actions
        roots_distributions = roots.get_distributions()
@@ -626,7 +649,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
        if self._cfg.model.continuous_action_space is True:
            # when the action space of the environment is continuous, action_mask[:] is None.
action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + list(np.ones(self.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ @@ -644,9 +667,15 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the target value # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 - m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + if self.task_id is not None: + m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, task_id=self.task_id) + else: + m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) # ====================================================================== + # print(f'model.training:{model.training}') + # model.training = False + # if not model.training: # if not in training, obtain the scalars of the value/reward [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ @@ -655,6 +684,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A m_output.policy_logits ] ) + network_output.append(m_output) if self._cfg.use_root_value: diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index 6208ce24a..38c1935ea 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy +from line_profiler import line_profiler @BUFFER_REGISTRY.register('game_buffer_unizero') @@ -48,6 +49,19 @@ def __init__(self, cfg: dict): self.game_segment_game_pos_look_up = [] self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + try: + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + except Exception as e: + self.action_space_size = self._cfg.model.action_space_size + else: + self.task_id = None + print("No task_id found in configuration. Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size + + #@profile def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -78,7 +92,7 @@ def sample( # target policy batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, current_batch[1], current_batch[-1]) # current_batch[1] is batch_action batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size + policy_non_re_context, self.action_space_size ) # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies @@ -95,6 +109,7 @@ def sample( train_data = [current_batch, target_batch] return train_data + #@profile def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: """ Overview: @@ -139,6 +154,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] mask_tmp += [0. 
for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # TODO: original buffer mask + # mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] + # mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # pad random action actions_tmp += [ np.random.randint(0, game.action_space_size) @@ -412,11 +431,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if self._cfg.model.continuous_action_space is True: # when the action space of the environment is continuous, action_mask[:] is None. action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + list(np.ones(self.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size) ] else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] @@ -432,18 +451,25 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # =============== NOTE: The key difference with MuZero ================= # To obtain the target policy from MCTS guided by the recent target model # TODO: batch_obs (policy_obs_list) is at timestep t, batch_action is at timestep t - m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + + if self.task_id is not None: + # m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num + + else: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + # ======================================================================= - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) + # if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) network_output.append(m_output) @@ -451,7 +477,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: reward_pool = reward_pool.squeeze().tolist() policy_logits_pool = policy_logits_pool.tolist() noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size ).astype(np.float32).tolist() for _ in range(transition_batch_size) ] if self._cfg.mcts_ctree: @@ -459,13 +485,20 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots = 
MCTSCtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + # MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num], task_id=self.task_id) + else: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) else: # python mcts_tree roots = MCTSPtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) + if self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num], task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -476,7 +509,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: distributions = roots_distributions[policy_index] if policy_mask[policy_index] == 0: # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + target_policies.append([0 for _ in range(self.action_space_size)]) else: # NOTE: It is very important to use the latest MCTS visit count distribution. 
sum_visits = sum(distributions) @@ -485,7 +518,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if distributions is None: # if at some obs, the legal_action is None, add the fake target_policy target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + list(np.ones(self.action_space_size) / self.action_space_size) ) else: if self._cfg.env_type == 'not_board_games': @@ -495,7 +528,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: target_policies.append(policy) else: # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + policy_tmp = [0 for _ in range(self.action_space_size)] # to make sure target_policies have the same dimension sum_visits = sum(distributions) policy = [visit_count / sum_visits for visit_count in distributions] @@ -540,9 +573,16 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the bootstrapped value and target value # NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps - m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + if self.task_id is not None: + # m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, task_id=self.task_id) + m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) + + else: + m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + # ====================================================================== + # if not model.training: # if not in training, obtain the scalars of the value/reward [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index ad216d196..2c45b328b 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -31,7 +31,7 @@ class GameSegment: - store_search_stats """ - def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None: + def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None, task_id = None) -> None: """ Overview: Init the ``GameSegment`` according to the provided arguments. @@ -45,19 +45,31 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.td_steps = config.td_steps self.frame_stack_num = config.model.frame_stack_num self.discount_factor = config.discount_factor - self.action_space_size = config.model.action_space_size + if not hasattr(config.model, "action_space_size_list"): + self.action_space_size = config.model.action_space_size self.gray_scale = config.gray_scale self.transform2string = config.transform2string self.sampled_algo = config.sampled_algo self.gumbel_algo = config.gumbel_algo self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder - if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: - # for vector obs input, e.g. classical control and box2d environments - self.zero_obs_shape = config.model.observation_shape - elif len(config.model.observation_shape) == 3: - # image obs input, e.g. 
atari environments - self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + if task_id is None: + if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: + # for vector obs input, e.g. classical control and box2d environments + self.zero_obs_shape = config.model.observation_shape + elif len(config.model.observation_shape) == 3: + # image obs input, e.g. atari environments + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + else: + if hasattr(config.model, "observation_shape_list"): + if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1: + # for vector obs input, e.g. classical control and box2d environments + self.zero_obs_shape = config.model.observation_shape_list[task_id] + elif len(config.model.observation_shape_list[task_id]) == 3: + # image obs input, e.g. atari environments + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1]) + else: + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) self.obs_segment = [] self.action_segment = [] diff --git a/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp index 7c5d11dd2..83f50e2da 100644 --- a/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_sampled_muzero/lib/cnode.cpp @@ -22,6 +22,7 @@ #include #include + #ifdef _WIN32 #include "..\..\common_lib\utils.cpp" #else diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index baef554a5..b9041c639 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -15,6 +15,7 @@ from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as gmz_ctree +from line_profiler import line_profiler class UniZeroMCTSCtree(object): """ @@ -71,10 +72,10 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]], timestep: Union[int, List[Any]] + List[Any]], timestep: Union[int, List[Any]]=None, task_id=None ) -> None: """ Overview: @@ -132,7 +133,24 @@ def search( for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): latent_states.append(latent_state_batch_in_search_path[ix][iy]) - latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + # latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + try: + # print ("latent_state_roots.shape:", latent_state_roots.shape) + # print ("latent_states[0].shape:", latent_states[0].shape) + # print ("latent_states[1].shape:", latent_states[1].shape) + # import ipdb; ipdb.set_trace() + latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + except Exception as e: + print("="*20) + print(e) + # print("latent_states raw:", latent_states) + print("roots:", roots, "latent_state_roots:", latent_state_roots) + print ("latent_state_roots.shape:", 
latent_state_roots.shape)
+                # if not all(isinstance(x, np.ndarray) and x.shape == latent_states[0].shape for x in latent_states):
+                #     raise ValueError(f"Inconsistent latent_states shapes: {[x.shape if isinstance(x, np.ndarray) else type(x) for x in latent_states]}")
+                import ipdb; ipdb.set_trace()
+
+
            # TODO: .long() is only for discrete action
            last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long()
@@ -149,7 +167,22 @@
                # search_depth is used for rope in UniZero
                search_depth = results.get_search_len()
                # print(f'simulation_index:{simulation_index}, search_depth:{search_depth}, latent_state_index_in_search_path:{latent_state_index_in_search_path}')
-                network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep)
+                if timestep is None:
+                    # for UniZero
+                    if task_id is not None:
+                        # multi task setting
+                        network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id)
+                    else:
+                        # single task setting
+                        network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth)
+                else:
+                    # for UniZero
+                    if task_id is not None:
+                        # multi task setting
+                        network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep, task_id=task_id)
+                    else:
+                        # single task setting
+                        network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep)
                network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
                network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
@@ -230,10 +263,10 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m
        from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree
        return ctree.Roots(active_collect_env_num, legal_actions)
-    # @profile
+    # #@profile
    def search(
            self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int,
-            List[Any]]
+            List[Any]], task_id=None
    ) -> None:
        """
        Overview:
@@ -303,6 +336,12 @@ def search(
                """
-                network_output = model.recurrent_inference(latent_states, last_actions)  # for classic muzero
+                if task_id is not None:
+                    # multi task setting
+                    network_output = model.recurrent_inference(latent_states, last_actions, task_id=task_id)
+                else:
+                    # single task setting
+                    network_output = model.recurrent_inference(latent_states, last_actions)
+
                network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
                network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
                network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
@@ -500,7 +540,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e
        """
        return tree_muzero.Roots(active_collect_env_num, legal_actions)
-    # @profile
+    # #@profile
    def search(
            self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any],
            world_model_latent_history_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None,
diff --git a/lzero/mcts/tree_search/mcts_ctree_sampled.py b/lzero/mcts/tree_search/mcts_ctree_sampled.py
index b5143d4f8..19c9f0140 100644
--- a/lzero/mcts/tree_search/mcts_ctree_sampled.py
+++ b/lzero/mcts/tree_search/mcts_ctree_sampled.py
@@ -82,7 +82,7 @@ def roots(
    # @profile
    def search(
            self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int,
-            List[Any]], timestep: Union[int, List[Any]]
+            List[Any]], timestep: 
Union[int, List[Any]], task_id=None ) -> None: """ Overview: @@ -140,7 +140,18 @@ def search( for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): latent_states.append(latent_state_batch_in_search_path[ix][iy]) + # try: latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + # except Exception as e: + # print("="*20) + # print(e) + # # print("latent_states raw:", latent_states) + # print("roots:", roots, "latent_state_roots:", latent_state_roots) + # print ("latent_state_roots.shape:", latent_state_roots.shape) + # # if not all(isinstance(x, np.ndarray) and x.shape == latent_states[0].shape for x in latent_states): + # # raise ValueError(f"Inconsistent latent_states shapes: {[x.shape if isinstance(x, np.ndarray) else type(x) for x in latent_states]}") + # import ipdb; ipdb.set_trace() + if self._cfg.model.continuous_action_space is True: # continuous action last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device) @@ -159,8 +170,12 @@ def search( At the end of the simulation, the statistics along the trajectory are updated. """ # for Sampled UniZero - network_output = model.recurrent_inference(state_action_history, simulation_index, - latent_state_index_in_search_path, timestep) + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path, timestep, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path, timestep) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) @@ -169,6 +184,9 @@ def search( latent_state_batch_in_search_path.append(network_output.latent_state) + # print("network_output.latent_state.shape:", network_output.latent_state.shape) + + # tolist() is to be compatible with cpp datatype. 
reward_batch = network_output.reward.reshape(-1).tolist() value_batch = network_output.value.reshape(-1).tolist() diff --git a/lzero/model/common.py b/lzero/model/common.py index 76cd591f2..cc607c2af 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -273,6 +273,8 @@ def __init__(self, observation_shape: SequenceType, out_channels: int, super().__init__() assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + assert num_resblocks == 1, "num_resblocks must be 1 in DownSample" + self.observation_shape = observation_shape self.conv1 = nn.Conv2d( observation_shape[0], @@ -319,7 +321,7 @@ def __init__(self, observation_shape: SequenceType, out_channels: int, [ ResBlock( in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(1) + ) for _ in range(num_resblocks) ] ) self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) @@ -455,12 +457,27 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: return cls_embedding +from torch.nn.utils import weight_norm + +# AdaptiveFeatureScaler:在对 1D 向量进行 scaling 时,加入 clamp 限制,避免 runaway +class AdaptiveFeatureScaler(nn.Module): + def __init__(self, init_scale=0.1, max_scale=1.0): + super().__init__() + self.scale = nn.Parameter(torch.tensor(init_scale)) + self.max_scale = max_scale + + def forward(self, x): + # 限制 scale 参数的最大值,避免数值爆炸 + clamped_scale = torch.clamp(self.scale, 0.0, self.max_scale) + return x * clamped_scale / math.sqrt(x.size(1)) + +# 假设 SimNorm, ResBlock, DownSample 在其他地方已经定义 +# 下面仅给出 RepresentationNetworkUniZero 的实现 class RepresentationNetworkUniZero(nn.Module): - def __init__( self, - observation_shape: SequenceType = (3, 64, 64), + observation_shape: tuple = (3, 64, 64), num_res_blocks: int = 1, num_channels: int = 64, downsample: bool = True, @@ -468,77 +485,112 @@ def __init__( norm_type: str = 'BN', embedding_dim: int = 256, group_size: int = 8, + final_norm_option_in_encoder: str = 'SimNorm', + use_adaptive_scale: bool = False + # use_global_pooling: bool = True # 新增超参数:是否使用全局平均池化 + # use_global_pooling: bool = False # 新增超参数:是否使用全局平均池化 ) -> None: """ - Overview: - Representation network used in UniZero. Encode the 2D image obs into latent state. - Currently, the network only supports obs images with both a width and height of 64. - Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel. - - num_res_blocks (:obj:`int`): The number of residual blocks. - - num_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - embedding_dim (:obj:`int`): The dimension of the latent state. - - group_size (:obj:`int`): The dimension for simplicial normalization. + Representation network used in UniZero. 
+ 对于 channel 数较大的场景,可使用全局平均池化来降低全连接层的输入维度,提高训练稳定性。 """ super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - logging.info(f"Using norm type: {norm_type}") - logging.info(f"Using activation type: {activation}") + assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']" + # 打印日志信息(可选) + print(f"Using norm type: {norm_type}") + print(f"Using activation type: {activation}") + + self.use_global_pooling = False self.observation_shape = observation_shape self.downsample = downsample + if self.downsample: + # DownSample 对象的实现需自行定义 self.downsample_net = DownSample( observation_shape, num_channels, activation=activation, norm_type=norm_type, + num_resblocks=1, ) else: self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) - if norm_type == 'BN': self.norm = nn.BatchNorm2d(num_channels) elif norm_type == 'LN': - if downsample: - self.norm = nn.LayerNorm( - [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], - eps=1e-5) - else: - self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + # 当不进行 downsample 时,观察图尺寸不变 + self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + # 构建 residual block 层 self.resblocks = nn.ModuleList( [ ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + in_channels=num_channels, + activation=activation, + norm_type=norm_type, + res_type='basic', + bias=False ) for _ in range(num_res_blocks) ] ) self.activation = activation self.embedding_dim = embedding_dim + # 根据观察图尺寸确定空间维度 if self.observation_shape[1] == 64: - self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) - + spatial_size = 8 elif self.observation_shape[1] in [84, 96]: - self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) + spatial_size = 6 + else: + spatial_size = self.observation_shape[1] # 默认采用输入H + + if self.observation_shape[1] == 64: + last_linear_in_dim = num_channels * 8 * 8 + elif self.observation_shape[1] in [84, 96]: + last_linear_in_dim = num_channels * 6 * 6 + else: + # 默认采用完整 flatten 的维度 + last_linear_in_dim = num_channels * self.observation_shape[1] * self.observation_shape[2] - self.sim_norm = SimNorm(simnorm_dim=group_size) + self.last_linear = nn.Linear(last_linear_in_dim, self.embedding_dim, bias=False) + + + # 根据是否使用全局平均池化决定 last_linear 前的输入维度以及 norm 的形状 + if self.use_global_pooling: + linear_in_dim = num_channels # 全局池化后形状: (B, num_channels, 1, 1) + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + # 对 1D 向量使用 LayerNorm + self.norm_before_last_linear = nn.LayerNorm(linear_in_dim, eps=1e-5) + else: + linear_in_dim = num_channels * spatial_size * spatial_size + if use_adaptive_scale: + # 若通过 flatten 后进行 adaptive scaling,对 1D 向量归一化 + self.norm_before_last_linear = nn.LayerNorm(linear_in_dim, eps=1e-5) + else: + # 保留空间信息时,在 (C, H, W) 上归一化 + self.norm_before_last_linear = nn.LayerNorm([num_channels, spatial_size, spatial_size], eps=1e-5) + + self.last_linear = nn.Linear(linear_in_dim, self.embedding_dim, bias=False) + + self.use_adaptive_scale = use_adaptive_scale + if self.use_adaptive_scale: + self.adaptive_scaler = AdaptiveFeatureScaler(init_scale=0.1, max_scale=1.0) + + # 最后归一化层,根据 final_norm_option_in_encoder 进行选择 + if final_norm_option_in_encoder == 'LayerNorm': + self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) + elif final_norm_option_in_encoder == 'SimNorm': + 
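+            # SimNorm partitions the embedding into groups of `group_size` dimensions and applies a softmax within each group, keeping the latent state bounded; the LayerNorm branch above is the unbounded alternative.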
self.final_norm = SimNorm(simnorm_dim=group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_encoder: {final_norm_option_in_encoder}") def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ - H is height. - - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ - output width, H_ is output height. + Args: + x: (B, C_in, H, W) + Returns: + x: (B, embedding_dim) """ if self.downsample: x = self.downsample_net(x) @@ -546,19 +598,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = self.norm(x) x = self.activation(x) + + # 依次通过多个 residual block for block in self.resblocks: x = block(x) + + # 分支1:使用全局平均池化 + if self.use_global_pooling: + x = self.global_pool(x) # 输出 shape: (B, num_channels, 1, 1) + x = x.view(x.size(0), -1) # 展平为 (B, num_channels) + x = self.norm_before_last_linear(x) # 对 1D 向量做归一化 + else: + # 分支2:不使用全局池化 + if self.use_adaptive_scale: + # 若启用 adaptive scaling:先展平再做 fan-in 缩放 + x = x.view(x.size(0), -1) # (B, num_channels * spatial_size^2) + x = self.adaptive_scaler(x) + x = self.norm_before_last_linear(x) # 归一化 1D 向量 + else: + # 保持完整空间信息:在 (B, C, H, W) 上归一化后,再展平 + x = self.norm_before_last_linear(x) + x = x.view(x.size(0), -1) - # Important: Transform the output feature plane to the latent state. - # For example, for an Atari feature plane of shape (64, 8, 8), - # flattening results in a size of 4096, which is then transformed to 768. - x = self.last_linear(x.view(x.size(0), -1)) - - x = x.view(-1, self.embedding_dim) - - # NOTE: very important for training stability. - x = self.sim_norm(x) - + # 最后一层全连接映射与归一化 + x = self.last_linear(x) + x = self.final_norm(x) return x @@ -999,9 +1063,9 @@ def __init__( self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) if observation_shape[1] == 96: - latent_shape = (observation_shape[1] / 16, observation_shape[2] / 16) + latent_shape = (observation_shape[1] // 16, observation_shape[2] // 16) elif observation_shape[1] == 64: - latent_shape = (observation_shape[1] / 8, observation_shape[2] / 8) + latent_shape = (observation_shape[1] // 8, observation_shape[2] // 8) if norm_type == 'BN': self.norm_value = nn.BatchNorm2d(value_head_channels) diff --git a/lzero/model/muzero_model_multitask.py b/lzero/model/muzero_model_multitask.py new file mode 100644 index 000000000..6d7326152 --- /dev/null +++ b/lzero/model/muzero_model_multitask.py @@ -0,0 +1,389 @@ +from typing import Optional, Tuple + +import math +import torch +import torch.nn as nn +from ding.torch_utils import MLP, ResBlock +from ding.utils import MODEL_REGISTRY, SequenceType +from numpy import ndarray + +from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork, FeatureAndGradientHook +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean + + +@MODEL_REGISTRY.register('MuZeroMTModel') +class MuZeroMTModel(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (12, 96, 96), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 16, + value_head_channels: int = 16, + policy_head_channels: int = 16, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + 
pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + downsample: bool = False, + norm_type: Optional[str] = 'BN', + discrete_action_encoding_type: str = 'one_hot', + analysis_sim_norm: bool = False, + task_num: int = 1, # 任务数量 + *args, + **kwargs + ): + """ + 多任务MuZero模型的定义,继承自MuZeroModel。 + 增加了多任务相关的处理,如任务数量和动作空间大小调整。 + """ + super(MuZeroMTModel, self).__init__() + + print(f'==========MuZeroMTModel, num_res_blocks:{num_res_blocks}, num_channels:{num_channels}, task_num:{task_num}===========') + + if discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + + if isinstance(observation_shape, int) or len(observation_shape) == 1: + # for vector obs input, e.g. classical control and box2d environments + # to be compatible with LightZero model/policy, transform to shape: [C, W, H] + observation_shape = [1, observation_shape, 1] + + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + else: + self.reward_support_size = 1 + self.value_support_size = 1 + + self.task_num = task_num + self.action_space_size = 18 # 假设每个任务的动作空间相同 + + self.categorical_distribution = categorical_distribution + + self.discrete_action_encoding_type = 'one_hot' + + # 共享表示网络 + self.representation_network = RepresentationNetwork( + observation_shape, + num_res_blocks, + num_channels, + downsample, + activation=activation, + norm_type=norm_type + ) + + # ====== for analysis ====== + if analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + # 共享动态网络 + self.dynamics_network = DynamicsNetwork( + observation_shape, + action_encoding_dim=self.action_encoding_dim, + num_res_blocks=num_res_blocks, + num_channels=num_channels + self.action_encoding_dim, + reward_head_channels=reward_head_channels, + fc_reward_layers=fc_reward_layers, + output_support_size=reward_support_size, + flatten_output_size_for_reward_head=reward_head_channels * self._get_latent_size(observation_shape, downsample), + downsample=downsample, + last_linear_layer_init_zero=last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) + + # 独立的预测网络,每个任务一个 + # 计算flatten_output_size + value_flatten_size = int(value_head_channels * self._get_latent_size(observation_shape, downsample)) + policy_flatten_size = int(policy_head_channels * self._get_latent_size(observation_shape, downsample)) + + self.prediction_networks = nn.ModuleList([ + PredictionNetwork( + observation_shape, + action_space_size, + num_res_blocks, + num_channels, + value_head_channels, + policy_head_channels, + fc_value_layers, + fc_policy_layers, + self.value_support_size, + flatten_output_size_for_value_head=value_flatten_size, + flatten_output_size_for_policy_head=policy_flatten_size, + downsample=downsample, + last_linear_layer_init_zero=last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) for _ in range(task_num) + ]) + + # 共享投影和预测头(如果使用自监督学习损失) + if self_supervised_learning_loss: + self.projection_network = nn.Sequential( + 
nn.Linear(num_channels * self._get_latent_size(observation_shape, downsample), proj_hid), + nn.BatchNorm1d(proj_hid), + activation, + nn.Linear(proj_hid, proj_hid), + nn.BatchNorm1d(proj_hid), + activation, + nn.Linear(proj_hid, proj_out), + nn.BatchNorm1d(proj_out) + ) + + self.prediction_head = nn.Sequential( + nn.Linear(proj_out, pred_hid), + nn.BatchNorm1d(pred_hid), + activation, + nn.Linear(pred_hid, pred_out), + ) + + self.self_supervised_learning_loss = self_supervised_learning_loss + self.state_norm = state_norm + self.downsample = downsample + + def _get_latent_size(self, observation_shape: SequenceType, downsample: bool) -> int: + """ + 辅助函数,根据观测形状和下采样选项计算潜在状态的大小。 + """ + if downsample: + return math.ceil(observation_shape[-2] / 16) * math.ceil(observation_shape[-1] / 16) + else: + return observation_shape[-2] * observation_shape[-1] + + def initial_inference(self, obs: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: + """ + 多任务初始推理,基于任务ID选择对应的预测网络。 + """ + batch_size = obs.size(0) + latent_state = self.representation_network(obs) + if self.state_norm: + latent_state = renormalize(latent_state) + prediction_net = self.prediction_networks[task_id] + policy_logits, value = prediction_net(latent_state) + + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: + """ + 多任务递归推理,根据任务ID选择对应的预测网络。 + """ + next_latent_state, reward = self._dynamics(latent_state, action) + + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + prediction_net = self.prediction_networks[task_id] + policy_logits, value = prediction_net(next_latent_state) + + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + + + def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + and ``reward``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - reward (:obj:`torch.Tensor`): The predicted reward of the current latent state and selected action. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action. + # The final action_encoding shape is (batch_size, action_space_size, latent_state[2], latent_state[3]), e.g. (8, 2, 4, 1). + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. 
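+            # Illustrative example (assumed values): with batch_size=2 and action_space_size=4,
+            # action = [[2], [0]] makes the scatter_ below produce action_one_hot = [[0, 0, 1, 0], [1, 0, 0, 0]];
+            # each row is then broadcast to a constant (H, W) plane and concatenated to the latent state channels.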
+ # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3] + ) + + elif self.discrete_action_encoding_type == 'not_one_hot': + # Stack latent_state with the normalized encoded action. + # The final action_encoding shape is (batch_size, 1, latent_state[2], latent_state[3]), e.g. (8, 1, 4, 1). + if len(action.shape) == 2: + # (batch_size, action_dim=1) -> (batch_size, 1, 1, 1) + # e.g., torch.Size([8, 1]) -> torch.Size([8, 1, 1, 1]) + action = action.unsqueeze(-1).unsqueeze(-1) + elif len(action.shape) == 1: + # (batch_size,) -> (batch_size, 1, 1, 1) + # e.g., -> torch.Size([8, 1, 1, 1]) + action = action.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + + action_encoding = action.expand( + latent_state.shape[0], 1, latent_state.shape[2], latent_state.shape[3] + ) / self.action_space_size + + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim, latent_state[2], latent_state[3]) or + # (batch_size, latent_state[1] + action_space_size, latent_state[2], latent_state[3]) depending on the discrete_action_encoding_type. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + return next_latent_state, reward + + def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor: + """ + 多任务投影方法,当前实现为共享投影网络。 + """ + if not self.self_supervised_learning_loss: + raise NotImplementedError("Self-supervised learning loss is not enabled for this model.") + + latent_state = latent_state.reshape(latent_state.shape[0], -1) + proj = self.projection_network(latent_state) + if with_grad: + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self) -> float: + return get_params_mean(self) + + +class DynamicsNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_encoding_dim: int = 2, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 64, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + flatten_output_size_for_reward_head: int = 64, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ): + """ + DynamicsNetwork定义,适用于多任务共享。 + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']" + assert num_channels > action_encoding_dim, f'num_channels:{num_channels} <= action_encoding_dim:{action_encoding_dim}' + + self.num_channels = num_channels + self.flatten_output_size_for_reward_head = flatten_output_size_for_reward_head + + self.action_encoding_dim = action_encoding_dim + self.conv = nn.Conv2d(num_channels, num_channels - self.action_encoding_dim, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm_common = nn.BatchNorm2d(num_channels - self.action_encoding_dim) + elif norm_type == 'LN': + if downsample: + self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, 
math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + else: + self.norm_common = nn.LayerNorm([num_channels - self.action_encoding_dim, observation_shape[-2], observation_shape[-1]]) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels - self.action_encoding_dim, activation=activation, norm_type='BN', res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_reward = nn.Conv2d(num_channels - self.action_encoding_dim, reward_head_channels, 1) + + if norm_type == 'BN': + self.norm_reward = nn.BatchNorm2d(reward_head_channels) + elif norm_type == 'LN': + if downsample: + self.norm_reward = nn.LayerNorm([reward_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + else: + self.norm_reward = nn.LayerNorm([reward_head_channels, observation_shape[-2], observation_shape[-1]]) + + self.fc_reward_head = MLP( + self.flatten_output_size_for_reward_head, + hidden_channels=fc_reward_layers[0], + layer_num=len(fc_reward_layers) + 1, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.activation = activation + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + DynamicsNetwork的前向传播,预测下一个潜在状态和奖励。 + """ + # 提取状态编码(去除动作编码部分) + state_encoding = state_action_encoding[:, :-self.action_encoding_dim, :, :] + x = self.conv(state_action_encoding) + x = self.norm_common(x) + + # 残差连接 + x += state_encoding + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + next_latent_state = x + + x = self.conv1x1_reward(next_latent_state) + x = self.norm_reward(x) + x = self.activation(x) + x = x.view(x.shape[0], -1) + + # 使用全连接层预测奖励 + reward = self.fc_reward_head(x) + + return next_latent_state, reward + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> Tuple[ndarray, float]: + return get_reward_mean(self) \ No newline at end of file diff --git a/lzero/model/sampled_unizero_model_multitask.py b/lzero/model/sampled_unizero_model_multitask.py new file mode 100644 index 000000000..a8c4f850e --- /dev/null +++ b/lzero/model/sampled_unizero_model_multitask.py @@ -0,0 +1,326 @@ +from typing import Optional, List + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType +from easydict import EasyDict + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, LatentDecoder, \ + FeatureAndGradientHook, SimNorm +from .unizero_world_models.tokenizer import Tokenizer +from .unizero_world_models.world_model_multitask import WorldModelMT + +class RepresentationNetworkMLPMT(nn.Module): + def __init__( + self, + observation_shape_list: List[int], # List of observation shapes for each task + hidden_channels: int = 64, + layer_num: int = 2, + activation: nn.Module = nn.GELU(approximate='tanh'), + norm_type: Optional[str] = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + use_shared_projection: bool = False, # 控制是否启用共享投影层 + shared_projection_dim: Optional[int] = None, # 共享投影层的维度 + final_norm_option_in_encoder: str = 'LayerNorm', # TODO + ) -> torch.Tensor: + """ + Overview: + Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ + with Multi-Layer Perceptron (MLP), optionally followed by a shared projection layer. 
+ Arguments: + - observation_shape_list (:obj:`List[int]`): The list of observation shape for each task. + - hidden_channels (:obj:`int`): The channel of output hidden state. + - layer_num (:obj:`int`): The number of layers in the MLP. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). + - norm_type (:obj:`str`): The type of normalization in networks, defaults to 'BN'. + - group_size (:obj:`int`): The group size used in SimNorm. + - use_shared_projection (:obj:`bool`): Whether to use a shared projection layer, defaults to False. + - shared_projection_dim (:obj:`Optional[int]`): The dimension of the shared projection layer. \ + If None, defaults to `hidden_channels`. + """ + super().__init__() + self.env_num = len(observation_shape_list) + self.use_shared_projection = use_shared_projection + self.hidden_channels = hidden_channels + self.shared_projection_dim = shared_projection_dim or hidden_channels + + # Task-specific representation networks + self.fc_representation = nn.ModuleList([ + MLP( + in_channels=obs_shape, + hidden_channels=hidden_channels, + out_channels=hidden_channels, + layer_num=layer_num, + activation=activation, + norm_type=norm_type, + # don't use activation and norm in the last layer of representation network is important for convergence. + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=True, + ) + for obs_shape in observation_shape_list + ]) + + # Shared projection layer + if self.use_shared_projection: + self.shared_projection = nn.Linear(hidden_channels, self.shared_projection_dim) + # self.projection_norm = nn.LayerNorm(self.shared_projection_dim) # Optional normalization for shared space + self.projection_norm = SimNorm(simnorm_dim=group_size) # Optional normalization for shared space + + self.embedding_dim = embedding_dim + # SimNorm for task-specific outputs + # self.sim_norm = SimNorm(simnorm_dim=group_size) + self.final_norm_option_in_encoder = final_norm_option_in_encoder + if self.final_norm_option_in_encoder == 'LayerNorm': + self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) + elif self.final_norm_option_in_encoder == 'SimNorm': + self.final_norm = SimNorm(simnorm_dim=group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") + + + def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. + - task_id (:obj:`int`): The ID of the current task. + - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)` if shared projection is not used, \ + otherwise :math:`(B, shared_projection_dim)`. 
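+        Examples (an illustrative sketch; the observation dims 17 and 24 below are placeholders):
+            >>> net = RepresentationNetworkMLPMT(observation_shape_list=[17, 24], hidden_channels=256, embedding_dim=256)
+            >>> obs = torch.randn(8, 17)            # a batch of task-0 vector observations
+            >>> net(obs, task_id=0).shape           # torch.Size([8, 256])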
+ """ + # Task-specific representation + x = self.fc_representation[task_id](x) + x = self.final_norm(x) + # x = self.sim_norm(x) + + # Shared projection layer (if enabled) + if self.use_shared_projection: + x = self.shared_projection(x) + x = self.projection_norm(x) # Optional normalization + return x + + +# class RepresentationNetworkMLPMT(nn.Module): +# def __init__( +# self, +# observation_shape_list: List[int], # List of observation shapes for each task +# hidden_channels: int = 64, +# layer_num: int = 2, +# activation: nn.Module = nn.GELU(approximate='tanh'), +# norm_type: Optional[str] = 'BN', +# group_size: int = 8, +# ) -> torch.Tensor: +# """ +# Overview: +# Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ +# with Multi-Layer Perceptron (MLP). +# Arguments: +# - observation_shape_list (:obj:`List[int]`): The list of observation shape for each task. +# - hidden_channels (:obj:`int`): The channel of output hidden state. +# - layer_num (:obj:`int`): The number of layers in the MLP. +# - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). +# - norm_type (:obj:`str`): The type of normalization in networks, defaults to 'BN'. +# - group_size (:obj:`int`): The group size used in SimNorm. +# """ +# super().__init__() +# self.env_num = len(observation_shape_list) +# self.fc_representation = nn.ModuleList([ +# MLP( +# in_channels=obs_shape, +# hidden_channels=hidden_channels, +# out_channels=hidden_channels, +# layer_num=layer_num, +# activation=activation, +# norm_type=norm_type, +# # don't use activation and norm in the last layer of representation network is important for convergence. +# output_activation=False, +# output_norm=False, +# # last_linear_layer_init_zero=True is beneficial for convergence speed. +# last_linear_layer_init_zero=True, +# ) +# for obs_shape in observation_shape_list +# ]) +# self.sim_norm = SimNorm(simnorm_dim=group_size) + +# def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor: +# """ +# Shapes: +# - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. +# - task_id (:obj:`int`): The ID of the current task. +# - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. +# """ +# x = self.fc_representation[task_id](x) +# x = self.sim_norm(x) +# return x + + +@MODEL_REGISTRY.register('SampledUniZeroMTModel') +class SampledUniZeroMTModel(nn.Module): + def __init__( + self, + observation_shape_list: List[SequenceType], # List of observation shapes for each task + action_space_size_list: List[int], # List of action space sizes for each task + num_res_blocks: int = 1, + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh'), + downsample: bool = True, + norm_type: Optional[str] = 'LN', + # world_model_cfgs: List[EasyDict] = None, # List of world model configs for each task + world_model_cfg: List[EasyDict] = None, # List of world model configs for each task + *args, + **kwargs + ): + """ + Overview: + The definition of data procession in the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), including two main parts: + - initial_inference, which is used to predict the value, policy, and latent state based on the current observation. + - recurrent_inference, which is used to predict the value, policy, reward, and next latent state based on the current latent state and action. 
+ The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + Arguments: + - observation_shape_list (:obj:`List[SequenceType]`): List of observation space shapes for each task, e.g. [C, W, H]=[3, 64, 64] for Atari. + - action_space_size_list (:obj:`List[int]`): List of action space sizes for each task. + - num_res_blocks (:obj:`int`): The number of res blocks in UniZero model. + - num_channels (:obj:`int`): The channels of hidden states in representation network. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - norm_type (:obj=`str`): The type of normalization in networks. Defaults to 'LN'. + - world_model_cfgs (:obj=`List[EasyDict]`): The list of world model configurations for each task. + """ + super(SampledUniZeroMTModel, self).__init__() + self.task_num = len(observation_shape_list) + self.activation = activation + self.downsample = downsample + + # Initialize environment-specific networks and models + self.representation_networks = nn.ModuleList() + # self.decoder_networks = nn.ModuleList() + # self.world_models = nn.ModuleList() + + if world_model_cfg.task_embed_option == "concat_task_embed": + obs_act_embed_dim = world_model_cfg.embed_dim - world_model_cfg.task_embed_dim if hasattr(world_model_cfg, "task_embed_dim") else 96 + else: + obs_act_embed_dim = world_model_cfg.embed_dim + + for task_id in range(self.task_num): + # world_model_cfg = world_model_cfgs[task_id] + world_model_cfg.norm_type = norm_type + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, 'max_tokens should be 2 * max_blocks, because each timestep has 2 tokens: obs and action' + + if world_model_cfg.obs_type == 'vector': + self.representation_network = RepresentationNetworkMLPMT( + observation_shape_list=observation_shape_list, + hidden_channels=obs_act_embed_dim, + layer_num=2, + activation=self.activation, + norm_type=norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + use_shared_projection=world_model_cfg.use_shared_projection, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + ) + self.tokenizer = Tokenizer(encoder=self.representation_network, + decoder_network=None, with_lpips=False) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + elif world_model_cfg.obs_type == 'image': + self.representation_network = nn.ModuleList() + # for task_id in range(self.task_num): # TODO: N independent encoder + for task_id in range(1): # TODO: one share encoder + self.representation_network.append(RepresentationNetworkUniZero( + observation_shape_list[task_id], + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + )) + # TODO: we should change the output_shape to the real observation shape + # self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, 
output_shape=(3, 64, 64)) + + + # Print model parameters for debugging + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + + def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: + """ + Overview: + Initial inference of UniZero model, which is the first step of the UniZero model. + To perform the initial inference, we first use the representation network to obtain the ``latent_state``. + Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. + Arguments: + - obs_batch (:obj:`torch.Tensor`): The 3D image observation data. + - task_id (:obj:`int`): The ID of the current task. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj=`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - value (:obj=`torch.Tensor`): :math=`(B, value_support_size)`, where B is batch_size. + - reward (:obj=`torch.Tensor`): :math=`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj=`torch.Tensor`): :math=`(B, action_dim)`, where B is batch_size. + - latent_state (:obj=`torch.Tensor`): :math=`(B, H_, W_)`, where B is batch_size, H_ is the height of latent state, W_ is the width of latent state. + """ + batch_size = obs_batch.size(0) + obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} + _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id) + latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0, + latent_state_index_in_search_path=[], task_id=0) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of UniZero model. To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state) + and decision-oriented quantities (value/policy) conditioned on the learned latent history in the world_model. + Arguments: + - state_action_history (:obj:`torch.Tensor`): The history of states and actions. + - task_id (:obj:`int`): The ID of the current task. + - simulation_index (:obj=`int`): The index of the current simulation. + - latent_state_index_in_search_path (:obj=`List[int]`): The indices of latent states in the search path. + Returns (MZNetworkOutput): + - value (:obj=`torch.Tensor`): The output value of input state to help policy improvement and evaluation. 
+ - reward (:obj=`torch.Tensor`): The predicted reward of input state and selected action. + - policy_logits (:obj=`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj=`torch.Tensor`): The encoding latent state of input state. + - next_latent_state (:obj=`torch.Tensor`): The predicted next latent state. + Shapes: + - obs (:obj=`torch.Tensor`): :math=`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - action (:obj=`torch.Tensor`): :math=`(B, )`, where B is batch_size. + - value (:obj=`torch.Tensor`): :math=`(B, value_support_size)`, where B is batch_size. + - reward (:obj=`torch.Tensor`): :math=`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj=`torch.Tensor`): :math=`(B, action_dim)`, where B is batch_size. + - latent_state (:obj=`torch.Tensor`): :math=`(B, H_, W_)`, where B is batch_size, H_ is the height of latent state, W_ is the width of latent state. + - next_latent_state (:obj=`torch.Tensor`): :math=`(B, H_, W_)`, where B is batch_size, H_ is the height of latent state, W_ is the width of latent state. + """ + _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( + state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) + next_latent_state, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + reward = reward.squeeze(1) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) \ No newline at end of file diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index 62e39a2fd..6b092a978 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -81,7 +81,7 @@ def __init__( # TODO: only for MemoryEnv now self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) self.tokenizer = Tokenizer(encoder=self.representation_network, - decoder_network=self.decoder_network, with_lpips=False) + decoder_network=self.decoder_network, with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print('==' * 20) @@ -107,13 +107,15 @@ def __init__( norm_type=norm_type, embedding_dim=world_model_cfg.embed_dim, group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, ) # ====== for analysis ====== if world_model_cfg.analysis_sim_norm: self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,) + + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print('==' * 20) @@ -144,7 +146,7 @@ def __init__( self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network) + self.tokenizer = Tokenizer(encoder=self.representation_network, 
decoder_network=self.decoder_network, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py new file mode 100644 index 000000000..71cf60ea6 --- /dev/null +++ b/lzero/model/unizero_model_multitask.py @@ -0,0 +1,256 @@ +from typing import Optional + +import torch +import torch.nn as nn +from ding.utils import MODEL_REGISTRY, SequenceType +from easydict import EasyDict + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook +from .unizero_world_models.tokenizer import Tokenizer +from .unizero_world_models.world_model_multitask import WorldModelMT + +from line_profiler import line_profiler + +# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. +@MODEL_REGISTRY.register('UniZeroMTModel') +class UniZeroMTModel(nn.Module): + + #@profile + def __init__( + self, + observation_shape: SequenceType = (4, 64, 64), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh'), + downsample: bool = True, + norm_type: Optional[str] = 'BN', + world_model_cfg: EasyDict = None, + task_num: int = 1, + *args, + **kwargs + ): + """ + Overview: + The definition of data procession in the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), including two main parts: + - initial_inference, which is used to predict the value, policy, and latent state based on the current observation. + - recurrent_inference, which is used to predict the value, policy, reward, and next latent state based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + Arguments: + - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[3, 64, 64] for Atari. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in UniZero model. + - num_channels (:obj:`int`): The channels of hidden states in representation network. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
+ - world_model_cfg (:obj:`EasyDict`): The configuration of the world model, including the following keys: + - obs_type (:obj:`str`): The type of observation, which can be 'image', 'vector', or 'image_memory'. + - embed_dim (:obj:`int`): The dimension of the embedding. + - group_size (:obj:`int`): The group size of the transformer. + - max_blocks (:obj:`int`): The maximum number of blocks in the transformer. + - max_tokens (:obj:`int`): The maximum number of tokens in the transformer. + - context_length (:obj:`int`): The context length of the transformer. + - device (:obj:`str`): The device of the model, which can be 'cuda' or 'cpu'. + - action_space_size (:obj:`int`): The shape of the action. + - num_layers (:obj:`int`): The number of layers in the transformer. + - num_heads (:obj:`int`): The number of heads in the transformer. + - policy_entropy_weight (:obj:`float`): The weight of the policy entropy. + - analysis_sim_norm (:obj:`bool`): Whether to analyze the similarity of the norm. + """ + super(UniZeroMTModel, self).__init__() + + print(f'==========UniZeroMTModel, num_res_blocks:{num_res_blocks}, num_channels:{num_channels}===========') + + self.action_space_size = action_space_size + + # for multi-task + self.action_space_size = 18 + self.task_num = task_num + + self.activation = activation + self.downsample = downsample + world_model_cfg.norm_type = norm_type + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, 'max_tokens should be 2 * max_blocks, because each timestep has 2 tokens: obs and action' + + if world_model_cfg.task_embed_option == "concat_task_embed": + obs_act_embed_dim = world_model_cfg.embed_dim - world_model_cfg.task_embed_dim if hasattr(world_model_cfg, "task_embed_dim") else 96 + else: + obs_act_embed_dim = world_model_cfg.embed_dim + + if world_model_cfg.obs_type == 'vector': + self.representation_network = RepresentationNetworkMLP( + observation_shape, + hidden_channels=obs_act_embed_dim, + layer_num=2, + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + # TODO: only for MemoryEnv now + self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) + self.tokenizer = Tokenizer(encoder=self.representation_network, + decoder_network=self.decoder_network, with_lpips=False, obs_type=world_model_cfg.obs_type) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + elif world_model_cfg.obs_type == 'image': + self.representation_network = nn.ModuleList() + # for task_id in range(self.task_num): # TODO: N independent encoder + for task_id in range(1): # TODO: one share encoder + self.representation_network.append(RepresentationNetworkUniZero( + observation_shape, + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + use_adaptive_scale=world_model_cfg.use_adaptive_scale, + )) + # self.representation_network = RepresentationNetworkUniZero( + # observation_shape, + # num_res_blocks, + # 
num_channels, + # self.downsample, + # activation=self.activation, + # norm_type=norm_type, + # embedding_dim=world_model_cfg.embed_dim, + # group_size=world_model_cfg.group_size, + # ) + # TODO: we should change the output_shape to the real observation shape + # self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64)) + + # ====== for analysis ====== + if world_model_cfg.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False, obs_type=world_model_cfg.obs_type) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + elif world_model_cfg.obs_type == 'image_memory': + # todo for concat_task_embed + self.representation_network = LatentEncoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[16, 32, 64], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + self.decoder_network = LatentDecoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[64, 32, 16], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + ) + + if world_model_cfg.analysis_sim_norm: + # ====== for analysis ====== + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + self.tokenizer = Tokenizer(with_lpips=True, encoder=self.representation_network, + decoder_network=self.decoder_network, obs_type=world_model_cfg.obs_type) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') + + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') + print('==' * 20) + + #@profile + def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: + """ + Overview: + Initial inference of UniZero model, which is the first step of the UniZero model. + To perform the initial inference, we first use the representation network to obtain the ``latent_state``. + Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. + Arguments: + - obs_batch (:obj:`torch.Tensor`): The 3D image observation data. 
+ Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + """ + batch_size = obs_batch.size(0) + # print('=here 5='*20) + # import ipdb; ipdb.set_trace() + obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} + _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id) + latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + #@profile + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0, + latent_state_index_in_search_path=[], task_id=None) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of UniZero model.To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state) + and decision-oriented quantities (value/policy) conditioned on the learned latent history in the world_model. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. 
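+        Examples (an illustrative sketch; assumes an already constructed ``model`` with image observations):
+            >>> obs = torch.randn(8, 3, 64, 64)
+            >>> out = model.initial_inference(obs, task_id=0)    # task_id routes to the task-specific heads in the world model
+            >>> out.value.shape, out.policy_logits.shape         # (8, value_support_size), (8, action_space_size)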
+ """ + _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( + state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) + next_latent_state, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + reward = reward.squeeze(1) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/kv_caching.py b/lzero/model/unizero_world_models/kv_caching.py index 28b7b0ba2..f373739c6 100644 --- a/lzero/model/unizero_world_models/kv_caching.py +++ b/lzero/model/unizero_world_models/kv_caching.py @@ -70,10 +70,59 @@ def update(self, x: torch.Tensor, tokens: int) -> None: - x (:obj:`torch.Tensor`): The new values to update the cache with. - tokens (:obj:`int`): The number of tokens to update. """ - # assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)]) - # assert self._size + tokens <= self._cache.shape[2] # TODO - self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + tokens) - self._size += tokens + try: + # Calculate the required capacity after adding the new tokens + required_capacity = self._size + tokens + # print(f'self._size:{self._size}, tokens:{tokens}') + + # Check if the cache has enough space to accommodate the new tokens, + # kv_cache, z/a, register_token + # 这样修复后kv_cache的位置编码不是从0开始的, 那后面按照从零开始矫正也就是错误的, + # 但是由于self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1,所以不会矫正 + # 但是在_add_position_embeddings时,prev_steps是错误的,导致新增的z/a的位置编码索引与前面的kv不连续 + if required_capacity > self._cache.shape[2]: + # Shift existing cache data by removing the oldest entries + shift_amount = required_capacity - self._cache.shape[2] + # =======TODO: 应该去掉偶数个(z,a)以保证 head 输出pattern保持不变======= + if shift_amount % 2 != 0: + shift_amount = shift_amount + 1 + # print(f'required_capacity:{required_capacity}, self._cache.shape[2]:{self._cache.shape[2]}, shift_amount:{shift_amount}') + if shift_amount >= self._size: + # If the shift amount exceeds or equals the current size, just reset the cache + print("Cache too small; resetting the entire cache") + self._cache = torch.zeros_like(self._cache) # Reset cache to zeros + self._size = 0 # Reset size + else: + # Shift the cache to make room for new data + self._cache[:, :, :self._size - shift_amount, :] = self._cache[:, :, shift_amount:self._size, :] + self._size -= shift_amount # Update the size after shifting + + # Update the cache with new values + self._cache = AssignWithoutInplaceCheck.apply( + self._cache, x, 2, self._size, self._size + tokens + ) + self._size += tokens # Update the size after adding new values + + except Exception as e: + print(f"An error occurred during cache update: {e}") + + # def update(self, x: torch.Tensor, tokens: int) -> None: + # """ + # Overview: + # Update the cache with new values. + # Arguments: + # - x (:obj:`torch.Tensor`): The new values to update the cache with. + # - tokens (:obj:`int`): The number of tokens to update. 
+ # """ + # # assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)]) + # # assert self._size + tokens <= self._cache.shape[2] # TODO + # try: + # self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + tokens) + # self._size += tokens + # except Exception as e: + # print(e) + # # import ipdb; ipdb.set_trace() + class KVCache: @@ -91,6 +140,12 @@ def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, devi self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device) self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device) + # self.register_token_num = 2 # Number of register tokens TODO====== + + # def set_register_token_num(self, num: int) -> None: + # """Set the number of register tokens.""" + # self.register_token_num = num + @property def shape(self) -> Tuple[int, int, int, int]: """ @@ -133,14 +188,11 @@ def update(self, k: torch.Tensor, v: torch.Tensor): """ Overview: Update both key and value caches with new values. - Arguments: - - k (:obj:`torch.Tensor`): The new values to update the key cache with. - - v (:obj:`torch.Tensor`): The new values to update the value cache with. + If `is_register_token` is True, prepend the register tokens to the cache. """ self._k_cache.update(k, k.size(2)) self._v_cache.update(v, v.size(2)) - class KeysValues: def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None: """ @@ -204,6 +256,18 @@ def prune(self, mask: np.ndarray) -> None: for kv_cache in self._keys_values: kv_cache.prune(mask) + def remove_register_tokens(self, register_token_num: int): + """ + Overview: + 移除所有层 KV 缓存开头的 Register Token。 + 在推理结束后调用,保证外层看到的 KV 不包含 Register Token。 + """ + # import ipdb; ipdb.set_trace() + for kv_cache in self._keys_values: + # 移除 KVCache 中后面的 register_token_num 个 token + kv_cache._k_cache._size -= register_token_num + kv_cache._v_cache._size -= register_token_num + class AssignWithoutInplaceCheck(torch.autograd.Function): """ diff --git a/lzero/model/unizero_world_models/lpips.py b/lzero/model/unizero_world_models/lpips.py index c6ee6426c..7abd5c062 100644 --- a/lzero/model/unizero_world_models/lpips.py +++ b/lzero/model/unizero_world_models/lpips.py @@ -22,11 +22,13 @@ def __init__(self, use_dropout: bool = True): self.chns = [64, 128, 256, 512, 512] # vg16 features # Comment out the following line if you don't need perceptual loss # self.net = vgg16(pretrained=True, requires_grad=False) - self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) - self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) - self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) - self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) - self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + + # self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + # self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + # self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + # self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + # self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # Comment out the following line if you don't need perceptual loss # self.load_from_pretrained() # for param in self.parameters(): diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py new file mode 100644 index 000000000..159afd69e --- /dev/null +++ b/lzero/model/unizero_world_models/moe.py @@ -0,0 +1,49 @@ 
+import dataclasses +from typing import List + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + +# Modified from https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer.py#L108 +class MultiplicationFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + + self.w1 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + self.w2 = nn.Linear(4 * config.embed_dim, config.embed_dim, bias=False) + self.w3 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore + +@dataclasses.dataclass +class MoeArgs(Serializable): + num_experts: int + num_experts_per_tok: int + + +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.num_experts_per_tok = num_experts_per_tok + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + # if len(self.experts) == 1: + # # 只有一个专家时,直接使用该专家 + # return self.experts[0](inputs) + + gate_logits = self.gate(inputs) + weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) + results = torch.zeros_like(inputs) + for i, expert in enumerate(self.experts): + # batch_idx, nth_expert = torch.where(selected_experts == i) + # results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx]) + batch_idx, token_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx]) + return results \ No newline at end of file diff --git a/lzero/model/unizero_world_models/test_moe.py b/lzero/model/unizero_world_models/test_moe.py new file mode 100644 index 000000000..6ab93cc16 --- /dev/null +++ b/lzero/model/unizero_world_models/test_moe.py @@ -0,0 +1,107 @@ +import dataclasses +from typing import List + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + +# 定义MoeArgs数据类,用于存储MoE的配置参数 +@dataclasses.dataclass +class MoeArgs(Serializable): + num_experts: int + num_experts_per_tok: int + +# 定义Mixture of Experts(MoE)层 +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.num_experts_per_tok = num_experts_per_tok + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if len(self.experts) == 1: + # 只有一个专家时,直接使用该专家 + return self.experts[0](inputs) + + gate_logits = self.gate(inputs) + weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) + results = torch.zeros_like(inputs) + for i, expert in enumerate(self.experts): + batch_idx, token_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx]) + return results + +# 定义一个简单的Transformer块 +class TransformerBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.mlp = nn.Sequential( + 
nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) + + if config.moe_in_transformer: + self.feed_forward = MoeLayer( + experts=[self.mlp for _ in range(config.num_experts_of_moe_in_transformer)], + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + print("="*20) + print('使用MoE在Transformer的feed_forward中') + print("="*20) + else: + self.feed_forward = self.mlp + + def forward(self, x): + return self.feed_forward(x) + +# 定义配置类 +class Config: + def __init__(self, embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer): + self.embed_dim = embed_dim + self.resid_pdrop = resid_pdrop + self.num_experts_of_moe_in_transformer = num_experts_of_moe_in_transformer + self.moe_in_transformer = moe_in_transformer + +# 测试代码 +def test_transformer_block(): + # 初始化配置 + embed_dim = 64 + resid_pdrop = 0.1 + num_experts_of_moe_in_transformer = 1 + + # 创建输入数据 + inputs = torch.randn(10, 5, embed_dim) # (batch_size, seq_len, embed_dim) + + # 初始化两个输出变量 + outputs_true = None + outputs_false = None + + # 对于moe_in_transformer为True和False分别进行测试 + for moe_in_transformer in [True, False]: + config = Config(embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer) + transformer_block = TransformerBlock(config) + + outputs = transformer_block(inputs) + print(f"moe_in_transformer={moe_in_transformer}: outputs={outputs}") + + if moe_in_transformer: + outputs_true = outputs + else: + outputs_false = outputs + + # 计算输出的差异 + mse_difference = None + if outputs_true is not None and outputs_false is not None: + mse_difference = F.mse_loss(outputs_true, outputs_false).item() + + print(f"输出差异的均方误差(MSE): {mse_difference}") + +if __name__ == "__main__": + test_transformer_block() \ No newline at end of file diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index bd066ccec..1e87efb17 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -36,7 +36,7 @@ class Tokenizer(nn.Module): Overview: Tokenizer model that encodes and decodes observations. """ - def __init__(self, encoder=None, decoder_network=None, with_lpips: bool = False) -> None: + def __init__(self, encoder=None, decoder_network=None, with_lpips: bool = False, obs_type=None) -> None: """Initialize the Tokenizer. Arguments: @@ -53,36 +53,66 @@ def __init__(self, encoder=None, decoder_network=None, with_lpips: bool = False) self.encoder = encoder self.decoder_network = decoder_network + self.obs_type = obs_type - def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor: + def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Tensor: """ Encode observations to embeddings. Arguments: - x (torch.Tensor): Input tensor of shape (B, ...). + - x (torch.Tensor): Input tensor of shape (B, ...). Returns: - torch.Tensor: Encoded embeddings of shape (B, 1, E). + - torch.Tensor: Encoded embeddings of shape (B, 1, E). 
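For reference, the (B, ...) -> (B, 1, E) contract described above can be checked with a minimal stand-in encoder; the nn.Linear below is an editor's placeholder rather than the project's encoder, and the task_id dispatch implemented further down is omitted.

import torch
import torch.nn as nn
from einops import rearrange

embed_dim, obs_dim = 64, 16
encoder = nn.Linear(obs_dim, embed_dim)          # stand-in for the real encoder

x2d = torch.randn(8, obs_dim)                    # (B, E_obs)
emb = rearrange(encoder(x2d), 'b e -> b 1 e')    # -> (B, 1, E)
assert emb.shape == (8, 1, embed_dim)

x3d = torch.randn(8, 5, obs_dim)                 # (B, T, E_obs)
emb = rearrange(encoder(x3d.reshape(-1, obs_dim)), 'b e -> b 1 e')  # -> (B*T, 1, E)
assert emb.shape == (8 * 5, 1, embed_dim)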
""" shape = x.shape + # TODO: ====== + if task_id is None: + # for compatibility with multitask setting + task_id = 0 + else: + # task_id = 0 # one share encoder + task_id = task_id # TODO: one encoder per task + # print(f'='*20) + # print(f'x.shape:{x.shape}') + # print(f'self.encoder:{self.encoder}') + # Process input tensor based on its dimensionality if len(shape) == 2: # Case when input is 2D (B, E) - obs_embeddings = self.encoder(x) + # obs_embeddings = self.encoder[task_id](x) + obs_embeddings = self.encoder(x, task_id) # TODO: + obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 3: # Case when input is 3D (B, T, E) x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E) - obs_embeddings = self.encoder(x) + # obs_embeddings = self.encoder[task_id](x) + obs_embeddings = self.encoder(x,task_id) # TODO: + obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 4: # Case when input is 4D (B, C, H, W) - obs_embeddings = self.encoder(x) + if self.obs_type == 'vector': + obs_embeddings = self.encoder(x, task_id=task_id) # TODO: for dmc multitask + elif self.obs_type == 'image': + try: + obs_embeddings = self.encoder[0](x) # TODO: for atari/memory env + except: + obs_embeddings = self.encoder(x) # TODO: for atari/memory env single-task + obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 5: # Case when input is 5D (B, T, C, H, W) x = x.contiguous().view(-1, *shape[-3:]) # Flatten the first two dimensions (B * T, C, H, W) - obs_embeddings = self.encoder(x) + if self.obs_type == 'vector': + obs_embeddings = self.encoder[task_id](x) + elif self.obs_type == 'image': + try: + obs_embeddings = self.encoder[0](x) # TODO: for atari/memory env + except: + obs_embeddings = self.encoder(x) # TODO: for atari/memory env single-task + obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') else: raise ValueError(f"Invalid input shape: {shape}") diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index c2feb8497..5e2e3e670 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -1,5 +1,9 @@ """ -The following code is modified from https://github.com/karpathy/nanoGPT. +Modified from https://github.com/karpathy/nanoGPT + +在原 transformer.py 基础上增加 LoRA 微调相关代码, +并通过传入配置参数控制 LoRA 微调的模块(默认是 attention 中的 k, q, v, proj 和 feed_forward) +保持原有代码的可扩展性。 """ import numpy as np @@ -15,7 +19,98 @@ from einops import rearrange from .kv_caching import KeysValues +from .moe import MoeLayer, MultiplicationFeedForward +from line_profiler import line_profiler +from lzero.model.common import SimNorm + + +############################################# +# 新增:LoRA 微调相关代码 +############################################# +class LoRALinear(nn.Module): + """ + LoRA 适配器包装的线性层。 + + 原理: + 使用冻结的原始 nn.Linear 层,并添加两个小型低秩矩阵, + 计算公式为:y = x @ W^T + scaling * ((drop(x) @ A^T) @ B^T) + 其中 A 和 B 为低秩参数,scaling = lora_alpha / r. 
+ """ + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0 + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.r = r + self.lora_alpha = lora_alpha + self.scaling = lora_alpha / r if r > 0 else 1.0 + self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity() + + # 原始权重(冻结参数,不更新) + self.weight = nn.Parameter(torch.empty(out_features, in_features)) + self.bias = nn.Parameter(torch.empty(out_features)) if bias else None + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if bias: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + # 低秩矩阵参数(仅在 r > 0 时添加) + if r > 0: + # A 将 in_features 映射到低秩 r;B 从低秩 r 映射回 out_features + self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01) + self.lora_B = nn.Parameter(torch.zeros(out_features, r)) + else: + self.lora_A = None + self.lora_B = None + + # 冻结原始权重参数,保证仅更新 LoRA 参数 + self.weight.requires_grad = False + if self.bias is not None: + self.bias.requires_grad = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # 原始线性输出(冻结部分) + result = F.linear(x, self.weight, self.bias) + # 如启用了 LoRA,则加上低秩部分 + if self.r > 0: + lora_out = F.linear(self.lora_dropout(x), self.lora_A) # (…, r) + lora_out = F.linear(lora_out, self.lora_B) # (…, out_features) + result = result + self.scaling * lora_out + return result + + +def _maybe_wrap_linear(linear: nn.Linear, config, module_label: str) -> nn.Module: + """ + 辅助函数:当 config.lora_r > 0 且 module_label 存在于 config.lora_target_modules 时, + 将传入的线性层替换为 LoRALinear,并复制原始权重数据。 + module_label 的取值含义由上层逻辑定义,例如: + - 若 module_label 为 "attn",表示在 SelfAttention 中替换 k, q, v, proj 等层。 + - 若 module_label 为 "feed_forward",表示在 Transformer Block 的 MLP 中替换线性层。 + """ + if config.lora_r > 0 and module_label in config.lora_target_modules: + new_linear = LoRALinear( + in_features=linear.in_features, + out_features=linear.out_features, + bias=(linear.bias is not None), + r=config.lora_r, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout + ) + new_linear.weight.data.copy_(linear.weight.data) + if linear.bias is not None: + new_linear.bias.data.copy_(linear.bias.data) + return new_linear + else: + return linear @dataclass class TransformerConfig: @@ -36,6 +131,24 @@ class TransformerConfig: max_seq_len: int rotary_emb: bool = False + # LoRA 参数: + lora_r: int = 0 + lora_alpha: int = 1 + lora_dropout: float = 0.0 + # 指定哪些模块应用 LoRA,默认:attention 中的 k, q, v, proj 和 feed_forward 层(当非 moe 模型时) + lora_target_modules: list = None + + # Register Token 相关 + task_embed_option: str = "none" + register_token_num: int = 4 + register_token_shared: bool = True + + # 其它配置项 + gru_gating: bool = False + moe_in_transformer: bool = False + multiplication_moe_in_transformer: bool = False + num_experts_of_moe_in_transformer: int = 1 + @property def max_tokens(self): return self.tokens_per_block * self.max_blocks @@ -119,7 +232,7 @@ class Transformer(nn.Module): - ln_f (:obj:`nn.LayerNorm`): Layer normalization applied to the final output. 
""" - def __init__(self, config: TransformerConfig) -> None: + def __init__(self, config: TransformerConfig, task_embed=None) -> None: super().__init__() self.config = config self.drop = nn.Dropout(config.embed_pdrop) @@ -133,6 +246,71 @@ def __init__(self, config: TransformerConfig) -> None: self.config.rope_theta, ) self.register_buffer("freqs_cis", freqs_cis) + self.task_embed = task_embed + self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings + self.register_token_shared = True + + # TODO: 共享模式下,所有任务使用同一参数 + + if self.task_embed_option == "register_task_embed": + self.use_register_token = True # TODO + # Register token setup + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + + # 判断是否采用共享模式 + self.register_token_shared = getattr(config, "register_token_shared", True) + if self.register_token_shared: + # print(f'self.register_token_shared:{self.register_token_shared}') + # print(f'='*20) + # 共享模式:所有任务使用同一个 register_tokens 参数,形状为 (register_token_num, embed_dim) + self.register_tokens = nn.Parameter(torch.empty(self.register_token_num, config.embed_dim)) + nn.init.xavier_uniform_(self.register_tokens) + else: + # 非共享模式:依赖外部传入的 task_embed 模块来生成 task embedding, + # 并通过 SimNorm 归一化后复制出 register token + self.task_embed = task_embed # 外部传入的模块,如 nn.Embedding + self.sim_norm = SimNorm(simnorm_dim=config.embed_dim) # Normalization for task embeddings + + else: + self.use_register_token = False # TODO + + + def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor: + """ + 将 register_token_num 个 Register Token 拼接到序列最前面。 + + Arguments: + - sequences (:obj:`torch.Tensor`): (B, T, C) + - task_id (:obj:`int`): 当前任务的 ID + + Returns: + - new_sequences (:obj:`torch.Tensor`): (B, T + register_token_num, C) + """ + B = sequences.size(0) + device = sequences.device + + if self.register_token_shared: + # 共享模式:直接使用同一组 register_tokens 参数 + # register_tokens 形状为 (register_token_num, embed_dim) + register_tokens = self.register_tokens + register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) # 形状 (B, register_token_num, embed_dim) + else: + # 非共享模式:依靠 task_embed 动态生成 task embedding,然后复制出 register tokens + task_embedding = self.task_embed(torch.tensor([task_id], device=device)) # (1, embed_dim) + task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1) # (embed_dim,) + register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1) # (register_token_num, embed_dim) + register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) # (B, register_token_num, embed_dim) + + new_sequences = torch.cat([sequences, register_tokens], dim=1) # 在序列末尾拼接 register tokens (B, register_token_num + T, C) + return new_sequences + + def remove_register_tokens_from_kv(self, past_keys_values: KeysValues) -> None: + """ + 移除所有层 KV 中最前面的 register_token_num 个 token,用于在 forward() 结束时调用。 + """ + if past_keys_values is None: + return + past_keys_values.remove_register_tokens(self.register_token_num) def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: """ @@ -160,7 +338,7 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues - start_pos (:obj:`int`): Starting position for rotary embeddings (default: 0). Returns: - - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). 
+ - 输出张量 (B, T + register_token_num, C) 或 (B, T, C),视是否添加 Register Token 而定 """ seqlen = sequences.shape[1] # If using Rotary Position Embeddings (RoPE), slice the frequency components accordingly @@ -207,9 +385,23 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, freqs_cis) # Apply final layer normalization x = self.ln_f(x) + + # 如果 past_keys_values 不为 None,说明是推理阶段,此时我们需要把 KV 缓存中 + # 尾部多加的 Register Token 移除,以保证外键信息一致,不用修改外部逻辑 + # if self.use_register_token and (past_keys_values is not None): + if self.use_register_token: + self.remove_register_tokens_from_kv(past_keys_values) + + # TODO + if self.use_register_token: + # import ipdb; ipdb.set_trace() + x = x[:, :-self.register_token_num, :] + return x + + class Block(nn.Module): """ Transformer block class. @@ -240,12 +432,55 @@ def __init__(self, config: TransformerConfig) -> None: self.ln1 = nn.LayerNorm(config.embed_dim) self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) - self.mlp = nn.Sequential( - nn.Linear(config.embed_dim, 4 * config.embed_dim), - nn.GELU(approximate='tanh'), - nn.Linear(4 * config.embed_dim, config.embed_dim), - nn.Dropout(config.resid_pdrop), - ) + + if config.moe_in_transformer: + # 创Create multiple independent MLP instances + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + self.feed_forward = MoeLayer( + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + print("="*20) + print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') + print("="*20) + elif config.multiplication_moe_in_transformer: + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(config) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + + self.feed_forward = MoeLayer( + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + print("="*20) + print(f'use multiplication moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') + print("="*20) + else: + # self.feed_forward = nn.Sequential( + # nn.Linear(config.embed_dim, 4 * config.embed_dim), + # nn.GELU(approximate='tanh'), + # nn.Linear(4 * config.embed_dim, config.embed_dim), + # nn.Dropout(config.resid_pdrop), + # ) + # 普通的 MLP,若在 feed_forward 上启用 LoRA,则对其中线性层进行包装 + self.feed_forward = nn.Sequential( + _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward"), + nn.GELU(approximate='tanh'), + _maybe_wrap_linear(nn.Linear(4 * config.embed_dim, config.embed_dim), config, "feed_forward"), + nn.Dropout(config.resid_pdrop), + ) def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor: @@ -264,10 +499,10 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis) if self.gru_gating: x = 
self.gate1(x, x_attn) - x = self.gate2(x, self.mlp(self.ln2(x))) + x = self.gate2(x, self.feed_forward(self.ln2(x))) else: x = x + x_attn - x = x + self.mlp(self.ln2(x)) + x = x + self.feed_forward(self.ln2(x)) return x @@ -295,19 +530,39 @@ def __init__(self, config: TransformerConfig) -> None: assert config.embed_dim % config.num_heads == 0, "Embedding dimension must be divisible by number of heads." self.config = config + + self.task_embed_option = self.config.task_embed_option + if self.task_embed_option == "register_task_embed": + self.use_register_token = True # TODO + # Register token setup + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + else: + self.use_register_token = False # TODO + self.num_heads = config.num_heads - self.key = nn.Linear(config.embed_dim, config.embed_dim) - self.query = nn.Linear(config.embed_dim, config.embed_dim) - self.value = nn.Linear(config.embed_dim, config.embed_dim) + if config.lora_r > 0 and ("attn" in config.lora_target_modules): + self.key = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.query = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.value = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.proj = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + else: + self.key = nn.Linear(config.embed_dim, config.embed_dim) + self.query = nn.Linear(config.embed_dim, config.embed_dim) + self.value = nn.Linear(config.embed_dim, config.embed_dim) + self.proj = nn.Linear(config.embed_dim, config.embed_dim) self.attn_drop = nn.Dropout(config.attn_pdrop) self.resid_drop = nn.Dropout(config.resid_pdrop) - self.proj = nn.Linear(config.embed_dim, config.embed_dim) - causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens)) + if self.use_register_token: # ======= TODO ======== + causal_mask = torch.tril(torch.ones(config.max_tokens+self.register_token_num*5, config.max_tokens+self.register_token_num*5)) + else: + causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens)) + self.register_buffer('mask', causal_mask) + #@profile def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor: """ @@ -326,7 +581,10 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, B, T, C = x.size() if kv_cache is not None: b, nh, L, c = kv_cache.shape - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." + try: + assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." 
+ except Exception as e: + print('debug') else: L = 0 @@ -338,6 +596,7 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) if kv_cache is not None: + # import ipdb; ipdb.set_trace() kv_cache.update(k, v) # time occupancy 21% k, v = kv_cache.get() # time occupancy 5% @@ -359,16 +618,39 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, # mask.shape: (T, L + T) mask = self.mask[L:L + T, :L + T] + # import ipdb; ipdb.set_trace() + + # Adjust mask for register tokens if applicable + if self.use_register_token and self.register_token_num > 0: + # Allow all positions to attend to the last `register_token_num` tokens + register_mask = mask.clone() # (T, L + T) + register_mask[-self.register_token_num:, :] = 1 # Allow register tokens to see all positions + register_mask[:, -self.register_token_num:] = 1 # Allow all positions to see register tokens + mask = register_mask + + if kv_cache is not None: + # =============TODO============= + # import ipdb; ipdb.set_trace() + b, nh, new_L, c = kv_cache.shape # new_L可能小于L + T + mask = mask[:,-new_L:] + # else: + # import ipdb; ipdb.set_trace() + + # att.shape: (B, num_heads, T, L + T) att = att.masked_fill(mask == 0, float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_drop(att) + + # import ipdb; ipdb.set_trace() y = att @ v # (B, num_heads, T, L + T) x (B, num_heads, L + T, head_size) -> (B, num_heads, T, head_size) y = rearrange(y, 'b h t e -> b t (h e)') # Combine the heads back together (B, T, embed_dim) y = self.resid_drop(self.proj(y)) + + return y @torch.no_grad() diff --git a/lzero/model/unizero_world_models/transformer_no-lora.py b/lzero/model/unizero_world_models/transformer_no-lora.py new file mode 100644 index 000000000..e0f0f0c0b --- /dev/null +++ b/lzero/model/unizero_world_models/transformer_no-lora.py @@ -0,0 +1,477 @@ +""" +Modified from https://github.com/karpathy/nanoGPT +""" + +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from ding.torch_utils.network import GRUGatingUnit +from einops import rearrange +from torch.nn import functional as F + +from .kv_caching import KeysValues +from .moe import MoeLayer, MultiplicationFeedForward +from line_profiler import line_profiler +from lzero.model.common import SimNorm + + +@dataclass +class TransformerConfig: + tokens_per_block: int + max_blocks: int + attention: str + + num_layers: int + num_heads: int + embed_dim: int + + embed_pdrop: float + resid_pdrop: float + attn_pdrop: float + + @property + def max_tokens(self): + return self.tokens_per_block * self.max_blocks + + +class Transformer(nn.Module): + """ + Transformer model class. + + Arguments: + config (:obj:`TransformerConfig`): Configuration for the Transformer model. + + Attributes: + - config (:obj:`TransformerConfig`): Configuration object. + - drop (:obj:`nn.Dropout`): Dropout layer for embedding dropout. + - blocks (:obj:`nn.ModuleList`): List of Transformer blocks. + - ln_f (:obj:`nn.LayerNorm`): Layer normalization applied to the final output. 
+ """ + + def __init__(self, config: TransformerConfig, task_embed=None) -> None: + super().__init__() + self.config = config + self.drop = nn.Dropout(config.embed_pdrop) + self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)]) + self.ln_f = nn.LayerNorm(config.embed_dim) + + self.task_embed = task_embed + self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings + self.register_token_shared = True + + # TODO: 共享模式下,所有任务使用同一参数 + + if self.task_embed_option == "register_task_embed": + self.use_register_token = True # TODO + # Register token setup + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + + # 判断是否采用共享模式 + self.register_token_shared = getattr(config, "register_token_shared", True) + if self.register_token_shared: + # print(f'self.register_token_shared:{self.register_token_shared}') + # print(f'='*20) + # 共享模式:所有任务使用同一个 register_tokens 参数,形状为 (register_token_num, embed_dim) + self.register_tokens = nn.Parameter(torch.empty(self.register_token_num, config.embed_dim)) + nn.init.xavier_uniform_(self.register_tokens) + else: + # 非共享模式:依赖外部传入的 task_embed 模块来生成 task embedding, + # 并通过 SimNorm 归一化后复制出 register token + self.task_embed = task_embed # 外部传入的模块,如 nn.Embedding + self.sim_norm = SimNorm(simnorm_dim=config.embed_dim) # Normalization for task embeddings + + else: + self.use_register_token = False # TODO + + + def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor: + """ + 将 register_token_num 个 Register Token 拼接到序列最前面。 + + Arguments: + - sequences (:obj:`torch.Tensor`): (B, T, C) + - task_id (:obj:`int`): 当前任务的 ID + + Returns: + - new_sequences (:obj:`torch.Tensor`): (B, T + register_token_num, C) + """ + B = sequences.size(0) + device = sequences.device + + if self.register_token_shared: + # 共享模式:直接使用同一组 register_tokens 参数 + # register_tokens 形状为 (register_token_num, embed_dim) + register_tokens = self.register_tokens + register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) # 形状 (B, register_token_num, embed_dim) + else: + # 非共享模式:依靠 task_embed 动态生成 task embedding,然后复制出 register tokens + task_embedding = self.task_embed(torch.tensor([task_id], device=device)) # (1, embed_dim) + task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1) # (embed_dim,) + register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1) # (register_token_num, embed_dim) + register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) # (B, register_token_num, embed_dim) + + new_sequences = torch.cat([sequences, register_tokens], dim=1) # 在序列末尾拼接 register tokens (B, register_token_num + T, C) + return new_sequences + + def remove_register_tokens_from_kv(self, past_keys_values: KeysValues) -> None: + """ + 移除所有层 KV 中最前面的 register_token_num 个 token,用于在 forward() 结束时调用。 + """ + if past_keys_values is None: + return + past_keys_values.remove_register_tokens(self.register_token_num) + + def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: + """ + Generate a placeholder for keys and values. + + Arguments: + - n (:obj:`int`): Batch size. + - max_tokens (:obj:`int`): Maximum number of tokens in the sequence. + + Returns: + - KeysValues: An object containing empty keys and values. 
+ """ + device = self.ln_f.weight.device # Assumption: All submodules are on the same device + return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) + + + #@profile + def forward( + self, + sequences: torch.Tensor, # (B, T, C) + past_keys_values: Optional[KeysValues] = None, + valid_context_lengths: Optional[torch.Tensor] = None, + task_id: int = 0 + ) -> torch.Tensor: + """ + Forward pass of the Transformer model. + + Arguments: + - sequences (:obj:`torch.Tensor`): (B, T, C) + - past_keys_values (:obj:`Optional[KeysValues]`): 缓存,用于推理时加速 + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): 某些场景下可用的有效上下文长度 + - task_id (:obj:`int`): 任务 ID + + Returns: + - 输出张量 (B, T + register_token_num, C) 或 (B, T, C),视是否添加 Register Token 而定 + """ + # 若使用 Register Token,则将其拼到序列最前面 + # 训练阶段和推理阶段都统一处理 + if self.use_register_token: + sequences = self.add_register_tokens(sequences, task_id) + + # 接入 dropout + x = self.drop(sequences) + + # 逐层调用 + for i, block in enumerate(self.blocks): + x = block(x, + None if past_keys_values is None else past_keys_values[i], + valid_context_lengths) + + # 最后层 LN + x = self.ln_f(x) + + # 如果 past_keys_values 不为 None,说明是推理阶段,此时我们需要把 KV 缓存中 + # 尾部多加的 Register Token 移除,以保证外键信息一致,不用修改外部逻辑 + # if self.use_register_token and (past_keys_values is not None): + if self.use_register_token: + self.remove_register_tokens_from_kv(past_keys_values) + + # TODO + if self.use_register_token: + # import ipdb; ipdb.set_trace() + x = x[:, :-self.register_token_num, :] + + return x + + + + +class Block(nn.Module): + """ + Transformer block class. + + Arguments: + config (:obj:`TransformerConfig`): Configuration for the Transformer block. + + Attributes: + - gru_gating (:obj:`bool`): Flag to use GRU gating mechanism. + - gru_bias (:obj:`float`): Bias for the GRU gating mechanism. + - gate1 (:obj:`Optional[GRUGatingUnit]`): First GRU gating unit (if GRU gating is enabled). + - gate2 (:obj:`Optional[GRUGatingUnit]`): Second GRU gating unit (if GRU gating is enabled). + - ln1 (:obj:`nn.LayerNorm`): Layer normalization before the attention layer. + - ln2 (:obj:`nn.LayerNorm`): Layer normalization before the MLP. + - attn (:obj:`SelfAttention`): Self-attention mechanism. + - mlp (:obj:`nn.Sequential`): Multi-layer perceptron. 
+ """ + + def __init__(self, config: TransformerConfig) -> None: + super().__init__() + # NOTE: GRU gating as in GTrXL + self.gru_gating = config.gru_gating + self.gru_bias = 2.0 + if self.gru_gating: + self.gate1 = GRUGatingUnit(config.embed_dim, self.gru_bias) + self.gate2 = GRUGatingUnit(config.embed_dim, self.gru_bias) + + self.ln1 = nn.LayerNorm(config.embed_dim) + self.ln2 = nn.LayerNorm(config.embed_dim) + self.attn = SelfAttention(config) + if config.moe_in_transformer: + # 创Create multiple independent MLP instances + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + + self.feed_forward = MoeLayer( + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + print("="*20) + print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') + print("="*20) + elif config.multiplication_moe_in_transformer: + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(config) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + + self.feed_forward = MoeLayer( + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + print("="*20) + print(f'use multiplication moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') + print("="*20) + else: + self.feed_forward = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, + valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward pass of the Transformer block. + + Arguments: + - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). + - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None). + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None). + + Returns: + - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). + """ + x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) + if self.gru_gating: + x = self.gate1(x, x_attn) + x = self.gate2(x, self.feed_forward(self.ln2(x))) + else: + x = x + x_attn + x = x + self.feed_forward(self.ln2(x)) + + return x + + +class SelfAttention(nn.Module): + """ + Implements self-attention mechanism for transformers. + + Arguments: + config (:obj:`TransformerConfig`): Configuration object containing hyperparameters. + + Attributes: + - config (:obj:`TransformerConfig`): Stores the configuration for the self-attention module. + - num_heads (:obj:`int`): Number of attention heads. + - key (:obj:`nn.Linear`): Linear layer to project input to key vectors. + - query (:obj:`nn.Linear`): Linear layer to project input to query vectors. + - value (:obj:`nn.Linear`): Linear layer to project input to value vectors. + - attn_drop (:obj:`nn.Dropout`): Dropout layer for attention weights. 
+ - resid_drop (:obj:`nn.Dropout`): Dropout layer for residual connection. + - proj (:obj:`nn.Linear`): Final linear layer for projection. + - mask (:obj:`torch.Tensor`): Mask tensor for causal or block-causal attention. + """ + def __init__(self, config: TransformerConfig) -> None: + super().__init__() + assert config.embed_dim % config.num_heads == 0, "Embedding dimension must be divisible by number of heads." + + self.config = config + + self.task_embed_option = self.config.task_embed_option + if self.task_embed_option == "register_task_embed": + self.use_register_token = True # TODO + # Register token setup + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + else: + self.use_register_token = False # TODO + + self.num_heads = config.num_heads + + self.key = nn.Linear(config.embed_dim, config.embed_dim) + self.query = nn.Linear(config.embed_dim, config.embed_dim) + self.value = nn.Linear(config.embed_dim, config.embed_dim) + + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + self.proj = nn.Linear(config.embed_dim, config.embed_dim) + + if self.use_register_token: # ======= TODO ======== + causal_mask = torch.tril(torch.ones(config.max_tokens+self.register_token_num*5, config.max_tokens+self.register_token_num*5)) + else: + causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens)) + + self.register_buffer('mask', causal_mask) + + #@profile + def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, + valid_context_lengths: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """ + Forward pass for the self-attention mechanism. + + Arguments: + - x (:obj:`torch.Tensor`): Input tensor of shape (B, T, C) where B is batch size, + T is sequence length, and C is embedding dimension. + - kv_cache (:obj:`Optional[KeysValues]`): Optional key-value cache for faster inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Optional tensor containing valid context lengths. + + Returns: + - torch.Tensor: Output tensor of shape (B, T, C). + """ + B, T, C = x.size() + if kv_cache is not None: + b, nh, L, c = kv_cache.shape + try: + assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." + except Exception as e: + print('debug') + else: + L = 0 + + q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) + k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) + v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) + + if kv_cache is not None: + # import ipdb; ipdb.set_trace() + kv_cache.update(k, v) # time occupancy 21% + k, v = kv_cache.get() # time occupancy 5% + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + + if valid_context_lengths is not None: + # Final mask.shape: (B, T, L + T) + # L is the context length, T is the current input length, + # valid_context_lengths is the valid length at the end of the context. + mask = torch.zeros(B, T, L + T, device=att.device) + # For each sample, set the invalid parts to 0 based on its valid length. + for i in range(B): + mask[i] = self.mask[L:L + T, :L + T].clone() + mask[i, :, :(L - valid_context_lengths[i])] = 0 # Set invalid parts to 0. + # Adjust mask dimensions to match the last two dimensions of att. 
+ # (B, T, L + T) -> (B, 1, T, L + T) -> (B, num_heads, T, L + T) + mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1) + else: + # mask.shape: (T, L + T) + mask = self.mask[L:L + T, :L + T] + + # import ipdb; ipdb.set_trace() + + # Adjust mask for register tokens if applicable + if self.use_register_token and self.register_token_num > 0: + # Allow all positions to attend to the last `register_token_num` tokens + register_mask = mask.clone() # (T, L + T) + register_mask[-self.register_token_num:, :] = 1 # Allow register tokens to see all positions + register_mask[:, -self.register_token_num:] = 1 # Allow all positions to see register tokens + mask = register_mask + + if kv_cache is not None: + # =============TODO============= + # import ipdb; ipdb.set_trace() + b, nh, new_L, c = kv_cache.shape # new_L可能小于L + T + mask = mask[:,-new_L:] + # else: + # import ipdb; ipdb.set_trace() + + + # att.shape: (B, num_heads, T, L + T) + att = att.masked_fill(mask == 0, float('-inf')) + + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + + # import ipdb; ipdb.set_trace() + y = att @ v # (B, num_heads, T, L + T) x (B, num_heads, L + T, head_size) -> (B, num_heads, T, head_size) + + y = rearrange(y, 'b h t e -> b t (h e)') # Combine the heads back together (B, T, embed_dim) + y = self.resid_drop(self.proj(y)) + + + + return y + + @torch.no_grad() + def get_attention_map(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, + valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Compute the attention map for the input sequence. This is useful for visualization purposes. + More details can be found in visualizing_utils.py. + + Arguments: + - x (:obj:`torch.Tensor`): Input sequence with shape (B, T, C). + - kv_cache (:obj:`Optional[KeysValues]`): Cached keys and values for supporting long sequence inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for handling variable-length contexts. + + Returns: + - torch.Tensor: Attention map with shape (B, nh, T, L + T), representing the distribution of attention. + """ + B, T, C = x.size() + if kv_cache is not None: + b, nh, L, c = kv_cache.shape + assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions are inconsistent with input dimensions." 
+ else: + L = 0 + + # Compute query, key, and value projections + q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + + if kv_cache is not None: + # Update the kv_cache with the new keys and values + kv_cache.update(k, v) + k, v = kv_cache.get() + + # Compute the attention scores + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + + if valid_context_lengths is not None: + mask = torch.zeros(B, T, L + T, device=att.device) + for i in range(B): + # Create attention mask for each batch + mask[i] = self.mask[L:L + T, :L + T].clone() + mask[i, :, :(L - valid_context_lengths[i])] = 0 + mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1) + else: + mask = self.mask[L:L + T, :L + T] + + # Apply the attention mask + att = att.masked_fill(mask == 0, float('-inf')) + att = F.softmax(att, dim=-1) + + return att \ No newline at end of file diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index 99c841cbe..0a0c9dd51 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -215,8 +215,14 @@ def init_weights(module, norm_type='BN'): module.bias.data.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): print(f"Init {module} using zero bias, 1 weight") - module.bias.data.zero_() - module.weight.data.fill_(1.0) + try: + module.bias.data.zero_() + except Exception as e: + print(e) + try: + module.weight.data.fill_(1.0) + except Exception as e: + print(e) elif isinstance(module, nn.BatchNorm2d): print(f"Init nn.BatchNorm2d using zero bias, 1 weight") module.weight.data.fill_(1.0) @@ -294,7 +300,7 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu self.loss_total += self.perceptual_loss_weight * v self.intermediate_losses = { - k: v if isinstance(v, dict) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) + k: v if isinstance(v, dict) or isinstance(v, np.ndarray) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) for k, v in kwargs.items() } diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index 833e4887e..86583d198 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -9,7 +9,7 @@ from torch.distributions import Categorical, Independent, Normal, TransformedDistribution, TanhTransform from lzero.model.common import SimNorm -from lzero.model.utils import cal_dormant_ratio +from lzero.model.utils import cal_dormant_ratio, compute_average_weight_magnitude, cal_effective_rank from .kv_caching import KeysValues from .slicer import Head, PolicyHeadCont from .tokenizer import Tokenizer @@ -41,8 +41,10 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: super().__init__() self.tokenizer = tokenizer self.config = config - self.transformer = Transformer(self.config) + self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings + self.transformer = Transformer(self.config) + self.task_num = 1 if self.config.device == 'cpu': self.device = torch.device('cpu') else: @@ -51,6 +53,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: logging.info(f"self.device: {self.device}") 
self.to(self.device) + self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96 + # Initialize configuration parameters self._initialize_config_parameters() @@ -65,6 +69,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.precompute_pos_emb_diff_kv() print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + self.continuous_action_space = self.config.continuous_action_space # Initialize action embedding table @@ -78,10 +84,16 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) logging.info(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'SimNorm') + # self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'LayerNorm') # TODO + # Head modules self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) - self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, - self.sim_norm) # NOTE: we add a sim_norm to the head for observations + self.head_observations = self._create_head( + self.all_but_last_latent_state_pattern, + self.config.embed_dim, + self._get_final_norm(self.final_norm_option_in_obs_head) # 使用指定的归一化方法 + ) if self.continuous_action_space: self.sigma_type = self.config.sigma_type self.bound_type = self.config.bound_type @@ -90,6 +102,14 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + # 对于 head 部分,查找所有以 "head_" 开头的子模块 + self.head_dict = {} + for name, module in self.named_children(): + if name.startswith("head_"): + self.head_dict[name] = module + if self.head_dict: + self.head_dict = nn.ModuleDict(self.head_dict) + # Apply weight initialization, the order is important self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) self._initialize_last_layer() @@ -131,6 +151,18 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.reanalyze_phase = False + + def _get_final_norm(self, norm_option: str) -> nn.Module: + """ + 根据指定的归一化选项返回相应的归一化模块。 + """ + if norm_option == 'LayerNorm': + return nn.LayerNorm(self.config.embed_dim, eps=1e-5) + elif norm_option == 'SimNorm': + return SimNorm(simnorm_dim=self.config.group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_obs_head: {norm_option}") + def custom_copy_kv_cache_to_shared_init_envs(self, src_kv: KeysValues, env_id) -> int: """ Overview: @@ -179,6 +211,7 @@ def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape if self.shared_pool_wm[self.shared_pool_index_wm] is None: + # import ipdb; ipdb.set_trace() self.shared_pool_wm[self.shared_pool_index_wm] = KeysValues( src_kv_shape[0], # Number of elements (n) src_kv_shape[1], # Number of attention heads (num_heads) @@ -192,7 +225,10 @@ def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): # Copy the key and value caches using torch.copy_() for 
efficient data transfer + # try: dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + # except Exception as e: + # import ipdb; ipdb.set_trace() dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) dst_layer._k_cache._size = src_layer._k_cache._size dst_layer._v_cache._size = src_layer._v_cache._size @@ -248,7 +284,7 @@ def _initialize_config_parameters(self) -> None: self.gamma = self.config.gamma self.context_length = self.config.context_length self.dormant_threshold = self.config.dormant_threshold - self.analysis_dormant_ratio = self.config.analysis_dormant_ratio + self.analysis_dormant_ratio_weight_rank = self.config.analysis_dormant_ratio_weight_rank self.num_observations_tokens = self.config.tokens_per_block - 1 self.latent_recon_loss_weight = self.config.latent_recon_loss_weight self.perceptual_loss_weight = self.config.perceptual_loss_weight @@ -257,7 +293,6 @@ def _initialize_config_parameters(self) -> None: self.max_cache_size = self.config.max_cache_size self.env_num = self.config.env_num self.num_layers = self.config.num_layers - self.obs_per_embdding_dim = self.config.embed_dim self.sim_norm = SimNorm(simnorm_dim=self.group_size) def _initialize_patterns(self) -> None: @@ -333,7 +368,15 @@ def _initialize_projection_input_dim(self) -> None: if self.num_observations_tokens == 16: self.projection_input_dim = 128 elif self.num_observations_tokens == 1: - self.projection_input_dim = self.obs_per_embdding_dim + # self.projection_input_dim = self.config.embed_dim + if self.task_embed_option == "concat_task_embed": + self.projection_input_dim = self.config.embed_dim - self.task_embed_dim + elif self.task_embed_option == "register_task_embed": + self.projection_input_dim = self.config.embed_dim + elif self.task_embed_option == "add_task_embed": + self.projection_input_dim = self.config.embed_dim + else: + self.projection_input_dim = self.config.embed_dim def _initialize_statistics(self) -> None: """Initialize counters for hit count and query count statistics.""" @@ -389,6 +432,7 @@ def precompute_pos_emb_diff_kv(self): self.pos_emb_diff_k.append(layer_pos_emb_diff_k) self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + #@profile def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: """ Helper function to get positional embedding for a given layer and attention type. @@ -599,6 +643,7 @@ def forward( # The 'logits_ends' is intentionally set to None. return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + #@profile def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths): """ @@ -627,6 +672,7 @@ def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_in valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) return embeddings + position_embeddings + #@profile def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. @@ -666,6 +712,7 @@ def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_step return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) return return_result, num_steps + #@profile def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. 
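The custom_copy_kv_cache_to_shared_wm change earlier in this file reuses preallocated cache slots and fills them with in-place copy_ calls instead of allocating new tensors on every query. A minimal sketch of that pool pattern with hypothetical shapes (the real code copies KeysValues objects layer by layer):

import torch

pool_size, shape = 4, (2, 8, 16, 64)         # slots, (B, num_heads, max_tokens, head_dim)
pool = [None] * pool_size
index = 0

def copy_to_pool(src: torch.Tensor) -> int:
    """Copy src into the next slot (allocating it lazily) and return the slot index."""
    global index
    if pool[index] is None:
        pool[index] = torch.empty_like(src)  # allocate once, reuse afterwards
    pool[index].copy_(src)                   # in-place copy, no new allocation
    slot = index
    index = (index + 1) % pool_size
    return slot

slot = copy_to_pool(torch.randn(*shape))
assert pool[slot].shape == shape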
@@ -718,6 +765,7 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va else: return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths, start_pos=start_pos) + #@profile @torch.no_grad() def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos: int = 0) -> torch.FloatTensor: """ @@ -752,6 +800,7 @@ def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos return outputs_wm, self.latent_state + #@profile @torch.no_grad() def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTensor, batch_action=None, @@ -859,7 +908,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens # ================ calculate the target value in Train phase or calculate the target policy in reanalyze phase ================ # [192, 16, 64] -> [32, 6, 16, 64] last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, - self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + self.config.embed_dim) # (BL, K) for unroll_step=1 last_obs_embeddings = last_obs_embeddings[:, :-1, :] batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) @@ -890,6 +939,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens return outputs_wm + #@profile @torch.no_grad() def forward_initial_inference(self, obs_act_dict, start_pos: int = 0): """ @@ -907,6 +957,7 @@ def forward_initial_inference(self, obs_act_dict, start_pos: int = 0): return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile @torch.no_grad() def forward_recurrent_inference(self, state_action_history, simulation_index=0, search_depth=[], start_pos: int = 0): @@ -993,6 +1044,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: """ Adjusts the key-value cache for each environment to ensure they all have the same size. 
@@ -1045,6 +1097,7 @@ def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: return self.keys_values_wm_size_list + #@profile def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, search_depth=[], valid_context_lengths=None): """ @@ -1188,6 +1241,7 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde self.past_kv_cache_recurrent_infer[cache_key] = cache_index + #@profile def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, simulation_index: int = 0, start_pos: int = 0) -> list: """ @@ -1272,17 +1326,44 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # self.save_as_image_with_timestep(batch['observations'], suffix='visual_match_memlen1-60-15_tsne') # ========= logging for analysis ========= - if self.analysis_dormant_ratio: + if self.analysis_dormant_ratio_weight_rank: # Calculate dormant ratio of the encoder shape = batch['observations'].shape # (..., C, H, W) inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) - dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), - percentage=self.dormant_threshold) + dormant_ratio_encoder_dict = cal_dormant_ratio(self.tokenizer.encoder, inputs.detach(), + dormant_threshold=self.dormant_threshold) + # print(dormant_ratio_encoder_dict) + dormant_ratio_encoder = dormant_ratio_encoder_dict['global'] + + # 计算全局平均权重绝对值 + avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder) + # print("Average Weight Magnitude of encoder:", avg_weight_mag_encoder) + # 计算全局平均权重绝对值 + avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer) + # print("Average Weight Magnitude of transformer:", avg_weight_mag_transformer) + # print(f"self.head_dict:{self.head_dict}") + avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict) + # print("Average Weight Magnitude of head:", avg_weight_mag_head) + + # 计算 effective rank,对于 representation 层,注意: + # representation 层在 model.named_modules() 的名称为 "representation" + # print(f"self.tokenizer.encoder:{self.tokenizer.encoder}") + e_rank_last_linear = cal_effective_rank(self.tokenizer.encoder, inputs, representation_layer_name="last_linear") + # print("Effective Rank of encoder_last_linear:", e_rank_last_linear) + e_rank_sim_norm = cal_effective_rank(self.tokenizer.encoder, inputs, representation_layer_name="sim_norm") + # print("Effective Rank of encoder_sim_norm:", e_rank_sim_norm) + + self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: dormant_ratio_encoder = torch.tensor(0.) + avg_weight_mag_encoder = torch.tensor(0.) + avg_weight_mag_transformer = torch.tensor(0.) + avg_weight_mag_head = torch.tensor(0.) + e_rank_last_linear = torch.tensor(0.) + e_rank_sim_norm = torch.tensor(0.) 
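The new plasticity diagnostics logged above, average weight magnitude and effective rank, can be illustrated generically. The exact definitions used by compute_average_weight_magnitude and cal_effective_rank live in lzero.model.utils, so the formulas below (mean |w| and the exponential of the singular-value entropy) are one common choice rather than necessarily the project's implementation.

import torch
import torch.nn as nn

def average_weight_magnitude(module: nn.Module) -> float:
    # Mean absolute value over all parameters of the module.
    params = torch.cat([p.detach().abs().flatten() for p in module.parameters()])
    return params.mean().item()

def effective_rank(features: torch.Tensor) -> float:
    # exp(entropy of the normalized singular values) of an (N, D) feature matrix.
    s = torch.linalg.svdvals(features)
    p = s / s.sum()
    return torch.exp(-(p * torch.log(p + 1e-12)).sum()).item()

net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
feats = net(torch.randn(64, 16)).detach()
print(average_weight_magnitude(net), effective_rank(feats))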
# Calculate the L2 norm of the latent state roots latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() @@ -1362,16 +1443,20 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos) # ========= logging for analysis ========= - if self.analysis_dormant_ratio: + if self.analysis_dormant_ratio_weight_rank: # Calculate dormant ratio of the world model dormant_ratio_world_model = cal_dormant_ratio(self, { 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, - percentage=self.dormant_threshold) + dormant_threshold=self.dormant_threshold) + dormant_ratio_transformer = dormant_ratio_world_model['transformer'] + dormant_ratio_head = dormant_ratio_world_model['head'] + self.past_kv_cache_recurrent_infer.clear() self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: - dormant_ratio_world_model = torch.tensor(0.) + dormant_ratio_transformer = torch.tensor(0.) + dormant_ratio_head = torch.tensor(0.) # ========== for visualization ========== # Uncomment the lines below for visualization @@ -1525,7 +1610,13 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar middle_step_losses=middle_step_losses, last_step_losses=last_step_losses, dormant_ratio_encoder=dormant_ratio_encoder, - dormant_ratio_world_model=dormant_ratio_world_model, + dormant_ratio_transformer=dormant_ratio_transformer, + dormant_ratio_head=dormant_ratio_head, + avg_weight_mag_encoder = avg_weight_mag_encoder, + avg_weight_mag_transformer = avg_weight_mag_transformer, + avg_weight_mag_head = avg_weight_mag_head, + e_rank_last_linear = e_rank_last_linear, + e_rank_sim_norm = e_rank_sim_norm, latent_state_l2_norms=latent_state_l2_norms, policy_mu=mu, policy_sigma=sigma, @@ -1548,7 +1639,13 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar middle_step_losses=middle_step_losses, last_step_losses=last_step_losses, dormant_ratio_encoder=dormant_ratio_encoder, - dormant_ratio_world_model=dormant_ratio_world_model, + dormant_ratio_transformer=dormant_ratio_transformer, + dormant_ratio_head=dormant_ratio_head, + avg_weight_mag_encoder = avg_weight_mag_encoder, + avg_weight_mag_transformer = avg_weight_mag_transformer, + avg_weight_mag_head = avg_weight_mag_head, + e_rank_last_linear = e_rank_last_linear, + e_rank_sim_norm = e_rank_sim_norm, latent_state_l2_norms=latent_state_l2_norms, ) @@ -1614,7 +1711,7 @@ def _calculate_policy_loss_cont_simple(self, outputs, batch: dict): return policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma - def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: + def _calculate_policy_loss_cont(self, outputs, batch: dict, task_id=None) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculate the policy loss for continuous actions. @@ -1629,9 +1726,12 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tenso - mu (:obj:`torch.Tensor`): The mean of the normal distribution. - sigma (:obj:`torch.Tensor`): The standard deviation of the normal distribution. 
""" - batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + if task_id is None: + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ 0], self.config.num_unroll_steps, self.config.action_space_size - + else: + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + 0], self.config.num_unroll_steps, self.config.action_space_size_list[task_id] policy_logits_all = outputs.logits_policy mask_batch = batch['mask_padding'] child_sampled_actions_batch = batch['child_sampled_actions'] @@ -1673,6 +1773,8 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tenso # KL as projector target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) + + # KL as projector policy_loss = -torch.sum( torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 ) * mask_batch @@ -1722,6 +1824,7 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): return loss + #@profile def compute_policy_entropy_loss(self, logits, mask): # Compute entropy of the policy probs = torch.softmax(logits, dim=1) @@ -1731,6 +1834,7 @@ def compute_policy_entropy_loss(self, logits, mask): entropy_loss = (entropy * mask) return entropy_loss + #@profile def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag @@ -1750,6 +1854,7 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc return labels_observations, labels_rewards.view(-1, self.support_size), None + #@profile def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute labels for value and policy predictions. 
""" diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py new file mode 100644 index 000000000..ccb47eb5a --- /dev/null +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -0,0 +1,2227 @@ +import collections +import logging +from typing import Any, Tuple +from typing import Optional +from typing import Union, Dict + +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from lzero.model.common import SimNorm +from lzero.model.unizero_world_models.world_model import WorldModel +from lzero.model.utils import cal_dormant_ratio, compute_average_weight_magnitude, cal_effective_rank +from .moe import MoeLayer, MultiplicationFeedForward +from .slicer import Head +from .tokenizer import Tokenizer +from .transformer import Transformer, TransformerConfig +from .utils import LossWithIntermediateLosses, init_weights +from .utils import WorldModelOutput, hash_state + +logging.getLogger().setLevel(logging.DEBUG) +from ding.utils import get_rank +import torch.distributed as dist +from sklearn.manifold import TSNE +import os +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Patch +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import torch + + +class WorldModelMT(WorldModel): + """ + Overview: + The WorldModel class is responsible for the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), + which is used to predict the next latent state, rewards, policy, and value based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + """ + + #@profile + def __init__(self, config: TransformerConfig, tokenizer) -> None: + """ + Overview: + Initialize the WorldModel class. + Arguments: + - config (:obj:`TransformerConfig`): The configuration for the transformer. + - tokenizer (:obj:`Tokenizer`): The tokenizer. + + - task_embed_option (str): Strategy for incorporating task embeddings. Options: + - "add_task_embed": Adds task embeddings to observation embeddings (default). + - "concat_task_embed": Concatenates task embeddings with observation embeddings. + - "register_task_embed": Uses task embeddings as additional input tokens. 
+ """ + super().__init__(config, tokenizer) + self.tokenizer = tokenizer + self.config = config + self.share_head = config.share_head # 新增参数 + + if self.config.device == 'cpu': + self.device = torch.device('cpu') + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # Move all modules to the specified device + print(f"self.device: {self.device}") + # Position embedding + self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) + print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + + if self.task_embed_option == "register_task_embed": + # 由于 "register_task_embed"设定下的位置编码没有矫正 + # 使用 nn.Embedding,但初始化为全零并禁止学习 + self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) + nn.init.constant_(self.pos_emb.weight, 0.0) # 初始化全零 + self.pos_emb.weight.requires_grad = False # 禁止更新 + + # Task embedding setup + self.use_task_embed = config.use_task_embed + self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings + self.task_num = config.task_num + self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96 + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + + self.precompute_pos_emb_diff_kv() + + self.sim_norm = SimNorm(simnorm_dim=self.group_size) + if self.task_embed_option == "concat_task_embed": + # TODO:目前在 "concat_task_embed"下面,self.pos_emb需要设置为固定的0 + self.task_emb = nn.Embedding(self.task_num, self.task_embed_dim, max_norm=1) # TODO: TDMPC2:max_norm=1性能更好 + # self.task_emb.weight = self.sim_norm(self.task_emb.weight) + self.obs_act_embed_dim = config.embed_dim - self.task_embed_dim + self.register_token_num = 0 + elif self.task_embed_option == "register_task_embed": + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) # TODO + self.obs_act_embed_dim = config.embed_dim + elif self.task_embed_option == "add_task_embed": + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) # TODO + self.obs_act_embed_dim = config.embed_dim + else: + self.task_emb = None + self.obs_act_embed_dim = config.embed_dim + self.register_token_num = 0 + + + self.transformer = Transformer(self.config, self.task_emb) + + # TODO ======== + self.analysis_tsne = self.config.get('analysis_tsne', False) + + if self.analysis_tsne: + self.env_id_list = self.config.env_id_list + # 自动生成 self.env_short_names + self.env_short_names = {} + + # 遍历 env_id_list,提取短名称 + for env_id in self.config.env_id_list: + # 提取 'NoFrameskip-v4' 之前的部分作为短名称 + short_name = env_id.replace('NoFrameskip-v4', '') + self.env_short_names[env_id] = short_name + # 映射环境 ID 到简写名称 + # self.env_short_names = { + # 'PongNoFrameskip-v4': 'Pong', + # 'MsPacmanNoFrameskip-v4': 'MsPacman', + # 'SeaquestNoFrameskip-v4': 'Seaquest', + # 'BoxingNoFrameskip-v4': 'Boxing', + # 'AlienNoFrameskip-v4': 'Alien', + # 'ChopperCommandNoFrameskip-v4': 'Chopper', + # 'HeroNoFrameskip-v4': 'Hero', + # 'RoadRunnerNoFrameskip-v4': 'RoadRunner' + # } + # 颜色映射,确保每个任务有固定的颜色 + self.num_tasks = len(self.env_id_list) + + # 生成足够多的颜色 + self.colors = self._generate_colors(len(self.env_id_list)) + + + self.head_policy_multi_task = nn.ModuleList() + self.head_value_multi_task = nn.ModuleList() + self.head_rewards_multi_task = nn.ModuleList() + self.head_observations_multi_task = nn.ModuleList() + + self.num_experts_in_moe_head = config.num_experts_in_moe_head + self.use_normal_head = config.use_normal_head + self.use_moe_head = 
config.use_moe_head + self.use_softmoe_head = config.use_softmoe_head + + + self.to(self.device) + + # Initialize configuration parameters + self._initialize_config_parameters() + + # Initialize patterns for block masks + self._initialize_patterns() + + self.hidden_size = config.embed_dim // config.num_heads + + self.continuous_action_space = self.config.continuous_action_space + + # Initialize action embedding table + if self.continuous_action_space: + # TODO: check the effect of SimNorm + # self.act_embedding_table = nn.Sequential( + # nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), + # SimNorm(simnorm_dim=self.group_size)) + # print(f'config.action_space_size_list:{config.action_space_size_list}') + self.act_embedding_table = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.action_space_size_list[task_id], self.obs_act_embed_dim, device=self.device, bias=False), + SimNorm(simnorm_dim=self.group_size) + ) + for task_id in range(self.task_num) + ]) + else: + # for discrete action space + self.act_embedding_table = nn.Embedding(config.action_space_size, self.obs_act_embed_dim, device=self.device) + print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + + print(f'='*20) + print(f"self.obs_act_embed_dim:{self.obs_act_embed_dim}") + print(f'='*20) + + + # if self.num_experts_in_moe_head == -1: + assert self.num_experts_in_moe_head > 0 + if self.use_normal_head: + self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'SimNorm') + # self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'LayerNorm') # TODO + + print('We use normal head') + # TODO: Normal Head + for task_id in range(self.task_num): + if self.continuous_action_space: + # TODO + self.sigma_type = self.config.sigma_type + self.bound_type = self.config.bound_type + self.head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.config.action_space_size_list[task_id]) # TODO + else: + self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + + if not self.share_head or task_id == 0: + self.head_policy_multi_task.append(self.head_policy) + + self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + if not self.share_head or task_id == 0: + self.head_value_multi_task.append(self.head_value) + + self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + if not self.share_head or task_id == 0: + self.head_rewards_multi_task.append(self.head_rewards) + + self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, + self.config.embed_dim, + # self.sim_norm + self._get_final_norm(self.final_norm_option_in_obs_head) # 使用指定的归一化方法 + ) # NOTE: we add a sim_norm to the head for observations + if not self.share_head or task_id == 0: + self.head_observations_multi_task.append(self.head_observations) + elif self.use_softmoe_head: + print(f'We use softmoe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + # Dictionary to store SoftMoE instances + self.soft_moe_instances = {} + + # Create softmoe head modules + self.create_head_modules_softmoe() + + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + elif self.use_moe_head: + print(f'We use moe head, 
self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + # Dictionary to store moe instances + self.moe_instances = {} + + # Create moe head modules + self.create_head_modules_moe() + + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + + # 对于 head 部分,查找所有以 "head_" 开头的子模块 + # self.head_dict = {} + # for name, module in self.named_children(): + # # TODO: check + # if name.startswith("head_") and name.endswith("_multi_task") : + # self.head_dict[name] = module + # if self.head_dict: + # self.head_dict = nn.ModuleDict(self.head_dict) + + self.head_dict = nn.ModuleDict({ + name: module for name, module in self.named_children() + if name.startswith("head_") and name.endswith("_multi_task") + }) + print("="*20) + print(f"self.head_dict:{self.head_dict}") + + # Apply weight initialization, the order is important + self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + self._initialize_last_layer() + + # Cache structures + self._initialize_cache_structures() + + # Projection input dimension + self._initialize_projection_input_dim() + + # Hit count and query count statistics + self._initialize_statistics() + + # Initialize keys and values for transformer + self._initialize_transformer_keys_values() + + self.latent_recon_loss = torch.tensor(0., device=self.device) + self.perceptual_loss = torch.tensor(0., device=self.device) + + # TODO: check the size of the shared pool + # for self.kv_cache_recurrent_infer + # If needed, recurrent_infer should store the results of the one MCTS search. + self.shared_pool_size = int(50*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size + self.shared_pool_index = 0 + + # for self.kv_cache_init_infer + # In contrast, init_infer only needs to retain the results of the most recent step. + # self.shared_pool_size_init = int(2*self.env_num) + self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? 
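+        # Illustrative sketch (an assumption for clarity, not code used here; the actual copy helpers
+        # live in the parent WorldModel): each per-environment init-infer pool is expected to behave
+        # like a small ring buffer that overwrites its oldest entry once full, e.g.:
+        #   pool, idx = [None] * self.shared_pool_size_init, 0
+        #   pool[idx] = kv_entry
+        #   idx = (idx + 1) % self.shared_pool_size_init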
+        self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)]
+        self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)]
+
+        # for self.kv_cache_wm
+        self.shared_pool_size_wm = int(self.env_num)
+        self.shared_pool_wm = [None] * self.shared_pool_size_wm
+        self.shared_pool_index_wm = 0
+
+        self.reanalyze_phase = False
+        self._rank = get_rank()
+
+    def _generate_colors(self, num_colors):
+        """
+        Generate enough distinct colors for a large number of categories.
+
+        Arguments:
+            - num_colors: The number of colors required.
+        Returns:
+            - colors: A list of colors.
+        """
+        # Concatenate several discrete matplotlib colormaps.
+        color_maps = ['tab20', 'tab20b', 'tab20c']
+        colors = []
+        for cmap_name in color_maps:
+            cmap = plt.get_cmap(cmap_name)
+            colors.extend([cmap(i) for i in range(cmap.N)])
+            if len(colors) >= num_colors:
+                break
+        if len(colors) < num_colors:
+            # Generate additional colors if still needed.
+            additional_colors = plt.cm.get_cmap('hsv', num_colors - len(colors))
+            colors.extend([additional_colors(i) for i in range(num_colors - len(colors))])
+        return colors[:num_colors]
+
+    def _initialize_config_parameters(self) -> None:
+        """Initialize configuration parameters."""
+        self.policy_entropy_weight = self.config.policy_entropy_weight
+        self.predict_latent_loss_type = self.config.predict_latent_loss_type
+        self.group_size = self.config.group_size
+        self.num_groups = self.config.embed_dim // self.group_size
+        self.obs_type = self.config.obs_type
+        self.embed_dim = self.config.embed_dim
+        self.num_heads = self.config.num_heads
+        self.gamma = self.config.gamma
+        self.context_length = self.config.context_length
+        self.dormant_threshold = self.config.dormant_threshold
+        self.analysis_dormant_ratio_weight_rank = self.config.analysis_dormant_ratio_weight_rank
+        self.num_observations_tokens = self.config.tokens_per_block - 1
+        self.latent_recon_loss_weight = self.config.latent_recon_loss_weight
+        self.perceptual_loss_weight = self.config.perceptual_loss_weight
+        self.device = self.config.device
+        self.support_size = self.config.support_size
+        self.action_space_size = self.config.action_space_size
+        self.max_cache_size = self.config.max_cache_size
+        self.env_num = self.config.env_num
+        self.num_layers = self.config.num_layers
+        self.sim_norm = SimNorm(simnorm_dim=self.group_size)
+
+    def _initialize_patterns(self) -> None:
+        """Initialize patterns for block masks."""
+        self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block)
+        self.all_but_last_latent_state_pattern[-2] = 0
+        self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block)
+        self.act_tokens_pattern[-1] = 1
+        self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block)
+        self.value_policy_tokens_pattern[-2] = 1
+
+    def _get_final_norm(self, norm_option: str) -> nn.Module:
+        """
+        Return the normalization module corresponding to the specified option.
+        """
+        if norm_option == 'LayerNorm':
+            return nn.LayerNorm(self.config.embed_dim, eps=1e-5)
+        elif norm_option == 'SimNorm':
+            return SimNorm(simnorm_dim=self.config.group_size)
+        else:
+            raise ValueError(f"Unsupported final_norm_option_in_obs_head: {norm_option}")
+
+    def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head:
+        """Create head modules for the transformer."""
+        modules = [
+            nn.Linear(self.config.embed_dim, self.config.embed_dim),
+            nn.GELU(approximate='tanh'),
+            nn.Linear(self.config.embed_dim, output_dim)
+        ]
+        if norm_layer:
+            modules.append(norm_layer)
+        return Head(
+            max_blocks=self.config.max_blocks,
+            block_mask=block_mask,
+            head_module=nn.Sequential(*modules)
+        )
+
+    def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, moe=None) -> Head:
+        """Create an MoE head module for the transformer."""
+        modules = [
+            moe,
+            nn.Linear(self.config.embed_dim, output_dim)
+        ]
+        if norm_layer:
+            modules.append(norm_layer)
+        return Head(
+            max_blocks=self.config.max_blocks,
+            block_mask=block_mask,
+            head_module=nn.Sequential(*modules)
+        )
+
+    def get_moe(self, name):
+        """Get or create an MoE instance."""
+        if name not in self.moe_instances:
+            # Create multiple FeedForward instances for multiplication-based MoE
+            self.experts = nn.ModuleList([
+                MultiplicationFeedForward(self.config) for _ in range(self.config.num_experts_of_moe_in_transformer)
+            ])
+
+            self.moe_instances[name] = MoeLayer(
+                experts=self.experts,
+                gate=nn.Linear(self.config.embed_dim, self.config.num_experts_of_moe_in_transformer, bias=False),
+                num_experts_per_tok=1,
+            )
+
+        return self.moe_instances[name]
+
+    def create_head_modules_moe(self):
+        """Create all MoE head modules."""
+        # Rewards head
+        self.head_rewards = self._create_head_moe(
+            self.act_tokens_pattern,
+            self.support_size,
+            moe=self.get_moe("rewards_moe")
+        )
+
+        # Observations head
+        self.head_observations = self._create_head_moe(
+            self.all_but_last_latent_state_pattern,
+            self.config.embed_dim,
+            norm_layer=self.sim_norm,  # NOTE
+            moe=self.get_moe("observations_moe")
+        )
+
+        # Policy head
+        self.head_policy = self._create_head_moe(
+            self.value_policy_tokens_pattern,
+            self.action_space_size,
+            moe=self.get_moe("policy_moe")
+        )
+
+        # Value head
+        self.head_value = self._create_head_moe(
+            self.value_policy_tokens_pattern,
+            self.support_size,
+            moe=self.get_moe("value_moe")
+        )
+
+    def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, soft_moe=None) -> Head:
+        """Create a SoftMoE head module for the transformer."""
+        modules = [
+            soft_moe,
+            nn.Linear(self.config.embed_dim, output_dim)
+        ]
+        if norm_layer:
+            modules.append(norm_layer)
+        return Head(
+            max_blocks=self.config.max_blocks,
+            block_mask=block_mask,
+            head_module=nn.Sequential(*modules)
+        )
+
+    def get_soft_moe(self, name):
+        """Get or create a SoftMoE instance."""
+        # from soft_moe_pytorch import SoftMoE
+        # if name not in self.soft_moe_instances:
+        #     self.soft_moe_instances[name] = SoftMoE(
+        #         dim=self.embed_dim,
+        #         seq_len=20,  # TODO
+        #         num_experts=self.num_experts_in_moe_head,
+        #     )
+        from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE
+        if name not in self.soft_moe_instances:
+            self.soft_moe_instances[name] = SoftMoE(
+                dim=self.embed_dim,
+                num_experts=self.num_experts_in_moe_head,
+                geglu=True
+            )
+        return self.soft_moe_instances[name]
+
+    def create_head_modules_softmoe(self):
+        """Create all SoftMoE head modules."""
+        # Rewards head
+        self.head_rewards = self._create_head_softmoe(
+            self.act_tokens_pattern,
+            self.support_size,
+            soft_moe=self.get_soft_moe("rewards_soft_moe")
+        )
+
+        # Observations head
+        self.head_observations = self._create_head_softmoe(
+            self.all_but_last_latent_state_pattern,
+            self.config.embed_dim,
+            norm_layer=self.sim_norm,  # NOTE
+            soft_moe=self.get_soft_moe("observations_soft_moe")
+        )
+
+        # Policy head
+        self.head_policy = self._create_head_softmoe(
+            self.value_policy_tokens_pattern,
+            self.action_space_size,
+            soft_moe=self.get_soft_moe("policy_soft_moe")
+        )
+
+        # Value head
+        self.head_value = self._create_head_softmoe(
+            self.value_policy_tokens_pattern,
+            self.support_size,
+            soft_moe=self.get_soft_moe("value_soft_moe")
+        )
+
+    def _initialize_last_layer(self) -> None:
+        """Initialize the last linear 
layer.""" + last_linear_layer_init_zero = True + print(f'world_model_mt.py:self.task_num:{self.task_num}') + if last_linear_layer_init_zero: + if self.continuous_action_space: + module_to_initialize = [self.head_value, self.head_rewards, self.head_observations] + else: + module_to_initialize = [self.head_policy, self.head_value, self.head_rewards, self.head_observations] + + # TODO: multitask + if self.task_num == 1: + for head in module_to_initialize: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + elif self.task_num > 1: + if self.continuous_action_space: + module_to_initialize = self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task + else: + module_to_initialize = self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task + + for head in module_to_initialize: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + self.past_kv_cache_recurrent_infer = collections.OrderedDict() + self.past_kv_cache_init_infer = collections.OrderedDict() + self.past_kv_cache_init_infer_envs = [collections.OrderedDict() for _ in range(self.env_num)] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _initialize_projection_input_dim(self) -> None: + """Initialize the projection input dimension based on the number of observation tokens.""" + if self.num_observations_tokens == 16: + self.projection_input_dim = 128 + elif self.num_observations_tokens == 1: + if self.task_embed_option == "concat_task_embed": + self.projection_input_dim = self.config.embed_dim - self.task_embed_dim + elif self.task_embed_option == "register_task_embed": + self.projection_input_dim = self.config.embed_dim + elif self.task_embed_option == "add_task_embed": + self.projection_input_dim = self.config.embed_dim + else: + self.projection_input_dim = self.config.embed_dim + + def _initialize_statistics(self) -> None: + """Initialize counters for hit count and query count statistics.""" + self.hit_count = 0 + self.total_query_count = 0 + self.length_largethan_maxminus5_context_cnt = 0 + self.length_largethan_maxminus7_context_cnt = 0 + self.root_hit_cnt = 0 + self.root_total_query_cnt = 0 + + #@profile + def _initialize_transformer_keys_values(self) -> None: + """Initialize keys and values for the transformer.""" + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, + max_tokens=self.context_length) + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, + max_tokens=self.context_length) + + #@profile + def precompute_pos_emb_diff_kv(self): + """ Precompute positional embedding differences for key and value. 
""" + if self.context_length <= 2: + # If context length is 2 or less, no context is present + return + + # Precompute positional embedding matrices for inference in collect/eval stages, not for training + self.positional_embedding_k = [ + self._get_positional_embedding(layer, 'key') + for layer in range(self.config.num_layers) + ] + self.positional_embedding_v = [ + self._get_positional_embedding(layer, 'value') + for layer in range(self.config.num_layers) + ] + + # Precompute all possible positional embedding differences + self.pos_emb_diff_k = [] + self.pos_emb_diff_v = [] + + for layer in range(self.config.num_layers): + layer_pos_emb_diff_k = {} + layer_pos_emb_diff_v = {} + + for start in [2]: + for end in [self.context_length - 1]: # TODO + # for end in [self.context_length - self.register_token_num - 1]: + original_pos_emb_k = self.positional_embedding_k[layer][:, :, start:end, :] + new_pos_emb_k = self.positional_embedding_k[layer][:, :, :end - start, :] + layer_pos_emb_diff_k[(start, end)] = new_pos_emb_k - original_pos_emb_k + + original_pos_emb_v = self.positional_embedding_v[layer][:, :, start:end, :] + new_pos_emb_v = self.positional_embedding_v[layer][:, :, :end - start, :] + layer_pos_emb_diff_v[(start, end)] = new_pos_emb_v - original_pos_emb_v + + self.pos_emb_diff_k.append(layer_pos_emb_diff_k) + self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + + #@profile + def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: + """ + Helper function to get positional embedding for a given layer and attention type. + + Arguments: + - layer (:obj:`int`): Layer index. + - attn_type (:obj:`str`): Attention type, either 'key' or 'value'. + + Returns: + - torch.Tensor: The positional embedding tensor. + """ + # TODO: detach() ========== + attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) + if torch.cuda.is_available(): + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).to(self.device).detach() + else: + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).detach() + + #@profile + def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], + past_keys_values: Optional[torch.Tensor] = None, + kvcache_independent: bool = False, is_init_infer: bool = True, + valid_context_lengths: Optional[torch.Tensor] = None, task_id=0) -> WorldModelOutput: + """ + Forward pass for the model. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing observation embeddings or action tokens. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths. + Returns: + - WorldModelOutput: Model output containing logits for observations, rewards, policy, and value. 
+ """ + if self.use_task_embed: + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1,-1)).view(-1) # TODO + else: + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) # ============= TODO: no task_embeddings now ============= + + # Determine previous steps based on key-value caching method + if kvcache_independent: + prev_steps = torch.tensor([0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], + device=self.device) + else: + prev_steps = 0 if past_keys_values is None else past_keys_values.size + + # Reset valid_context_lengths during initial inference + if is_init_infer: + valid_context_lengths = None + + # inference阶段: collect或者eval Process observation embeddings + if 'obs_embeddings' in obs_embeddings_or_act_tokens: + obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] + if len(obs_embeddings.shape) == 2: + obs_embeddings = obs_embeddings.unsqueeze(1) + + # TODO: multitask + if self.task_embed_option == "add_task_embed": + obs_embeddings = obs_embeddings + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + + # print(f'=='*20) + # print(f"is_init_infer:{is_init_infer}") + # print(f'obs_embeddings.shape:{obs_embeddings.shape}') + # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}') + # print(f'=='*20) + + # if is_init_infer: + # # 注意只有在inference时,只有在is_init_infer时拼接task embeddings,recurr_infer中已经在init_infer中增加了task embeddings的信息了 + # # Expand task embeddings to match the sequence shape + # task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) + # obs_embeddings = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) + + if is_init_infer and not self.reanalyze_phase: + # 注意只有在inference时,只有在is_init_infer时拼接task embeddings,recurr_infer中已经在init_infer中增加了task embeddings的信息了 + # Expand task embeddings to match the sequence shape + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) + obs_embeddings = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) + + # if is_init_infer: + # if self.task_embed_option == "register_task_embed": + # # Register task embeddings as input tokens + # task_tokens = self.task_embeddings.expand(obs_embeddings.shape[0], self.register_token_length, -1) + # obs_embeddings = torch.cat([task_tokens, obs_embeddings], dim=1) + + num_steps = obs_embeddings.size(1) + sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths) + + + # inference阶段: collect或者eval Process action tokens + elif 'act_tokens' in obs_embeddings_or_act_tokens: + act_tokens = obs_embeddings_or_act_tokens['act_tokens'] + + if self.continuous_action_space: + num_steps = 1 + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: + act_tokens = act_tokens.unsqueeze(1) + else: + if len(act_tokens.shape) == 3: + act_tokens = act_tokens.squeeze(1) + num_steps = act_tokens.size(1) + if self.task_num >= 1 and self.continuous_action_space: + act_embeddings = self.act_embedding_table[task_id](act_tokens) + else: + act_embeddings = self.act_embedding_table(act_tokens) + + if self.task_embed_option == "add_task_embed": + # TODO: 对于action_token不需要增加task_embeddings会造成歧义,反而干扰学习 + # obs_embeddings = obs_embeddings + self.task_embeddings + pass + elif self.task_embed_option == 
"concat_task_embed": + # print(f'=='*20) + # print(f'act_embeddings.shape:{act_embeddings.shape}') + # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}') + # print(f'=='*20) + # Expand task embeddings to match the sequence shape + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(act_embeddings.shape[0], act_embeddings.shape[1], -1) + act_embeddings = torch.cat([act_embeddings, task_emb_expanded], dim=-1) + + + sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths) + + # 训练阶段: Process combined observation embeddings and action tokens + else: + # "add_task_embed"在self._process_obs_act_combined_cont方法内部处理, + # process_obs_act_combined目前还没有增加task_embed的concat和register模式 + if self.continuous_action_space: + sequences, num_steps = self._process_obs_act_combined_cont(obs_embeddings_or_act_tokens, prev_steps, task_id=task_id) + else: + sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) + + + # Pass sequences through transformer + x = self._transformer_pass(sequences, past_keys_values, kvcache_independent, valid_context_lengths, task_id=task_id) + + # Generate logits + + # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 + # TODO: one head or moe head + if self.use_moe_head: + logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + else: + # 使用共享head或任务特定的head + head_index = 0 if self.share_head else task_id + # print(f"="*20) + # print(f"head_index:{head_index}") + # print(f"="*20) + logits_observations = self.head_observations_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + + # logits_ends is None + return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + + + #@profile + def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, + valid_context_lengths): + """ + Add position embeddings to the input embeddings. + + Arguments: + - embeddings (:obj:`torch.Tensor`): Input embeddings. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + - num_steps (:obj:`int`): Number of steps. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Embeddings with position information added. 
+ """ + if kvcache_independent: + steps_indices = prev_steps + torch.arange(num_steps, device=embeddings.device) + position_embeddings = self.pos_emb(steps_indices).view(-1, num_steps, embeddings.shape[-1]) + return embeddings + position_embeddings + else: + # 修复前面kv_cache和z/a的位置编码不对, kv_cache, z/a, register_token + # if self.use_task_embed and self.task_embed_option == "register_task_embed": + # if prev_steps + num_steps + self.register_token_num > self.context_length: + # prev_steps = self.context_length - self.register_token_num - 1 + + if is_init_infer: + return embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + else: + valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) + + # try: + position_embeddings = self.pos_emb( + valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) + # except Exception as e: + # print(e) + # import ipdb; ipdb.set_trace() + + return embeddings + position_embeddings + + #@profile + def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0): + """ + Process combined observation embeddings and action tokens. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. + """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + if self.continuous_action_space: + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: # TODO + act_tokens = act_tokens.unsqueeze(-1) + + # B, L, E + act_embeddings = self.act_embedding_table[task_id](act_tokens) + + B, L, K, E = obs_embeddings.size() + + if self.task_embed_option == "concat_task_embed": + # B, L*2, E + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + else: + # B, L*2, E + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + + + if self.task_embed_option == "concat_task_embed": + # print(f'=='*20) + # print(f'self.task_embeddings.shape:{self.task_embeddings.shape}') + # print(f'=='*20) + # Expand task embeddings to match the sequence shape + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(B, 1, -1) + + for i in range(L): + if self.task_embed_option == "add_task_embed": + obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings + elif self.task_embed_option == "concat_task_embed": + # print(f'=='*20) + # print(f'obs_embeddings.shape:{obs_embeddings.shape}') + # print(f'=='*20) + obs = torch.cat([obs_embeddings[:, i, :, :], task_emb_expanded], dim=-1) + else: + obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) + + act = act_embeddings[:, i, :].unsqueeze(1) + if self.task_embed_option == "concat_task_embed": + # print(f'=='*20) + # print(f'act_embeddings.shape:{act_embeddings.shape}') + # print(f'=='*20) + act = torch.cat([act, task_emb_expanded], dim=-1) + + obs_act = torch.cat([obs, act], dim=1) + # print(f'obs_act.shape:{obs_act.shape}') + + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + + return obs_act_embeddings 
+ self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + + + #@profile + def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0): + """ + Process combined observation embeddings and action tokens. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. + """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_embeddings = self.act_embedding_table(act_tokens) + + B, L, K, E = obs_embeddings.size() + if self.task_embed_option == "concat_task_embed": + # B, L*2, E + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + else: + # B, L*2, E + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + + if self.task_embed_option == "concat_task_embed": + # Expand task embeddings to match the sequence shape + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(B, 1, -1) + + + for i in range(L): + if self.task_embed_option == "add_task_embed": + obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings + elif self.task_embed_option == "concat_task_embed": + obs = torch.cat([obs_embeddings[:, i, :, :], task_emb_expanded], dim=-1) + else: + obs = obs_embeddings[:, i, :, :] # Shape: (B, K, E) + + act = act_embeddings[:, i, 0, :].unsqueeze(1) + if self.task_embed_option == "concat_task_embed": + act = torch.cat([act, task_emb_expanded], dim=-1) + + obs_act = torch.cat([obs, act], dim=1) + # print(f'obs_act.shape:{obs_act.shape}') + + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + + + #@profile + # def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps, task_id=0): + # """ + # Process combined observation embeddings and action tokens. + + # Arguments: + # - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + # - prev_steps (:obj:`torch.Tensor`): Previous steps. + # Returns: + # - torch.Tensor: Combined observation and action embeddings with position information added. 
+ # """ + # obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + # if len(obs_embeddings.shape) == 3: + # obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + # -1) + + # num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + # # act_embeddings = self.act_embedding_table[task_id](act_tokens) + # act_embeddings = self.act_embedding_table(act_tokens) + + # B, L, K, E = obs_embeddings.size() + # obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) + + # for i in range(L): + # # obs = obs_embeddings[:, i, :, :] + # obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings + # act = act_embeddings[:, i, 0, :].unsqueeze(1) + # obs_act = torch.cat([obs, act], dim=1) + # obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + # return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + + #@profile + def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths, task_id=0): + """ + Pass sequences through the transformer. + + Arguments: + - sequences (:obj:`torch.Tensor`): Input sequences. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Transformer output. + """ + if kvcache_independent: + x = [self.transformer(sequences[k].unsqueeze(0), past_kv, + valid_context_lengths=valid_context_lengths[k].unsqueeze(0)) for k, past_kv in + enumerate(past_keys_values)] + return torch.cat(x, dim=0) + else: + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + + #@profile + @torch.no_grad() + def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, task_id = 0) -> torch.FloatTensor: + """ + Reset the model state based on initial observations and actions. + + Arguments: + - obs_act_dict (:obj:`torch.FloatTensor`): A dictionary containing 'obs', 'action', and 'current_obs'. + Returns: + - torch.FloatTensor: The outputs from the world model and the latent state. + """ + if self.use_task_embed: + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1,-1)).view(-1) # TODO + else: + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) # ============= TODO: no task_embeddings now ============= + + + # Extract observations, actions, and current observations from the dictionary. + if isinstance(obs_act_dict, dict): + batch_obs = obs_act_dict['obs'] + batch_action = obs_act_dict['action'] + batch_current_obs = obs_act_dict['current_obs'] + + # Encode observations to latent embeddings. 
+ obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs, task_id=task_id) + + if batch_current_obs is not None: + # ================ Collect and Evaluation Phase ================ + # Encode current observations to latent embeddings + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs, task_id=task_id) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + + if self.use_task_embed and self.task_embed_option == "register_task_embed": + self.latent_state = current_obs_embeddings + elif not self.use_task_embed: + self.latent_state = current_obs_embeddings + + # ================ NOTE ================ + # import ipdb; ipdb.set_trace() + # self.latent_state 是原来的obs_embeddings与task_embedding的组合: add或者concat + if self.use_task_embed and self.task_embed_option == "add_task_embed": + self.latent_state = current_obs_embeddings + self.task_embeddings + if self.use_task_embed and self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(current_obs_embeddings.shape[0], current_obs_embeddings.shape[1], -1) + self.latent_state = torch.cat([current_obs_embeddings, task_emb_expanded], dim=-1) + # ================ NOTE ================ + + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, current_obs_embeddings, task_id=task_id) + else: + # ================ calculate the target value in Train phase ================ + + # self.latent_state = obs_embeddings + + # ================ NOTE ================ + # import ipdb; ipdb.set_trace() + # self.latent_state 是原来的obs_embeddings与task_embedding的组合: add或者concat + if self.use_task_embed and self.task_embed_option == "add_task_embed": + self.latent_state = obs_embeddings + self.task_embeddings + elif self.use_task_embed and self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) + self.latent_state = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) + else: + self.latent_state = obs_embeddings + + # print(f" Train phase self.latent_state.shape: {self.latent_state.shape}") + # ================ NOTE ================ + + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, None, task_id=task_id) + + return outputs_wm, self.latent_state + + + #@profile + @torch.no_grad() + def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor, + batch_action=None, + current_obs_embeddings=None, task_id = 0) -> torch.FloatTensor: + """ + Refresh key-value pairs with the initial latent state for inference. + + Arguments: + - latent_state (:obj:`torch.LongTensor`): The latent state embeddings. + - batch_action (optional): Actions taken. + - current_obs_embeddings (optional): Current observation embeddings. + Returns: + - torch.FloatTensor: The outputs from the world model. 
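+        Note:
+            The collect/eval branch is taken when ``current_obs_embeddings`` is not None; otherwise the
+            train-phase branch reshapes ``last_obs_embeddings`` to (B, L, K, E) and unrolls it together
+            with ``batch_action`` to compute target values and policies.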
+ """ + n, num_observations_tokens, _ = last_obs_embeddings.shape + if n <= self.env_num and current_obs_embeddings is not None: + # ================ Collect and Evaluation Phase ================ + if current_obs_embeddings is not None: + if self.continuous_action_space: + first_step_flag = not isinstance(batch_action[0], np.ndarray) + else: + first_step_flag = max(batch_action) == -1 + if first_step_flag: + # First step in an episode + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], + max_tokens=self.context_length) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + if self.use_task_embed and self.task_embed_option in ["concat_task_embed", "add_task_embed"]: + # Copy and store keys_values_wm for a single environment + self.update_cache_context(self.latent_state, is_init_infer=True) + else: + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + else: + # Assume latest_state is the new latent_state, containing information from ready_env_num environments + ready_env_num = current_obs_embeddings.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + for i in range(ready_env_num): + # Retrieve latent state for a single environment + state_single_env = last_obs_embeddings[i] + # Compute hash value using latent state for a single environment + cache_key = hash_state( + state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor + + # Retrieve cached value + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + self.root_total_query_cnt += 1 + if matched_value is not None: + # If a matching value is found, add it to the list + self.root_hit_cnt += 1 + # deepcopy is needed because forward modifies matched_value in place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # Reset using zero values + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, + past_keys_values=self.keys_values_wm_single_env, + is_init_infer=True, task_id=task_id) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + # Input self.keys_values_wm_list, output self.keys_values_wm + self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True) + + batch_action = batch_action[:ready_env_num] + # if ready_env_num < self.env_num: + # print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}') + if self.continuous_action_space: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(1) + else: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1) + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, + is_init_infer=True, task_id=task_id) + + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, 
task_id=task_id) + + # Copy and store keys_values_wm for a single environment + if self.use_task_embed and self.task_embed_option in ["concat_task_embed", "add_task_embed"]: + # Copy and store keys_values_wm for a single environment + self.update_cache_context(self.latent_state, is_init_infer=True) + else: + # import ipdb; ipdb.set_trace() + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + + elif batch_action is not None and current_obs_embeddings is None: + # elif n > self.env_num and batch_action is not None and current_obs_embeddings is None: + # ================ calculate the target value in Train phase ================ + # [192, 16, 64] -> [32, 6, 16, 64] + last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, + self.obs_act_embed_dim) # (BL, K) for unroll_step=1 + + last_obs_embeddings = last_obs_embeddings[:, :-1, :] + batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) + + if self.continuous_action_space: + act_tokens = batch_action + else: + act_tokens = rearrange(batch_action, 'b l -> b l 1') + + # select the last timestep for each sample + # This will select the last column while keeping the dimensions unchanged, and the target policy/value in the final step itself is not used. + last_steps_act = act_tokens[:, -1:, :] + act_tokens = torch.cat((act_tokens, last_steps_act), dim=1) + + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, task_id=task_id) + + # if self.reanalyze_phase: + # # TODO + # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, is_init_infer=False, task_id=task_id) + # else: + # outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, is_init_infer=True, task_id=task_id) + + # select the last timestep for each sample + last_steps_value = outputs_wm.logits_value[:, -1:, :] + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (B, H, 101) = (B*H, 101) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + return outputs_wm + + + #@profile + @torch.no_grad() + def forward_initial_inference(self, obs_act_dict, task_id = 0): + """ + Perform initial inference based on the given observation-action dictionary. + + Arguments: + - obs_act_dict (:obj:`dict`): Dictionary containing observations and actions. + Returns: + - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value. + """ + # UniZero has context in the root node + outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, task_id=task_id) + self.past_kv_cache_recurrent_infer.clear() + + return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, + outputs_wm.logits_policy, outputs_wm.logits_value) + + #@profile + @torch.no_grad() + def forward_recurrent_inference(self, state_action_history, simulation_index=0, + latent_state_index_in_search_path=[], task_id = 0): + """ + Perform recurrent inference based on the state-action history. 
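+        Each call first feeds the latest action token and then the predicted next latent token through the
+        transformer, reusing the per-environment key-value caches retrieved (or regenerated) for the current
+        latent states and updating them afterwards.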
+ + Arguments: + - state_action_history (:obj:`list`): List containing tuples of state and action history. + - simulation_index (:obj:`int`, optional): Index of the current simulation. Defaults to 0. + - latent_state_index_in_search_path (:obj:`list`, optional): List containing indices of latent states in the search path. Defaults to []. + Returns: + - tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value. + """ + # import ipdb; ipdb.set_trace() + + latest_state, action = state_action_history[-1] + ready_env_num = latest_state.shape[0] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + self.keys_values_wm_size_list = self.retrieve_or_generate_kvcache(latest_state, ready_env_num, simulation_index, task_id=task_id) + + latent_state_list = [] + if not self.continuous_action_space: + token = action.reshape(-1, 1) + else: + token = action.reshape(-1, self.config.action_space_size_list[task_id]) + + # ======= Print statistics for debugging ============= + # min_size = min(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 5: + # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 7: + # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) + # if self.total_query_count > 0 and self.total_query_count % 10000 == 0: + # self.hit_freq = self.hit_count / self.total_query_count + # print('total_query_count:', self.total_query_count) + # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) + # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) + # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) + # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + + # Trim and pad kv_cache + self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) + self.keys_values_wm_size_list_current = self.keys_values_wm_size_list + + for k in range(2): + # action_token obs_token + if k == 0: + obs_embeddings_or_act_tokens = {'act_tokens': token} + else: + obs_embeddings_or_act_tokens = {'obs_embeddings': token} + + # try: + # Perform forward pass + outputs_wm = self.forward( + obs_embeddings_or_act_tokens, + past_keys_values=self.keys_values_wm, + kvcache_independent=False, + is_init_infer=False, + task_id = task_id + ) + # except Exception as e: + # print(e) + # import ipdb; ipdb.set_trace() + + self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] + + # if self.task_embed_option == "register_task_embed": + # # kv_cache, z/a, register_token + # # 这样修复后kv_cache的位置编码不是从0开始的, 那后面按照从零开始矫正也就是错误的, + # # 但是由于self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1,所以不会矫正 + # # 但是在_add_position_embeddings时,prev_steps是错误的,导致新增的z/a的位置编码索引与前面的kv不连续 + # # import ipdb; ipdb.set_trace() + # print(f'self.keys_values_wm_size_list_current:{self.keys_values_wm_size_list_current}') + # print(f'self.keys_values_wm.size:{self.keys_values_wm.size}') + # self.keys_values_wm_size_list_current = [min(self.keys_values_wm.size, i + 1) for i in 
self.keys_values_wm_size_list_current] + # else: + # self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] + + + if k == 0: + reward = outputs_wm.logits_rewards # (B,) + + if k < self.num_observations_tokens: + token = outputs_wm.logits_observations + if len(token.shape) != 3: + token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) + # print(f'token.shape:{token.shape}') + + latent_state_list.append(token) + + del self.latent_state # Very important to minimize cuda memory usage + self.latent_state = torch.cat(latent_state_list, dim=1) # (B, K) + + self.update_cache_context( + self.latent_state, + is_init_infer=False, + simulation_index=simulation_index, + latent_state_index_in_search_path=latent_state_index_in_search_path + ) + + return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + + def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: + """ + Adjusts the key-value cache for each environment to ensure they all have the same size. + + In a multi-environment setting, the key-value cache (kv_cache) for each environment is stored separately. + During recurrent inference, the kv_cache sizes may vary across environments. This method pads each kv_cache + to match the largest size found among them, facilitating batch processing in the transformer forward pass. + + Arguments: + - is_init_infer (:obj:`bool`): Indicates if this is an initial inference. Default is True. + Returns: + - list: Updated sizes of the key-value caches. + """ + # Find the maximum size among all key-value caches + max_size = max(self.keys_values_wm_size_list) + + # Iterate over each layer of the transformer + for layer in range(self.num_layers): + kv_cache_k_list = [] + kv_cache_v_list = [] + + # Enumerate through each environment's key-value pairs + for idx, keys_values in enumerate(self.keys_values_wm_list): + k_cache = keys_values[layer]._k_cache._cache + v_cache = keys_values[layer]._v_cache._cache + + effective_size = self.keys_values_wm_size_list[idx] + pad_size = max_size - effective_size + + # If padding is required, trim the end and pad the beginning of the cache + if pad_size > 0: + k_cache_trimmed = k_cache[:, :, :-pad_size, :] + v_cache_trimmed = v_cache[:, :, :-pad_size, :] + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + else: + k_cache_padded = k_cache + v_cache_padded = v_cache + + kv_cache_k_list.append(k_cache_padded) + kv_cache_v_list.append(v_cache_padded) + + # Stack the caches along a new dimension and remove any extra dimensions + self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) + self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) + + # Update the cache size to the maximum size + self.keys_values_wm._keys_values[layer]._k_cache._size = max_size + self.keys_values_wm._keys_values[layer]._v_cache._size = max_size + + return self.keys_values_wm_size_list + + #@profile + def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, + latent_state_index_in_search_path=[], valid_context_lengths=None): + """ + Update the cache context with the given latent state. + + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state tensor. + - is_init_infer (:obj:`bool`): Flag to indicate if this is the initial inference. 
+ - simulation_index (:obj:`int`): Index of the simulation. + - latent_state_index_in_search_path (:obj:`list`): List of indices in the search path. + - valid_context_lengths (:obj:`list`): List of valid context lengths. + """ + if self.context_length <= 2: + # No context to update if the context length is less than or equal to 2. + return + for i in range(latent_state.size(0)): + # ============ Iterate over each environment ============ + cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor + # if self.task_embed_option == "register_task_embed": + # context_length = self.context_length - self.register_token_num + # else: + # context_length = self.context_length + + context_length = self.context_length + + if not is_init_infer: + # ============ Internal Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + current_max_context_length = max(self.keys_values_wm_size_list_current) + trim_size = current_max_context_length - self.keys_values_wm_size_list_current[i] + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + # cache shape [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + if trim_size > 0: + # Trim invalid leading zeros as per effective length + # Remove the first trim_size zero kv items + k_cache_trimmed = k_cache_current[:, trim_size:, :] + v_cache_trimmed = v_cache_current[:, trim_size:, :] + # If effective length < current_max_context_length, pad the end of cache with 'trim_size' zeros + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", + 0) # Pad with 'trim_size' zeros at end of cache + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + else: + k_cache_padded = k_cache_current + v_cache_padded = v_cache_current + + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = \ + self.keys_values_wm_size_list_current[i] + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = \ + self.keys_values_wm_size_list_current[i] + + # ============ NOTE: Very Important ============ + if self.keys_values_wm_single_env._keys_values[layer]._k_cache._size >= context_length - 1: + # import ipdb; ipdb.set_trace() + + # Keep only the last self.context_length-3 timesteps of context + # For memory environments, training is for H steps, recurrent_inference might exceed H steps + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache + v_cache_current = self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + + # Index pre-computed positional encoding differences + # import ipdb; ipdb.set_trace() + 
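# [Editor's note, not part of the patch] A hedged reading of the correction applied below:
# self.pos_emb_diff_k / self.pos_emb_diff_v appear to store, per layer, the pre-computed difference
# between the (projected) absolute positional embeddings at the old cache slots (2 .. context_length-2)
# and at the new slots those entries occupy once the two oldest timesteps are dropped. Assuming additive
# absolute position embeddings, adding this difference re-bases the cached K/V so that they look as if
# they had been computed at their new positions, roughly k_new ≈ k_old + (pos_emb_new - pos_emb_old).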
pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). + padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update single environment cache + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + else: + # ============ Root Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + + if self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1: # Keep only the last self.context_length-1 timesteps of context + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # Shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + else: + # import ipdb; ipdb.set_trace() + + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, 2:context_length - 1, :] + v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :] + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). 
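# [Editor's example, not part of the patch] F.pad pads from the last dimension backwards, so for the
# 3-D cache (num_heads, seq_len, head_dim) the tuple (0, 0, 0, 3) leaves the feature dimension untouched
# and appends 3 zero timesteps at the end of seq_len, e.g.:
#   x = torch.zeros(8, 5, 64)
#   F.pad(x, (0, 0, 0, 3)).shape   # -> torch.Size([8, 8, 64])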
+ padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + if is_init_infer: + # Store the latest key-value cache for initial inference + # import ipdb; ipdb.set_trace() + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + else: + # Store the latest key-value cache for recurrent inference + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + #@profile + def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, + simulation_index: int = 0, task_id = 0) -> list: + """ + Retrieves or generates key-value caches for each environment based on the latent state. + + For each environment, this method either retrieves a matching cache from the predefined + caches if available, or generates a new cache if no match is found. The method updates + the internal lists with these caches and their sizes. + + Arguments: + - latent_state (:obj:`list`): List of latent states for each environment. + - ready_env_num (:obj:`int`): Number of environments ready for processing. + - simulation_index (:obj:`int`, optional): Index for simulation tracking. Default is 0. + Returns: + - list: Sizes of the key-value caches for each environment. 
+ """ + for i in range(ready_env_num): + self.total_query_count += 1 + state_single_env = latent_state[i] # latent_state[i] is np.array + cache_key = hash_state(state_single_env) + + if self.reanalyze_phase: + # TODO: check if this is correct + matched_value = None + else: + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + # If not found, try to retrieve from past_kv_cache_recurrent_infer + if matched_value is None: + # import ipdb; ipdb.set_trace() + matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + + if matched_value is not None: + # If a matching cache is found, add it to the lists + self.hit_count += 1 + # Perform a deep copy because the transformer's forward pass might modify matched_value in-place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # If no matching cache is found, generate a new one using zero reset + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values( + n=1, max_tokens=self.context_length + ) + self.forward( + {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, + past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id + ) + # if self.reanalyze_phase: + # self.forward( + # {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, + # past_keys_values=self.keys_values_wm_single_env, is_init_infer=False, task_id=task_id + # ) + # else: + # self.forward( + # {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, + # past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id + # ) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + return self.keys_values_wm_size_list + + + def plot_embeddings(self, tsne_results, task_ids, observations, samples_per_task=5, save_dir='tsne_plots_26games'): + """ + 生成 t-SNE 可视化图,并在图中为每个任务随机标注指定数量的观测样本图像。 + + 参数: + - tsne_results: t-SNE 降维结果 (N x 2 的数组) + - task_ids: 环境任务 ID,用于着色 (N 的数组) + - observations: 对应的观测样本 (N x C x H x W 的张量或数组) + - samples_per_task: 每个任务选择的样本数量,默认 5 + - save_dir: 保存路径,默认 'tsne_plots_26games' + """ + + # 创建保存目录 + os.makedirs(save_dir, exist_ok=True) + print(f"[INFO] 保存目录已创建或已存在: {save_dir}") + + # 创建 t-SNE 图 + print("[INFO] 开始绘制 t-SNE 散点图...") + plt.figure(figsize=(18, 10)) # 增大图像宽度以适应右侧图例 + + # 散点图 + scatter = plt.scatter( + tsne_results[:, 0], + tsne_results[:, 1], + c=[self.colors[tid] for tid in task_ids], + alpha=0.6, + edgecolor='w', + linewidth=0.5 + ) + + # 创建自定义图例 + legend_elements = [] + for idx, env_id in enumerate(self.env_id_list): + short_name = self.env_short_names.get(env_id, env_id) + color = self.colors[idx] + legend_elements.append( + Patch(facecolor=color, edgecolor='w', label=f"{idx}: {short_name}") + ) + + # 将图例放在图像右侧,并且每个图例项占一行 + plt.legend( + handles=legend_elements, + title="Environment IDs", + loc='center left', + bbox_to_anchor=(1, 0.5), # 图例在图像右侧中央 + fontsize=10, + title_fontsize=12, + ncol=1, + frameon=False # 去除图例边框,增强美观 + ) + + # 设置标题和轴标签 + plt.title("t-SNE of Latent States across Environments", fontsize=16) + plt.xlabel("t-SNE Dimension 1", fontsize=14) + plt.ylabel("t-SNE Dimension 2", 
fontsize=14) + plt.xticks(fontsize=12) + plt.yticks(fontsize=12) + plt.grid(True, linestyle='--', alpha=0.5) + print(f"[INFO] t-SNE 散点图绘制完成,共有 {len(tsne_results)} 个点。") + + # 为每个任务选择指定数量的样本进行图像标注 + print(f"[INFO] 开始为每个任务选择 {samples_per_task} 个样本进行图像标注...") + for task_id in range(len(self.env_id_list)): + # 找到当前任务的所有索引 + task_indices = np.where(task_ids == task_id)[0] + if len(task_indices) == 0: + print(f"[WARNING] 任务 ID {task_id} 没有对应的样本。") + continue + # 如果样本数量少于所需,全部选取 + if len(task_indices) < samples_per_task: + selected_indices = task_indices + print(f"[INFO] 任务 ID {task_id} 的样本数量 ({len(task_indices)}) 少于 {samples_per_task},选取全部。") + else: + selected_indices = np.random.choice(task_indices, size=samples_per_task, replace=False) + print(f"[INFO] 任务 ID {task_id} 随机选取 {samples_per_task} 个样本进行标注。") + + for idx in selected_indices: + img = observations[idx] + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + if img.shape[0] == 1 or img.shape[0] == 3: # 处理灰度图或 RGB 图 + img = np.transpose(img, (1, 2, 0)) + else: + raise ValueError(f"Unsupported image shape: {img.shape}") + + # 标准化图像到 [0,1] 范围 + img_min, img_max = img.min(), img.max() + if img_max - img_min > 1e-5: + img = (img - img_min) / (img_max - img_min) + else: + img = np.zeros_like(img) + + imagebox = OffsetImage(img, zoom=0.5) + ab = AnnotationBbox( + imagebox, + (tsne_results[idx, 0], tsne_results[idx, 1]), + frameon=False, + pad=0.3 + ) + plt.gca().add_artist(ab) + print(f"[INFO] 已添加图像标注: 任务 ID {task_id}, 点索引 {idx}, t-SNE 坐标 ({tsne_results[idx, 0]:.2f}, {tsne_results[idx, 1]:.2f})") + + # 调整布局以适应图例 + plt.tight_layout(rect=[0, 0, 0.9, 1]) # 为右侧的图例预留空间 + + # 保存图像,使用高分辨率 + save_path_png = os.path.join(save_dir, 'tsne_plot.png') + save_path_pdf = os.path.join(save_dir, 'tsne_plot.pdf') + plt.savefig(save_path_png, dpi=300, bbox_inches='tight') + plt.savefig(save_path_pdf, dpi=300, bbox_inches='tight') + print(f"[INFO] t-SNE 可视化图已保存至: {save_path_png} 和 {save_path_pdf}") + plt.close() + + @torch.no_grad() + def gather_and_plot(self, local_embeddings, local_task_ids, local_observations): + world_size = dist.get_world_size() + rank = dist.get_rank() + + # 准备接收来自所有进程的CUDA张量 + embeddings_list = [torch.zeros_like(local_embeddings) for _ in range(world_size)] + task_ids_list = [torch.zeros_like(local_task_ids) for _ in range(world_size)] + + # 准备接收来自所有进程的CPU对象 + observations_list = [None for _ in range(world_size)] + + try: + # 收集CUDA张量:embeddings和task_ids + dist.all_gather(embeddings_list, local_embeddings) + dist.all_gather(task_ids_list, local_task_ids) + + # 收集CPU对象:observations + local_observations_cpu = local_observations.cpu().numpy().tolist() + dist.all_gather_object(observations_list, local_observations_cpu) + except RuntimeError as e: + print(f"Rank {rank}: all_gather failed with error: {e}") + return + + if rank == 0: + # 拼接所有embeddings和task_ids + all_embeddings = torch.cat(embeddings_list, dim=0).cpu().numpy() + all_task_ids = torch.cat(task_ids_list, dim=0).cpu().numpy() + + # 拼接所有observations + all_observations = [] + for obs in observations_list: + all_observations.extend(obs) + all_observations = np.array(all_observations) + + print(f"Shape of all_embeddings: {all_embeddings.shape}") + all_embeddings = all_embeddings.reshape(-1, all_embeddings.shape[-1]) + print(f"Shape of all_observations: {all_observations.shape}") + all_observations = all_observations.reshape(-1, *all_observations.shape[-3:]) + + # 执行t-SNE降维 + tsne = TSNE(n_components=2, random_state=42) + tsne_results = tsne.fit_transform(all_embeddings) + + # 
Plot and save the figure
+            self.plot_embeddings(tsne_results, all_task_ids, all_observations, save_dir=f'tsne_plots_{self.num_tasks}games')
+
+    #@profile
+    def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, task_id = 0, **kwargs: Any) -> LossWithIntermediateLosses:
+        # Encode observations into latent state representations
+        obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id)
+
+        if self.analysis_tsne:
+            # =========== tsne analysis ===========
+            # Ensure the embeddings are dense tensors on the CUDA device
+            if not obs_embeddings.is_cuda:
+                obs_embeddings = obs_embeddings.cuda()
+            obs_embeddings = obs_embeddings.contiguous()
+
+            # Keep the embeddings and task_id of the current process
+            local_embeddings = obs_embeddings.detach()
+            local_task_ids = torch.full((local_embeddings.size(0),), task_id, dtype=torch.long, device=local_embeddings.device)
+
+            # Move observations to the CPU and convert them to numpy
+            local_observations = batch['observations'].detach().cpu()
+
+            # Gather the data across processes and visualize
+            self.gather_and_plot(local_embeddings, local_task_ids, local_observations)
+
+        # ========= logging for analysis =========
+        if self.analysis_dormant_ratio_weight_rank:
+            # Calculate dormant ratio of the encoder
+            shape = batch['observations'].shape  # (..., C, H, W)
+            inputs = batch['observations'].contiguous().view(-1, *shape[-3:])  # (32,5,3,64,64) -> (160,3,64,64)
+            if self.continuous_action_space:
+                encoder_index = task_id
+            else:
+                encoder_index = 0
+            dormant_ratio_encoder_dict = cal_dormant_ratio(self.tokenizer.encoder[encoder_index], inputs.detach(),
+                                                           dormant_threshold=self.dormant_threshold)
+
+            # print(dormant_ratio_encoder_dict)
+            dormant_ratio_encoder = dormant_ratio_encoder_dict['global']
+
+            # Compute the global average absolute weight magnitude of the encoder
+            avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder[encoder_index])
+            # print("Average Weight Magnitude of encoder:", avg_weight_mag_encoder)
+            # Compute the global average absolute weight magnitude of the transformer backbone
+            avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer)
+            # print("Average Weight Magnitude of transformer:", avg_weight_mag_transformer)
+            # print(f"self.head_dict:{self.head_dict}")
+            avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict)
+            # print("Average Weight Magnitude of head:", avg_weight_mag_head)
+
+            # Compute the effective rank of the representation layer. Note that the representation
+            # layer is addressed by its name in model.named_modules() (here "last_linear" / "final_norm").
+            # print(f"self.tokenizer.encoder:{self.tokenizer.encoder}")
+
+            e_rank_last_linear = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="last_linear")
+            # print("Effective Rank of encoder_last_linear:", e_rank_last_linear)
+            e_rank_sim_norm = cal_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="final_norm")
+            # print("Effective Rank of encoder_sim_norm:", e_rank_sim_norm)
+
+            self.past_kv_cache_init_infer.clear()
+            self.past_kv_cache_recurrent_infer.clear()
+            self.keys_values_wm_list.clear()
+            torch.cuda.empty_cache()
+        else:
+            dormant_ratio_encoder = torch.tensor(0.)
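As a standalone illustration of the plasticity metrics logged above (editorial sketch, not part of the patch): only the helper functions from lzero/model/utils.py introduced later in this diff are real; ToyEncoder, its layer names, and the shapes are hypothetical stand-ins.

import torch
import torch.nn as nn

from lzero.model.utils import cal_dormant_ratio, cal_effective_rank, compute_average_weight_magnitude

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for an observation encoder; layer names are arbitrary."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=8, stride=8)   # (B, 3, 64, 64) -> (B, 16, 8, 8)
        self.last_linear = nn.Linear(16 * 8 * 8, 64)

    def forward(self, x):
        return self.last_linear(torch.relu(self.conv(x)).flatten(1))

encoder = ToyEncoder()
obs = torch.randn(8, 3, 64, 64)

# Element-level dormant ratio (in %) over all Conv2d/Linear outputs, plus a global aggregate
print(cal_dormant_ratio(encoder, obs, dormant_threshold=1e-2))   # e.g. {'model': ..., 'global': ...}
# Mean |w| over every parameter of the module
print(compute_average_weight_magnitude(encoder))
# Entropy-based effective rank of the "last_linear" output over the batch
print(cal_effective_rank(encoder, obs, representation_layer_name="last_linear"))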
+ + # Calculate the L2 norm of the latent state roots + latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() + + if self.obs_type == 'image': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # original_images, reconstructed_images = batch['observations'], reconstructed_images + # target_policy = batch['target_policy'] + # ==== for value priority ==== + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'vector': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'image_memory': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + # original_images, reconstructed_images = batch['observations'], reconstructed_images + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), + # reconstructed_images) + + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Action tokens + if self.continuous_action_space: + act_tokens = batch['actions'] + else: + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + # Forward pass to obtain predictions for observations, rewards, and policies + outputs = 
self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, task_id=task_id) + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio_weight_rank: + # Calculate dormant ratio of the world model + dormant_ratio_world_model = cal_dormant_ratio(self, { + 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, + dormant_threshold=self.dormant_threshold) + dormant_ratio_transformer = dormant_ratio_world_model['transformer'] + dormant_ratio_head = dormant_ratio_world_model['head'] + self.past_kv_cache_init_infer.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_transformer = torch.tensor(0.) + dormant_ratio_head = torch.tensor(0.) + avg_weight_mag_encoder = torch.tensor(0.) + avg_weight_mag_transformer = torch.tensor(0.) + avg_weight_mag_head = torch.tensor(0.) + e_rank_last_linear = torch.tensor(0.) + e_rank_sim_norm = torch.tensor(0.) + + # ========== for visualization ========== + # Uncomment the lines below for visualization + # predict_policy = outputs.logits_policy + # predict_policy = F.softmax(outputs.logits_policy, dim=-1) + # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # import pdb; pdb.set_trace() + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') + + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') + # ========== for visualization ========== + + # For training stability, use target_tokenizer to compute the true next latent state representations + with torch.no_grad(): + target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + # Compute labels for observations, rewards, and ends + labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(target_obs_embeddings, + batch['rewards'], + batch['ends'], + batch['mask_padding']) + + # Reshape the logits and labels for observations + logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + if self.use_task_embed and self.task_embed_option == "concat_task_embed": + # print(f'=='*20) + # print(f'labels_observations.shape:{labels_observations.shape}') + # print(f'=='*20) + # Expand task embeddings to match the sequence shape + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1,-1)).view(-1) # TODO + task_emb_expanded = 
self.task_embeddings.expand(labels_observations.shape[0], -1) + # print(f'task_emb_expanded:{task_emb_expanded}') + # print(f"task_emb_expanded.shape: {task_emb_expanded.shape}") + # print(f"task_emb_expanded (min, max, mean): {task_emb_expanded.min()}, {task_emb_expanded.max()}, {task_emb_expanded.mean()}") + # assert not torch.isnan(task_emb_expanded).any(), "task_emb_expanded 存在 NaN 值" + # print(f"logits_observations.shape: {logits_observations.shape}") + labels_observations = torch.cat([labels_observations, task_emb_expanded.detach()], dim=-1) # NOTE: detach() + # print(f"labels_observations.shape: {labels_observations.shape}") + # assert logits_observations.shape == labels_observations.shape, "logits 和 labels 的形状不匹配" + + + + # Compute prediction loss for observations. Options: MSE and Group KL + if self.predict_latent_loss_type == 'mse': + # MSE loss, directly compare logits and labels + loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations, reduction='none').mean( + -1) + elif self.predict_latent_loss_type == 'group_kl': + # Group KL loss, group features and calculate KL divergence within each group + batch_size, num_features = logits_observations.shape + epsilon = 1e-6 + logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) + + # ========== for debugging ========== + # assert not torch.isnan(logits_reshaped).any(), "logits_reshaped contains NaN values" + # assert not torch.isnan(labels_reshaped).any(), "labels_reshaped contains NaN values" + # print('loss_obs:', loss_obs.mean()) + # for name, param in self.tokenizer.encoder.named_parameters(): + # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + # logits_grad = torch.autograd.grad(loss_obs.mean(), logits_observations, retain_graph=True)[0] + # print(f"logits_grad (min, max, mean): {logits_grad.min()}, {logits_grad.max()}, {logits_grad.mean()}") + + # Apply mask to loss_obs + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded) + + # Compute labels for policy and value + labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], + batch['target_policy'], + batch['mask_padding']) + + # Compute losses for rewards, policy, and value + loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') + + if not self.continuous_action_space: + loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, + batch, + element='policy') + else: + # NOTE: for continuous action space + if self.config.policy_loss_type == 'simple': + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont_simple( + outputs, batch) + else: + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont( + outputs, batch, task_id=task_id) + + loss_policy = orig_policy_loss + self.policy_entropy_weight * policy_entropy_loss + policy_entropy = - policy_entropy_loss + + loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + + # Compute timesteps + timesteps = torch.arange(batch['actions'].shape[1], 
device=batch['actions'].device) + # Compute discount coefficients for each timestep + discounts = self.gamma ** timesteps + + if batch['mask_padding'].sum() == 0: + assert False, "mask_padding is all zeros" + + # Group losses into first step, middle step, and last step + first_step_losses = {} + middle_step_losses = {} + last_step_losses = {} + # batch['mask_padding'] indicates mask status for future H steps, exclude masked losses to maintain accurate mean statistics + # Group losses for each loss item + for loss_name, loss_tmp in zip( + ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], + [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] + ): + if loss_name == 'loss_obs': + seq_len = batch['actions'].shape[1] - 1 + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, 1:seq_len] + else: + seq_len = batch['actions'].shape[1] + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, :seq_len] + + # Adjust loss shape to (batch_size, seq_len) + loss_tmp = loss_tmp.view(-1, seq_len) + + # First step loss + first_step_mask = mask_padding[:, 0] + first_step_losses[loss_name] = loss_tmp[:, 0][first_step_mask].mean() + + # Middle step loss + middle_step_index = seq_len // 2 + middle_step_mask = mask_padding[:, middle_step_index] + middle_step_losses[loss_name] = loss_tmp[:, middle_step_index][middle_step_mask].mean() + + # Last step loss + last_step_mask = mask_padding[:, -1] + last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean() + + # Discount reconstruction loss and perceptual loss + discounted_latent_recon_loss = latent_recon_loss + discounted_perceptual_loss = perceptual_loss + + # Calculate overall discounted loss + discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum()/ batch['mask_padding'][:,1:].sum() + discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + + if self.continuous_action_space: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=True, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_transformer=dormant_ratio_transformer, + dormant_ratio_head=dormant_ratio_head, + avg_weight_mag_encoder = avg_weight_mag_encoder, + avg_weight_mag_transformer = avg_weight_mag_transformer, + avg_weight_mag_head = avg_weight_mag_head, + e_rank_last_linear = 
e_rank_last_linear, + e_rank_sim_norm = e_rank_sim_norm, + latent_state_l2_norms=latent_state_l2_norms, + policy_mu=mu, + policy_sigma=sigma, + target_sampled_actions=target_sampled_actions, + ) + else: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=False, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_transformer=dormant_ratio_transformer, + dormant_ratio_head=dormant_ratio_head, + avg_weight_mag_encoder = avg_weight_mag_encoder, + avg_weight_mag_transformer = avg_weight_mag_transformer, + avg_weight_mag_head = avg_weight_mag_head, + e_rank_last_linear = e_rank_last_linear, + e_rank_sim_norm = e_rank_sim_norm, + latent_state_l2_norms=latent_state_l2_norms, + ) + + #@profile + def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): + # Assume outputs is an object with logits attributes like 'rewards', 'policy', and 'value'. + # labels is a target tensor for comparison. batch is a dictionary with a mask indicating valid timesteps. + + logits = getattr(outputs, f'logits_{element}') + + # Reshape your tensors + logits = rearrange(logits, 'b t e -> (b t) e') + labels = labels.reshape(-1, labels.shape[-1]) # Assume labels initially have shape [batch, time, dim] + + # Reshape your mask. True indicates valid data. 
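# [Editor's note, not part of the patch] The loss below is a soft-label cross-entropy: `labels` are
# categorical target distributions (e.g. the discretized value/reward support or MCTS visit-count
# policies), so each (b, t) position contributes -sum_i p_i * log softmax(logits)_i, and padded
# timesteps are zeroed out by the mask afterwards. A tiny hypothetical check:
#   logits = torch.randn(4, 101); labels = torch.softmax(torch.randn(4, 101), dim=1)
#   loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1)   # shape (4,)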
+ mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') + + # Compute cross-entropy loss + loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1) + loss = (loss * mask_padding) + + # if torch.isnan(loss).any(): + # raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") + + if element == 'policy': + # Compute policy entropy loss + policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding) + # Combine losses with specified weight + combined_loss = loss - self.policy_entropy_weight * policy_entropy + return combined_loss, loss, policy_entropy + + return loss + + #@profile + def compute_policy_entropy_loss(self, logits, mask): + # Compute entropy of the policy + probs = torch.softmax(logits, dim=1) + log_probs = torch.log_softmax(logits, dim=1) + entropy = -(probs * log_probs).sum(1) + # Apply mask and return average entropy loss + entropy_loss = (entropy * mask) + return entropy_loss + + #@profile + def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag + mask_fill = torch.logical_not(mask_padding) + + # Prepare observation labels + labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] + + # Fill the masked areas of rewards + mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) + labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) + + # Fill the masked areas of ends + # labels_ends = ends.masked_fill(mask_fill, -100) + + # return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + return labels_observations, labels_rewards.view(-1, self.support_size), None + + #@profile + def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ Compute labels for value and policy predictions. """ + mask_fill = torch.logical_not(mask_padding) + + # Fill the masked areas of policy + mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) + labels_policy = target_policy.masked_fill(mask_fill_policy, -100) + + # Fill the masked areas of value + mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) + labels_value = target_value.masked_fill(mask_fill_value, -100) + + if self.continuous_action_space: + return None, labels_value.reshape(-1, self.support_size) + else: + return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size) + + #@profile + def clear_caches(self): + """ + Clears the caches of the world model. + """ + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + + print(f'rank {self._rank} Cleared {self.__class__.__name__} past_kv_cache.') + + def __repr__(self) -> str: + return "transformer-based latent world_model of UniZero" diff --git a/lzero/model/utils.py b/lzero/model/utils.py index 70a89d3b0..c849aedca 100644 --- a/lzero/model/utils.py +++ b/lzero/model/utils.py @@ -3,102 +3,263 @@ In this file, we provide a set of utility functions for probing network parameters and gradients, which can be helpful in analyzing and debugging the inner workings of various models. 
""" -from typing import List, Tuple - +from typing import List, Tuple, Union, Dict +from torch.nn import functional as F import numpy as np import torch import torch.nn as nn +############################### +# 1. 计算 average_weight_magnitude +############################### +def compute_average_weight_magnitude(model: nn.Module) -> float: + """ + 计算模型中所有参数的平均绝对值。 + + Arguments: + model: 待评估模型,类型为 nn.Module -class LinearOutputHook: + Returns: + 平均权重绝对值(float) """ - Overview: - Hook to capture the output of linear layers. + num_weights = 0 + # 使用模型中第一个参数的设备,保证计算时设备一致 + device = next(model.parameters()).device + sum_weight_magnitude = torch.tensor(0.0, device=device) + + for p in model.parameters(): + num_weights += p.numel() + sum_weight_magnitude += torch.sum(torch.abs(p)) + + if num_weights == 0: + return 0.0 + return sum_weight_magnitude.cpu().item() / num_weights + +############################### +# 2. 计算 effective_rank +############################### +def compute_effective_rank(singular_values: np.ndarray) -> float: """ + 根据给定的奇异值数组计算 effective rank,公式为: + effective_rank = exp( - sum_i [p_i * log(p_i)] ) + 其中 p_i 是归一化后的奇异值(p_i = s_i / ∑ s_i) + + Arguments: + singular_values: 奇异值数组,类型为 np.ndarray + Returns: + effective rank(float) + """ + norm_sv = singular_values / np.sum(np.abs(singular_values)) + entropy = 0.0 + for p in norm_sv: + if p > 0.0: + entropy -= p * np.log(p) + return np.e ** entropy + + +# 定义一个 Hook 类,用来捕获中间层的输出 +class IntermediateOutputHook: + """ + 用于捕获模块输出的 Hook,保存输出张量列表。 + """ def __init__(self): - """ - Overview: - Initialize the hook. - """ self.outputs: List[torch.Tensor] = [] def __call__(self, module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor) -> None: - """ - Overview: - Capture the output of the module. - Arguments: - - module: The module being hooked. - - input: The input to the module (unused in this hook). - - output: The output from the module. - """ - self.outputs.append(output) - - -def cal_dormant_ratio(model: nn.Module, *inputs: torch.Tensor, percentage: float = 0.025) -> float: + # 这里使用 detach 防止反向传播干扰,并转移到 CPU 便于后续统计 + self.outputs.append(output.detach().cpu()) + +def cal_effective_rank( + model: nn.Module, + inputs: Union[torch.Tensor, List[torch.Tensor]], + representation_layer_name: str, +) -> float: """ - Overview: - Calculate the dormant neuron ratio in the model. A neuron is considered dormant if its output is less than a - specified percentage of the average output of the layer. This function is useful for analyzing the sparsity of the model. - More details can be found in the paper https://arxiv.org/abs/2302.12902. 
+ 针对模型指定的中间层(representation 层), + 使用 Hook 捕获该层输出,并计算 effective rank。 + + Arguments: + model: 待评估模型,应为 nn.Module 类型。 + inputs: 模型 forward 的输入,可以为 tensor 或 tensor-list。 + representation_layer_name: 模型中表示 representation 层的名称, + 该名称必须能够在 model.named_modules() 中找到对应模块。 + + Returns: + effective rank(float) + """ + # 获取 representation 层模块(若名称不存在将引发 KeyError) + module_dict = dict(model.named_modules()) + if representation_layer_name not in module_dict: + raise KeyError(f"Representation layer '{representation_layer_name}' not found in model.named_modules().") + representation_module = module_dict[representation_layer_name] + + # 注册 hook + hook = IntermediateOutputHook() + handle = representation_module.register_forward_hook(hook) + + # 执行 forward 推理 + model.eval() + with torch.no_grad(): + if isinstance(inputs, (list, tuple)): + _ = model(*inputs) + else: + _ = model(inputs) + + # 注销 hook,避免内存泄露 + handle.remove() + + if not hook.outputs: + raise RuntimeError("No outputs captured from the representation layer.") + + # 这里假定有一个或多个 forward(例如在 batch 或多次调用的场景), + # 将所有输出在 batch 维度上拼接 + if len(hook.outputs) > 1: + rep_tensor = torch.cat(hook.outputs, dim=0) + else: + rep_tensor = hook.outputs[0] + + # 将 representation 展开为二维矩阵: (samples, features) + rep_tensor = rep_tensor.view(rep_tensor.size(0), -1) + + # 将 tensor 转换为 numpy 数组以使用 numpy.linalg.svd + rep_np = rep_tensor.cpu().numpy() + + # 计算奇异值 + singular_values = np.linalg.svd(rep_np, full_matrices=False, compute_uv=False) + + # 计算 effective rank + e_rank = compute_effective_rank(singular_values) + + # 清空 hook 存储(若需要多次调用可以保持清洁状态) + hook.outputs.clear() + return e_rank + + + +def compute_dormant_stats(outputs: List[torch.Tensor], threshold: float) -> Tuple[int, int]: + """ + 对给定的一组输出(同一层可能 forward 多次)进行元素级统计。 + Arguments: - - model: The model to evaluate. - - inputs: The inputs to the model. - - percentage: The threshold percentage to consider a neuron dormant, defaults to 0.025. + outputs: List[torch.Tensor],每个 tensor 表示一次 forward 的输出 + threshold: 判断 dormant 的阈值,当激活值 <= threshold 时视为 dormant + Returns: - - float: The ratio of dormant neurons in the model. 
+ layer_total: 该层总元素数(累加多个 forward) + layer_dormant: 该层中满足 dormant 条件的元素数目 """ - # List to store hooks and their handlers - hooks: List[LinearOutputHook] = [] - hook_handlers: List[torch.utils.hooks.RemovableHandle] = [] - total_neurons: int = 0 - dormant_neurons: int = 0 - - # Register hooks to capture outputs of specific layers - for _, module in model.named_modules(): - if isinstance(module, (nn.Linear, nn.Conv2d, nn.LSTM)): - hook = LinearOutputHook() - hooks.append(hook) - hook_handlers.append(module.register_forward_hook(hook)) + layer_total = 0 + layer_dormant = 0 + for out in outputs: + flattened = out.view(-1) + total = flattened.numel() + dormant = torch.sum(flattened <= threshold).item() + layer_total += total + layer_dormant += dormant + return layer_total, layer_dormant +def cal_dormant_ratio( + model: nn.Module, + inputs: Union[torch.Tensor, List[torch.Tensor]], + dormant_threshold: float = 1e-2, +) -> Dict[str, float]: + """ + 针对模型中 encoder、transformer backbone 以及 head 三个部分, + 分别统计各部分中所有目标层(例如 nn.Conv2d、nn.Linear、nn.MultiheadAttention 等)的 + dormant ratio(元素级 dormant 百分比),同时返回全局统计指标。 + + Arguments: + model: 待评估模型,应包含属性 encoder、transformer(backbone)以及 head(可选)。 + inputs: 模型的输入,支持 tensor 或 tensor-list,要求与模型 forward 调用一致。 + dormant_threshold: 激活值低于该阈值时视为 dormant,默认 1e-2。 + + Returns: + results: 包含各部分以及全局 dormant ratio 的字典,单位为百分比(%)。 + 如:{"encoder": 2.5, "transformer": 1.8, "head": 0.5, "global": 1.6} + """ + + # 我们将统计分类为三个部分 + parts = {} + if hasattr(model, "encoder"): + parts["encoder"] = model.encoder + if hasattr(model, "transformer"): + parts["transformer"] = model.transformer + + # 对于 head 部分,查找所有以 "head_" 开头的子模块 + # head_dict = {} + # for name, module in model.named_children(): + # if name.startswith("head_"): + # head_dict[name] = module + # if head_dict: + # parts["head"] = nn.ModuleDict(head_dict) + + if hasattr(model, "head_dict"): + parts["head"] = model.head_dict + + if not hasattr(model, "encoder") and not hasattr(model, "transformer") and not hasattr(model, "head"): + # 如果传入的是self.tokenizer.encoder + parts["model"] = model + + # 定义要捕获的目标模块类型 TODO: 增加更多模块 + target_modules = (nn.Conv2d, nn.Linear) + + # 用于存储各部分的 hook(字典:部分名 -> list of (module_name, hook)) + hooks_dict = {part: [] for part in parts} + hook_handles = [] + + # 为每个部分中的满足类型条件的模块注册 hook + for part_name, submodule in parts.items(): + for name, module in submodule.named_modules(): + if isinstance(module, target_modules): + hook = IntermediateOutputHook() + # 为了避免名称冲突,加上所属部分前缀 + full_name = f"{part_name}/{name}" + hooks_dict[part_name].append((full_name, hook)) + handle = module.register_forward_hook(hook) + hook_handles.append(handle) + + # 调用 forward,执行一次推理 + model.eval() with torch.no_grad(): - # Forward pass to capture outputs - model(*inputs) - - # Analyze the captured outputs - for module, hook in zip((module for module in model.modules() if isinstance(module, (nn.Linear, nn.Conv2d, nn.LSTM))), hooks): - with torch.no_grad(): - for output_data in hook.outputs: - mean_output = output_data.abs().mean(0) - avg_neuron_output = mean_output.mean() - dormant_indices = (mean_output < avg_neuron_output * percentage).nonzero(as_tuple=True)[0] - - if isinstance(module, nn.Linear): - # Calculate total and dormant neurons for Linear layers - total_neurons += module.weight.shape[0] * output_data.shape[0] - dormant_neurons += len(dormant_indices) - elif isinstance(module, nn.Conv2d): - # Calculate total and dormant neurons for Conv2D layers - total_neurons += module.weight.shape[0] * output_data.shape[0] * 
output_data.shape[2] * output_data.shape[3] - dormant_neurons += len(dormant_indices) - elif isinstance(module, nn.LSTM): - # Calculate total and dormant neurons for LSTM layers - total_neurons += module.hidden_size * module.num_layers * output_data.shape[0] * output_data.shape[1] - dormant_neurons += len(dormant_indices) - - # Clean up hooks - for hook in hooks: - hook.outputs.clear() - del hook.outputs - - for hook_handler in hook_handlers: - hook_handler.remove() - del hook_handler - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return dormant_neurons / total_neurons + if isinstance(inputs, (list, tuple)): + _ = model(*inputs) + else: + _ = model(inputs) + + # 统计各部分各个模块的 dormant 数量和总数 + results = {} + total_global = 0 + dormant_global = 0 + for part, hooks in hooks_dict.items(): + part_total = 0 + part_dormant = 0 + for full_name, hook in hooks: + layer_total, layer_dormant = compute_dormant_stats(hook.outputs, dormant_threshold) + # if part == "model": + # print(hook.outputs) + # 可打印日志,也可记录更详细信息 + # print(f"{full_name}: {layer_dormant}/{layer_total} -> {layer_dormant / layer_total * 100.0 if layer_total > 0 else 0.0}%") + part_total += layer_total + part_dormant += layer_dormant + if part_total > 0: + ratio = (part_dormant / part_total) * 100.0 + else: + ratio = 0.0 + results[part] = ratio + total_global += part_total + dormant_global += part_dormant + + results["global"] = (dormant_global / total_global) * 100.0 if total_global > 0 else 0.0 + + # 清理所有 hook + for handle in hook_handles: + handle.remove() + for hooks in hooks_dict.values(): + for _, hook in hooks: + hook.outputs.clear() + + return results def renormalize(inputs: torch.Tensor, first_dim: int = 1) -> torch.Tensor: """ diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 5311062b0..55bf0dc6c 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -940,7 +940,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ return output - def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: + def _reset_collect(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: """ Overview: Reset the observation and action for the collector environment. @@ -955,7 +955,7 @@ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + def _reset_eval(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: """ Overview: Reset the observation and action for the evaluator environment. 
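Before moving on to the new multi-task policy below, a quick sanity check for the entropy-based compute_effective_rank helper added to lzero/model/utils.py above (editorial sketch, not part of the patch): two equal non-zero singular values give an effective rank of exactly 2.

import numpy as np
from lzero.model.utils import compute_effective_rank

# Normalized spectrum is (0.5, 0.5, 0); entropy = ln(2), so effective rank = exp(ln 2) = 2
print(compute_effective_rank(np.array([3.0, 3.0, 0.0])))   # ~2.0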
@@ -969,6 +969,7 @@ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: self._cfg.device ) self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + def _monitor_vars_learn(self) -> List[str]: """ Overview: diff --git a/lzero/policy/muzero_multitask.py b/lzero/policy/muzero_multitask.py new file mode 100644 index 000000000..c65ccc5e8 --- /dev/null +++ b/lzero/policy/muzero_multitask.py @@ -0,0 +1,890 @@ +import copy +from typing import List, Dict, Tuple, Union, Optional + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY + +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.model.utils import cal_dormant_ratio +from lzero.policy import ( + scalar_transform, + InverseScalarTransform, + cross_entropy_loss, + phi_transform, + DiscreteSupport, + to_torch_float_tensor, + mz_network_output_unpack, + select_action, + negative_cosine_similarity, + prepare_obs, +) +from lzero.policy.muzero import MuZeroPolicy + + +def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): + """ + 生成每个任务的损失字典 + :param multi_task_losses: 包含每个任务损失的列表 + :param task_name_template: 任务名称模板,例如 'loss_task{}' + :param task_id: 任务起始ID + :return: 一个字典,包含每个任务的损失 + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception: + task_loss_dict[task_name] = task_loss + return task_loss_dict + +class WrappedModelV2: + def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return ( + list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters()) + ) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + +@POLICY_REGISTRY.register('muzero_multitask') +class MuZeroMTPolicy(MuZeroPolicy): + """ + 概述: + MuZero 的多任务策略类,扩展自 MuZeroPolicy。支持同时训练多个任务,通过分离每个任务的损失并进行优化。 + """ + + # MuZeroMTPolicy 的默认配置 + config = dict( + type='muzero_multitask', + model=dict( + model_type='conv', # options={'mlp', 'conv'} + continuous_action_space=False, + observation_shape=(4, 96, 96), # example shape + self_supervised_learning_loss=False, + categorical_distribution=True, + image_channel=1, + frame_stack_num=1, + num_res_blocks=1, + num_channels=64, + support_scale=300, + bias=True, + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + analysis_sim_norm=False, + analysis_dormant_ratio=False, + harmony_balance=False, + ), + # ****** common ****** + use_rnd_model=False, + multi_gpu=False, + sampled_algo=False, + gumbel_algo=False, + mcts_ctree=True, + cuda=True, + collector_env_num=8, + evaluator_env_num=3, + 
env_type='not_board_games', + action_type='fixed_action_space', + battle_mode='play_with_bot_mode', + monitor_extra_statistics=True, + game_segment_length=200, + eval_offline=False, + cal_dormant_ratio=False, + analysis_sim_norm=False, + analysis_dormant_ratio=False, + + # ****** observation ****** + transform2string=False, + gray_scale=False, + use_augmentation=False, + augmentation=['shift', 'intensity'], + + # ******* learn ****** + use_wandb=False, + ignore_done=False, + update_per_collect=None, + replay_ratio=0.25, + batch_size=256, + optim_type='SGD', + learning_rate=0.2, + target_update_freq=100, + target_update_freq_for_intrinsic_reward=1000, + weight_decay=1e-4, + momentum=0.9, + grad_clip_value=10, + n_episode=8, + num_segments=8, + num_simulations=50, + discount_factor=0.997, + td_steps=5, + num_unroll_steps=5, + reward_loss_weight=1, + value_loss_weight=0.25, + policy_loss_weight=1, + policy_entropy_weight=0, + ssl_loss_weight=0, + lr_piecewise_constant_decay=True, + threshold_training_steps_for_final_lr=int(5e4), + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(1e5), + fixed_temperature_value=0.25, + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + use_priority=False, + priority_prob_alpha=0.6, + priority_prob_beta=0.4, + + # ****** UCB ****** + root_dirichlet_alpha=0.3, + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + eps_greedy_exploration_in_collect=False, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + + # ****** 多任务相关 ****** + task_num=2, # 任务数量,根据实际需求调整 + task_id=0, # 当前任务的起始ID + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + 概述: + 返回该算法的默认模型设置。 + 返回: + - model_info (:obj:`Tuple[str, List[str]]`): 模型名称和模型导入路径列表。 + """ + return 'MuZeroMTModel', ['lzero.model.muzero_model_multitask'] + + def _init_learn(self) -> None: + """ + 概述: + 学习模式初始化方法。初始化学习模型、优化器和MCTS工具。 + """ + super()._init_learn() + + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + + if self._cfg.lr_piecewise_constant_decay: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
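# [Editor's note, not part of the patch] With the default threshold_training_steps_for_final_lr = 5e4
# and learning_rate = 0.2, the multiplier below evaluates to 1.0 for step < 25k, 0.1 for
# 25k <= step < 50k, and 0.01 afterwards, i.e. an effective lr of 0.2 -> 0.02 -> 0.002.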
+ lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + # ============================================================== + # harmonydream (learnable weights for different losses) + # ============================================================== + if self._cfg.model.harmony_balance: + # List of parameter names + harmony_names = ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + # Initialize and name each parameter + for name in harmony_names: + param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) + setattr(self, name, param) + + if self._cfg.use_rnd_model: + if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} + ) + elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} + ) + + # ========= logging for analysis ========= + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + self.dormant_ratio_encoder = 0. + self.dormant_ratio_dynamics = 0. 
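A side note on the per-task bookkeeping initialized just below: the generate_task_loss_dict helper defined near the top of this file flattens a per-task loss list into per-task log keys. A minimal usage sketch (editorial, not part of the patch; the 'loss_task{}' template follows the example given in its docstring):

import torch
from lzero.policy.muzero_multitask import generate_task_loss_dict

# Suppose this rank owns tasks 2 and 3 and has just computed their scalar losses.
losses = [torch.tensor(0.51), torch.tensor(0.23)]
print(generate_task_loss_dict(losses, 'loss_task{}', task_id=2))
# -> {'loss_task2': ~0.51, 'loss_task3': ~0.23}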
+ # 初始化多任务相关参数 + self.task_num_for_current_rank = self._cfg.task_num + self.task_id = self._cfg.task_id + + def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> Dict[str, Union[float, int]]: + """ + 概述: + 学习模式的前向函数,是学习过程的核心。数据从重放缓冲区采样,计算损失并反向传播更新模型。 + 参数: + - data (:obj:`List[Tuple[torch.Tensor, torch.Tensor, int]]`): 每个任务的数据元组列表, + 每个元组包含 (current_batch, target_batch, task_id)。 + 返回: + - info_dict (:obj:`Dict[str, Union[float, int]]`): 用于记录的信息字典,包含当前学习损失和学习统计信息。 + """ + self._learn_model.train() + self._target_model.train() + + # 初始化多任务损失列表 + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + consistency_loss_multi_task = [] + policy_entropy_multi_task = [] + lambd_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + weighted_total_loss = 0.0 # 初始化为0 + losses_list = [] # 用于存储每个任务的损失 + + for task_idx, (current_batch, target_batch, task_id) in enumerate(data): + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # 数据增强 + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # 准备动作批次并转换为张量 + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [mask_batch, target_reward, target_value, target_policy, weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor( + data_list, self._cfg.device + ) + + target_reward = target_reward.view(self._cfg.batch_size[task_idx], -1) + target_value = target_value.view(self._cfg.batch_size[task_idx], -1) + + assert obs_batch.size(0) == self._cfg.batch_size[task_idx] == target_reward.size(0) + + # 变换奖励和价值到缩放形式 + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # 转换为类别分布 + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # 初始推理 + network_output = self._learn_model.initial_inference(obs_batch, task_id=task_id) + + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # 记录 Dormant Ratio 和 L2 Norm + if self._cfg.cal_dormant_ratio: + self.dormant_ratio_encoder = cal_dormant_ratio( + self._learn_model.representation_network, obs_batch.detach(), + percentage=self._cfg.dormant_threshold + ) + latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() + + # 逆变换价值 + original_value = self.inverse_scalar_transform_handle(value) + + # 初始化预测值和策略 + predicted_rewards = [] + if self._cfg.monitor_extra_statistics: + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # 计算优先级 + value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # 计算第一个步骤的策略和价值损失 + policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) + 
policy_entropy_loss = -entropy + + reward_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) + target_policy_entropy = 0 + + # 循环进行多个unroll步骤 + for step_k in range(self._cfg.num_unroll_steps): + # 使用动态函数进行递归推理 + network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_k]) + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # 记录 Dormant Ratio + if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio: + action_tmp = action_batch[:, step_k] + if len(action_tmp.shape) == 1: + action_tmp = action_tmp.unsqueeze(-1) + # 转换动作为独热编码 + action_one_hot = torch.zeros(action_tmp.shape[0], policy_logits.shape[-1], device=action_tmp.device) + action_tmp = action_tmp.long() + action_one_hot.scatter_(1, action_tmp, 1) + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3] + ) + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + self.dormant_ratio_dynamics = cal_dormant_ratio( + self._learn_model.dynamics_network, + state_action_encoding.detach(), + percentage=self._cfg.dormant_threshold + ) + + # 逆变换价值 + original_value = self.inverse_scalar_transform_handle(value) + + # 计算一致性损失 + if self._cfg.model.self_supervised_learning_loss and self._cfg.ssl_loss_weight > 0: + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index], task_id=task_id) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + consistency_loss += temp_loss + + # 计算策略和价值损失 + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) + + # 计算策略熵损失 + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) + policy_entropy_loss += -entropy + + # 计算目标策略熵(仅用于调试) + target_normalized_visit_count = target_policy[:, step_k + 1] + non_masked_indices = torch.nonzero(mask_batch[:, step_k + 1]).squeeze(-1) + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count, 0, non_masked_indices + ) + target_policy_entropy += -( + (target_normalized_visit_count_masked + 1e-6) * + torch.log(target_normalized_visit_count_masked + 1e-6) + ).sum(-1).mean() + else: + target_policy_entropy += torch.log( + torch.tensor(target_normalized_visit_count.shape[-1], device=self._cfg.device) + ) + + + # 记录预测值和奖励(如果监控额外统计) + if self._cfg.monitor_extra_statistics: + original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards_cpu = original_rewards.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_rewards.append(original_rewards_cpu) + predicted_policies = torch.cat( + 
(predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu()) + ) + + # 核心学习模型更新步骤 + weighted_loss = self._cfg.policy_loss_weight * policy_loss + \ + self._cfg.value_loss_weight * value_loss + \ + self._cfg.reward_loss_weight * reward_loss + \ + self._cfg.ssl_loss_weight * consistency_loss + \ + self._cfg.policy_entropy_weight * policy_entropy_loss + + # 将多个任务的损失累加 + weighted_total_loss += weighted_loss.mean() + + # 保留每个任务的损失用于日志记录 + reward_loss_multi_task.append(reward_loss.mean().item()) + policy_loss_multi_task.append(policy_loss.mean().item()) + value_loss_multi_task.append(value_loss.mean().item()) + consistency_loss_multi_task.append(consistency_loss.mean().item()) + policy_entropy_multi_task.append(policy_entropy_loss.mean().item()) + lambd_multi_task.append(torch.tensor(0., device=self._cfg.device).item()) # TODO: 如果使用梯度校正,可以在这里调整 + value_priority_multi_task.append(value_priority.mean().item()) + value_priority_mean_multi_task.append(value_priority.mean().item()) + losses_list.append(weighted_loss.mean().item()) + + # 清零优化器的梯度 + self._optimizer.zero_grad() + + # 反向传播 + weighted_total_loss.backward() + + # 梯度裁剪 + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), + self._cfg.grad_clip_value + ) + + # 多GPU训练时同步梯度 + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + + # 更新优化器 + self._optimizer.step() + if self._cfg.lr_piecewise_constant_decay: + self.lr_scheduler.step() + + # 更新目标模型 + self._target_model.update(self._learn_model.state_dict()) + + # 获取GPU内存使用情况 + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0.0 + max_memory_allocated_gb = 0.0 + + # 构建返回的损失字典 + return_loss_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, + 'cur_lr_world_model': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + + # print(f'self.task_id:{self.task_id}') + # 生成任务相关的损失字典,并为每个任务相关的 loss 添加前缀 "noreduce_" + multi_task_loss_dicts = { + **generate_task_loss_dict(consistency_loss_multi_task, 'noreduce_consistency_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd_multi_task, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + } + + # 合并两个字典 + return_loss_dict.update(multi_task_loss_dicts) + + # 返回最终的损失字典 + return return_loss_dict + + def _reset_collect(self, data_id: Optional[List[int]] = None, task_id: int = 
None) -> None: + """ + Overview: + Reset the observation and action for the collector environment. + Arguments: + - data_id (`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + """ + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + + def _reset_eval(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: + """ + Overview: + Reset the observation and action for the evaluator environment. + Arguments: + - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + """ + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + + def _monitor_vars_learn(self, num_tasks: int = None) -> List[str]: + """ + 概述: + 注册学习模式中需要监控的变量。注册的变量将根据 `_forward_learn` 的返回值记录到tensorboard。 + 如果提供了 `num_tasks`,则为每个任务生成监控变量。 + 参数: + - num_tasks (:obj:`int`, 可选): 任务数量。 + 返回: + - monitored_vars (:obj:`List[str]`): 需要监控的变量列表。 + """ + # 基本监控变量 + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + ] + + # 任务特定的监控变量 + task_specific_vars = [ + 'noreduce_consistency_loss', + 'noreduce_reward_loss', + 'noreduce_policy_loss', + 'noreduce_value_loss', + 'noreduce_policy_entropy', + 'noreduce_lambd', + 'noreduce_value_priority', + 'noreduce_value_priority_mean', + ] + # self.task_num_for_current_rank 作为当前rank的base_index + num_tasks = self.task_num_for_current_rank + print(f'self.task_num_for_current_rank: {self.task_num_for_current_rank}') + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + monitored_vars.append(f'{var}_task{self.task_id + task_idx}') + else: + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self.collect_epsilon = 0.0 + if self._cfg.model.model_type == 'conv_context': + self.last_batch_obs = torch.zeros([8, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(8)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + task_id: int = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. 
+ - epsilon (:obj:`float`): The epsilon of the eps greedy exploration. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - epsilon: :math:`(1, )`. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self._collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._collect_model.initial_inference(data, task_id=task_id) + elif self._cfg.model.model_type == "conv_context": + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, + data, task_id=task_id) + + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + if not self._cfg.collect_with_pure_policy: + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the 
``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = data + self.last_batch_action = batch_action + else: + for i, env_id in enumerate(ready_env_id): + policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), + dim=0).tolist() + policy_values = policy_values / np.sum(policy_values) + action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'searched_value': pred_values[i], + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + + return output + + def _get_target_obs_index_in_step_k(self, step): + """ + Overview: + Get the begin index and end index of the target obs in step k. + Arguments: + - step (:obj:`int`): The current step k. + Returns: + - beg_index (:obj:`int`): The begin index of the target obs in step k. + - end_index (:obj:`int`): The end index of the target obs in step k. + Examples: + >>> self._cfg.model.model_type = 'conv' + >>> self._cfg.model.image_channel = 3 + >>> self._cfg.model.frame_stack_num = 4 + >>> self._get_target_obs_index_in_step_k(0) + >>> (0, 12) + """ + if self._cfg.model.model_type in ['conv', 'conv_context']: + beg_index = self._cfg.model.image_channel * step + end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type in ['mlp', 'mlp_context']: + beg_index = self._cfg.model.observation_shape * step + end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + return beg_index, end_index + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + if self._cfg.model.model_type == 'conv_context': + self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(3)] + # elif self._cfg.model.model_type == 'mlp_context': + # self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape]).to(self._cfg.device) + # self.last_batch_action = [-1 for _ in range(3)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. 
+ Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._eval_model.initial_inference(data, task_id=task_id) + elif self._cfg.model.model_type == "conv_context": + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = data + self.last_batch_action = batch_action + + return output + diff --git a/lzero/policy/sampled_unizero_multitask.py b/lzero/policy/sampled_unizero_multitask.py new file mode 100644 index 000000000..ccdefb656 --- /dev/null +++ b/lzero/policy/sampled_unizero_multitask.py @@ -0,0 +1,1046 @@ +# /Users/puyuan/code/LightZero/lzero/policy/sample_unizero_multitask.py + +import copy +import logging +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import wandb +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import SampledUniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import ( + scalar_transform, + InverseScalarTransform, + phi_transform, + DiscreteSupport, + to_torch_float_tensor, + mz_network_output_unpack, + select_action, + prepare_obs, + prepare_obs_stack4_for_unizero +) +from lzero.policy.unizero import UniZeroPolicy +from .utils import configure_optimizers_nanogpt +import torch.nn.functional as F +import torch.distributed as dist +import sys +sys.path.append('/mnt/afs/niuyazhe/code/LibMTL/') +from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect +# from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect + +def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): + """ + 生成每个任务的损失字典 + :param multi_task_losses: 包含每个任务损失的列表 + :param task_name_template: 任务名称模板,例如 'obs_loss_task{}' + :param task_id: 基础任务 ID + :return: 一个字典,包含每个任务的损失 + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else float(task_loss) + except Exception as e: + task_loss_dict[task_name] = task_loss + return task_loss_dict + + +class WrappedModelV2: + def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return ( + list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + # list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters()) + ) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + # self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + def get_group_parameters(self): + """ + 返回一个字典,其中 key 为模块名或更细粒度的层, + value 为对应的参数列表。注意返回顺序应与 parameters()方法中参数的排列顺序一致。 + """ + groups = {} + groups['tokenizer'] = 
list(self.tokenizer.parameters()) + groups['transformer'] = list(self.transformer.parameters()) + groups['pos_emb'] = list(self.pos_emb.parameters()) + groups['act_embedding_table'] = list(self.act_embedding_table.parameters()) + + # 如 transformer 内部分层(假设 transformer.blocks 是列表) + if hasattr(self.transformer, 'blocks'): + # 若要单独统计 transformer 内各层,保持原 transformer 参数在 parameters() 中顺序不变, + # 可以在这里添加各层的切片,但需保证 parameters() 返回的顺序与此一致, + # 此处仅作为示例: + for i, layer in enumerate(self.transformer.blocks): + groups[f'transformer_layer_{i}'] = list(layer.parameters()) + return groups + +@POLICY_REGISTRY.register('sampled_unizero_multitask') +class SampledUniZeroMTPolicy(UniZeroPolicy): + """ + Overview: + The policy class for Sampled UniZero Multitask, combining multi-task learning with sampled-based MCTS. + This implementation extends the UniZeroPolicy to handle multiple tasks simultaneously while utilizing + sampled MCTS for action selection. It ensures scalability and correctness in multi-task environments. + """ + + # The default_config for Sampled UniZero Multitask policy. + config = dict( + type='sampled_unizero_multitask', + model=dict( + model_type='conv', # options={'mlp', 'conv'} + continuous_action_space=False, + observation_shape=(3, 64, 64), + self_supervised_learning_loss=True, + categorical_distribution=True, + image_channel=3, + frame_stack_num=1, + num_res_blocks=1, + num_channels=64, + support_scale=50, + bias=True, + res_connection_in_dynamics=True, + norm_type='LN', + analysis_sim_norm=False, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + tokens_per_block=2, + max_blocks=10, + max_tokens=20, + context_length=8, + gru_gating=False, + device='cpu', + analysis_sim_norm=False, + analysis_dormant_ratio=False, + action_space_size=6, + group_size=8, + attention='causal', + num_layers=2, + num_heads=8, + embed_dim=768, + embed_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, + support_size=101, + max_cache_size=5000, + env_num=8, + latent_recon_loss_weight=0., + perceptual_loss_weight=0., + policy_entropy_weight=5e-3, + predict_latent_loss_type='group_kl', + obs_type='image', + gamma=1, + dormant_threshold=0.025, + policy_loss_type='kl', + ), + ), + use_rnd_model=False, + multi_gpu=True, + sampled_algo=True, + gumbel_algo=False, + mcts_ctree=True, + cuda=True, + collector_env_num=8, + evaluator_env_num=3, + env_type='not_board_games', + action_type='fixed_action_space', + battle_mode='play_with_bot_mode', + monitor_extra_statistics=True, + game_segment_length=400, + analysis_sim_norm=False, + collect_with_pure_policy=False, + eval_freq=int(5e3), + sample_type='transition', + + transform2string=False, + gray_scale=False, + use_augmentation=False, + augmentation=['shift', 'intensity'], + + ignore_done=False, + update_per_collect=None, + replay_ratio=0.25, + batch_size=256, + optim_type='AdamW', + learning_rate=0.0001, + init_w=3e-3, + target_update_freq=100, + target_update_theta=0.05, + target_update_freq_for_intrinsic_reward=1000, + weight_decay=1e-4, + momentum=0.9, + grad_clip_value=5, + n_episode=8, + num_simulations=50, + discount_factor=0.997, + td_steps=5, + num_unroll_steps=10, + reward_loss_weight=1, + value_loss_weight=0.25, + policy_loss_weight=1, + ssl_loss_weight=0, + cos_lr_scheduler=False, + piecewise_decay_lr_scheduler=False, + threshold_training_steps_for_final_lr=int(5e4), + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(1e5), + fixed_temperature_value=0.25, + 
use_ture_chance_label_in_chance_encoder=False, + + use_priority=False, + priority_prob_alpha=0.6, + priority_prob_beta=0.4, + train_start_after_envsteps=0, + + root_dirichlet_alpha=0.3, + root_noise_weight=0.25, + + random_collect_episode_num=0, + + eps=dict( + eps_greedy_exploration_in_collect=False, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Return this algorithm's default model setting for demonstration. + """ + return 'SampledUniZeroMTModel', ['lzero.model.sampled_unizero_model_multitask'] + + def _init_learn(self) -> None: + """ + Learn mode init method. Initialize the learn model, optimizer, and MCTS utils. + """ + # Configure optimizer for world model + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR + + if self._cfg.cos_lr_scheduler: + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, T_max=int(1e5), eta_min=0, last_epoch=-1 + ) + elif self._cfg.piecewise_decay_lr_scheduler: + # Example step scheduler, adjust milestones and gamma as needed + self.lr_scheduler = StepLR( + self._optimizer_world_model, step_size=int(5e4), gamma=0.1 + ) + + if self._cfg.model.continuous_action_space: + # Weight Init for the last output layer of gaussian policy head in prediction network. + init_w = self._cfg.init_w + self._model.world_model.fc_policy_head.mu.weight.data.uniform_(-init_w, init_w) + self._model.world_model.fc_policy_head.mu.bias.data.uniform_(-init_w, init_w) + try: + self._model.world_model.fc_policy_head.log_sigma_layer.weight.data.uniform_(-init_w, init_w) + self._model.world_model.fc_policy_head.log_sigma_layer.bias.data.uniform_(-init_w, init_w) + except Exception as exception: + logging.warning(exception) + + # Initialize target model + self._target_model = copy.deepcopy(self._model) + # Ensure torch version >= 2.0 + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + # Soft target update + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + # if self._cfg.use_augmentation: + # self.image_transforms = ImageTransforms( + # self._cfg.augmentation, + # image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + # ) + + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. 
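The target model above is wrapped with `update_type='momentum'` and `theta = target_update_theta` (0.05 in this config). A minimal sketch of what such a soft update does once per learner step; the EMA form below is the conventional one, while the actual bookkeeping lives inside ding's target wrapper, and the two linear layers are stand-ins for the real models:

import copy
import torch
from torch import nn

theta = 0.05                        # target_update_theta in this config
online = nn.Linear(8, 4)            # stand-in for the learn model
target = copy.deepcopy(online)      # stand-in for the target model

@torch.no_grad()
def soft_update(target_model: nn.Module, online_model: nn.Module, theta: float) -> None:
    # target <- (1 - theta) * target + theta * online, parameter by parameter.
    for p_t, p_o in zip(target_model.parameters(), online_model.parameters()):
        p_t.mul_(1.0 - theta).add_(theta * p_o)

soft_update(target, online, theta)  # played by self._target_model.update(...) after each optimizer step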
+ + self.task_id = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + print(f'self._cfg.only_use_moco_stats:{self._cfg.only_use_moco_stats}') + if self._cfg.use_moco or self._cfg.only_use_moco_stats: + # 创建 WrappedModel 实例,仅矫正部分参数,保持可扩展性 + # wrapped_model = WrappedModelV2( + # self._learn_model.world_model.tokenizer.encoder[0], # 假设只有一个编码器 + # self._learn_model.world_model.transformer, + # self._learn_model.world_model.pos_emb, + # self._learn_model.world_model.task_emb, + # self._learn_model.world_model.act_embedding_table, + # ) + + # head 没有矫正梯度 + wrapped_model = WrappedModelV2( + self._learn_model.world_model.tokenizer.encoder, # TODO: one or N encoder inside + self._learn_model.world_model.transformer, + self._learn_model.world_model.pos_emb, + self._learn_model.world_model.task_emb, + self._learn_model.world_model.act_embedding_table, + ) + + # TODO + # 如果需要,可以在这里初始化梯度校正方法(如 MoCo, CAGrad) + # self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device) + # self.grad_correct = GradCorrect(wrapped_model, self._cfg.task_num, self._cfg.device, self._cfg.multi_gpu) # only compatiable with for 1GPU training + self.grad_correct = GradCorrect(wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) # only compatiable with for 1GPU training + + self.grad_correct.init_param() + self.grad_correct.rep_grad = False + + + def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[str, Union[float, int]]: + """ + Forward function for learning policy in learn mode, handling multiple tasks. + """ + self._learn_model.train() + self._target_model.train() + + # Initialize multi-task loss lists + task_weight_multi_task = [] + + obs_loss_multi_task = [] + reward_loss_multi_task = [] + policy_loss_multi_task = [] + orig_policy_loss_multi_task = [] + policy_entropy_multi_task = [] + value_loss_multi_task = [] + latent_recon_loss_multi_task = [] + perceptual_loss_multi_task = [] + latent_state_l2_norms_multi_task = [] + average_target_policy_entropy_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + weighted_total_loss = 0.0 + losses_list = [] # 存储每个任务的损失 + + for task_id, data_one_task in enumerate(data): + current_batch, target_batch, task_id = data_one_task + obs_batch_ori, action_batch, child_sampled_actions_batch, target_action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack4_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg, task_id) + + # Apply augmentations if needed + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to torch tensor + if self._cfg.model.continuous_action_space: + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1) + else: + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [ + mask_batch, + target_reward.astype('float32'), + target_value.astype('float32'), + target_policy, + weights + ] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, self._cfg.device) 
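Before the per-task loss computation continues below, a minimal sketch of the aggregation pattern this loop builds up to: each task contributes one scalar loss, an optional `task_weights` mapping rescales it, and a single backward pass runs over the sum. The numbers are illustrative stand-ins, not real training outputs:

import torch

task_losses = [torch.tensor(1.2, requires_grad=True), torch.tensor(0.7, requires_grad=True)]
task_weights = {0: 0.4, 1: 0.6}     # e.g. derived from per-task eval returns; None disables weighting

weighted_total_loss = 0.0
losses_list = []                    # kept per task for logging / gradient-correction methods
for task_idx, loss in enumerate(task_losses):
    w = task_weights[task_idx] if task_weights is not None else 1.0
    weighted_total_loss = weighted_total_loss + loss * w
    losses_list.append(loss * w)

weighted_total_loss.backward()      # one backward pass over the summed multi-task loss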
+ + target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) + target_value = target_value.view(self._cfg.batch_size[task_id], -1) + + # Transform rewards and values to their scaled forms + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert to categorical distributions + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Prepare batch for GPT model + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape_list[task_id], int) or len(self._cfg.model.observation_shape_list[task_id]) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, self._cfg.model.observation_shape_list[task_id]) + elif len(self._cfg.model.observation_shape_list[task_id]) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, *self._cfg.model.observation_shape_list[task_id]) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['child_sampled_actions'] = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device)[:, :-1] + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # Extract valid target policy data and compute entropy + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean().item() + + # Update world model + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, + self._target_model.world_model.tokenizer, + self.inverse_scalar_transform_handle, + task_id=task_id + ) + if task_weights is not None: + weighted_total_loss += losses.loss_total * task_weights[task_id] + losses_list.append(losses.loss_total * task_weights[task_id]) + + task_weight_multi_task.append(task_weights[task_id]) + else: + weighted_total_loss += losses.loss_total + losses_list.append(losses.loss_total) + + task_weight_multi_task.append(1) + + + for loss_name, loss_value in losses.intermediate_losses.items(): + self.intermediate_losses[f"{loss_name}"] = loss_value + # print(f'{loss_name}: {loss_value.sum()}') + # print(f'{loss_name}: {loss_value[0][0]}') + + # print(f"=== 全局任务权重 (按 task_id 排列): {task_weights}") + # assert not torch.isnan(losses.loss_total).any(), f"Loss contains NaN values, losses.loss_total:{losses.loss_total}, losses:{losses}" + # assert not torch.isinf(losses.loss_total).any(), f"Loss contains Inf values, losses.loss_total:{losses.loss_total}, losses:{losses}" + + # Collect losses per task + obs_loss = self.intermediate_losses.get('loss_obs', 0.0) or 0.0 + reward_loss = self.intermediate_losses.get('loss_rewards', 0.0) or 0.0 + policy_loss = self.intermediate_losses.get('loss_policy', 0.0) or 0.0 + orig_policy_loss = 
self.intermediate_losses.get('orig_policy_loss', 0.0) or 0.0 + policy_entropy = self.intermediate_losses.get('policy_entropy', 0.0) or 0.0 + value_loss = self.intermediate_losses.get('loss_value', 0.0) or 0.0 + latent_recon_loss = self.intermediate_losses.get('latent_recon_loss', 0.0) or 0.0 + perceptual_loss = self.intermediate_losses.get('perceptual_loss', 0.0) or 0.0 + latent_state_l2_norms = self.intermediate_losses.get('latent_state_l2_norms', 0.0) or 0.0 + value_priority = torch.tensor(0., device=self._cfg.device) # Placeholder, adjust as needed + + obs_loss_multi_task.append(obs_loss) + reward_loss_multi_task.append(reward_loss) + policy_loss_multi_task.append(policy_loss) + orig_policy_loss_multi_task.append(orig_policy_loss) + policy_entropy_multi_task.append(policy_entropy) + value_loss_multi_task.append(value_loss) + latent_recon_loss_multi_task.append(latent_recon_loss) + perceptual_loss_multi_task.append(perceptual_loss) + latent_state_l2_norms_multi_task.append(latent_state_l2_norms) + average_target_policy_entropy_multi_task.append(average_target_policy_entropy) + value_priority_multi_task.append(value_priority) + value_priority_mean_multi_task.append(value_priority.mean().item()) + + # Core learn model update step + self._optimizer_world_model.zero_grad() + + # 假设每个进程计算出的 losses_list 为可求梯度的 tensor list,比如多个标量 loss 组成的列表 + # 例如 losses_list = [loss1, loss2, ...],其中每个 loss_i 都是形如 (1,) 的 tensor 且 requires_grad=True + if self._cfg.use_moco: + # 调用 MoCo backward,由 grad_correct 中的 backward 实现梯度校正 + lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + elif self._cfg.only_use_moco_stats: + lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + # 不使用梯度校正的情况,由各 rank 自己执行反向传播 + weighted_total_loss.backward() + else: + # 不使用梯度校正的情况,由各 rank 自己执行反向传播 + lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device) + weighted_total_loss.backward() + + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), self._cfg.grad_clip_value) + + if self._cfg.multi_gpu: + # if not self._cfg.use_moco or self._cfg.only_use_moco_stats: + # self.sync_gradients(self._learn_model) + if not self._cfg.use_moco: + self.sync_gradients(self._learn_model) + + self._optimizer_world_model.step() + + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # Core target model update step + self._target_model.update(self._learn_model.state_dict()) + + # 获取GPU内存使用情况 + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0. + max_memory_allocated_gb = 0. 
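A usage sketch for the `generate_task_loss_dict` helper defined at the top of this file, which the return-dict construction below relies on: local per-task values are keyed by their global task id, so a rank whose first task has `task_id=2` logs `noreduce_*_task2`, `noreduce_*_task3`, and so on. The loss values here are illustrative:

obs_loss_multi_task = [0.81, 0.47]   # plain floats or tensors with .item() both work
print(generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=2))
# -> {'noreduce_obs_loss_task2': 0.81, 'noreduce_obs_loss_task3': 0.47}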
+ + # 构建损失字典 + return_loss_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self._collect_epsilon, + 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + + # if task_weights is None: + # task_weights = {self.task_id+i: 1 for i in range(self.task_num_for_current_rank)} + # else: + # print(f'task_weights:{task_weights}') + # from ding.utils import EasyTimer, set_pkg_seed, get_rank + + # print(f'rank:{get_rank()}, task_id:{self.task_id}') + + # 生成任务相关的损失字典,并为每个任务相关的 loss 添加前缀 "noreduce_" + multi_task_loss_dicts = { + **generate_task_loss_dict(task_weight_multi_task, 'noreduce_task_weight_task{}', task_id=self.task_id), + **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + } + + # print(f'multi_task_loss_dicts:{ multi_task_loss_dicts}') + + # 合并两个字典 + return_loss_dict.update(multi_task_loss_dicts) + + # 如果需要,可以将损失字典记录到日志或其他地方 + if self._cfg.use_wandb: + wandb.log({'learner_step/' + k: v for k, v in return_loss_dict.items()}, step=self.env_step) + wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) + + return return_loss_dict + + # TODO: num_tasks + def _monitor_vars_learn(self, num_tasks=2) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + If num_tasks is provided, generate monitored variables for each task. 
+ """ + # Basic monitored variables that do not depend on the number of tasks + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + ] + + # rank = get_rank() + task_specific_vars = [ + 'noreduce_task_weight', + 'noreduce_obs_loss', + 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', + 'noreduce_latent_recon_loss', + 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', + 'noreduce_reward_loss', + 'noreduce_value_loss', + 'noreduce_perceptual_loss', + 'noreduce_latent_state_l2_norms', + 'noreduce_lambd', + 'noreduce_value_priority_mean', + ] + # self.task_num_for_current_rank 作为当前rank的base_index + num_tasks = self.task_num_for_current_rank + # If the number of tasks is provided, extend the monitored variables list with task-specific variables + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + # print(f"learner policy Rank {rank}, self.task_id: {self.task_id}") + monitored_vars.append(f'{var}_task{self.task_id+task_idx}') + else: + # If num_tasks is not provided, we assume there's only one task and keep the original variable names + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def monitor_weights_and_grads(self, model): + """ + Monitor and print the weights and gradients of the model. + """ + for name, param in model.named_parameters(): + if param.requires_grad: + print(f"Layer: {name} | " + f"Weight mean: {param.data.mean():.4f} | " + f"Weight std: {param.data.std():.4f} | " + f"Grad mean: {param.grad.mean():.4f} | " + f"Grad std: {param.grad.std():.4f}") + + def _init_collect(self) -> None: + """ + Collect mode init method. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self._task_weight_temperature = 10. + + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros( + [self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64] + ).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.collector_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros( + [self.collector_env_num, self._cfg.model.observation_shape_list[0]] + ).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.collector_env_num)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + task_id: int = None, + ) -> Dict: + """ + Forward function for collecting data in collect mode, handling multiple tasks. 
+ """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference( + self.last_batch_obs, + self.last_batch_action, + data, + task_id=task_id + ) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [ + [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num) + ] if not self._cfg.model.continuous_action_space else [ + [-1 for _ in range(self._cfg.model.world_model_cfg.num_of_sampled_actions)] + for _ in range(active_collect_env_num) + ] + + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(self._cfg.model.world_model_cfg.num_of_sampled_actions)) + .astype(np.float32).tolist() for _ in range(active_collect_env_num) + ] + + if self._cfg.mcts_ctree: + roots = MCTSCtree.roots( + active_collect_env_num, + legal_actions, + self._cfg.model.world_model_cfg.action_space_size, + self._cfg.model.world_model_cfg.num_of_sampled_actions, + self._cfg.model.continuous_action_space + ) + else: + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + + # try: + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + # print("latent_state_roots.shape:", latent_state_roots.shape) + # except Exception as e: + # print("="*20) + # print(e) + # print("roots:", roots, "latent_state_roots:", latent_state_roots) + # print("latent_state_roots.shape:", latent_state_roots.shape) + # print("="*20) + # import ipdb; ipdb.set_trace() + + + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + roots_sampled_actions = roots.get_sampled_actions() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + root_sampled_actions = np.array([ + getattr(action, 'value', action) for action in roots_sampled_actions[i] + ]) + + # 选择动作 + action, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + + # 获取采样动作 + action = root_sampled_actions[action] + if not self._cfg.model.continuous_action_space: + action = int(action.item()) + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'root_sampled_actions': root_sampled_actions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + # 检查并重置采集器 + if active_collect_env_num < self.collector_env_num: + print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True, task_id=task_id) + + return output 
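For reference, a standalone sketch of the temperature-based selection that `select_action(distributions, temperature, deterministic)` performs on the root visit counts above: stochastic sampling during collection, argmax during evaluation. This mirrors the standard visit-count-to-probability conversion; the real helper lives in `lzero.policy` and is not reproduced here:

import numpy as np

def select_action_sketch(visit_counts, temperature=1.0, deterministic=False):
    counts = np.asarray(visit_counts, dtype=np.float64)
    probs = counts ** (1.0 / temperature)   # higher temperature flattens the distribution
    probs /= probs.sum()
    entropy = -np.sum(probs * np.log(probs + 1e-9))
    index = int(np.argmax(probs)) if deterministic else int(np.random.choice(len(probs), p=probs))
    return index, entropy                   # index is within the legal action set

print(select_action_sketch([50, 30, 20], deterministic=False))  # collect: sample from visit counts
print(select_action_sketch([50, 30, 20], deterministic=True))   # eval: argmax over visit counts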
+ + def _init_eval(self) -> None: + """ + Evaluate mode init method. Initialize the eval model and MCTS utils. + """ + from ding.utils import EasyTimer, set_pkg_seed, get_rank + + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num + + self.task_id_for_eval = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs_eval = torch.zeros( + [self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64] + ).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs_eval = torch.zeros( + [self.evaluator_env_num, self._cfg.model.observation_shape_list[self.task_id_for_eval]] # TODO + ).to(self._cfg.device) + print(f'rank {get_rank()} last_batch_obs_eval:', self.last_batch_obs_eval.shape) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id: int = None) -> Dict: + """ + Forward function for evaluating the current policy in eval mode, handling multiple tasks. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + network_output = self._eval_model.initial_inference( + self.last_batch_obs_eval, + self.last_batch_action, + data, + task_id=task_id + ) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + # TODO:======== + # self._eval_model.training = False + # if not self._eval_model.training: + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [ + [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num) + ] if not self._cfg.model.continuous_action_space else [ + [-1 for _ in range(self._cfg.model.world_model_cfg.num_of_sampled_actions)] + for _ in range(active_eval_env_num) + ] + + if self._cfg.mcts_ctree: + roots = MCTSCtree.roots( + active_eval_env_num, + legal_actions, + self._cfg.model.world_model_cfg.action_space_size, + self._cfg.model.world_model_cfg.num_of_sampled_actions, + self._cfg.model.continuous_action_space + ) + else: + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + + # print(f'type(policy_logits): {type(policy_logits)}') + # print(f'policy_logits.shape: {policy_logits.shape}') + # print(f'policy_logits: {policy_logits}') + + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + roots_sampled_actions = roots.get_sampled_actions() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + root_sampled_actions = np.array([ + getattr(action, 'value', action) for action in roots_sampled_actions[i] + ]) + + # 选择动作(确定性) + action, visit_count_distribution_entropy = select_action( + 
distributions, temperature=1, deterministic=True + ) + + # 获取采样动作 + action = root_sampled_actions[action] + if not self._cfg.model.continuous_action_space: + action = int(action.item()) + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'root_sampled_actions': root_sampled_actions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs_eval = data + self.last_batch_action = batch_action + + return output + + def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Reset the collection process for a specific environment. + """ + if reset_init_data: + if task_id is not None: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape_list[task_id], + self._cfg.collector_env_num, + self._cfg.device + ) + else: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + logging.info(f'collector: last_batch_obs, last_batch_action reset() {self.last_batch_obs.shape}') + + if env_id is None or isinstance(env_id, list): + return + + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + if current_steps % clear_interval == 0: + logging.info(f'clear_interval: {clear_interval}') + + world_model = self._collect_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + + logging.info('collector: collect_model clear()') + logging.info(f'eps_steps_lst[{env_id}]: {current_steps}') + + self._reset_target_model() + + def _reset_target_model(self) -> None: + """ + Reset the target model's caches. + """ + world_model = self._target_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + logging.info('collector: target_model past_kv_cache.clear()') + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Return the state_dict of learn mode, including model, target_model, and optimizer. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + # ========== TODO: original version: load all parameters ========== + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. 
+ """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters ========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Load the state_dict variable into policy learn mode, excluding multi-task related parameters. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously. + # """ + # # 定义需要排除的参数前缀 + # exclude_prefixes = [ + # '_orig_mod.world_model.head_policy_multi_task.', + # '_orig_mod.world_model.head_value_multi_task.', + # '_orig_mod.world_model.head_rewards_multi_task.', + # '_orig_mod.world_model.head_observations_multi_task.', + # '_orig_mod.world_model.task_emb.' + # ] + + # # 定义需要排除的具体参数(如果有特殊情况) + # exclude_keys = [ + # '_orig_mod.world_model.task_emb.weight', + # '_orig_mod.world_model.task_emb.bias', # 如果存在则添加 + # # 添加其他需要排除的具体参数名 + # ] + + # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + # """ + # 过滤掉需要排除的参数。 + # """ + # filtered = {} + # for k, v in state_dict_loader.items(): + # if any(k.startswith(prefix) for prefix in exclude_prefixes): + # print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 + # continue + # if k in exclude_keys: + # print(f"Excluding specific parameter: {k}") # 调试用 + # continue + # filtered[k] = v + # return filtered + + # # 过滤并加载 'model' 部分 + # if 'model' in state_dict: + # model_state_dict = state_dict['model'] + # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _learn_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + # else: + # print("No 'model' key found in the state_dict.") + + # # 过滤并加载 'target_model' 部分 + # if 'target_model' in state_dict: + # target_model_state_dict = state_dict['target_model'] + # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _target_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + # else: + # print("No 'target_model' key found in the state_dict.") + + # # 加载优化器的 state_dict,不需要过滤,因为优化器通常不包含模型参数 + # if 'optimizer_world_model' in state_dict: + # optimizer_state_dict = state_dict['optimizer_world_model'] + # try: + # self._optimizer_world_model.load_state_dict(optimizer_state_dict) + # except Exception as e: + # print(f"Error loading optimizer state_dict: {e}") + # else: + # print("No 'optimizer_world_model' key found in the state_dict.") + + # # 如果需要,还可以加载其他部分,例如 scheduler 等 \ No newline at end of file diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index ad688f07e..a459275a7 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -80,8 +80,8 @@ class UniZeroPolicy(MuZeroPolicy): device='cpu', # (bool) Whether to analyze simulation normalization. 
analysis_sim_norm=False, - # (bool) Whether to analyze dormant ratio. - analysis_dormant_ratio=False, + # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent. + analysis_dormant_ratio_weight_rank=False, # (int) The shape of the action space. action_space_size=6, # (int) The size of the group, related to simulation normalization. @@ -129,6 +129,7 @@ class UniZeroPolicy(MuZeroPolicy): rope_theta=10000, # (int) The maximum sequence length for position encoding. max_seq_len=8192, + lora_r= 0, ), ), # ****** common ****** @@ -164,7 +165,7 @@ class UniZeroPolicy(MuZeroPolicy): # (bool) Whether to use the pure policy to collect data. collect_with_pure_policy=False, # (int) The evaluation frequency. - eval_freq=int(2e3), + eval_freq=int(5e3), # (str) The sample type. Options are ['episode', 'transition']. sample_type='transition', # ****** observation ****** @@ -439,6 +440,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in ) weighted_total_loss = losses.loss_total + # 合并 intermediate_losses 字典,避免重复赋值 + # self.intermediate_losses.update(losses.intermediate_losses) + for loss_name, loss_value in losses.intermediate_losses.items(): self.intermediate_losses[f"{loss_name}"] = loss_value @@ -454,7 +458,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in middle_step_losses = self.intermediate_losses['middle_step_losses'] last_step_losses = self.intermediate_losses['last_step_losses'] dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder'] - dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model'] + dormant_ratio_transformer = self.intermediate_losses['dormant_ratio_transformer'] + dormant_ratio_head = self.intermediate_losses['dormant_ratio_head'] + avg_weight_mag_encoder = self.intermediate_losses['avg_weight_mag_encoder'] + avg_weight_mag_transformer = self.intermediate_losses['avg_weight_mag_transformer'] + avg_weight_mag_head = self.intermediate_losses['avg_weight_mag_head'] + e_rank_last_linear = self.intermediate_losses['e_rank_last_linear'] + e_rank_sim_norm = self.intermediate_losses['e_rank_sim_norm'] latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" @@ -550,8 +560,16 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'transformed_target_reward': transformed_target_reward.mean().item(), 'transformed_target_value': transformed_target_value.mean().item(), 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), - 'analysis/dormant_ratio_encoder': dormant_ratio_encoder.item(), - 'analysis/dormant_ratio_world_model': dormant_ratio_world_model.item(), + 'analysis/dormant_ratio_encoder': dormant_ratio_encoder, #.item(), + 'analysis/dormant_ratio_transformer': dormant_ratio_transformer,#.item(), + 'analysis/dormant_ratio_head': dormant_ratio_head,#.item(), + + 'analysis/avg_weight_mag_encoder': avg_weight_mag_encoder, + 'analysis/avg_weight_mag_transformer': avg_weight_mag_transformer, + 'analysis/avg_weight_mag_head': avg_weight_mag_head, + 'analysis/e_rank_last_linear': e_rank_last_linear, + 'analysis/e_rank_sim_norm': e_rank_sim_norm, + 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), 'analysis/l2_norm_before': self.l2_norm_before, 'analysis/l2_norm_after': self.l2_norm_after, @@ -603,8 +621,9 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 
0.25, - ready_env_id: np.ndarray = None, - timestep: List = [0] + ready_env_id: np.array = None, + timestep: List = [0], + task_id: int = None, ) -> Dict: """ Overview: @@ -617,6 +636,7 @@ def _forward_collect( - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - timestep (:obj:`list`): The step index of the env in one episode. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ @@ -713,6 +733,8 @@ def _forward_collect( # ========= TODO: for muzero_segment_collector now ========= if active_collect_env_num < self.collector_env_num: + # 当collect_env中有一个环境先done时,传回的self.last_batch_obs的长度会减少1, transformer在检索kv_cache时需要知道env_id,实现比较复杂 + # 因此直接《self.collector_env_num》个环境的self.last_batch_action全部重置为-1,让transformer从0开始,避免检索错误 print('==========collect_forward============') print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') self._reset_collect(reset_init_data=True) @@ -740,8 +762,8 @@ def _init_eval(self) -> None: self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1], - ready_env_id: np.array = None, timestep: List = [0]) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, timestep: List = [0], task_id: int = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -752,6 +774,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to eval. - timestep (:obj:`list`): The step index of the env in one episode. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. 
Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of eval_env, C is the number of channels, \ @@ -772,7 +795,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ ready_env_id = np.arange(active_eval_env_num) output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, timestep) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) # if not in training, obtain the scalars of the value/reward @@ -822,12 +845,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ } batch_action.append(action) - self.last_batch_obs = data + self.last_batch_obs_eval = data self.last_batch_action = batch_action return output - def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: """ Overview: This method resets the collection process for a specific environment. It clears caches and memory @@ -871,7 +894,7 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in print('collector: collect_model clear()') print(f'eps_steps_lst[{env_id}]: {current_steps}') - def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: """ Overview: This method resets the evaluation process for a specific environment. It clears caches and memory @@ -884,11 +907,22 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
""" if reset_init_data: - self.last_batch_obs = initialize_zeros_batch( - self._cfg.model.observation_shape, - self._cfg.evaluator_env_num, - self._cfg.device - ) + if task_id is not None: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape_list[task_id], + self._cfg.evaluator_env_num, + self._cfg.device + ) + print(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + + else: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + print(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] # Return immediately if env_id is None or a list @@ -923,7 +957,15 @@ def _monitor_vars_learn(self) -> List[str]: """ return [ 'analysis/dormant_ratio_encoder', - 'analysis/dormant_ratio_world_model', + 'analysis/dormant_ratio_transformer', + 'analysis/dormant_ratio_head', + + 'analysis/avg_weight_mag_encoder', + 'analysis/avg_weight_mag_transformer', + 'analysis/avg_weight_mag_head', + 'analysis/e_rank_last_linear', + 'analysis/e_rank_sim_norm', + 'analysis/latent_state_l2_norms', 'analysis/l2_norm_before', 'analysis/l2_norm_after', diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py new file mode 100644 index 000000000..52469d1eb --- /dev/null +++ b/lzero/policy/unizero_multitask.py @@ -0,0 +1,1445 @@ +import copy +from collections import defaultdict +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import UniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import prepare_obs_stack_for_unizero +from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs +from lzero.policy.unizero import UniZeroPolicy +from .utils import configure_optimizers_nanogpt +import sys + +sys.path.append('/fs-computility/ai-shen/puyuan/code/LibMTL') +from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect +# from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect + +# from LibMTL.weighting.abstract_weighting import AbsWeighting + + +def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): + """ + 生成每个任务的损失字典 + :param multi_task_losses: 包含每个任务损失的列表 + :param task_name_template: 任务名称模板,例如 'obs_loss_task{}' + :return: 一个字典,包含每个任务的损失 + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception as e: + task_loss_dict[task_name] = task_loss + return task_loss_dict + + + +class WrappedModel: + def __init__(self, world_model): + self.world_model = world_model + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return self.world_model.parameters() + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.world_model.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV2: + def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + self.tokenizer = 
tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return (list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + # list(self.task_emb.parameters()) + # TODO + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + # self.task_emb.zero_grad(set_to_none=set_to_none) # TODO + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + +class WrappedModelV3: + def __init__(self, transformer, pos_emb, task_emb, act_embedding_table): + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self): + # 返回 tokenizer, transformer 以及所有嵌入层的参数 + return (list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none=False): + # 将 tokenizer, transformer 和所有嵌入层的梯度设为零 + # self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + + +@POLICY_REGISTRY.register('unizero_multitask') +class UniZeroMTPolicy(UniZeroPolicy): + """ + Overview: + The policy class for UniZero, official implementation for paper UniZero: Generalized and Efficient Planning + with Scalable LatentWorld Models. UniZero aims to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations found in MuZero-style algorithms, particularly in environments requiring the + capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + """ + + # The default_config for UniZero policy. + config = dict( + type='unizero_multitask', + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The obs shape. + observation_shape=(3, 64, 64), + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=3, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=50, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (bool) whether to use res connection in dynamics. 
+ res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'. + norm_type='LN', # NOTE: TODO + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (int) The save interval of the model. + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + # (int) The number of tokens per block. + tokens_per_block=2, + # (int) The maximum number of blocks. + max_blocks=10, + # (int) The maximum number of tokens, calculated as tokens per block multiplied by max blocks. + max_tokens=2 * 10, + # (int) The context length, usually calculated as twice the number of some base unit. + context_length=2 * 4, + # (bool) Whether to use GRU gating mechanism. + gru_gating=False, + # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'. + device='cpu', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (int) The shape of the action space. + action_space_size=6, + # (int) The size of the group, related to simulation normalization. + group_size=8, # NOTE: sim_norm + # (str) The type of attention mechanism used. Options could be ['causal']. + attention='causal', + # (int) The number of layers in the model. + num_layers=2, + # (int) The number of attention heads. + num_heads=8, + # (int) The dimension of the embedding. + embed_dim=768, + # (float) The dropout probability for the embedding layer. + embed_pdrop=0.1, + # (float) The dropout probability for the residual connections. + resid_pdrop=0.1, + # (float) The dropout probability for the attention mechanism. + attn_pdrop=0.1, + # (int) The size of the support set for value and reward heads. + support_size=101, + # (int) The maximum size of the cache. + max_cache_size=5000, + # (int) The number of environments. + env_num=8, + # (float) The weight of the latent reconstruction loss. + latent_recon_loss_weight=0., + # (float) The weight of the perceptual loss. + perceptual_loss_weight=0., + # (float) The weight of the policy entropy. + policy_entropy_weight=1e-4, + # (str) The type of loss for predicting latent variables. Options could be ['group_kl', 'mse']. + predict_latent_loss_type='group_kl', + # (str) The type of observation. Options are ['image', 'vector']. + obs_type='image', + # (float) The discount factor for future rewards. + gamma=1, + # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent. + analysis_dormant_ratio_weight_rank=False, + # (float) The threshold for a dormant neuron. + dormant_threshold=0.025, + ), + ), + # ****** common ****** + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=True, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. 
+ env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=400, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to use the pure policy to collect data. + collect_with_pure_policy=False, + # (int) The evaluation frequency. + eval_freq=int(5e3), + # (str) The sample type. Options are ['episode', 'transition']. + sample_type='transition', + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. + optim_type='AdamW', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.0001, + # (int) Frequency of hard target network update. + target_update_freq=100, + # (int) Frequency of soft target network update. + target_update_theta=0.05, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=5, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # (int) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=10, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. 
+ value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + cos_lr_scheduler=False, + piecewise_decay_lr_scheduler=False, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=False, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + # (int) The initial Env Steps for training. + train_start_after_envsteps=int(0), + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For MuZero, ``lzero.model.unizero_model.MuZeroModel`` + """ + # NOTE: multi-task model + return 'UniZeroMTModel', ['lzero.model.unizero_model_multitask'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. 
+ """ + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR + + if self._cfg.cos_lr_scheduler: + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, T_max=int(2e5), eta_min=0, last_epoch=-1 + ) # TODO + elif self._cfg.piecewise_decay_lr_scheduler: + # Example step scheduler, adjust milestones and gamma as needed + self.lr_scheduler = StepLR( + self._optimizer_world_model, step_size=int(5e4), gamma=0.1 + ) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + # Ensure that the installed torch version is greater than or equal to 2.0 + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + # NOTE: soft target + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. 
+
+        # Create the WrappedModel instance.
+        # All parameters are shared, i.e. every parameter receives gradient correction.
+        # wrapped_model = WrappedModel(
+        #     self._learn_model.world_model,
+        # )
+
+        self.task_id = self._cfg.task_id
+        self.task_num_for_current_rank = self._cfg.task_num
+
+        print(f'self._cfg.only_use_moco_stats:{self._cfg.only_use_moco_stats}')
+        if self._cfg.use_moco or self._cfg.only_use_moco_stats:
+            # The heads do not receive gradient correction.
+            self.wrapped_model = WrappedModelV2(
+                # self._learn_model.world_model.tokenizer,  # TODO:
+                self._learn_model.world_model.tokenizer.encoder[0],  # TODO: one encoder
+                self._learn_model.world_model.transformer,
+                self._learn_model.world_model.pos_emb,
+                self._learn_model.world_model.task_emb,
+                self._learn_model.world_model.act_embedding_table,
+            )
+
+            # Neither the heads nor tokenizer.encoder receive gradient correction.
+            # wrapped_model = WrappedModelV3(
+            #     self._learn_model.world_model.transformer,
+            #     self._learn_model.world_model.pos_emb,
+            #     self._learn_model.world_model.task_emb,
+            #     self._learn_model.world_model.act_embedding_table,
+            # )
+
+            # Pass wrapped_model to GradCorrect as the shared model.
+            # ========= Initialize the MoCo/CAGrad parameters =========
+            # self.grad_correct = GradCorrect(self.wrapped_model, self.task_num_for_current_rank, self._cfg.device)
+            self.grad_correct = GradCorrect(self.wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu)  # only compatible with 1-GPU training
+            self.grad_correct.init_param()
+            self.grad_correct.rep_grad = False
+
+    #@profile
+    def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[str, Union[float, int]]:
+        """
+        Overview:
+            The forward function for learning policy in learn mode, which is the core of the learning process.
+            The data is sampled from the replay buffer.
+            The loss is calculated by the loss function and backpropagated to update the model.
+        Arguments:
+            - data (:obj:`Tuple[torch.Tensor]`): The data sampled from the replay buffer, which is a tuple of tensors.
+                The first tensor is the current_batch, the second tensor is the target_batch.
+        Returns:
+            - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \
+                current learning loss and learning statistics.
+ """ + self._learn_model.train() + self._target_model.train() + + obs_loss_multi_task = [] + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + latent_recon_loss_multi_task = [] + perceptual_loss_multi_task = [] + orig_policy_loss_multi_task = [] + policy_entropy_multi_task = [] + weighted_total_loss = 0.0 # 初始化为0,避免使用in-place操作 + + latent_state_l2_norms_multi_task = [] + average_target_policy_entropy_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + # 网络可塑性分析指标 + dormant_ratio_encoder_multi_task = [] + dormant_ratio_transformer_multi_task = [] + dormant_ratio_head_multi_task = [] + avg_weight_mag_encoder_multi_task = [] + avg_weight_mag_transformer_multi_task = [] + avg_weight_mag_head_multi_task = [] + e_rank_last_linear_multi_task = [] + e_rank_sim_norm_multi_task = [] + + + losses_list = [] # 用于存储每个任务的损失 + for task_id, data_one_task in enumerate(data): + current_batch, target_batch, task_id = data_one_task + # current_batch, target_batch, _ = data + # TODO: multitask适配rope(timestep_batch) + obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # Apply augmentations if needed + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to torch tensor + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() # For discrete action space + data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, + weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, + self._cfg.device) + + + # rank = get_rank() + # print(f'Rank {rank}: cfg.policy.task_id : {self._cfg.task_id}, self._cfg.batch_size {self._cfg.batch_size}') + + target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) + target_value = target_value.view(self._cfg.batch_size[task_id], -1) + + target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) + target_value = target_value.view(self._cfg.batch_size[task_id], -1) + + # assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + + # Transform rewards and values to their scaled forms + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert to categorical distributions + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Prepare batch for a transformer-based world model + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = 
torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, *self._cfg.model.observation_shape) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, + device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # Extract valid target policy data and compute entropy + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean().item() + + # Update world model + intermediate_losses = defaultdict(float) + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id + ) + + weighted_total_loss += losses.loss_total # TODO + + # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" # TODO + # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" + + losses_list.append(losses.loss_total) # TODO: for moco + + for loss_name, loss_value in losses.intermediate_losses.items(): + intermediate_losses[f"{loss_name}"] = loss_value + + obs_loss = intermediate_losses['loss_obs'] + reward_loss = intermediate_losses['loss_rewards'] + policy_loss = intermediate_losses['loss_policy'] + orig_policy_loss = intermediate_losses['orig_policy_loss'] + policy_entropy = intermediate_losses['policy_entropy'] + value_loss = intermediate_losses['loss_value'] + latent_recon_loss = intermediate_losses['latent_recon_loss'] + perceptual_loss = intermediate_losses['perceptual_loss'] + latent_state_l2_norms = intermediate_losses['latent_state_l2_norms'] + + # value_priority = intermediate_losses['value_priority'] + # logits_value = intermediate_losses['logits_value'] + + # print(f'logits_value:" {logits_value}') + # print(f'logits_value.shape:" {logits_value.shape}') + # print(f"batch_for_gpt['observations'].shape: {batch_for_gpt['observations'].shape}") + + # ============ for value priority ============ + # transform the categorical representation of the scaled value to its original value + # original_value = self.inverse_scalar_transform_handle(logits_value.reshape(-1, 101)).reshape( + # batch_for_gpt['observations'].shape[0], batch_for_gpt['observations'].shape[1], 1) + # calculate the new priorities for each transition. 
+ # value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1)[:,0], target_value[:, 0]) # TODO: mix of mean and sum + # value_priority = value_priority.data.cpu().numpy() + 1e-6 # TODO: log-reduce not support array now + value_priority = torch.tensor(0., device=self._cfg.device) + # ============ for value priority ============ + + # 关于网络可塑性的指标 + dormant_ratio_encoder = intermediate_losses['dormant_ratio_encoder'] + dormant_ratio_transformer = intermediate_losses['dormant_ratio_transformer'] + dormant_ratio_head = intermediate_losses['dormant_ratio_head'] + avg_weight_mag_encoder = intermediate_losses['avg_weight_mag_encoder'] + avg_weight_mag_transformer = intermediate_losses['avg_weight_mag_transformer'] + avg_weight_mag_head = intermediate_losses['avg_weight_mag_head'] + e_rank_last_linear = intermediate_losses['e_rank_last_linear'] + e_rank_sim_norm = intermediate_losses['e_rank_sim_norm'] + + obs_loss_multi_task.append(obs_loss) + reward_loss_multi_task.append(reward_loss) + policy_loss_multi_task.append(policy_loss) + orig_policy_loss_multi_task.append(orig_policy_loss) + policy_entropy_multi_task.append(policy_entropy) + reward_loss_multi_task.append(reward_loss) + value_loss_multi_task.append(value_loss) + latent_recon_loss_multi_task.append(latent_recon_loss) + perceptual_loss_multi_task.append(perceptual_loss) + latent_state_l2_norms_multi_task.append(latent_state_l2_norms) + value_priority_multi_task.append(value_priority) + value_priority_mean_multi_task.append(value_priority.mean().item()) + + # 关于网络可塑性的指标 + dormant_ratio_encoder_multi_task.append(dormant_ratio_encoder) + dormant_ratio_transformer_multi_task.append(dormant_ratio_transformer) + dormant_ratio_head_multi_task.append(dormant_ratio_head) + avg_weight_mag_encoder_multi_task.append(avg_weight_mag_encoder) + avg_weight_mag_transformer_multi_task.append(avg_weight_mag_transformer) + avg_weight_mag_head_multi_task.append(avg_weight_mag_head) + e_rank_last_linear_multi_task.append(e_rank_last_linear) + e_rank_sim_norm_multi_task.append(e_rank_sim_norm) + + + # Core learn model update step + self._optimizer_world_model.zero_grad() + + # 假设每个进程计算出的 losses_list 为可求梯度的 tensor list,比如多个标量 loss 组成的列表 + # 例如 losses_list = [loss1, loss2, ...],其中每个 loss_i 都是形如 (1,) 的 tensor 且 requires_grad=True + if self._cfg.use_moco: + # 调用 MoCo backward,由 grad_correct 中的 backward 实现梯度校正 + lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + elif self._cfg.only_use_moco_stats: + lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + # 不使用梯度校正的情况,由各 rank 自己执行反向传播 + weighted_total_loss.backward() + else: + # 不使用梯度校正的情况,由各 rank 自己执行反向传播 + lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device) + weighted_total_loss.backward() + + # TODO: 使用 MoCo 或 CAGrad 来计算梯度和权重 + # ============= for CAGrad and MoCo ============= + # lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + + # ============= TODO: 不使用梯度矫正的情况 ============= + # lambd = torch.tensor([0. 
for i in range(self.task_num_for_current_rank)], device=self._cfg.device) + # weighted_total_loss.backward() + + # ========== for debugging ========== + # for name, param in self._learn_model.world_model.tokenizer.encoder.named_parameters(): + # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + # if param.requires_grad: + # print(name, param.grad.norm()) + + if self._cfg.analysis_sim_norm: + del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after + self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() + self._target_model.encoder_hook.clear_data() + + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), + self._cfg.grad_clip_value) + + # if self._cfg.multi_gpu: + # # Very important to sync gradients before updating the model + # # rank = get_rank() + # # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad begin...') + # self.sync_gradients(self._learn_model) + # # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad end...') + + if self._cfg.multi_gpu: + # if not self._cfg.use_moco or self._cfg.only_use_moco_stats: + # self.sync_gradients(self._learn_model) + if not self._cfg.use_moco: + self.sync_gradients(self._learn_model) + + # print("=== Step 前,参数梯度详细信息 ===") + # for idx, param in enumerate(self.grad_correct.share_model.parameters()): + # if param.grad is not None: + # print(f"Param[{idx}] - device: {param.device}, dtype: {param.dtype}, " + # f"grad device: {param.grad.device}, grad dtype: {param.grad.dtype}") + # else: + # print(f"Param[{idx}] 没有梯度!") + + self._optimizer_world_model.step() + + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # Core target model update step + self._target_model.update(self._learn_model.state_dict()) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0. + max_memory_allocated_gb = 0. 
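Each per-task list accumulated above is flattened into prefixed logging keys by the `generate_task_loss_dict` helper defined near the top of this file; the key index is offset by this rank's base `task_id`, so different ranks log disjoint task columns. A small usage example with made-up values, assuming this rank's base task id is 2:

```python
# Reuses the generate_task_loss_dict helper defined earlier in this file.
obs_loss_multi_task = [0.31, 0.27]  # one entry per task handled by this rank
log_dict = generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=2)
# -> {'noreduce_obs_loss_task2': 0.31, 'noreduce_obs_loss_task3': 0.27}
```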
+
+        # Build the dict of losses and statistics to be logged.
+        return_loss_dict = {
+            'Current_GPU': current_memory_allocated_gb,
+            'Max_GPU': max_memory_allocated_gb,
+            'collect_mcts_temperature': self._collect_mcts_temperature,
+            'collect_epsilon': self._collect_epsilon,
+            'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'],
+            'weighted_total_loss': weighted_total_loss.item(),
+            # 'policy_entropy': policy_entropy,
+            # 'target_policy_entropy': average_target_policy_entropy,
+            'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(),
+        }
+
+        # Generate the task-specific loss dict, prefixing every per-task loss with "noreduce_".
+        multi_task_loss_dicts = {
+            **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id),
+
+            # Metrics for network plasticity.
+            **generate_task_loss_dict(dormant_ratio_encoder_multi_task, 'noreduce_dormant_ratio_encoder_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(dormant_ratio_transformer_multi_task, 'noreduce_dormant_ratio_transformer_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(avg_weight_mag_encoder_multi_task, 'noreduce_avg_weight_mag_encoder_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(avg_weight_mag_transformer_multi_task, 'noreduce_avg_weight_mag_transformer_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(avg_weight_mag_head_multi_task, 'noreduce_avg_weight_mag_head_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(e_rank_last_linear_multi_task, 'noreduce_e_rank_last_linear_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(e_rank_sim_norm_multi_task, 'noreduce_e_rank_sim_norm_task{}', task_id=self.task_id),
+
+            **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id),
+            **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id),
+        }
+        # Merge the two dicts.
+        return_loss_dict.update(multi_task_loss_dicts)
+        # print(f'return_loss_dict:{return_loss_dict}')
+
+        # Return the final loss dict.
+        return return_loss_dict
+
+    def monitor_weights_and_grads(self, model):
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                print(f"Layer: {name} | "
+                      f"Weight mean: 
{param.data.mean():.4f} | " + f"Weight std: {param.data.std():.4f} | " + f"Grad mean: {param.grad.mean():.4f} | " + f"Grad std: {param.grad.std():.4f}") + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + + # TODO: num_tasks + def _monitor_vars_learn(self, num_tasks=2) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + If num_tasks is provided, generate monitored variables for each task. + """ + # Basic monitored variables that do not depend on the number of tasks + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + ] + + # rank = get_rank() + task_specific_vars = [ + 'noreduce_obs_loss', + 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', + 'noreduce_latent_recon_loss', + 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', + 'noreduce_reward_loss', + 'noreduce_value_loss', + 'noreduce_perceptual_loss', + 'noreduce_latent_state_l2_norms', + 'noreduce_lambd', + 'noreduce_value_priority_mean', + # 关于网络可塑性的指标 + 'noreduce_dormant_ratio_encoder', + 'noreduce_dormant_ratio_transformer', + 'noreduce_dormant_ratio_head', + 'noreduce_avg_weight_mag_encoder', + 'noreduce_avg_weight_mag_transformer', + 'noreduce_avg_weight_mag_head', + 'noreduce_e_rank_last_linear', + 'noreduce_e_rank_sim_norm' + + ] + # self.task_num_for_current_rank 作为当前rank的base_index + num_tasks = self.task_num_for_current_rank + # If the number of tasks is provided, extend the monitored variables list with task-specific variables + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + # print(f"learner policy Rank {rank}, self.task_id: {self.task_id}") + monitored_vars.append(f'{var}_task{self.task_id+task_idx}') + else: + # If num_tasks is not provided, we assume there's only one task and keep the original variable names + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + #@profile + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + task_id: int = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. 
the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self._collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. 
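In the eps-greedy collect branch above, the action selected from the MCTS visit counts is, with probability `self._collect_epsilon`, replaced by a uniformly random legal action. A tiny illustrative helper (not part of the policy class) capturing that override:

```python
import numpy as np

def eps_greedy_override(mcts_action: int, legal_actions: list, epsilon: float) -> int:
    """Return the MCTS action, except with probability epsilon pick a random legal action."""
    if np.random.rand() < epsilon:
        return int(np.random.choice(legal_actions))
    return mcts_action
```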
+ action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + # ============== TODO: only for visualize ============== + # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self._collect_mcts_temperature, deterministic=True + # ) + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # ============== TODO: only for visualize ============== + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + # ========= TODO: for muzero_segment_collector now ========= + if active_collect_env_num < self.collector_env_num: + # 当collect_env中有一个环境先done时,传回的self.last_batch_obs的长度会减少1, transformer在检索kv_cache时需要知道env_id,实现比较复杂 + # 因此直接《self.collector_env_num》个环境的self.last_batch_action全部重置为-1,让transformer从0开始,避免检索错误 + print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True, task_id=task_id) + if getattr(self._cfg, 'sample_type', '') == 'episode': + print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + #@profile + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. 
+ Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + # if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # print("roots_visit_count_distributions:", distributions, "root_value:", value) + + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. 
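# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): what `select_action` over the
# root visit counts amounts to. In eval mode (deterministic=True) the action
# with the highest visit count is taken; in collect mode the counts are
# temperature-scaled and sampled. This is a minimal re-implementation for
# clarity only; the actual LightZero helper may differ in details.
import numpy as np

def select_from_visit_counts(visit_counts, temperature: float = 1.0, deterministic: bool = False) -> int:
    counts = np.asarray(visit_counts, dtype=np.float64)
    if deterministic:
        return int(np.argmax(counts))
    probs = counts ** (1.0 / temperature)
    probs /= probs.sum()
    return int(np.random.choice(len(counts), p=probs))

# Example: visit counts over 3 legal actions after 50 simulations.
print(select_from_visit_counts([30, 15, 5], deterministic=True))   # -> 0 (argmax, eval mode)
print(select_from_visit_counts([30, 15, 5], temperature=1.0))      # sampled index (collect mode)
# --------------------------------------------------------------------------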
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs_eval = data + self.last_batch_action = batch_action + + return output + + #@profile + def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + This method resets the collection process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, the initial data + will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. + """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + # print('collector: last_batch_obs, last_batch_action reset()', self.last_batch_obs.shape) + + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return + + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the collect model's world model + world_model = self._collect_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('collector: collect_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + + # TODO: check its correctness ========= + self._reset_target_model() + + #@profile + def _reset_target_model(self) -> None: + """ + Overview: + This method resets the target model. It clears caches and memory, ensuring optimal performance. + Arguments: + - None + """ + + # Clear various caches in the target_model + world_model = self._target_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + print('collector: target_model past_kv_cache.clear()') + + #@profile + def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + This method resets the evaluation process for a specific environment. 
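# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the cache-clearing schedule used
# by `_reset_collect` above (and `_reset_eval` below). The world model's KV
# caches are only flushed every `clear_interval` environment steps, with a
# longer interval when sampling whole episodes. Attribute names follow the
# diff; treat this as a simplified stand-in, not the actual implementation.
def should_clear_cache(current_steps: int, sample_type: str) -> bool:
    clear_interval = 2000 if sample_type == 'episode' else 200
    return current_steps % clear_interval == 0

def clear_world_model_caches(world_model) -> None:
    # Per-environment init-inference caches plus the shared recurrent cache.
    for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs:
        kv_cache_dict_env.clear()
    world_model.past_kv_cache_recurrent_infer.clear()
    world_model.keys_values_wm_list.clear()
# --------------------------------------------------------------------------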
It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, + the initial data will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. + """ + if reset_init_data: + # if task_id is not None: + # self.last_batch_obs_eval = initialize_zeros_batch( + # self._cfg.model.observation_shape_list[task_id], + # self._cfg.evaluator_env_num, + # self._cfg.device + # ) + # print('unizero_multitask.py task_id is not None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + + # else: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + print('unizero_multitask.py task_id is None after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return + + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the eval model's world model + world_model = self._eval_model.world_model + # world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('evaluator: eval_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + + + def recompute_pos_emb_diff_and_clear_cache(self) -> None: + """ + Overview: + Clear the caches and precompute positional embedding matrices in the model. + """ + # NOTE: Clear caches and precompute positional embedding matrices both for the collect and target models + for model in [self._collect_model, self._target_model]: + model.world_model.precompute_pos_emb_diff_kv() + model.world_model.clear_caches() + torch.cuda.empty_cache() + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + # ========== TODO: original version: load all parameters ========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Load the state_dict variable into policy learn mode. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. 
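# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the pretrain-finetune loading
# scheme of the `_load_state_dict_learn` variant that follows — multi-task
# heads and the task embedding are filtered out of the checkpoint, and any
# backbone parameter whose component is not listed in `finetune_components`
# is frozen. Matching components by substring mirrors the diff; a real model
# may need module-based checks instead. Helper name is hypothetical.
from typing import Any, Dict, List
import torch

def load_backbone_and_freeze(model: torch.nn.Module,
                             pretrained: Dict[str, Any],
                             exclude_prefixes: List[str],
                             finetune_components: List[str]) -> None:
    # Drop multi-task-specific parameters before loading.
    filtered = {k: v for k, v in pretrained.items()
                if not any(k.startswith(p) for p in exclude_prefixes)}
    model.load_state_dict(filtered, strict=False)
    for name, param in model.named_parameters():
        # Freeze encoder/transformer/representation parameters unless they are
        # explicitly listed as finetunable.
        if any(c in name for c in ('encoder', 'transformer', 'representation_network')):
            param.requires_grad = any(c in name for c in finetune_components)
# --------------------------------------------------------------------------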
+ # """ + # self._learn_model.load_state_dict(state_dict['model']) + # self._target_model.load_state_dict(state_dict['target_model']) + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters ========== + def _load_state_dict_learn(self, state_dict: Dict[str, Any], finetune_components: List[str] = []) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode, excluding multi-task related parameters. + 根据 finetune_components 参数,决定加载 encoder 和 transformer 后,哪些部分参与后续更新,哪些被冻结。 + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously. + - finetune_components (:obj:`List[str]`, optional): A list of component names that will remain trainable after loading. + For example, it can include "encoder", "transformer", or both. The components not in this list will be frozen. + """ + # finetune_components = [] # load-enc-trans_finetune-head + # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head + finetune_components = ["representation_network", "encoder"] # load-enc-trans_finetune-encoder-head + + # 定义需要排除的参数前缀,即不加载这些参数 + exclude_prefixes = [ + '_orig_mod.world_model.head_policy_multi_task.', + '_orig_mod.world_model.head_value_multi_task.', + '_orig_mod.world_model.head_rewards_multi_task.', + '_orig_mod.world_model.head_observations_multi_task.', + '_orig_mod.world_model.task_emb.' + ] + + # 定义需要排除的具体参数(如果有特殊情况) + exclude_keys = [ + '_orig_mod.world_model.task_emb.weight', + '_orig_mod.world_model.task_emb.bias', # 如果存在则添加 + # 添加其他需要排除的具体参数名 + ] + + def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + """ + 过滤掉需要排除的参数。 + """ + filtered = {} + for k, v in state_dict_loader.items(): + if any(k.startswith(prefix) for prefix in exclude_prefixes): + print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 + continue + if k in exclude_keys: + print(f"Excluding specific parameter: {k}") # 调试用 + continue + filtered[k] = v + return filtered + + # 过滤并加载 'model' 部分 + if 'model' in state_dict: + model_state_dict = state_dict['model'] + filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _learn_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + else: + print("No 'model' key found in the state_dict.") + + # 过滤并加载 'target_model' 部分 + if 'target_model' in state_dict: + target_model_state_dict = state_dict['target_model'] + filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _target_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + else: + print("No 'target_model' key found in the state_dict.") + + # 对 _learn_model 中的参数进行冻结/解冻的处理 + # 假设模型中参数的名字如果包含 "encoder" 则属于 encoder 模块, + # 包含 "transformer" 则属于 transformer 模块,其它部分可根据需要扩展。 + for name, param in self._learn_model.named_parameters(): + # 如果参数属于 encoder 且不在需要微调的组件中,则冻结该参数 + if "encoder" in name and "encoder" not in 
finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + elif "representation_network" in name and "representation_network" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + # 如果参数属于 transformer 且不在需要微调的组件中,则冻结该参数 + elif "transformer" in name and "transformer" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + else: + # 如果参数属于其他模块,或者包含在 finetune_components 中,则保持默认(或者根据需要调整) + print(f"Parameter remains default: {name}") + + # 注意: + # 如果你的模型中嵌套模块更为复杂,可以基于 module 的属性而不是仅仅依靠参数名称进行判断,比如: + # for module in self._learn_model.modules(): + # if isinstance(module, EncoderModule) and "encoder" not in finetune_components: + # for param in module.parameters(): + # param.requires_grad = False + + # # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters ========== + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Load the state_dict variable into policy learn mode, excluding multi-task related parameters. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously. + # """ + # # 定义需要排除的参数前缀 + # exclude_prefixes = [ + # '_orig_mod.world_model.head_policy_multi_task.', + # '_orig_mod.world_model.head_value_multi_task.', + # '_orig_mod.world_model.head_rewards_multi_task.', + # '_orig_mod.world_model.head_observations_multi_task.', + # '_orig_mod.world_model.task_emb.' + # ] + + # # 定义需要排除的具体参数(如果有特殊情况) + # exclude_keys = [ + # '_orig_mod.world_model.task_emb.weight', + # '_orig_mod.world_model.task_emb.bias', # 如果存在则添加 + # # 添加其他需要排除的具体参数名 + # ] + + # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + # """ + # 过滤掉需要排除的参数。 + # """ + # filtered = {} + # for k, v in state_dict_loader.items(): + # if any(k.startswith(prefix) for prefix in exclude_prefixes): + # print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 + # continue + # if k in exclude_keys: + # print(f"Excluding specific parameter: {k}") # 调试用 + # continue + # filtered[k] = v + # return filtered + + # # 过滤并加载 'model' 部分 + # if 'model' in state_dict: + # model_state_dict = state_dict['model'] + # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _learn_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + # else: + # print("No 'model' key found in the state_dict.") + + # # 过滤并加载 'target_model' 部分 + # if 'target_model' in state_dict: + # target_model_state_dict = state_dict['target_model'] + # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + # missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + # if missing_keys: + # print(f"Missing keys when loading _target_model: {missing_keys}") + # if unexpected_keys: + # print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + # else: + # print("No 'target_model' key found in the state_dict.") + + # # 不要加载优化器的 state_dict,因为优化器通常不包含模型参数,加载后性能反而变差 + # # if 'optimizer_world_model' in state_dict: + # # optimizer_state_dict = 
state_dict['optimizer_world_model'] + # # try: + # # self._optimizer_world_model.load_state_dict(optimizer_state_dict) + # # except Exception as e: + # # print(f"Error loading optimizer state_dict: {e}") + # # else: + # # print("No 'optimizer_world_model' key found in the state_dict.") diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index f2cba7161..631af8391 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -359,7 +359,7 @@ def prepare_obs_stack_for_unizero(obs_batch_ori: np.ndarray, cfg: EasyDict) -> T return obs_batch, obs_target_batch -def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]: +def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict, task_id = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: Prepare the observations for the model by converting the original batch of observations @@ -382,9 +382,12 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # Calculate the dimension size to slice based on the model configuration. # For convolutional models ('conv'), use the number of frames to stack times the number of channels. # For multi-layer perceptron models ('mlp'), use the number of frames to stack times the size of the observation space. - stack_dim = cfg.model.frame_stack_num * ( + if task_id is None: + stack_dim = cfg.model.frame_stack_num * ( cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape) - + else: + stack_dim = cfg.model.frame_stack_num * ( + cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape_list[task_id]) # Slice the original observation tensor to obtain the batch for the initial inference. obs_batch = obs_batch_ori[:, :stack_dim] @@ -395,7 +398,10 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # Determine the starting dimension to exclude based on the model type. # For 'conv', exclude the first 'image_channel' dimensions. # For 'mlp', exclude the first 'observation_shape' dimensions. - exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape + if task_id is None: + exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape + else: + exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape_list[task_id] # Slice the original observation tensor to obtain the batch for consistency loss calculation. 
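# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the task-aware `prepare_obs`
# above picks its slicing dimension. Conv-style models always use the image
# channel count, while MLP-style models look up the per-task observation size
# from `observation_shape_list[task_id]` when a task id is given. The minimal
# config below is hypothetical and for demonstration only.
from easydict import EasyDict

def compute_stack_dim(cfg: EasyDict, task_id=None) -> int:
    if cfg.model.model_type in ['conv', 'conv_context']:
        per_frame = cfg.model.image_channel
    elif task_id is None:
        per_frame = cfg.model.observation_shape
    else:
        per_frame = cfg.model.observation_shape_list[task_id]
    return cfg.model.frame_stack_num * per_frame

cfg = EasyDict(dict(model=dict(model_type='mlp', frame_stack_num=1,
                               observation_shape=8,
                               observation_shape_list=[8, 17, 4])))
assert compute_stack_dim(cfg) == 8               # single-task fallback
assert compute_stack_dim(cfg, task_id=1) == 17   # multi-task lookup
# --------------------------------------------------------------------------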
obs_target_batch = obs_batch_ori[:, exclude_dim:] @@ -550,7 +556,11 @@ def concat_output_value(output_lst: List) -> np.ndarray: # concat the values of the model output list value_lst = [] for output in output_lst: - value_lst.append(output.value) + value_lst.append(output.value) # TODO:cpu + + # print(f'value_lst:{value_lst}') + # print(f'value_lst[0]:{value_lst[0]}') + # print(f'value_lst[0].shape:{value_lst[0].shape}') value_lst = np.concatenate(value_lst) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 8e08e6c61..0299abf8f 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -42,6 +42,7 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: @@ -54,7 +55,9 @@ def __init__( - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - instance_name (:obj:`str`): Unique identifier for this collector instance. - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. """ + self.task_id = task_id self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -268,6 +271,7 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm end_index = beg_index + self.unroll_plus_td_steps - 1 pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_lst = game_segments[i].chance_segment[beg_index:end_index] @@ -294,7 +298,7 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm game_segment element shape: obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 - action: game_segment_length -> 20 + action: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 child_visits: game_segment_length + num_unroll_steps -> 20 +5 to_play: game_segment_length -> 20 @@ -445,8 +449,13 @@ def collect(self, # Key policy forward step # ============================================================== # print(f'ready_env_id:{ready_env_id}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) - + # policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + if self.task_id is None: + # single task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) + else: + # multi-task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep, task_id=self.task_id) # Extract relevant policy outputs actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} @@ -571,9 +580,9 @@ def collect(self, completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) eps_steps_lst[env_id] += 1 - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # only for 
UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero', 'unizero_multitask', 'sampled_unizero_multitask']: + # TODO: only for UniZero now + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) # NOTE: reset_init_data=False total_transitions += 1 @@ -799,10 +808,16 @@ def _output_log(self, train_iter: int) -> None: for k, v in info.items(): if k in ['each_reward']: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, train_iter) if k in ['total_envstep_count']: continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + if self.task_id is None: + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + else: + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, self._total_envstep_count) if self.policy_config.use_wandb: wandb.log({'{}_step/'.format(self._instance_name) + k: v for k, v in info.items()}, step=self._total_envstep_count) diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 454b81b31..468c75aa5 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -15,6 +15,7 @@ from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation +import threading class MuZeroEvaluator(ISerialEvaluator): @@ -56,6 +57,7 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'evaluator', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: @@ -70,7 +72,10 @@ def __init__( - exp_name (:obj:`str`): Name of the experiment, used to determine output directory. - instance_name (:obj:`str`): Name of this evaluator instance. - policy_config (:obj:`Optional[dict]`): Optional configuration for the game policy. + - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. 
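# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the per-task TensorBoard tag
# scheme introduced in `_output_log` above. In single-task mode (task_id is
# None) the original '{instance}_iter/' and '{instance}_step/' prefixes are
# kept; in multi-task mode the task id is appended so curves from different
# tasks do not overwrite each other. Helper name is hypothetical.
def tb_tag(instance_name: str, key: str, task_id=None, by_step: bool = False) -> str:
    suffix = '_step' if by_step else '_iter'
    if task_id is None:
        return f'{instance_name}{suffix}/{key}'
    return f'{instance_name}{suffix}_task{task_id}/{key}'

assert tb_tag('collector', 'reward_mean') == 'collector_iter/reward_mean'
assert tb_tag('collector', 'reward_mean', task_id=3) == 'collector_iter_task3/reward_mean'
# --------------------------------------------------------------------------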
""" + self.stop_event = threading.Event() # Add stop event to handle timeouts + self.task_id = task_id self._eval_freq = eval_freq self._exp_name = exp_name self._instance_name = instance_name @@ -88,7 +93,19 @@ def __init__( './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name ) else: - self._logger, self._tb_logger = None, None # for close elegantly + # self._logger, self._tb_logger = None, None # for close elegantly + # ========== TODO: unizero_multitask ddp_v2 ======== + if tb_logger is not None: + self._logger, _ = build_logger( + './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False + ) + self._tb_logger = tb_logger + + + self._rank = get_rank() + + print(f'rank {self._rank}, self.task_id: {self.task_id}') + self.reset(policy, env) @@ -101,6 +118,9 @@ def __init__( # ============================================================== self.policy_config = policy_config + # def stop(self): + # self.stop_event.set() + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: @@ -129,7 +149,7 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: assert hasattr(self, '_env'), "please set env first" if _policy is not None: self._policy = _policy - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ @@ -210,10 +230,20 @@ def eval( - stop_flag (:obj:`bool`): Indicates whether the training can be stopped based on the stop value. - episode_info (:obj:`Dict[str, Any]`): A dictionary containing information about the evaluation episodes. """ + if torch.cuda.is_available(): + print(f"=========in eval() Rank {get_rank()} ===========") + device = torch.cuda.current_device() + print(f"当前默认的 GPU 设备编号: {device}") + torch.cuda.set_device(get_rank()) + print(f"set device后的 GPU 设备编号: {get_rank()}") + # the evaluator only works on rank0 episode_info = None stop_flag = False - if get_rank() == 0: + # ======== TODO: unizero_multitask ddp_v2 ======== + # if get_rank() == 0: + if get_rank() >= 0: + if n_episode is None: n_episode = self._default_n_episode assert n_episode is not None, "please indicate eval n_episode" @@ -222,7 +252,7 @@ def eval( env_nums = self._env.env_num self._env.reset() - self._policy.reset() + self._policy.reset(task_id=self.task_id) # initializations init_obs = self._env.ready_obs @@ -256,7 +286,8 @@ def eval( GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) for _ in range(env_nums) ] for i in range(env_nums): @@ -269,6 +300,12 @@ def eval( eps_steps_lst = np.zeros(env_nums) with self._timer: while not eval_monitor.is_finished(): + + # Check if stop_event is set (timeout occurred) + if self.stop_event.is_set(): + self._logger.info("[EVALUATOR]: Evaluation aborted due to timeout.") + break + # Get current ready env obs. 
obs = self._env.ready_obs new_available_env_id = set(obs.keys()).difference(ready_env_id) @@ -292,7 +329,13 @@ def eval( # ============================================================== # policy forward # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep) + # policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + if self.task_id is None: + # single task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + else: + # multi task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, task_id=self.task_id) actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} @@ -341,7 +384,7 @@ def eval( eps_steps_lst[env_id] += 1 if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: # only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False, task_id=self.task_id) game_segments[env_id].append( actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], @@ -404,7 +447,8 @@ def eval( game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) game_segments[env_id].reset( @@ -441,14 +485,23 @@ def eval( episode_info = eval_monitor.get_episode_info() if episode_info is not None: info.update(episode_info) + + print(f'rank {self._rank}, self.task_id: {self.task_id}') + self._logger.info(self._logger.get_tabulate_vars_hor(info)) for k, v in info.items(): if k in ['train_iter', 'ckpt_name', 'each_reward']: continue if not np.isscalar(v): continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, + train_iter) + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, + envstep) if self.policy_config.use_wandb: wandb.log({'{}_step/'.format(self._instance_name) + k: v}, step=envstep) @@ -466,12 +519,16 @@ def eval( ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." 
) - if get_world_size() > 1: - objects = [stop_flag, episode_info] - broadcast_object_list(objects, src=0) - stop_flag, episode_info = objects + # ========== TODO: unizero_multitask ddp_v2 ======== + # if get_world_size() > 1: + # objects = [stop_flag, episode_info] + # print(f'rank {self._rank}, self.task_id: {self.task_id}') + # print('before broadcast_object_list') + # broadcast_object_list(objects, src=0) + # print('evaluator after broadcast_object_list') + # stop_flag, episode_info = objects episode_info = to_item(episode_info) if return_trajectory: episode_info['trajectory'] = game_segments - return stop_flag, episode_info + return stop_flag, episode_info \ No newline at end of file diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 46cc016bc..668a05118 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -46,19 +46,22 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: - Initialize the MuZeroSegmentCollector with the given parameters. + Initialize the MuZeroCollector with the given parameters. Arguments: - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. - - policy (:obj:`Optional[namedtuple]`): Namedtuple of the collection mode policy API. + - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - instance_name (:obj:`str`): Unique identifier for this collector instance. - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. 
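# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): what the commented-out
# broadcast above used to do in the single-task pipeline — rank 0 evaluates and
# then shares `stop_flag` / `episode_info` with the other ranks. In the
# multi-task ddp_v2 flow every rank evaluates its own tasks, so the broadcast
# is skipped. Assumes `torch.distributed` is already initialized.
import torch.distributed as dist

def sync_eval_result(stop_flag, episode_info, src: int = 0):
    objects = [stop_flag, episode_info]
    dist.broadcast_object_list(objects, src=src)
    return objects[0], objects[1]
# --------------------------------------------------------------------------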
""" + self.task_id = task_id + self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -66,6 +69,10 @@ def __init__( self._end_flag = False self._rank = get_rank() + + print(f'rank {self._rank}, self.task_id: {self.task_id}') + + self._world_size = get_world_size() if self._rank == 0: if tb_logger is not None: @@ -83,7 +90,9 @@ def __init__( self._logger, _ = build_logger( path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False ) - self._tb_logger = None + # =========== TODO: for unizero_multitask ddp_v2 ======== + self._tb_logger = tb_logger + self.policy_config = policy_config self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy @@ -124,7 +133,7 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: self._logger.debug( 'Set default num_segments mode(num_segments({}), env_num({}))'.format(self._default_num_segments, self._env_num) ) - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ @@ -390,7 +399,8 @@ def collect(self, GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) for _ in range(env_nums) ] # stacked observation windows in reset stage for init game_segments @@ -448,6 +458,8 @@ def collect(self, # ready_env_id = set(obs.keys()) stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + + stack_obs = list(stack_obs.values()) self.action_mask_dict_tmp = {env_id: self.action_mask_dict[env_id] for env_id in ready_env_id} @@ -469,9 +481,14 @@ def collect(self, # ============================================================== # Key policy forward step # ============================================================== - # logging.info(f'ready_env_id:{ready_env_id}') - # logging.info(f'timestep:{timestep}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) + # print(f'ready_env_id:{ready_env_id}') + if self.task_id is None: + # single task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + else: + # multi task setting + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, task_id=self.task_id) + # Extract relevant policy outputs actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} @@ -643,7 +660,8 @@ def collect(self, game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) game_segments[env_id].reset(observation_window_stack[env_id]) @@ -705,7 +723,7 @@ def collect(self, # Env reset is done by env_manager automatically # NOTE: ============ reset the policy for the env_id. Default reset_init_data=True. 
================ - self._policy.reset([env_id]) + self._policy.reset([env_id], task_id=self.task_id) self._reset_stat(env_id) ready_env_id.remove(env_id) @@ -714,7 +732,8 @@ def collect(self, game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) game_segments[env_id].reset(observation_window_stack[env_id]) @@ -735,11 +754,13 @@ def collect(self, break collected_duration = sum([d['time'] for d in self._episode_info]) + # TODO: for atari multitask new ddp pipeline # reduce data when enables DDP - if self._world_size > 1: - collected_step = allreduce_data(collected_step, 'sum') - collected_episode = allreduce_data(collected_episode, 'sum') - collected_duration = allreduce_data(collected_duration, 'sum') + # if self._world_size > 1: + # collected_step = allreduce_data(collected_step, 'sum') + # collected_episode = allreduce_data(collected_episode, 'sum') + # collected_duration = allreduce_data(collected_duration, 'sum') + self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration @@ -755,8 +776,9 @@ def _output_log(self, train_iter: int) -> None: Arguments: - train_iter (:obj:`int`): Current training iteration number for logging context. """ - if self._rank != 0: - return + # TODO: for atari multitask new ddp pipeline + # if self._rank != 0: + # return if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) @@ -789,11 +811,20 @@ def _output_log(self, train_iter: int) -> None: if self.policy_config.gumbel_algo: info['completed_value'] = np.mean(completed_value) self._episode_info.clear() + print(f'collector output_log: rank {self._rank}, self.task_id: {self.task_id}') self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) for k, v in info.items(): if k in ['each_reward']: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, + train_iter) if k in ['total_envstep_count']: continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) \ No newline at end of file + if self.task_id is None: + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + else: + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, + self._total_envstep_count) diff --git a/requirements.txt b/requirements.txt index 8d50b281d..18645a2e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,6 @@ moviepy pytest line_profiler xxhash +simple_parsing einops openai \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_multitask_segment_config.py b/zoo/atari/config/atari_muzero_multitask_segment_config.py new file mode 100644 index 000000000..ce486a050 --- /dev/null +++ b/zoo/atari/config/atari_muzero_multitask_segment_config.py @@ -0,0 +1,260 @@ +from easydict import EasyDict + +def create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + 
batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + + return EasyDict(dict( + env=dict( + stop_value=int(5e5), # Adjusted max_env_step based on user TODO + env_id=env_id, + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + full_action_space=True, + # ===== TODO: only for debug ===== + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + learn=dict( + learner=dict( + hook=dict(save_ckpt_after_iter=200000,), # Adjusted checkpoint frequency + ), + ), + grad_correct_params=dict( + # Placeholder for gradient correction parameters if needed + ), + task_num=len(env_id_list), + model=dict( + device='cuda', + num_res_blocks=2, # NOTE: encoder for 4 game + num_channels=256, + reward_head_channels= 16, + value_head_channels= 16, + policy_head_channels= 16, + fc_reward_layers= [32], + fc_value_layers= [32], + fc_policy_layers= [32], + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + action_space_size=action_space_size, + norm_type=norm_type, + model_type='conv', + image_channel=1, + downsample=True, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + use_sim_norm=True, + use_sim_norm_kl_loss=False, + task_num=len(env_id_list), + ), + cuda=True, + env_type='not_board_games', + # train_start_after_envsteps=2000, + train_start_after_envsteps=0, + game_segment_length=20, # Fixed segment length as per user config + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=0.25, + num_unroll_steps=num_unroll_steps, + # =========== TODO: debug =========== + # update_per_collect=2, # TODO: debug + update_per_collect=80, # Consistent with UniZero config + batch_size=batch_size, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_segments=num_segments, + num_simulations=num_simulations, + policy_entropy_weight=5e-3, #TODO + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(5e5), # Adjusted as per UniZero config + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs( + env_id_list, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + seed, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + configs = [] + exp_name_prefix = ( + f'data_muzero_mt_8games/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/' + f'{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_' + f'{len(env_id_list)}-pred-head_mbs-512_upc80_H{num_unroll_steps}_seed{seed}/' + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + # 
evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = f"{exp_name_prefix}{env_id.split('NoFrameskip')[0]}_muzero-mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager()]]) + + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + # env_manager=dict(type='base'), + policy=dict( + type='muzero_multitask', + import_names=['lzero.policy.muzero_multitask'], + ), + )) + +if __name__ == "__main__": + import sys + sys.path.insert(0, "/mnt/afs/niuyazhe/code/LightZero") + import lzero + print("lzero path:", lzero.__file__) + # import sys + # import os + # # 添加项目根目录到 PYTHONPATH + # sys.path.append(os.path.dirname(os.path.abspath(__file__))) + + from lzero.entry import train_muzero_multitask_segment_noddp + import argparse + + parser = argparse.ArgumentParser(description='Train MuZero Multitask on Atari') + parser.add_argument('--seed', type=int, default=0, help='Random seed') + args = parser.parse_args() + + # Define your list of environment IDs + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + ] + # env_id_list = [ + # 'PongNoFrameskip-v4', + # 'MsPacmanNoFrameskip-v4', + # ] + + action_space_size = 18 # Full action space, adjust if different per env + seed = args.seed + collector_env_num = 8 + evaluator_env_num = 3 + num_segments = 8 + n_episode = 8 + num_simulations = 50 + reanalyze_ratio = 0.0 + + max_batch_size = 512 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + print(f'=========== batch_size: {batch_size} ===========') + + num_unroll_steps = 5 + infer_context_length = 4 + # norm_type = 'LN' + norm_type = 'BN' + + buffer_reanalyze_freq = 1 / 50 # Adjusted as per UniZero config + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + num_segments = 8 + + # =========== TODO: debug =========== + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # num_simulations = 5 + # batch_size = [int(min(2, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + + # Generate configurations + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments + ) + + # Start training + train_muzero_multitask_segment_noddp(configs, seed=seed, max_env_step=int(5e5)) \ No newline at end of file diff --git 
a/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py new file mode 100644 index 000000000..698a3d1ac --- /dev/null +++ b/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py @@ -0,0 +1,294 @@ +from easydict import EasyDict +from copy import deepcopy +from atari_env_action_space_map import atari_env_action_space_map + +def create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + + return EasyDict(dict( + env=dict( + stop_value=int(5e5), # Adjusted max_env_step based on user TODO + env_id=env_id, + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + multi_gpu=True, # ======== Very important for ddp ============= + learn=dict( + learner=dict( + hook=dict(save_ckpt_after_iter=200000,), # Adjusted checkpoint frequency + ), + ), + grad_correct_params=dict( + # Placeholder for gradient correction parameters if needed + ), + task_num=len(env_id_list), + model=dict( + device='cuda', + num_res_blocks=2, # NOTE: encoder for 4 game + num_channels=256, + reward_head_channels= 16, + value_head_channels= 16, + policy_head_channels= 16, + fc_reward_layers= [32], + fc_value_layers= [32], + fc_policy_layers= [32], + observation_shape=(4, 96, 96), + frame_stack_num=4, + gray_scale=True, + action_space_size=action_space_size, + norm_type=norm_type, + model_type='conv', + image_channel=1, + downsample=True, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + use_sim_norm=True, + use_sim_norm_kl_loss=False, + task_num=len(env_id_list), + ), + allocated_batch_sizes=False, + cuda=True, + env_type='not_board_games', + train_start_after_envsteps=2000, + # train_start_after_envsteps=0, # TODO: debug + game_segment_length=20, # Fixed segment length as per user config + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=0.25, + num_unroll_steps=num_unroll_steps, + # update_per_collect=2, # TODO: debug + update_per_collect=80, # Consistent with UniZero config + batch_size=batch_size, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_segments=num_segments, + num_simulations=num_simulations, + policy_entropy_weight=5e-3, #TODO + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(5e5), # Adjusted as per UniZero config + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs( + env_id_list, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + 
infer_context_length, + norm_type, + seed, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments +): + configs = [] + # TODO: debug name + exp_name_prefix = ( + f'data_lz/data_muzero_mt_atari_20250228/{len(env_id_list)}games_brf{buffer_reanalyze_freq}/' + f'{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_' + f'{len(env_id_list)}-pred-head_mbs-512_upc80_H{num_unroll_steps}_seed{seed}/' + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments + ) + config.policy.task_id = task_id + config.exp_name = f"{exp_name_prefix}{env_id.split('NoFrameskip')[0]}_muzero-mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager()]]) + + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero_multitask', + import_names=['lzero.policy.muzero_multitask'], + ), + )) + +if __name__ == "__main__": + # import sys + # sys.path.insert(0, "/mnt/afs/niuyazhe/code/LightZero") + # import lzero + # print("lzero path:", lzero.__file__) + + # parser = argparse.ArgumentParser(description='Train MuZero Multitask on Atari') + # parser.add_argument('--seed', type=int, default=0, help='Random seed') + # args = parser.parse_args() + + # Define your list of environment IDs + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', + # 'ChopperCommandNoFrameskip-v4', + # 'HeroNoFrameskip-v4', + # 'RoadRunnerNoFrameskip-v4', + ] + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', + 'AssaultNoFrameskip-v4', + 'AsterixNoFrameskip-v4', + 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', + 'CrazyClimberNoFrameskip-v4', + 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', + 'GopherNoFrameskip-v4', + 'JamesbondNoFrameskip-v4', + 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', + 'KungFuMasterNoFrameskip-v4', + 'PrivateEyeNoFrameskip-v4', + 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', + 'BreakoutNoFrameskip-v4', + ] + + action_space_size = 18 # Full action space, adjust if different per env + seed = 0 + + collector_env_num = 8 + evaluator_env_num = 3 + num_segments = 8 + n_episode = 8 + num_simulations = 50 + reanalyze_ratio = 0.0 + max_env_step = 5e5 + + max_batch_size = 512 + # max_batch_size = 1024 + batch_size = [int(min(64, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + num_unroll_steps = 5 + infer_context_length = 4 + # norm_type = 'LN' + norm_type = 'BN' + + 
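# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the per-task batch sizes in
# these multi-task configs are derived from a global budget. Each task gets an
# equal share of `max_batch_size`, capped at 64, so adding more games shrinks
# the per-task batch instead of growing total GPU memory use. Helper name is
# hypothetical.
def per_task_batch_sizes(max_batch_size: int, num_tasks: int, cap: int = 64) -> list:
    return [int(min(cap, max_batch_size / num_tasks)) for _ in range(num_tasks)]

assert per_task_batch_sizes(512, 8) == [64] * 8     # 8 games  -> 64 per task
assert per_task_batch_sizes(512, 26) == [19] * 26   # 26 games -> 19 per task
# --------------------------------------------------------------------------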
buffer_reanalyze_freq = 1 / 50 # Adjusted as per UniZero config + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + + # =========== TODO: debug =========== + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # num_simulations = 3 + # batch_size = [int(min(2, max_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + print(f'=========== batch_size: {batch_size} ===========') + # Generate configurations + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments + ) + + """ + Overview: + This script should be executed with GPUs. + Run the following command to launch the script: + export NCCL_TIMEOUT=3600000 + python -m torch.distributed.launch --nproc_per_node=4 --master_port=29501 ./zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py + 或者使用 torchrun: + torchrun --nproc_per_node=4 ./zoo/atari/config/atari_muzero_multitask_segment_8games_ddp_config.py + """ + from lzero.entry import train_muzero_multitask_segment_ddp + from ding.utils import DDPContext + with DDPContext(): + train_muzero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_segment_config.py b/zoo/atari/config/atari_muzero_segment_config.py index 4289fb957..212798506 100644 --- a/zoo/atari/config/atari_muzero_segment_config.py +++ b/zoo/atari/config/atari_muzero_segment_config.py @@ -43,7 +43,7 @@ def main(env_id, seed): env=dict( stop_value=int(1e6), env_id=env_id, - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), frame_stack_num=4, gray_scale=True, collector_env_num=collector_env_num, @@ -59,7 +59,7 @@ def main(env_id, seed): analysis_sim_norm=False, cal_dormant_ratio=False, model=dict( - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), image_channel=1, frame_stack_num=4, gray_scale=True, @@ -123,7 +123,7 @@ def main(env_id, seed): # ============ use muzero_segment_collector instead of muzero_collector ============= from lzero.entry import train_muzero_segment - main_config.exp_name = f'data_muzero/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_seed{seed}' + main_config.exp_name = f'data_lz/data_muzero/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_seed{seed}' train_muzero_segment([main_config, create_config], seed=seed, max_env_step=max_env_step) if __name__ == "__main__": diff --git a/zoo/atari/config/atari_rezero_mz_config.py b/zoo/atari/config/atari_rezero_mz_config.py index c7787831b..91517afd5 100644 --- a/zoo/atari/config/atari_rezero_mz_config.py +++ b/zoo/atari/config/atari_rezero_mz_config.py @@ -18,6 +18,17 @@ reuse_search = True collect_with_pure_policy = True 
buffer_reanalyze_freq = 1 + +# ====== only for debug ===== +# collector_env_num = 8 +# num_segments = 8 +# evaluator_env_num = 2 +# num_simulations = 5 +# max_env_step = int(2e5) +# reanalyze_ratio = 0.1 +# batch_size = 64 +# num_unroll_steps = 10 +# replay_ratio = 0.01 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -32,6 +43,9 @@ evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), + # # TODO: only for debug + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), ), policy=dict( model=dict( diff --git a/zoo/atari/config/atari_unizero_ddp_config.py b/zoo/atari/config/atari_unizero_ddp_config.py index d64332d58..887e5f7cb 100644 --- a/zoo/atari/config/atari_unizero_ddp_config.py +++ b/zoo/atari/config/atari_unizero_ddp_config.py @@ -55,13 +55,20 @@ max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action context_length=2 * infer_context_length, device='cuda', - # device='cpu', action_space_size=action_space_size, num_layers=2, num_heads=8, embed_dim=768, obs_type='image', env_num=max(collector_env_num, evaluator_env_num), + task_num=1, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, ), ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model. diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py new file mode 100644 index 000000000..a08064748 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py @@ -0,0 +1,243 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + # collect_max_episode_steps=int(5e3), + # eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + collect_max_episode_steps=int(20), + eval_max_episode_steps=int(20), + ), + policy=dict( + multi_gpu=True, # Very important for ddp + only_use_moco_stats=False, + use_moco=False, # ==============TODO============== + # use_moco=True, # ==============TODO============== + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + # num_channels=512, # ==============TODO============== + continuous_action_space=False, + world_model_cfg=dict( + # use_adaptive_scale=True, + use_adaptive_scale=False, + + 
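+                    # NOTE: final_norm_option_in_obs_head / final_norm_option_in_encoder and predict_latent_loss_type
+                    # are chosen as a matched pair: LayerNorm goes with the 'mse' latent loss, while the commented-out
+                    # SimNorm variant goes with 'group_kl' (the same pairing appears in atari_unizero_segment_config.py).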
final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + + share_head=False, # TODO + + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + + continuous_action_space=False, + + task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + + # task_embed_option='concat_task_embed', # ==============TODO: none ============== + # use_task_embed=True, # ==============TODO============== + # task_embed_dim=128, + # # task_embed_dim=96, + + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, + num_heads=24, + + # ===== only for debug ===== + # num_layers=1, + # num_heads=8, + + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + + # LoRA 参数: + lora_r= 0, + lora_alpha =1, + lora_dropout= 0.0, + ), + ), + use_task_exploitation_weight=False, # TODO + task_complexity_weight=False, # TODO + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), # TODO: ===== only for debug ===== + # train_start_after_envsteps=int(2000), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=2, # TODO: ===== only for debug ===== + # update_per_collect=80, # TODO + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + # cos_lr_scheduler=True, + cos_lr_scheduler=False, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size): + configs = [] + # ===== only for debug ===== + exp_name_prefix = f'data_lz/data_unizero_atari_mt_20250425_debug/atari_{len(env_id_list)}games_tbs1536-encoderchannel256-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_encoder-final-ln_seed{seed}/' + + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + 
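+        # Each list entry pairs a task_id with its [task_config, env/policy manager] spec; this is the
+        # structure passed to train_unizero_multitask_segment_ddp below.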
configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. + Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29504 ./zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee ./log/uz_mt_atari26_channel256_debug.log + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + + # List of Atari games used for multi-task learning + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + action_space_size = 18 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + + total_batch_size =int(512*3) + batch_size = [int(total_batch_size / len(env_id_list)) for _ in range(len(env_id_list))] + + # total_batch_size = 512 + # batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + # buffer_reanalyze_freq = 1 / 50 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + collector_env_num = 2 + num_segments = 2 + n_episode = 2 + evaluator_env_num = 2 + num_simulations = 1 + reanalyze_batch_size = 2 + num_unroll_steps = 5 + infer_context_length = 2 + batch_size = [4, 4, 4, 4, 4, 4, 4, 4] + + + for seed in [0]: + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size) + + with DDPContext(): + # train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) + + # ======== TODO: only for debug ======== + train_unizero_multitask_segment_ddp(configs[:8], seed=seed, max_env_step=max_env_step) # train on the first four tasks diff --git a/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py 
b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py new file mode 100644 index 000000000..29de4f112 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py @@ -0,0 +1,167 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, # Enable multi-GPU for DDP + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, + MoCo_rho=0, calpha=0.5, rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + env_id_list=env_id_list, + analysis_tsne=True, # TODO + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, # Transformer layers + num_heads=8, + # num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + configs = [] + exp_name_prefix = f'data_unizero_mt_ddp-8gpu_eval-latent_state_tsne/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_nlayer8-nh24-lsd768_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, + n_episode, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, 
reanalyze_partition, + num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This program is designed to obtain the t-SNE of the latent states in 8games multi-task learning. + + This script should be executed with GPUs. + Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_eval_config.py + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_eval_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_eval + from ding.utils import DDPContext + + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + + action_space_size = 18 + + for seed in [0]: + collector_env_num = 2 + num_segments = 2 + n_episode = 2 + evaluator_env_num = 2 + num_simulations = 50 + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + total_batch_size = int(4*len(env_id_list)) + batch_size = [4 for _ in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1/50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + + configs = generate_configs( + env_id_list, action_space_size, collector_env_num, n_episode, + evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, seed, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size + ) + + # Pretrained model paths + # 8games + pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/iteration_200000.pth.tar' + # 26games + # pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu-26game_1127/26games_brf0.02_nlayer8-nhead24_seed0/26games_brf0.02_1-encoder-LN-res2-channel256_gsl20_26-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed0/Pong_unizero-mt_seed0/ckpt/iteration_150000.pth.tar' + + with DDPContext(): + train_unizero_multitask_segment_eval(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py new file mode 100644 index 000000000..badcd9585 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py @@ -0,0 +1,236 @@ +from easydict import EasyDict + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, buffer_reanalyze_freq, reanalyze_batch_size, 
reanalyze_partition, num_segments, total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), + ), + policy=dict( + multi_gpu=True, + only_use_moco_stats=False, + use_moco=False, # ==============TODO============== + # use_moco=True, # ==============TODO============== + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( # Gradient correction parameters + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + + share_head=False, # TODO + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + + continuous_action_space=False, + + task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + + # task_embed_option='concat_task_embed', # ==============TODO: none ============== + # use_task_embed=True, # ==============TODO============== + # task_embed_dim=96, + # task_embed_dim=128, + + use_shared_projection=False, + + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + + # LoRA 参数(启用LoRA) + lora_r=0, + # lora_r=8, + lora_alpha=32, + lora_dropout=0.1, + # 默认目标模块:attn和feed_forward + lora_target_modules=["attn", "feed_forward"], + # 调整finetune_components + ), + ), + use_task_exploitation_weight=False, # TODO + task_complexity_weight=False, # TODO + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + cos_lr_scheduler=True, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), 
+ )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size): + configs = [] + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/' + exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head-encoder/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head-trans/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/amidar_load-enc-trans_finetune-head-trans-lora/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh24_upc80_seed{seed}/' + + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/pong_load-enc-trans_finetune-head-trans-lora/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh24_upc80_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_finetune_20250308/pong_load-enc-trans_finetune-head/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_lsd768-nlayer8-nh24_upc80_seed{seed}/' + + + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type, + buffer_reanalyze_freq, + reanalyze_batch_size, + reanalyze_partition, + num_segments, + total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=1 --master_port=29507 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py + torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + from easydict import EasyDict + + # env_id_list = ['PongNoFrameskip-v4'] # Debug setup + env_id_list = ['AmidarNoFrameskip-v4'] # Debug setup + + action_space_size = 18 + + # NCCL environment setup + import os + os.environ["NCCL_TIMEOUT"] = "3600000000" + + # for seed in [0, 1, 2]: + for seed in [0]: + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(4e5) + + reanalyze_ratio = 0.0 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + # buffer_reanalyze_freq = 1 / 50 + buffer_reanalyze_freq = 1 / 10000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 1 + # reanalyze_batch_size = 2 + # batch_size = [4, 4, 4, 4, 4, 4, 4, 4] + + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size) + + # pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/iteration_200000.pth.tar' + # pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_atari_mt_20250217/atari_8games_notaskembed_bs64_brf0.02_seed0_dev-uz-mz-mt-cont/Pong_seed0_250218_124624/ckpt/ckpt_best.pth.tar' + + pretrained_model_path = '/fs-computility/ai-shen/puyuan/code/LightZero/data_lz/data_unizero_atari_mt_20250307/atari_8games_brf0.02_not-share-head_final-ln_seed0/Pong_seed0/ckpt/ckpt_best.pth.tar' + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_segment_config.py b/zoo/atari/config/atari_unizero_segment_config.py index d9e78dfd4..4518fbe4d 100644 --- a/zoo/atari/config/atari_unizero_segment_config.py +++ b/zoo/atari/config/atari_unizero_segment_config.py @@ -11,9 +11,9 @@ def main(env_id, seed): collector_env_num = 8 num_segments = 8 game_segment_length = 20 - evaluator_env_num = 10 + evaluator_env_num = 3 num_simulations = 50 - max_env_step = int(5e5) + max_env_step = int(4e5) batch_size = 64 num_layers = 2 replay_ratio = 0.25 @@ -31,8 +31,9 @@ def main(env_id, seed): # collector_env_num = 2 # num_segments = 2 # evaluator_env_num = 2 - # num_simulations = 10 + # num_simulations = 5 # batch_size = 5 + # buffer_reanalyze_freq = 1/1000000 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -47,9 
+48,11 @@ def main(env_id, seed): evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), + # collect_max_episode_steps=int(5e3), + # eval_max_episode_steps=int(5e3), # TODO: only for debug - # collect_max_episode_steps=int(50), - # eval_max_episode_steps=int(50), + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), ), policy=dict( learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 @@ -58,6 +61,20 @@ def main(env_id, seed): action_space_size=action_space_size, support_scale=300, world_model_cfg=dict( + # final_norm_option_in_obs_head='LayerNorm', + # final_norm_option_in_encoder='LayerNorm', + # predict_latent_loss_type='mse', # TODO: only for latent state layer_norm + + final_norm_option_in_obs_head='SimNorm', + final_norm_option_in_encoder='SimNorm', + predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + # analysis_dormant_ratio_weight_rank=True, # TODO + + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + use_shared_projection=False, support_size=601, policy_entropy_weight=5e-3, continuous_action_space=False, @@ -73,6 +90,17 @@ def main(env_id, seed): env_num=max(collector_env_num, evaluator_env_num), num_simulations=num_simulations, rotary_emb=False, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + # LoRA 参数: + lora_r= 0, + lora_alpha =1, + lora_dropout= 0.0, ), ), # (str) The path of the pretrained model. If None, the model will be initialized by the default model. @@ -90,7 +118,8 @@ def main(env_id, seed): num_simulations=num_simulations, num_segments=num_segments, td_steps=5, - train_start_after_envsteps=0, + # train_start_after_envsteps=0, # only for debug + train_start_after_envsteps=2000, game_segment_length=game_segment_length, grad_clip_value=5, replay_buffer_size=int(1e6), @@ -135,5 +164,4 @@ def main(env_id, seed): parser.add_argument('--env', type=str, help='The environment to use', default='PongNoFrameskip-v4') parser.add_argument('--seed', type=int, help='The seed to use', default=0) args = parser.parse_args() - main(args.env, args.seed) diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index 8bc491674..f5e43f6c8 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -24,6 +24,8 @@ class AtariEnvLightZero(BaseEnv): _reward_space, obs, _eval_episode_return, has_reset, _seed, _dynamic_seed """ config = dict( + # (bool) Whether to use the full action space of the environment. Default is False. If set to True, the action space size is 18 for Atari. + full_action_space=False, # (int) The number of environment instances used for data collection. collector_env_num=8, # (int) The number of environment instances used for evaluator. @@ -180,6 +182,8 @@ def step(self, action: int) -> BaseEnvTimestep: if done: logging.info(f'one episode done! 
total episode length is: {self._timestep}') info['eval_episode_return'] = self._eval_episode_return + print(f'one episode of {self.cfg.env_id} done') + return BaseEnvTimestep(observation, self.reward, done, info) def observe(self) -> dict: diff --git a/zoo/atari/envs/atari_wrappers.py b/zoo/atari/envs/atari_wrappers.py index f38aa24d6..265ef31ac 100644 --- a/zoo/atari/envs/atari_wrappers.py +++ b/zoo/atari/envs/atari_wrappers.py @@ -93,9 +93,9 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations. """ if config.render_mode_human: - env = gym.make(config.env_id, render_mode='human') + env = gym.make(config.env_id, render_mode='human', full_action_space=config.full_action_space) else: - env = gym.make(config.env_id, render_mode='rgb_array') + env = gym.make(config.env_id, render_mode='rgb_array', full_action_space=config.full_action_space) assert 'NoFrameskip' in env.spec.id if hasattr(config, 'save_replay') and config.save_replay \ and hasattr(config, 'replay_path') and config.replay_path is not None: diff --git a/zoo/box2d/box2d_suz_multitask.py b/zoo/box2d/box2d_suz_multitask.py new file mode 100644 index 000000000..cf87e189d --- /dev/null +++ b/zoo/box2d/box2d_suz_multitask.py @@ -0,0 +1,179 @@ +from easydict import EasyDict +from copy import deepcopy +import torch +def create_config(env_id, observation_shapes, action_space_sizes, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + continuous=True, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000,),),), # default is 10000 + grad_correct_params=dict( + # for MoCo + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + # for CAGrad + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shapes=observation_shapes, + action_space_size=4, + action_space_sizes=action_space_sizes, + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=dict( + obs_type='vector', + num_unroll_steps=num_unroll_steps, + policy_entropy_loss_weight=1e-4, + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + norm_type=norm_type, + bound_type=None, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda' if torch.cuda.is_available() else 'cpu', + action_space_size=action_space_sizes, + env_num=max(collector_env_num, evaluator_env_num), + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, # NOTE + moe_in_transformer=False, # NOTE + multiplication_moe_in_transformer=False, # NOTE + num_experts_of_moe_in_transformer=4, + ), + ), + use_priority=True, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + learning_rate=1e-4, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + 
replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + )) + +def generate_configs(env_id_list, observation_shapes, action_space_sizes, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed): + configs = [] + exp_name_prefix = f'data_unizero_mt_box2d/{len(env_id_list)}games_cont_action_seed{seed}/' + + for task_id, (env_id, observation_shape, action_space_size) in enumerate(zip(env_id_list, observation_shapes, action_space_sizes)): + config = create_config( + env_id, + observation_shapes, # TODO + action_space_sizes, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('-v')[0]}_unizero_mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager(env_name=env_id)]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='box2d', + import_names=['zoo.box2d.lunarlander.envs.lunarlander_env', 'zoo.box2d.bipedalwalker.envs.bipedalwalker_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + +def create_env_manager(env_name: str): + if env_name == 'LunarLanderContinuous-v2': + return EasyDict(dict( + env=dict( + type='lunarlander', + import_names=[f'zoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + elif env_name == 'BipedalWalker-v3': + return EasyDict(dict( + env=dict( + type='bipedalwalker', + import_names=[f'zoo.box2d.bipedalwalker.envs.bipedalwalker_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask + + env_id_list = [ + 'LunarLanderContinuous-v2', + 'BipedalWalker-v3', + ] + + observation_shapes = [ + 8, # LunarLanderContinuous-v2 + 24, # BipedalWalker-v3 + ] + + action_space_sizes = [ + 2, # LunarLanderContinuous-v2 + 4, # BipedalWalker-v3 + ] + + seed = 0 + collector_env_num = 6 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(1e6) + reanalyze_ratio = 0. 
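Note that box2d_suz_multitask.py defines create_env_manager twice; at import time the second definition (the one taking env_name) shadows the zero-argument version, so the create_env_manager(env_name=env_id) call in generate_configs resolves to the per-environment dispatcher. A minimal sketch of the same dispatch written as a single table-driven function, assuming only the two listed environments are supported (the _ENV_TYPE_MAP name is illustrative, not part of the file):

from easydict import EasyDict

# Maps each env_id to its DI-engine env type and import path (values taken from the two
# branches of the original create_env_manager).
_ENV_TYPE_MAP = {
    'LunarLanderContinuous-v2': ('lunarlander', 'zoo.box2d.lunarlander.envs.lunarlander_env'),
    'BipedalWalker-v3': ('bipedalwalker', 'zoo.box2d.bipedalwalker.envs.bipedalwalker_env'),
}

def create_env_manager(env_name: str) -> EasyDict:
    env_type, import_path = _ENV_TYPE_MAP[env_name]
    return EasyDict(dict(
        env=dict(type=env_type, import_names=[import_path]),
        env_manager=dict(type='subprocess'),
        policy=dict(
            type='sampled_unizero_multitask',
            import_names=['lzero.policy.sampled_unizero_multitask'],
        ),
    ))
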
+ max_batch_size = 1000 + batch_size = [int(max_batch_size/len(env_id_list)) for i in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + + configs = generate_configs(env_id_list, observation_shapes, action_space_sizes, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed) + + train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py b/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py new file mode 100644 index 000000000..4f5ca5bda --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py @@ -0,0 +1,132 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== + +from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + +env_id = 'cartpole-swingup' # You can specify any DMC task here +action_space_size = dmc_state_env_action_space_map[env_id] +obs_space_size = dmc_state_env_obs_space_map[env_id] +print(f'env_id: {env_id}, action_space_size: {action_space_size}, obs_space_size: {obs_space_size}') + +domain_name = env_id.split('-')[0] +task_name = env_id.split('-')[1] + +continuous_action_space = True +K = 20 # num_of_sampled_actions +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = None +replay_ratio = 0.25 +max_env_step = int(1e6) +reanalyze_ratio = 0 +batch_size = 64 +num_unroll_steps = 10 +infer_context_length = 4 +norm_type = 'LN' +seed = 0 + +# for debug +# collector_env_num = 2 +# n_episode = 2 +# evaluator_env_num = 1 +# num_simulations = 2 +# batch_size = 2 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +dmc2gym_pixels_cont_sampled_unizero_config = dict( + exp_name=f'data_sampled_unizero_0901/dmc2gym_{env_id}_image_cont_sampled_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_{norm_type}_seed{seed}', + env=dict( + env_id='dmc2gym-v0', + continuous=True, + domain_name=domain_name, + task_name=task_name, + from_pixels=True, # pixel/image obs + frame_skip=2, + warp_frame=True, + scale=True, + frame_stack_num=1, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, 84, 84), + action_space_size=action_space_size, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + world_model_cfg=dict( + obs_type='image', + num_unroll_steps=num_unroll_steps, + policy_entropy_loss_weight=5e-3, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + sigma_type='conditioned', + fixed_sigma_value=0.3, + bound_type=None, + model_type='conv', + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + # device='cpu', + device='cuda', + action_space_size=action_space_size, + num_layers=2, + num_heads=8, + 
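+                # This single-task pixel config uses a small world model (2 transformer layers, 8 heads);
+                # the multi-task DMC config later in this patch scales this up to 8 layers and 24 heads.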
embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + num_unroll_steps=num_unroll_steps, + cuda=True, + use_augmentation=False, + env_type='not_board_games', + game_segment_length=100, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + lr_piecewise_constant_decay=False, + learning_rate=0.0001, + target_update_freq=100, + grad_clip_value=5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +dmc2gym_pixels_cont_sampled_unizero_config = EasyDict(dmc2gym_pixels_cont_sampled_unizero_config) +main_config = dmc2gym_pixels_cont_sampled_unizero_config + +dmc2gym_pixels_cont_sampled_unizero_create_config = dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + # env_manager=dict(type='subprocess'), + env_manager=dict(type='base'), + policy=dict( + type='sampled_unizero', + import_names=['lzero.policy.sampled_unizero'], + ), +) +dmc2gym_pixels_cont_sampled_unizero_create_config = EasyDict(dmc2gym_pixels_cont_sampled_unizero_create_config) +create_config = dmc2gym_pixels_cont_sampled_unizero_create_config + +if __name__ == "__main__": + from lzero.entry import train_unizero + + train_unizero([main_config, create_config], seed=seed, max_env_step=max_env_step) diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_config.py similarity index 100% rename from zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py rename to zoo/dmc2gym/config/dmc2gym_state_suz_config.py diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py new file mode 100644 index 000000000..57770d6a3 --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py @@ -0,0 +1,355 @@ +from easydict import EasyDict +from typing import List + +import logging + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(message)s', + handlers=[ + logging.FileHandler("output.log", encoding="utf-8"), # 文件日志 + logging.StreamHandler() # 终端日志 + ] +) + +def create_config(env_id, observation_shape_list, action_space_size_list, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + domain_name = env_id.split('-')[0] + task_name = env_id.split('-')[1] + return EasyDict(dict( + env=dict( + stop_value=int(5e5), + env_id=env_id, + domain_name=domain_name, + task_name=task_name, + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + from_pixels=False, + frame_skip=2, + continuous=True, # Assuming all DMC tasks use continuous action spaces + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + game_segment_length=100, # As per single-task config + # ===== TODO: only for debug ===== + # game_segment_length=10, # As per single-task config + # collect_max_episode_steps=int(20), + # eval_max_episode_steps=int(20), + ), + policy=dict( + multi_gpu=True, 
# TODO: enable multi-GPU for DDP + only_use_moco_stats=False, + # use_moco=False, # ==============TODO============== + use_moco=True, # ==============TODO============== + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), + grad_correct_params=dict( + # Example gradient correction parameters, adjust as needed + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, # To be set per task + model=dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + + share_head=False, # TODO + use_shared_projection=False, + # analysis_dormant_ratio_weight_rank=True, # TODO + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + + # task_embed_option=None, # ==============TODO: none ============== + # use_task_embed=False, # ==============TODO============== + + task_embed_option='concat_task_embed', # ==============TODO: none ============== + use_task_embed=True, # ==============TODO============== + task_embed_dim=128, + # task_embed_dim=96, + + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + policy_loss_type='kl', + obs_type='vector', + num_unroll_steps=num_unroll_steps, + policy_entropy_weight=5e-2, + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + fixed_sigma_value=0.5, + bound_type=None, + model_type='mlp', + norm_type=norm_type, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + # num_layers=1, # TODO: debug config + num_layers=8, # TODO + num_heads=24, + embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + + # LoRA 参数: + lora_r= 0, + lora_alpha =1, + lora_dropout= 0.0, + ), + ), + use_task_exploitation_weight=False, # TODO + # use_task_exploitation_weight=True, # TODO + + task_complexity_weight=False, # TODO + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + # train_start_after_envsteps=int(2e3), # TODO + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + # update_per_collect=2, # TODO: debug config + # update_per_collect=200, # TODO: 8*100*0.25=200 + update_per_collect=80, # TODO: 8*100*0.1=80 + replay_ratio=reanalyze_ratio, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(1e6), + # eval_freq=int(5e3), + eval_freq=int(4e3), + grad_clip_value=5, + learning_rate=1e-4, + discount_factor=0.99, + td_steps=5, + piecewise_decay_lr_scheduler=False, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + cos_lr_scheduler=True, + 
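+            # LR schedule: cosine decay is enabled (cos_lr_scheduler=True) while piecewise decay is off;
+            # the MCTS action-selection temperature is annealed manually over the first 2.5e4 training steps.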
collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list: List[str], + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int): + configs = [] + + exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250413_moco/dmc_{len(env_id_list)}tasks_concattaskembed128_nlayer8_not-share-head_final-ln_bs64_brf{buffer_reanalyze_freq}_seed{seed}/' + # exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250413_moco/dmc_{len(env_id_list)}tasks_notaskembed_nlayer8_not-share-head_final-ln_bs64_brf{buffer_reanalyze_freq}_seed{seed}/' + + # exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250409_moco/dmc_{len(env_id_list)}tasks_notaskembed_nlayer8_not-share-head_final-ln_bs64_brf{buffer_reanalyze_freq}_seed{seed}/' + + # exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250325/dmc_{len(env_id_list)}tasks_task-exploitation-weight_notaskembed_nlayer8_not-share-head_final-ln_bs64_brf{buffer_reanalyze_freq}_seed{seed}/' + # exp_name_prefix = f'data_lz/data_suz_dmc_mt_20250311/dmc_{len(env_id_list)}tasks_concattaskembed-128_nlayer8_not-share-head_final-ln_bs64*8_brf{buffer_reanalyze_freq}_seed{seed}/' + + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + for task_id, (env_id, obs_shape, act_space) in enumerate(zip(env_id_list, observation_shape_list, action_space_size_list)): + config = create_config( + env_id=env_id, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py 2>&1 | tee ./log/uz_mt_dmc_moco_taskembed_20250409.log + torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + + os.environ["NCCL_TIMEOUT"] = "3600000000" + + + # DMC 8games + # env_id_list = [ + # 'acrobot-swingup', + # 'cartpole-balance', + # 'cartpole-balance_sparse', + # 'cartpole-swingup', + # 'cartpole-swingup_sparse', + # 'cheetah-run', + # "ball_in_cup-catch", + # "finger-spin", + # ] + + # DMC 18games + env_id_list = [ + 'acrobot-swingup', # 0 + 'cartpole-balance', # 1 + 'cartpole-balance_sparse', # 2 + 'cartpole-swingup', # 3 + 'cartpole-swingup_sparse', # 4 bad + 'cheetah-run', # 5 bad + "ball_in_cup-catch", # 6 + "finger-spin", # 7 bad + "finger-turn_easy", # 8 波动 + "finger-turn_hard", # 9 波动 + 'hopper-hop', # 10 bad + 'hopper-stand', # 11 + 'pendulum-swingup', # 12 bad + 'reacher-easy', # 13 + 'reacher-hard', # 14 波动 + 'walker-run', # 15 略差 + 'walker-stand', # 16 + 'walker-walk', # 17 + ] + + # debug + # env_id_list = [ + # 'acrobot-swingup', # 0 + # 'cartpole-balance', # 1 + # 'cartpole-balance_sparse', # 2 + # 'cartpole-swingup', # 3 + # 'cartpole-swingup_sparse', # 4 bad + # 'cheetah-run', # 5 bad + # "ball_in_cup-catch", # 6 + # "finger-spin", # 7 bad + # # "finger-turn_easy", # 8 波动 + # # "finger-turn_hard", # 9 波动 + # ] + + # 获取各环境的 action_space_size 和 observation_shape + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + + # nlayer=8 + total_batch_size = 1024 + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + # nlayer=12 + # total_batch_size = 256 + # batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 1 + # batch_size = [4 for _ in range(len(env_id_list))] + # ======================================= + + seed = 0 # You can iterate over multiple seeds if needed + + configs = generate_configs( + env_id_list=env_id_list, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + + with DDPContext(): + 
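+        # DDPContext wraps distributed process-group setup/teardown for the torchrun launch described above.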
train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step) + # 如果只想训练部分任务,可以修改 configs,例如: + # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py b/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py index 4fcfb209a..068790293 100644 --- a/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py +++ b/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py @@ -18,6 +18,8 @@ from gym.spaces import Box from matplotlib import animation import imageio +import logging + def dmc2gym_observation_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Callable: def observation_space(from_pixels=True, height=84, width=84, channels_first=True) -> Box: @@ -268,6 +270,8 @@ def __init__(self, cfg: dict = {}) -> None: self._save_replay_gif = cfg.save_replay_gif self._replay_path_gif = cfg.replay_path_gif self._save_replay_count = 0 + self._timestep = 0 + self.max_episode_steps = cfg.max_episode_steps def reset(self) -> Dict[str, np.ndarray]: """ @@ -409,11 +413,15 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: if self._save_replay_gif: self._frames.append(image_obs) + + if self._timestep > self.max_episode_steps: + done = True if self._timestep > self._cfg.max_episode_steps: done = True if done: + logging.info(f'one episode done! episode return: {self._eval_episode_return}, episode_steps:{self._timestep}') info['eval_episode_return'] = self._eval_episode_return if self._save_replay_gif: @@ -422,7 +430,8 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( self._replay_path_gif, - '{}_episode_{}_seed{}_{}.gif'.format(f'{self._cfg["domain_name"]}_{self._cfg["task_name"]}', self._save_replay_count, self._seed, timestamp) + '{}_episode_{}_seed{}_{}.gif'.format(f'{self._cfg["domain_name"]}_{self._cfg["task_name"]}', + self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') @@ -487,7 +496,7 @@ def __repr__(self) -> str: String representation of the environment. """ return "LightZero DMC2Gym Env({}:{})".format(self._cfg["domain_name"], self._cfg["task_name"]) - + @staticmethod def create_collector_env_cfg(cfg: dict) -> List[dict]: collector_env_num = cfg.pop('collector_env_num') @@ -502,4 +511,4 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]: cfg = copy.deepcopy(cfg) cfg.max_episode_steps = cfg.eval_max_episode_steps cfg.is_eval = True - return [cfg for _ in range(evaluator_env_num)] + return [cfg for _ in range(evaluator_env_num)] \ No newline at end of file diff --git a/zoo/jericho/detective_unizero_cprofile_10k_envstep b/zoo/jericho/detective_unizero_cprofile_10k_envstep new file mode 100644 index 000000000..ae3d22ab1 Binary files /dev/null and b/zoo/jericho/detective_unizero_cprofile_10k_envstep differ
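One more note on the dmc2gym_lightzero_env.py hunk above: the time-limit check is applied twice in step() (once through self.max_episode_steps cached in __init__ and once through self._cfg.max_episode_steps), which is redundant though harmless. A minimal sketch of the intended truncation logic as a single check, written as a free function for illustration (apply_time_limit is a hypothetical helper, not part of the env):

def apply_time_limit(timestep: int, max_episode_steps: int, done: bool) -> tuple:
    # Advance the per-episode step counter and force termination once the budget is exhausted.
    timestep += 1
    if timestep > max_episode_steps:
        done = True  # a single comparison against the configured limit is enough
    return timestep, done

# Example with a budget of 2 steps:
t, d = 0, False
for _ in range(3):
    t, d = apply_time_limit(t, 2, d)
print(t, d)  # -> 3 True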