WIP: feature(pu): unizero and muzero multitask ddp pipeline #350

Closed
wants to merge 176 commits
Commits
dd2c95c
feature(pu): add UniZero multitask related pipeline
puyuan1996 Jul 5, 2024
8769a5c
polish(pu): polish unizero_multitask config
dyyoungg Jul 8, 2024
c342ce1
fix(pu): fix empty_keys_values in init_infer
dyyoungg Jul 11, 2024
6eb772a
feature(pu): add softmoe head option in unizero_multitask
dyyoungg Jul 11, 2024
71f55b4
fix(pu): fix unizero reset in muzero_collector
dyyoungg Jul 12, 2024
445fd70
polish(pu): polish unizero-multitask config
dyyoungg Jul 14, 2024
4954581
fix(pu): fix replay ratio
dyyoungg Jul 16, 2024
44304bf
feature(pu): add moe option of feedforward in transformer backbone
dyyoungg Jul 16, 2024
d6be21a
feature(pu): add value_priority in unizero_multitask
dyyoungg Jul 17, 2024
fde51cc
polish(pu): polish value_priority in unizero_multitask
dyyoungg Jul 17, 2024
b460d2f
sync code
dyyoungg Jul 18, 2024
5117459
fix(pu): fix moe in feedforward layer of transformer and polish configs
dyyoungg Jul 19, 2024
2495d60
feature(pu): add mistralai moe in transformer feedforward and head of…
dyyoungg Jul 23, 2024
95886bd
polish(pu): polish quantize_state_hash and deepcopy
PaParaZz1 Aug 18, 2024
0e49a30
fix(pu): fix np.array dtype bug in buffer
PaParaZz1 Aug 18, 2024
00147f4
polish(pu): use 0 deepcopy in kv_cache operation in collect/eval phas…
PaParaZz1 Aug 19, 2024
b40c71b
polish(pu): use custom deepcopy for kv_cache
PaParaZz1 Aug 22, 2024
2cc81be
polish(pu): use value_array rather than value_list in compute_target_…
PaParaZz1 Aug 22, 2024
bc5332f
polish(pu): optimize compute_target_policy_non_re
PaParaZz1 Aug 22, 2024
a6c6a8e
polish(pu): optimize kv_caching update()
PaParaZz1 Aug 22, 2024
b5dcdcc
polish(pu): kv_cache_dict no to_cpu
PaParaZz1 Aug 22, 2024
5b0cbd4
polish(pu): optimize custom kv_cache copy
PaParaZz1 Aug 22, 2024
0035829
polish(pu): kv_cache_dict no to_cpu
PaParaZz1 Aug 22, 2024
043727b
feature(pu): add unizero ddp config
PaParaZz1 Aug 23, 2024
d568008
fix(pu): fix unizero ddp
dyyoungg Aug 23, 2024
d349137
sync code
dyyoungg Aug 25, 2024
3a344aa
polish(pu): use deepcopy for kv_cache only in recur_infer load
PaParaZz1 Aug 26, 2024
40053f7
Merge branch 'dev-efficiency' of https://github.com/opendilab/LightZe…
dyyoungg Aug 26, 2024
61a1139
sync code
dyyoungg Aug 26, 2024
0e545c7
polish(pu): polish suz dmc config
dyyoungg Aug 26, 2024
bb38a10
sync code
dyyoungg Aug 26, 2024
b813be7
Merge branch 'dev-efficiency' of https://github.com/opendilab/LightZe…
dyyoungg Aug 27, 2024
715d17e
polish(pu): use share_pool for kv_cache in recurrent_inference and u…
jiayilee65 Aug 27, 2024
39d6bbe
polish(pu): all kv_cache copy use predefined share_pool
jiayilee65 Aug 27, 2024
f18be2a
polish(pu): unuse decoder_net and lpips in ddp config
dyyoungg Aug 28, 2024
1d010d3
sync code
dyyoungg Aug 28, 2024
a2ce5a8
feature(pu): add dmc save_replay_gif option
puyuan1996 Aug 30, 2024
abf8924
sync code
dyyoungg Sep 2, 2024
9de4096
polish(pu): polish sampled muzero ctree
dyyoungg Sep 2, 2024
54416e6
Merge branch 'polish-unizero-cont' of https://github.com/opendilab/Li…
dyyoungg Sep 3, 2024
5804ff2
test(pu): add sac cheetah config
dyyoungg Sep 3, 2024
fabffd2
fix(pu): fix render_image in dmc_env
dyyoungg Sep 3, 2024
7d0d4c7
fix(pu): fix reanalyze in sampled unizero
dyyoungg Sep 3, 2024
e666d12
polish(pu): polish policy projector
dyyoungg Sep 4, 2024
fea98ee
feature(pu): add muzero_segment_collector.py
dyyoungg Sep 5, 2024
0391f4c
polish(pu): use uniform prior in ucb_score of suz mcts
dyyoungg Sep 5, 2024
2a376ec
fix(pu): fix self.action_mask_dict init bug
dyyoungg Sep 5, 2024
0121b63
test(pu): use clamp0.9->1
dyyoungg Sep 10, 2024
ed4773b
polish(pu): polish suz
dyyoungg Sep 12, 2024
d36196e
fix(pu): fix muzero_segment_collector
dyyoungg Sep 12, 2024
51e10f2
fix(pu): uz target-value obs also use aug when use_aug=True
dyyoungg Sep 12, 2024
31543c3
sync code
dyyoungg Sep 13, 2024
8615899
fix(pu): fix last_game_segment bug in muzero_segment_collector.py
dyyoungg Sep 13, 2024
4c969f6
fix(pu): one episode done then return in muzero_segment_collector.py
dyyoungg Sep 13, 2024
cf2fd81
fix(pu): fix muzero_collector
dyyoungg Sep 14, 2024
bff16f7
polish(pu): polish unizero config and polish sample from segments
dyyoungg Sep 16, 2024
f0ff953
fix(pu): fix reanalyze in uz
dyyoungg Sep 17, 2024
91d48c1
polish(pu): add batch config and bash
dyyoungg Sep 17, 2024
1c8b92b
polish(pu): polish uz configs
dyyoungg Sep 20, 2024
1eed401
feature(pu): add unizero buffer_reanalyze variant
dyyoungg Sep 20, 2024
380f693
fix(pu): fix uz reanalyze_buffer
dyyoungg Sep 20, 2024
c43cdd4
polish(pu): polish configs
dyyoungg Sep 22, 2024
05a2ec3
feature(pu): add atari_muzero_segment_config
dyyoungg Sep 23, 2024
639c2e1
Merge branch 'dev-efficieny-plus-tune-uz-atari100k' of https://github…
dyyoungg Sep 23, 2024
81b47d2
fix(pu): fix sampled_unizero reanalyze_policy
dyyoungg Sep 23, 2024
a78fa70
polish(pu):polish configs
dyyoungg Sep 24, 2024
dc2d454
polish(pu):polish suz configs
dyyoungg Sep 24, 2024
f634e1f
polish(pu):polish configs
dyyoungg Sep 24, 2024
c19b203
Merge branch 'dev-efficieny-plus-tune-uz-atari100k' of https://github…
dyyoungg Sep 24, 2024
f536f3c
fix(pu): fix root value in suz buffer
dyyoungg Sep 25, 2024
f3e6d8d
fix(pu): fix suz ctree
dyyoungg Sep 25, 2024
b7243ea
polish(pu): polish uz related configs, segment collector, train_entry
puyuan1996 Sep 26, 2024
dba9ca7
polish(pu): polish unizero world_model
puyuan1996 Sep 26, 2024
d5fff6d
polish(pu): polish reanalyze in buffer
puyuan1996 Sep 26, 2024
29197d2
fix(pu): fix entry import and nparray object bug in buffer
puyuan1996 Sep 26, 2024
eb268ac
polish(pu): polish configs
puyuan1996 Sep 26, 2024
5e75d09
polish(pu): polish configs
dyyoungg Sep 26, 2024
31221d8
Merge branch 'dev-efficieny-plus-tune-uz-atari100k-polish' of https:/…
dyyoungg Sep 27, 2024
32ad2d0
polish(pu): fix collector, polish configs
dyyoungg Sep 27, 2024
4a18e33
fix(pu): fix truncation segment sample in buffer
dyyoungg Sep 28, 2024
cb37b29
fix(pu): fix segment sample for uz in buffer
dyyoungg Sep 28, 2024
39051c5
fix(pu): use origin buffer
dyyoungg Sep 29, 2024
5842e89
fix(pu): fix value bug V8
jiayilee65 Sep 29, 2024
361d6c6
sync code
dyyoungg Sep 30, 2024
dafb655
fix(pu): fix target action when calculating bootstrap value in unizero
dyyoungg Oct 2, 2024
f7792c0
fix(pu): fix target-action in sampled_unizero buffer
dyyoungg Oct 3, 2024
858006c
polish(pu): delete wrongly added files
dyyoungg Oct 3, 2024
1e13f9b
polish(pu): polish entry/buffer/ctree, and fix index+1 bug in compute…
puyuan1996 Oct 3, 2024
0027834
polish(pu): polish buffer and config
puyuan1996 Oct 8, 2024
46b7096
polish(pu): rename train_xxx_reanalyze to train_xxx_segment
puyuan1996 Oct 8, 2024
a2e1611
polish(pu): polish world_model
puyuan1996 Oct 8, 2024
5d384bb
polish(pu): polish entry comments
puyuan1996 Oct 8, 2024
b7f0f0f
fix(pu): fix reward shape bug in dmc
puyuan1996 Oct 9, 2024
ea16e48
fix(pu): polish sample_orig_reanalyze_batch and fix sample_orig_data …
puyuan1996 Oct 12, 2024
b4ae014
polish(pu): polish comments in _sample_orig_reanalyze_batch
puyuan1996 Oct 16, 2024
6abef12
Merge remote-tracking branch 'origin/dev-unizero-multitask-v2' into d…
puyuan1996 Oct 16, 2024
cc7cc66
polish(pu): adapt muzero_multitask to segment_collector
puyuan1996 Oct 16, 2024
43e0966
fix(pu): fix unizero task_id
dyyoungg Oct 16, 2024
ad01aab
fix(pu): fix unizero task_id
dyyoungg Oct 16, 2024
144e7d4
Merge branch 'dev-uz-multitask-v3' of https://github.com/opendilab/Li…
dyyoungg Oct 17, 2024
9bb1189
polish(pu): polish uz_mt config
dyyoungg Nov 4, 2024
04730e1
polish(pu): polish uz_mt config
dyyoungg Nov 6, 2024
f1d62e0
polish(pu): polish uz_mt config
dyyoungg Nov 6, 2024
025527c
feature(pu): add uz_mt_ddp config
dyyoungg Nov 7, 2024
122bcf2
feature(pu): add atari100k unizero_multitask ddp config
dyyoungg Nov 9, 2024
558c048
fix(pu): fix unizero_multitask ddp config
dyyoungg Nov 9, 2024
140df70
fix(pu): fix unizero_multitask ddp_v2 log in collector/evaluator/tb_l…
dyyoungg Nov 11, 2024
92fd026
fix(pu): fix uz_mt ddp_v2 learner log
dyyoungg Nov 12, 2024
4799b74
feature(pu): add eval_async use ThreadPool
dyyoungg Nov 13, 2024
aeb997c
fix(pu): fix ddp bug when task_num>gpu_num
dyyoungg Nov 13, 2024
9f4fba9
fix(pu): fix log_buffer_memory_usage in ddp setting
dyyoungg Nov 13, 2024
d30cf00
feature(pu): add allocated_batch_sizes option
dyyoungg Nov 13, 2024
77cb8b9
fix(pu): add timeout in eval_async
dyyoungg Nov 14, 2024
eeaf986
fix(pu): use stop_event to quit eval() when timeout in eval_async
dyyoungg Nov 15, 2024
0fb4263
polish(pu): polish unizero_mt configs
dyyoungg Nov 19, 2024
aaa2793
sync code
dyyoungg Nov 24, 2024
13fbe4c
sync code
dyyoungg Nov 24, 2024
d5842f1
feature(pu): add muzero multitask (and its ddp version) pipeline
puyuan1996 Nov 28, 2024
7723f13
polish(pu): polish configs
puyuan1996 Nov 28, 2024
f1e8d8c
polish(pu): polish config
dyyoungg Nov 29, 2024
62c8a96
polish(pu): polish config
dyyoungg Dec 1, 2024
99c08e2
Merge branch 'dev-mz-multitask-ddp' of https://github.com/opendilab/L…
dyyoungg Dec 1, 2024
1edcba3
fix(pu): fix embed dim in uz_multitask pipeline
dyyoungg Dec 2, 2024
4b195eb
feature(pu): add uz finetune config
dyyoungg Dec 2, 2024
13183e7
feature(pu): add uz eval-tsne config
dyyoungg Dec 2, 2024
1e37cae
fix(pu): add uz eval-tsne config
dyyoungg Dec 2, 2024
29298e6
polish(pu): polish tsne-plot legend
dyyoungg Dec 3, 2024
ffdf4db
polish(pu): polish atari multitask related configs
puyuan1996 Dec 18, 2024
2b0af34
polish(pu): polish unizero/muzero multitask related entry
puyuan1996 Dec 18, 2024
0bd688e
polish(pu): delete unused files
puyuan1996 Dec 18, 2024
69a1842
Merge remote-tracking branch 'origin/main' into dev-mz-multitask-ddp
puyuan1996 Dec 18, 2024
d06ce61
polish(pu): polish policy/model in multitask settings
puyuan1996 Dec 18, 2024
3a88f46
feature(pu): add sampled_unizero multitask pipeline
puyuan1996 Dec 24, 2024
563548b
fix(pu): fix sampled_unizero multitask ddp pipeline
puyuan1996 Dec 24, 2024
d8705e6
fix(pu): fix sampled_unizero multitask ddp pipeline
puyuan1996 Dec 24, 2024
a5b38b6
sync code
puyuan1996 Dec 25, 2024
1c8c4fb
polish(pu): polish suz dmc multitask configs
puyuan1996 Dec 26, 2024
aedf83b
fix(pu): fix self.last_batch_obs_eval
Jan 2, 2025
22fd7c1
fix(pu): fix self.last_batch_obs_eval in suz_multitask
dyyoungg Jan 6, 2025
d098e71
feature(pu): add task_complexity_weight option
Jan 6, 2025
8beb492
feature(pu): add task_exploitation_weight option
Jan 7, 2025
afa5123
feature(pu): add concat_task_embed task_embed_option
PaParaZz1 Jan 14, 2025
90d45c3
feature(pu): add register_task_embed task_embed_option
PaParaZz1 Jan 15, 2025
406c607
fix(pu): fix register_task_embed
PaParaZz1 Jan 20, 2025
51d4f17
fix(pu): fix register_task_embed
PaParaZz1 Jan 20, 2025
db958d7
fix(pu): fix cache update()
PaParaZz1 Jan 20, 2025
a38f61c
fix(pu): fix world_model_multitask.py
PaParaZz1 Jan 21, 2025
911ed12
fix(pu): fix self.pos_emb in register_task_embed partially
PaParaZz1 Jan 21, 2025
557e8f9
sync code
PaParaZz1 Jan 21, 2025
8fe1a6d
feature(pu): add register_token_shared option
Feb 7, 2025
97988e2
polish(pu): add moco multigpu support
Feb 7, 2025
1860f11
polish(pu): polish moco multigpu option
Feb 8, 2025
04c4a68
polish(pu): adapt to discrete action space env like atari
Feb 8, 2025
e9ad078
polish(pu): adapt to discrete action space env like atari
Feb 12, 2025
a04b896
tmp: polish(pu): add only_use_moco_stats option
Feb 13, 2025
d53a402
feature(pu): add unizero_multitask atari concat_task_embed support
Feb 17, 2025
772ba00
fix(pu): fix atari uz mt when use concat_task_embed in reanalyze phase
Feb 18, 2025
17eea95
tmp
Feb 19, 2025
f8a8957
tmp:fix(pu): adapt to single-task setting
Feb 19, 2025
3a25c08
fix(pu): fix cal_dormant_ratio
Feb 20, 2025
ed9d7c1
feature(pu): add analysis_dormant_ratio_weight_rank option in single-…
Feb 20, 2025
b7c9295
fix(pu): fix encoder dormant_ratio
Feb 20, 2025
3bb1e57
fix(pu): fix encoder dormant_ratio log
Feb 20, 2025
22a2428
feature(pu): add analysis_dormant_ratio_weight_rank option in unizero…
Feb 20, 2025
1c115c9
feature(pu): add lora option in transformer
Feb 20, 2025
fc3b2e4
fix(pu): fix backward when only_use_moco_stats is True
PaParaZz1 Feb 21, 2025
83a9f35
feature(pu): add share_head and final_norm_option
PaParaZz1 Feb 28, 2025
f0c57bb
tmp
PaParaZz1 Mar 1, 2025
8c36a6f
tmp
PaParaZz1 Mar 1, 2025
05dd910
polish(pu): polish uz-mt config
Mar 7, 2025
fb7520e
polish(pu): polish uz-mt finetune config
Mar 8, 2025
8dc52ea
polish(pu): polish dmc suz_mt config
Mar 11, 2025
dc58845
polish(pu): polish dmc mt config
Mar 13, 2025
4c5b5ab
fix(pu): fix unizero_moco_8gpu_18task
Apr 7, 2025
8e76cbc
fix(pu): fix atari uz mt moco
Apr 8, 2025
26e8a6d
feature(pu): add use_adaptive_scale option in encoder
Apr 9, 2025
2 changes: 2 additions & 0 deletions .gitignore
@@ -1449,3 +1449,5 @@ events.*
# pooltool-specific stuff
!/assets/pooltool/**
lzero/mcts/ctree/ctree_alphazero/pybind11

zoo/jericho/envs/z-machine-games-master/
10 changes: 10 additions & 0 deletions lzero/entry/__init__.py
@@ -9,4 +9,14 @@
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment

from .train_muzero_multitask_segment_noddp import train_muzero_multitask_segment_noddp
from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp


from .train_unizero_multitask_serial import train_unizero_multitask_serial
from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp
from .train_unizero_multitask_segment_serial import train_unizero_multitask_segment_serial

from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval
from .utils import *
80 changes: 80 additions & 0 deletions lzero/entry/compute_task_weight.py
@@ -0,0 +1,80 @@



import numpy as np
import torch


def symlog(x: torch.Tensor) -> torch.Tensor:
"""
    Symlog normalization, which compresses the magnitude of the target values.
symlog(x) = sign(x) * log(|x| + 1)
"""
return torch.sign(x) * torch.log(torch.abs(x) + 1)


def inv_symlog(x: torch.Tensor) -> torch.Tensor:
"""
    Inverse of symlog, used to recover the original values.
inv_symlog(x) = sign(x) * (exp(|x|) - 1)
"""
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
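    # Note: inv_symlog is the exact inverse of symlog, so inv_symlog(symlog(x))
    # recovers x up to floating-point error.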


def compute_task_weights(
task_rewards: dict,
epsilon: float = 1e-6,
min_weight: float = 0.1,
max_weight: float = 0.5,
temperature: float = 1.0,
use_symlog: bool = True,
) -> dict:
"""
    Improved task-weight computation that adds symlog preprocessing and robustness safeguards.

Args:
        task_rewards (dict): Mapping from task_id to its evaluation reward.
        epsilon (float): Small value that avoids division by zero.
        min_weight (float): Lower bound used when clipping the weights.
        max_weight (float): Upper bound used when clipping the weights.
        temperature (float): Temperature coefficient controlling how peaked the weight distribution is.
        use_symlog (bool): Whether to apply symlog correction to task_rewards.

Returns:
        dict: Mapping from task_id to its normalized and clipped weight.
"""
    # Step 1: Correct the reward values (optional, using symlog)
if use_symlog:
rewards_tensor = torch.tensor(list(task_rewards.values()), dtype=torch.float32)
        corrected_rewards = symlog(rewards_tensor).numpy()  # apply symlog correction
task_rewards = dict(zip(task_rewards.keys(), corrected_rewards))

    # Step 2: Compute initial weights (inversely proportional to the reward)
raw_weights = {task_id: 1 / (reward + epsilon) for task_id, reward in task_rewards.items()}

    # Step 3: Apply temperature scaling
scaled_weights = {task_id: weight ** (1 / temperature) for task_id, weight in raw_weights.items()}

    # Step 4: Normalize the weights
total_weight = sum(scaled_weights.values())
normalized_weights = {task_id: weight / total_weight for task_id, weight in scaled_weights.items()}

    # Step 5: Clip the weights to stay within [min_weight, max_weight]
clipped_weights = {task_id: np.clip(weight, min_weight, max_weight) for task_id, weight in normalized_weights.items()}

final_weights = clipped_weights
return final_weights

if __name__ == "__main__":
    # Simple sanity check comparing weights computed with and without symlog correction.
    task_rewards_list = [
        {"task1": 10, "task2": 100, "task3": 1000, "task4": 500, "task5": 300},
        {"task1": 1, "task2": 10, "task3": 100, "task4": 1000, "task5": 10000},
        {"task1": 0.1, "task2": 0.5, "task3": 0.9, "task4": 5, "task5": 10},
    ]

    for i, task_rewards in enumerate(task_rewards_list, start=1):
        print(f"Case {i}: Original Rewards: {task_rewards}")
        print("Original Weights:")
        print(compute_task_weights(task_rewards, use_symlog=False))
        print("Improved Weights with Symlog:")
        print(compute_task_weights(task_rewards, use_symlog=True))
        print()
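Below is a minimal, hypothetical sketch of how the weights produced by compute_task_weights could scale per-task losses during multitask training. The variable names (per_task_loss, eval_return_per_task) and the wiring are illustrative assumptions, not code contained in this PR.

import torch

from lzero.entry.compute_task_weight import compute_task_weights

# Dummy per-task losses and per-task evaluation returns (illustrative values only).
per_task_loss = {0: torch.tensor(1.2), 1: torch.tensor(0.4), 2: torch.tensor(2.7)}
eval_return_per_task = {0: 350.0, 1: 20.0, 2: 900.0}

# Tasks with lower returns receive larger normalized (and clipped) weights.
weights = compute_task_weights(eval_return_per_task, min_weight=0.1, max_weight=0.5, use_symlog=True)

# Weighted sum of the per-task losses; in a real training step this scalar would be back-propagated.
total_loss = sum(weights[task_id] * loss for task_id, loss in per_task_loss.items())
print(total_loss)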
203 changes: 203 additions & 0 deletions lzero/entry/eval_muzero_v2.py
@@ -0,0 +1,203 @@
import logging
import os
from functools import partial
from typing import Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
import wandb
from tensorboardX import SummaryWriter

from ding.config import compile_config
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed, get_rank, get_world_size
from ding.worker import BaseLearner

from lzero.entry.utils import initialize_zeros_batch, log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
from .utils import random_collect, calculate_update_per_collect

def eval_muzero_v2(
        input_cfg: Tuple[dict, dict],
        seed: int = 0,
        model: Optional[torch.nn.Module] = None,
        model_path: Optional[str] = None,
        num_episodes_each_seed: int = 1,
        print_seed_details: bool = False,
) -> 'Policy': # noqa
"""
Overview:
The eval entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, StochasticMuZero, GumbelMuZero, UniZero, etc.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""
cfg, create_cfg = input_cfg

# Ensure the specified policy type is supported
    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "eval_muzero_v2 only supports the following algorithms: 'unizero', 'sampled_unizero'"
logging.info(f"Using policy type: {create_cfg.policy.type}")

# Import the appropriate GameBuffer class based on the policy type
game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
game_buffer_classes[create_cfg.policy.type])

# Check for GPU availability and set the device accordingly
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
logging.info(f"Device set to: {cfg.policy.device}")

# Compile the configuration file
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

# Create environment manager
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

# Initialize environment and random seed
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())

# Initialize wandb if specified
if cfg.policy.use_wandb:
logging.info("Initializing wandb...")
wandb.init(
project="LightZero",
config=cfg,
sync_tensorboard=False,
monitor_gym=False,
save_code=True,
)
logging.info("wandb initialization completed!")

# Create policy
logging.info("Creating policy...")
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
logging.info("Policy created successfully!")

# Load pretrained model if specified
if model_path is not None:
logging.info(f"Loading pretrained model from {model_path}...")
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
logging.info("Pretrained model loaded successfully!")

# Create core components for training
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
replay_buffer = GameBuffer(cfg.policy)
collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
policy_config=cfg.policy)
evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=cfg.policy)

# Execute the learner's before_run hook
learner.call_hook('before_run')

if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

# Randomly collect data if specified
if cfg.policy.random_collect_episode_num > 0:
logging.info("Collecting random data...")
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
logging.info("Random data collection completed!")

batch_size = policy._cfg.batch_size

if cfg.policy.multi_gpu:
# Get current world size and rank
world_size = get_world_size()
rank = get_rank()
else:
world_size = 1
rank = 0

while True:
# Log memory usage of the replay buffer
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)

# Set temperature parameter for data collection
collect_kwargs = {
'temperature': visit_count_temperature(
cfg.policy.manual_temperature_decay,
cfg.policy.fixed_temperature_value,
cfg.policy.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
),
'epsilon': 0.0 # Default epsilon value
}

# Configure epsilon-greedy exploration
if cfg.policy.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=cfg.policy.eps.start,
end=cfg.policy.eps.end,
decay=cfg.policy.eps.decay,
type_=cfg.policy.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

# Evaluate policy performance
# logging.info(f"Training iteration {learner.train_iter}: Starting evaluation...")
# stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
# logging.info(f"Training iteration {learner.train_iter}: Evaluation completed, stop condition: {stop}, current reward: {reward}")
# if stop:
# logging.info("Stopping condition met, training ends!")
# break

# Collect new data
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
logging.info(f"Rank {rank}, Training iteration {learner.train_iter}: New data collection completed!")

if world_size > 1:
# Synchronize all ranks before training
try:
dist.barrier()
except Exception as e:
logging.error(f'Rank {rank}: Synchronization barrier failed, error: {e}')
break


policy.recompute_pos_emb_diff_and_clear_cache()



learner.call_hook('after_run')
if cfg.policy.use_wandb:
wandb.finish()
logging.info("===== Training Completed =====")
return policy
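Below is a minimal, hypothetical invocation sketch for this entry. The config import is a placeholder (any UniZero config module that defines main_config and create_config would do), the checkpoint-path format follows the docstring above, and none of this script is part of the PR.

from lzero.entry.eval_muzero_v2 import eval_muzero_v2
# Placeholder config module: substitute a real UniZero config pair, e.g. from zoo/atari/config/.
from zoo.atari.config.atari_unizero_config import main_config, create_config

if __name__ == '__main__':
    eval_muzero_v2(
        [main_config, create_config],
        seed=0,
        model_path='./exp_name/ckpt/ckpt_best.pth.tar',
        num_episodes_each_seed=1,
    )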