This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] Bugfix: wrong device placement and invalid CUDA ordinal when using CGO engine #4086

Merged 17 commits on Oct 11, 2021
2 changes: 1 addition & 1 deletion dependencies/recommended.txt
@@ -8,7 +8,7 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
torch == 1.9.0 ; sys_platform == "darwin"
torchvision == 0.10.0+cpu ; sys_platform != "darwin"
torchvision == 0.10.0 ; sys_platform == "darwin"
pytorch-lightning >= 1.2.8, < 1.4.2
pytorch-lightning >= 1.4.2
onnx
peewee
graphviz
44 changes: 37 additions & 7 deletions nni/common/device.py
@@ -2,39 +2,69 @@
# Licensed under the MIT license.

from dataclasses import dataclass
from abc import ABC, abstractmethod

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal


@dataclass
class GPUDevice:
class Device(ABC):
node_id: str
gpu_id: int
status: Literal['idle', 'busy', 'unknown'] = 'idle'

def __eq__(self, o) -> bool:
if type(self) == type(o):
return self.node_id == o.node_id
else:
return False

def __lt__(self, o) -> bool:
return self.node_id < o.node_id

def set_status(self, status):
self.status = status

def __repr__(self) -> str:
return "{Abstract Device %s, Status %s}" % (self.node_id, self.status)

@abstractmethod
def device_repr(self) -> str:
pass


@dataclass
class GPUDevice(Device):
gpu_id: str = -1

def __init__(self, node_id, gpu_id, status='idle'):
self.node_id = node_id
self.gpu_id = gpu_id
self.status = status
Contributor:
Is this typical usage of @dataclass?

Contributor Author:
I'm not sure if it is typical. Since the dataclass GPUDevice inherits the dataclass Device, the argument order of the generated __init__ would by default be node_id, status, gpu_id, which does not look natural. So I explicitly declare __init__ here with the expected order.
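For illustration, a minimal standalone sketch of the dataclass-inheritance behavior described above (the Base and Gpu class names here are hypothetical, not from this PR):

```python
from dataclasses import dataclass, fields

@dataclass
class Base:
    node_id: str
    status: str = 'idle'

@dataclass
class Gpu(Base):
    # a subclass field must take a default because it follows the
    # defaulted base-class field `status`
    gpu_id: int = -1

# base-class fields come first in the generated __init__
print([f.name for f in fields(Gpu)])  # ['node_id', 'status', 'gpu_id']

# so the auto-generated signature is Gpu(node_id, status='idle', gpu_id=-1),
# which is why GPUDevice declares an explicit __init__(node_id, gpu_id, status)
```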


def __eq__(self, o: Device) -> bool:
if isinstance(o, GPUDevice):
return self.node_id == o.node_id and self.gpu_id == o.gpu_id
return False

def __lt__(self, o) -> bool:
def __lt__(self, o: Device) -> bool:
if self.node_id < o.node_id:
return True
elif self.node_id > o.node_id:
return False
else:
return self.gpu_id < o.gpu_id
if isinstance(o, GPUDevice):
return self.gpu_id < o.gpu_id
else:
return True

def __repr__(self) -> str:
return "{Environment %s, GPU %d, Status %s}" % (self.node_id, self.gpu_id, self.status)

def __hash__(self) -> int:
return hash(self.node_id + '_' + str(self.gpu_id))

def set_status(self, status):
self.status = status

def device_repr(self,):
return f"cuda:{self.gpu_id}"
37 changes: 34 additions & 3 deletions nni/retiarii/codegen/pytorch.py
@@ -2,7 +2,10 @@
# Licensed under the MIT license.

import logging
from typing import List, Tuple, Any
from typing import Dict, List, Tuple, Any

from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.common.device import Device, GPUDevice

from ..graph import IllegalGraphError, Edge, Graph, Node, Model

@@ -70,7 +73,7 @@ def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
# when the input comes from a single-output operator
inputs.append('{}'.format(edge.head.name))
if edge.head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
'value' in edge.head.operation.parameters:
'value' in edge.head.operation.parameters:
inputs_value.append(edge.head.operation.parameters['value'])
else:
inputs_value.append(None)
@@ -98,15 +101,39 @@ def _remove_prefix(names, graph_name):
return names[len(graph_name):] if names.startswith(graph_name) else names


def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
'''
Since CUDA_VISIBLE_DEVICES will be set to the list of real GPU IDs,
we need to remap the GPU IDs when generating code so that they match what PyTorch sees.
For example, when CUDA_VISIBLE_DEVICES="0,3", we need to use "cuda:0", "cuda:1" in the generated code.
Contributor:
I don't get the point: why does CUDA_VISIBLE_DEVICES="0,3" correspond to cuda:0 and cuda:1?

Contributor Author:
nni_manager sets CUDA_VISIBLE_DEVICES to the allocated GPUs when running a trial, and those are the physical GPU IDs. When CUDA_VISIBLE_DEVICES=0,3, PyTorch identifies two GPUs and names them cuda:0 and cuda:1. Thus, when generating code that explicitly places operations (e.g., x.to("cuda:1")), we should use the "cuda:X" ID instead of the physical GPU ID. (See the usage sketch after this function.)

'''
unique_devices = sorted(list(set([e for e in placement.values() if isinstance(e, GPUDevice)])))
node_gpu_cnt = {}
cuda_remapped_id = {}
for d in unique_devices:
if d.node_id not in node_gpu_cnt:
node_gpu_cnt[d.node_id] = 0
node_gpu_cnt[d.node_id] += 1
cuda_remapped_id[d] = node_gpu_cnt[d.node_id] - 1

return cuda_remapped_id
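
A rough usage sketch of the remapping above. The placement keys are shown as plain strings for brevity (the real call passes Graph Node objects, but only the values are inspected), and the server name and GPU IDs are made up:

```python
from nni.common.device import GPUDevice

# two operators placed on physical GPUs 0 and 3 of the same server
placement = {
    'conv1': GPUDevice('server1', 0),
    'conv2': GPUDevice('server1', 3),
}

cuda_remapped_id = generate_cuda_mapping(placement)
for device, cuda_id in cuda_remapped_id.items():
    # physical GPU 0 -> cuda:0, physical GPU 3 -> cuda:1
    print(device, '->', f'cuda:{cuda_id}')
```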


def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
nodes = graph.topo_sort()

# handle module node and function node differently
# only need to generate code for module here
import_pkgs = set()
node_codes = []
cuda_remapped_id = None
if placement:
cuda_remapped_id = generate_cuda_mapping(placement)
for node in nodes:
if node.operation:
if placement and isinstance(node.operation, ToDevice):
node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])

if node.operation.type == 'shared':
continue
pkg_name = node.operation.get_import_pkg()
@@ -115,7 +142,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
node_codes.append(f"{node_code}.to('{placement[node].device_repr()}')")
if isinstance(placement[node], GPUDevice):
device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
else:
device_repr = placement[node].device_repr()
node_codes.append(f"{node_code}.to('{device_repr}')")
else:
node_codes.append(node_code)

41 changes: 35 additions & 6 deletions nni/retiarii/evaluator/pytorch/cgo/accelerator.py
@@ -4,6 +4,7 @@
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer import Trainer

from pytorch_lightning.plugins import Plugin
from pytorch_lightning.plugins.environments import ClusterEnvironment
@@ -53,6 +54,13 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
"""Perform a all_gather on all processes """
return tensor

def teardown(self):
"""
This method is called to teardown the training process.
It is the right place to release memory and free other resources.
"""
pass

@property
def root_device(self) -> torch.device:
return torch.device(self.device)
@@ -78,10 +86,13 @@ def broadcast(self, obj: object, src: int = 0) -> object:

def get_accelerator_connector(
num_processes: int = 1,
devices: Optional[Union[List[int], str, int]] = None,
tpu_cores: Optional[Union[List[int], str, int]] = None,
ipus: Optional[int] = None,
distributed_backend: Optional[str] = None,
auto_select_gpus: bool = False,
accelerator: Optional[Union[str, Accelerator]] = None,
gpus: Optional[Union[List[int], str, int]] = None,
auto_select_gpus: bool = False,
num_nodes: int = 1,
sync_batchnorm: bool = False,
benchmark: bool = False,
@@ -90,17 +101,35 @@
precision: int = 32,
amp_backend: str = 'native',
amp_level: str = 'O2',
plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None):
plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
**other_trainier_kwargs) -> AcceleratorConnector:
gpu_ids = Trainer()._parse_devices(gpus, auto_select_gpus, tpu_cores)
return AcceleratorConnector(
num_processes, tpu_cores, distributed_backend, auto_select_gpus, gpus, num_nodes, sync_batchnorm, benchmark,
replace_sampler_ddp, deterministic, precision, amp_backend, amp_level, plugins
num_processes,
devices,
tpu_cores,
ipus,
distributed_backend,
accelerator,
gpus,
gpu_ids,
num_nodes,
sync_batchnorm,
benchmark,
replace_sampler_ddp,
deterministic,
precision,
amp_backend,
amp_level,
plugins,
)


@serialize_cls
class BypassAccelerator(Accelerator):
def __init__(self, precision_plugin=None, device="cpu"):
def __init__(self, precision_plugin=None, device="cpu", **trainer_kwargs):
if precision_plugin is None:
precision_plugin = get_accelerator_connector().precision_plugin
precision_plugin = get_accelerator_connector(**trainer_kwargs).select_precision_plugin()

# pylint: disable=abstract-class-instantiated
super().__init__(precision_plugin=precision_plugin, training_type_plugin=BypassPlugin(device))
2 changes: 1 addition & 1 deletion nni/retiarii/evaluator/pytorch/cgo/trainer.py
@@ -26,6 +26,6 @@ def __init__(self, use_cgo=False, **trainer_kwargs):
if use_cgo:
if "accelerator" in trainer_kwargs:
raise ValueError("accelerator should not be set when cross-graph optimization is enabled.")
trainer_kwargs['accelerator'] = BypassAccelerator(device='cpu')
trainer_kwargs['accelerator'] = BypassAccelerator(device='cpu', **trainer_kwargs)

super().__init__(**trainer_kwargs)
55 changes: 29 additions & 26 deletions nni/retiarii/execution/cgo_engine.py
@@ -9,7 +9,7 @@
import threading
from typing import Iterable, List, Dict, Tuple

from nni.common.device import GPUDevice
from nni.common.device import GPUDevice, Device
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Node
@@ -33,24 +33,23 @@ class CGOExecutionEngine(AbstractExecutionEngine):

Parameters
----------
devices : List[str] or List[GPUDevice]
devices : List[Device]
Available devices for execution.
If a list of str is provided, it will build a list of GPUDevice in a server named ``single_server``
max_concurrency : int
The maximum number of trials to run concurrently.
batch_waiting_time: int
Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization.
"""

def __init__(self, devices: List[GPUDevice] = None,
def __init__(self, devices: List[Device] = None,
max_concurrency: int = None,
batch_waiting_time: int = 60,
) -> None:
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
self.available_devices: List[GPUDevice] = []
self.available_devices: List[Device] = []
self.max_concurrency: int = max_concurrency
for device in devices:
self.available_devices.append(device)
@@ -61,7 +60,7 @@ def __init__(self, devices: List[GPUDevice] = None,
self._original_models = {}
self._original_model_to_multi_model = {}
self._trial_to_original_models = {}
self._trial_used_devices: Dict[int, List[GPUDevice]] = {}
self._trial_used_devices: Dict[int, List[Device]] = {}

self._history: List[Model] = []

@@ -110,6 +109,15 @@ def _consume_queue(self):
self._queue_lock.release()
time.sleep(1)

def _extract_placement_constaint(self, placement_mapping: Dict[Node, Device]):
unique_gpus = sorted(list(set([ e for e in placement_mapping.values() if isinstance(e, GPUDevice)])))
placement_constraint = None
if len(unique_gpus) > 0:
placement_constraint = {}
placement_constraint['type'] = 'Device'
placement_constraint['gpus'] = [(e.node_id, e.gpu_id) for e in unique_gpus]
return placement_constraint

def _submit_models_in_batch(self, *models: List[Model]) -> None:
_logger.info('%d models are submitted in batch', len(models))
logical = self._build_logical(models)
@@ -120,9 +128,10 @@ def _submit_models_in_batch(self, *models: List[Model]) -> None:
phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator)
trial_id = send_trial(data.dump())
placement_constraint = self._extract_placement_constaint(placement)
trial_id = send_trial(data.dump(), placement_constraint=placement_constraint)
# unique non-cpu devices used by the trial
self._trial_used_devices[trial_id] = list([_ for _ in set(placement.values()) if isinstance(_, GPUDevice)])
self._trial_used_devices[trial_id] = list(set([_ for _ in placement.values() if isinstance(_, GPUDevice)]))

# currently, it is impossible for search strategy to submit models more than the number of available devices
for used_device in self._trial_used_devices[trial_id]:
@@ -139,14 +148,18 @@ def _submit_models_in_batch(self, *models: List[Model]) -> None:
def list_models(self) -> Iterable[Model]:
return self._history

def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, GPUDevice], List[Model]]]:
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, Device], List[Model]]]:
"""
Return the assembled models as a list of tuple.
Each tuple contains the assembled model, the device placement of graph nodes, and the original models.
"""
# try to use the available_devices first so that it can be launched as early as possible
# if free devices are not enough to assemble all models in one trial, try all devices
if len(self.available_devices) > 0:
grouped_models: List[Dict[Model, GPUDevice]] = AssemblePolicy().group(logical_plan, self.available_devices)
grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.available_devices)

if len(self.available_devices) == 0 or len(grouped_models) > 1:
grouped_models: List[Dict[Model, GPUDevice]] = AssemblePolicy().group(logical_plan, self.all_devices)
grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.all_devices)

phy_models_and_placements = []
for multi_model in grouped_models:
@@ -256,17 +269,7 @@ def trial_execute_graph(cls) -> None:
os.remove(file_name)


def _remap_cuda_device(group_model: Dict[Model, GPUDevice]):
used_devices = {}
for m in group_model:
if group_model[m].node_id not in used_devices:
used_devices[group_model[m].node_id] = {}
if isinstance(group_model[m], GPUDevice):
if group_model[m].gpu_id not in used_devices[group_model[m].node_id]:
n_used_gpu_in_server = len(used_devices[group_model[m].node_id])
used_devices[group_model[m].node_id][group_model[m].gpu_id] = n_used_gpu_in_server
group_model[m].gpu_id = used_devices[group_model[m].node_id][group_model[m].gpu_id]
return group_model



class AssemblePolicy:
@@ -282,7 +285,7 @@ def _is_related_node(model: Model, node: Node):

@staticmethod
def _check_graph_connectivity(model: Model,
group_model: Dict[Model, GPUDevice],
group_model: Dict[Model, Device],
logical_plan: LogicalPlan) -> bool:
for edge in logical_plan.logical_graph.edges:
if AssemblePolicy._is_related_node(model, edge.head) or \
@@ -294,7 +297,7 @@ def _check_graph_connectivity(model: Model,
return False

@staticmethod
def _check_evaluator(new_model: Model, group_model: Dict[Model, GPUDevice]) -> bool:
def _check_evaluator(new_model: Model, group_model: Dict[Model, Device]) -> bool:
if not (isinstance(new_model.evaluator, Lightning)
and isinstance(new_model.evaluator.module, MultiModelSupervisedLearningModule)):
return False
@@ -318,11 +321,11 @@ def group(logical_plan, available_devices):
if len(group_model) > 0 and \
(AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) == False or
AssemblePolicy._check_evaluator(m, group_model) == False):
all_grouped_models.append(_remap_cuda_device(group_model))
all_grouped_models.append(group_model)
group_model = {}
group_model[m] = available_devices[idx % len(available_devices)]
if len(group_model) == len(available_devices) or \
idx == len(logical_plan.models) - 1:
all_grouped_models.append(_remap_cuda_device(group_model))
all_grouped_models.append(group_model)
group_model = {}
return all_grouped_models