diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index dfd4638e38..d9036e9486 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -121,10 +121,8 @@ async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]: Dict: Metrics of the sampling step. List[Dict]: A list of representative samples for logging. """ - with Timer({}, "time/sample_data"): - batch, metrics, repr_samples = await self.sample_strategy.sample( - self.train_step_num + 1 - ) + batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1) + metrics["sample/task_count"] = len(set(eid.task for eid in batch.eids)) return batch, metrics, repr_samples async def need_sync(self) -> bool: