Skip to content

Commit

Permalink
Fix reporter
Browse files Browse the repository at this point in the history
  • Loading branch information
KKIEEK authored and KKIEEK committed Dec 2, 2022
1 parent 1e8f996 commit 96ac47a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions siatune/mm/hooks/reporter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) SI-Analytics. All rights reserved.
import ray
from mmcv.runner import HOOKS, BaseRunner
from mmcv.runner.dist_utils import get_dist_info
from mmcv.runner.hooks.logger import LoggerHook
from ray.air import session
from torch import distributed as dist


Expand Down Expand Up @@ -90,4 +90,4 @@ def log(self, runner: BaseRunner) -> None:
filter(lambda elem: self.filtering_key in elem, tags.keys())):
return
tags['global_step'] = self.get_iter(runner)
ray.tune.report(**tags)
session.report(tags)
8 changes: 4 additions & 4 deletions siatune/mm/tasks/mmtrainbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def create_trainable(self) -> TorchTrainer:
return TorchTrainer(
self.context_aware_run,
scaling_config=ScalingConfig(
num_workers=self.num_workers,
use_gpu=torch.cuda.is_available(),
resources_per_worker=dict(
trainer_resources=dict(
CPU=self.num_cpus_per_worker,
GPU=self.num_gpus_per_worker)),
GPU=self.num_gpus_per_worker),
num_workers=self.num_workers,
use_gpu=torch.cuda.is_available()),
torch_config=TorchConfig(backend='gloo'))

0 comments on commit 96ac47a

Please sign in to comment.