diff --git a/swanlab/converter/wb/wb_converter.py b/swanlab/converter/wb/wb_converter.py index 5f0dfae4..673459d1 100644 --- a/swanlab/converter/wb/wb_converter.py +++ b/swanlab/converter/wb/wb_converter.py @@ -31,6 +31,13 @@ def parse_wandb_logs(self, wb_project: str, wb_entity: str, wb_run_id: str = Non raise TypeError( "Wandb Converter requires wandb. Install with 'pip install wandb'." ) + + try: + import pandas as pd + except ImportError as e: + raise TypeError( + "Wandb Converter requires pandas when process wandb logs. Install with 'pip install pandas'." + ) client = wandb.Api() @@ -51,11 +58,16 @@ def parse_wandb_logs(self, wb_project: str, wb_entity: str, wb_run_id: str = Non workspace=self.workspace, experiment_name=wb_run.name, description=wb_run.notes, - mode="cloud" if self.cloud else "local", + cloud=self.cloud, logdir=self.logdir, ) else: swanlab_run = swanlab.get_run() + + try: + wb_run_metadata = {"wandb_metadata": wb_run_metadata} + except: + wb_run_metadata = None wb_config = { "wandb_run_id": wb_run.id, @@ -64,8 +76,8 @@ def parse_wandb_logs(self, wb_project: str, wb_entity: str, wb_run_id: str = Non "wandb_user": wb_run.user, "wandb_tags": wb_run.tags, "wandb_url": wb_run.url, - "wandb_metadata": wb_run.metadata, } + wb_config.update(wb_run_metadata) swanlab_run.config.update(wb_config) swanlab_run.config.update(wb_run.config) @@ -73,7 +85,13 @@ def parse_wandb_logs(self, wb_project: str, wb_entity: str, wb_run_id: str = Non # Get the first history record to extract available keys history = wb_run.history(stream="default") if len(history) > 0: - keys = [key for key in history[0].keys() if not key.startswith("_")] + # 检查 history 是否为 DataFrame 类型 + if isinstance(history, pd.DataFrame): + # 如果是 DataFrame,直接获取列名 + keys = [key for key in history.columns if not key.startswith("_")] + else: + # 原来的逻辑,假设是字典列表 + keys = [key for key in history[0].keys() if not key.startswith("_")] else: keys = []