This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] Bugfix: wrong device placement and invalid CUDA ordinal when using CGO engine #4086

Merged: 17 commits, Oct 11, 2021
Changes from 12 commits
44 changes: 37 additions & 7 deletions nni/common/device.py
@@ -2,39 +2,69 @@
# Licensed under the MIT license.

from dataclasses import dataclass
from abc import ABC, abstractmethod

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal


@dataclass
class GPUDevice:
class Device(ABC):
node_id: str
gpu_id: int
status: Literal['idle', 'busy', 'unknown'] = 'idle'

def __eq__(self, o) -> bool:
if type(self) == type(o):
return self.node_id == o.node_id
else:
return False

def __lt__(self, o) -> bool:
return self.node_id < o.node_id

def set_status(self, status):
self.status = status

def __repr__(self) -> str:
return "{Abstract Device %s, Status %s}" % (self.node_id, self.status)

@abstractmethod
def device_repr(self) -> str:
pass


@dataclass
class GPUDevice(Device):
gpu_id: str = -1

def __init__(self, node_id, gpu_id, status='idle'):
self.node_id = node_id
self.gpu_id = gpu_id
self.status = status
Contributor: Is this typical usage of @dataclass?

Contributor Author: I'm not sure if it is typical. Since the dataclass GPUDevice inherits the dataclass Device, the default __init__ argument order would be node_id, status, gpu_id, which does not look natural to a human reader. So I explicitly declare __init__ here with the expected order.
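
A minimal sketch of the field-ordering issue described in the reply above (a simplification under assumptions: Device is shown as a plain dataclass rather than an ABC, and gpu_id is typed as int):

from dataclasses import dataclass
from typing import Literal

@dataclass
class Device:
    node_id: str
    status: Literal['idle', 'busy', 'unknown'] = 'idle'

@dataclass
class GPUDevice(Device):
    # Because the inherited `status` field has a default, any field added here
    # must also carry a default, hence the `-1` sentinel in the real code.
    gpu_id: int = -1

# The auto-generated __init__ follows declaration order across the hierarchy:
#   GPUDevice(node_id, status='idle', gpu_id=-1)
# so a caller would have to write GPUDevice('server0', 'idle', 1).
# Hand-writing __init__(node_id, gpu_id, status='idle') restores the natural order:
#   GPUDevice('server0', 1)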


def __eq__(self, o: Device) -> bool:
if isinstance(o, GPUDevice):
return self.node_id == o.node_id and self.gpu_id == o.gpu_id
return False

def __lt__(self, o) -> bool:
def __lt__(self, o: Device) -> bool:
if self.node_id < o.node_id:
return True
elif self.node_id > o.node_id:
return False
else:
return self.gpu_id < o.gpu_id
if isinstance(o, GPUDevice):
return self.gpu_id < o.gpu_id
else:
return True

def __repr__(self) -> str:
return "{Environment %s, GPU %d, Status %s}" % (self.node_id, self.gpu_id, self.status)

def __hash__(self) -> int:
return hash(self.node_id + '_' + str(self.gpu_id))

def set_status(self, status):
self.status = status

def device_repr(self,):
return f"cuda:{self.gpu_id}"
37 changes: 34 additions & 3 deletions nni/retiarii/codegen/pytorch.py
@@ -2,7 +2,10 @@
# Licensed under the MIT license.

import logging
from typing import List, Tuple, Any
from typing import Dict, List, Tuple, Any

from nni.retiarii.operation_def.torch_op_def import ToDevice
from nni.common.device import Device, GPUDevice

from ..graph import IllegalGraphError, Edge, Graph, Node, Model

@@ -70,7 +73,7 @@ def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
# when the input comes from a single-output operator
inputs.append('{}'.format(edge.head.name))
if edge.head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
'value' in edge.head.operation.parameters:
'value' in edge.head.operation.parameters:
inputs_value.append(edge.head.operation.parameters['value'])
else:
inputs_value.append(None)
@@ -98,15 +101,39 @@ def _remove_prefix(names, graph_name):
return names[len(graph_name):] if names.startswith(graph_name) else names


def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
'''
Since CUDA_VISIBLE_DEVICES will be set to the list of real GPU ID,
we need to remap the GPU ID when generating code to match them correctly.
For example, when CUDA_VISIBLE_DEVICES="0,3", we need to use "cuda:0", "cuda:1" in the generated code.
'''

Contributor: I don't get the point: why does CUDA_VISIBLE_DEVICES="0,3" correspond to cuda:0 and cuda:1?

Contributor Author: nni_manager sets CUDA_VISIBLE_DEVICES to the allocated GPUs when running a trial, which are the physical GPU IDs. When CUDA_VISIBLE_DEVICES=0,3, PyTorch sees two GPUs and names them cuda:0 and cuda:1. Thus, when generating code that explicitly places operations (e.g., x.to("cuda:1")), we should use the "cuda:X" ID instead of the physical GPU ID.
unique_devices = sorted(list(set([e for e in placement.values() if isinstance(e, GPUDevice)])))
node_gpu_cnt = {}
cuda_remapped_id = {}
for d in unique_devices:
if d.node_id not in node_gpu_cnt:
node_gpu_cnt[d.node_id] = 0
node_gpu_cnt[d.node_id] += 1
cuda_remapped_id[d] = node_gpu_cnt[d.node_id] - 1

return cuda_remapped_id
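
A minimal, self-contained sketch of the remapping performed above (server name and physical GPU IDs are made-up values):

# Physical GPUs 0 and 3 of server "node0" are allocated, so nni_manager sets
# CUDA_VISIBLE_DEVICES="0,3"; inside the trial, PyTorch renames them cuda:0 and cuda:1.
allocated = [("node0", 3), ("node0", 0)]                  # (node_id, physical gpu_id)
cuda_remapped_id = {dev: i for i, dev in enumerate(sorted(set(allocated)))}
assert cuda_remapped_id[("node0", 0)] == 0                # physical GPU 0 -> "cuda:0"
assert cuda_remapped_id[("node0", 3)] == 1                # physical GPU 3 -> "cuda:1"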


def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
nodes = graph.topo_sort()

# handle module node and function node differently
# only need to generate code for module here
import_pkgs = set()
node_codes = []
cuda_remapped_id = None
if placement:
cuda_remapped_id = generate_cuda_mapping(placement)
for node in nodes:
if node.operation:
if placement and isinstance(node.operation, ToDevice):
node.operation.override_device_repr("cuda:%d" % cuda_remapped_id[node.operation.device])

if node.operation.type == 'shared':
continue
pkg_name = node.operation.get_import_pkg()
@@ -115,7 +142,11 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
node_codes.append(f"{node_code}.to('{placement[node].device_repr()}')")
if isinstance(placement[node], GPUDevice):
device_repr = "cuda:%d" % cuda_remapped_id[placement[node]]
else:
device_repr = placement[node].device_repr()
node_codes.append(f"{node_code}.to('{device_repr}')")
else:
node_codes.append(node_code)

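To make the effect of this branch concrete, a hypothetical module node placed on a remapped device would have its generated init line suffixed with a .to(...) call (values below are illustrative, not taken from the PR):

node_code = "self.fc1 = torch.nn.Linear(128, 64)"   # what to_init_code(...) might return
device_repr = "cuda:1"                               # remapped ID of the node's GPUDevice
print(f"{node_code}.to('{device_repr}')")
# -> self.fc1 = torch.nn.Linear(128, 64).to('cuda:1')
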
55 changes: 29 additions & 26 deletions nni/retiarii/execution/cgo_engine.py
@@ -9,7 +9,7 @@
import threading
from typing import Iterable, List, Dict, Tuple

from nni.common.device import GPUDevice
from nni.common.device import GPUDevice, Device
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Node
@@ -33,24 +33,23 @@ class CGOExecutionEngine(AbstractExecutionEngine):

Parameters
----------
devices : List[str] or List[GPUDevice]
devices : List[Device]
Available devices for execution.
If a list of str is provided, it will build a list of GPUDevice in a server named ``single_server``
max_concurrency : int
The maximum number of trials to run concurrently.
batch_waiting_time: int
Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization.
"""

def __init__(self, devices: List[GPUDevice] = None,
def __init__(self, devices: List[Device] = None,
max_concurrency: int = None,
batch_waiting_time: int = 60,
) -> None:
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
self.available_devices: List[GPUDevice] = []
self.available_devices: List[Device] = []
self.max_concurrency: int = max_concurrency
for device in devices:
self.available_devices.append(device)
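
As a usage sketch of the constructor above (server name, GPU IDs, and concurrency are illustrative assumptions):

from nni.common.device import GPUDevice

# Two GPUs on one server; the engine treats them as the initially available devices.
devices = [GPUDevice('server0', 0), GPUDevice('server0', 1)]
engine = CGOExecutionEngine(devices=devices, max_concurrency=2, batch_waiting_time=60)
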
@@ -61,7 +60,7 @@ def __init__(self, devices: List[GPUDevice] = None,
self._original_models = {}
self._original_model_to_multi_model = {}
self._trial_to_original_models = {}
self._trial_used_devices: Dict[int, List[GPUDevice]] = {}
self._trial_used_devices: Dict[int, List[Device]] = {}

self._history: List[Model] = []

@@ -110,6 +109,15 @@ def _consume_queue(self):
self._queue_lock.release()
time.sleep(1)

def _extract_placement_constaint(self, placement_mapping: Dict[Node, Device]):
unique_gpus = sorted(list(set([ e for e in placement_mapping.values() if isinstance(e, GPUDevice)])))
placement_constraint = None
if len(unique_gpus) > 0:
placement_constraint = {}
placement_constraint['type'] = 'Device'
placement_constraint['gpus'] = [(e.node_id, e.gpu_id) for e in unique_gpus]
return placement_constraint
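
For illustration, the constraint built by this helper for a trial placed on two GPUs of one server would look like the following (hypothetical node and GPU IDs):

# Hypothetical result of _extract_placement_constaint when the trial uses
# physical GPUs 0 and 3 of server "node0":
placement_constraint = {
    'type': 'Device',
    'gpus': [('node0', 0), ('node0', 3)],
}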

def _submit_models_in_batch(self, *models: List[Model]) -> None:
_logger.info('%d models are submitted in batch', len(models))
logical = self._build_logical(models)
@@ -120,9 +128,10 @@ def _submit_models_in_batch(self, *models: List[Model]) -> None:
phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator)
trial_id = send_trial(data.dump())
placement_constraint = self._extract_placement_constaint(placement)
trial_id = send_trial(data.dump(), placement_constraint=placement_constraint)
# unique non-cpu devices used by the trial
self._trial_used_devices[trial_id] = list([_ for _ in set(placement.values()) if isinstance(_, GPUDevice)])
self._trial_used_devices[trial_id] = list(set([_ for _ in placement.values() if isinstance(_, GPUDevice)]))

# currently, it is impossible for search strategy to submit models more than the number of available devices
for used_device in self._trial_used_devices[trial_id]:
@@ -139,14 +148,18 @@ def _submit_models_in_batch(self, *models: List[Model]) -> None:
def list_models(self) -> Iterable[Model]:
return self._history

def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, GPUDevice], List[Model]]]:
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, Dict[Node, Device], List[Model]]]:
"""
Return the assembled models as a list of tuple.
Each tuple contains the assembled model, the device placement of graph nodes, and the original models.
"""
# try to use the available_devices first so that it can be launched as early as possible
# if free devices are not enough to assemble all models in one trial, try all devices
if len(self.available_devices) > 0:
grouped_models: List[Dict[Model, GPUDevice]] = AssemblePolicy().group(logical_plan, self.available_devices)
grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.available_devices)

if len(self.available_devices) == 0 or len(grouped_models) > 1:
grouped_models: List[Dict[Model, GPUDevice]] = AssemblePolicy().group(logical_plan, self.all_devices)
grouped_models: List[Dict[Model, Device]] = AssemblePolicy().group(logical_plan, self.all_devices)

phy_models_and_placements = []
for multi_model in grouped_models:
@@ -256,17 +269,7 @@ def trial_execute_graph(cls) -> None:
os.remove(file_name)


def _remap_cuda_device(group_model: Dict[Model, GPUDevice]):
used_devices = {}
for m in group_model:
if group_model[m].node_id not in used_devices:
used_devices[group_model[m].node_id] = {}
if isinstance(group_model[m], GPUDevice):
if group_model[m].gpu_id not in used_devices[group_model[m].node_id]:
n_used_gpu_in_server = len(used_devices[group_model[m].node_id])
used_devices[group_model[m].node_id][group_model[m].gpu_id] = n_used_gpu_in_server
group_model[m].gpu_id = used_devices[group_model[m].node_id][group_model[m].gpu_id]
return group_model



class AssemblePolicy:
@@ -282,7 +285,7 @@ def _is_related_node(model: Model, node: Node):

@staticmethod
def _check_graph_connectivity(model: Model,
group_model: Dict[Model, GPUDevice],
group_model: Dict[Model, Device],
logical_plan: LogicalPlan) -> bool:
for edge in logical_plan.logical_graph.edges:
if AssemblePolicy._is_related_node(model, edge.head) or \
@@ -294,7 +297,7 @@ def _check_graph_connectivity(model: Model,
return False

@staticmethod
def _check_evaluator(new_model: Model, group_model: Dict[Model, GPUDevice]) -> bool:
def _check_evaluator(new_model: Model, group_model: Dict[Model, Device]) -> bool:
if not (isinstance(new_model.evaluator, Lightning)
and isinstance(new_model.evaluator.module, MultiModelSupervisedLearningModule)):
return False
@@ -318,11 +321,11 @@ def group(logical_plan, available_devices):
if len(group_model) > 0 and \
(AssemblePolicy._check_graph_connectivity(m, group_model, logical_plan) == False or
AssemblePolicy._check_evaluator(m, group_model) == False):
all_grouped_models.append(_remap_cuda_device(group_model))
all_grouped_models.append(group_model)
group_model = {}
group_model[m] = available_devices[idx % len(available_devices)]
if len(group_model) == len(available_devices) or \
idx == len(logical_plan.models) - 1:
all_grouped_models.append(_remap_cuda_device(group_model))
all_grouped_models.append(group_model)
group_model = {}
return all_grouped_models
22 changes: 13 additions & 9 deletions nni/retiarii/execution/logical_optimizer/logical_plan.py
@@ -2,20 +2,23 @@
# Licensed under the MIT license.

import copy
from typing import Dict, Tuple, Any, Union
from typing import Dict, Tuple, Any

from nni.retiarii.utils import uid
from nni.common.device import GPUDevice
from nni.common.device import Device

from ...graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation


class CPUDevice:
class CPUDevice(Device):
Contributor: It is a little strange: why not put CPUDevice into device.py? It should also be a dataclass, shouldn't it?

Contributor Author: Good comment. I have moved CPUDevice into nni.common.device and marked it as a dataclass; a sketch follows after this class below.

def __init__(self, node_id):
self.node_id = node_id
self.device = 'cpu'

def __repr__(self) -> str:
return "{CPU Device, NodeID %s, Status %s}" % (self.node_id, self.status)

def device_repr(self):
return "cpu"

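Following up on the review suggestion above, a sketch of what the relocated class could look like in nni/common/device.py (an assumption; the actual follow-up commit may differ):

from dataclasses import dataclass

# Device is the abstract base shown in the first file of this diff (same module).
@dataclass
class CPUDevice(Device):
    # node_id and status are inherited from the Device dataclass.
    def device_repr(self) -> str:
        return "cpu"
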
@@ -25,7 +28,7 @@ def __init__(self, graph, node_id, name, operation, _internal=False):
super().__init__(graph, node_id, name, operation, _internal=_internal)
self.related_models = []

def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
raise NotImplementedError

def _fork_to(self, graph: Graph):
@@ -92,7 +95,7 @@ def __init__(self, logical_graph: LogicalGraph,
self.original_graph = original_graph
self.original_node = original_node

def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) -> Tuple[Node, GPUDevice]:
def assemble(self, multi_model_placement: Dict[Model, Device]) -> Tuple[Node, Device]:
model_id = self.original_node.graph.model.model_id
new_node = Node(self.original_node.graph, self.original_node.id,
f"M_{model_id}_" +
@@ -138,8 +141,8 @@ def _merge_graph(self, from_graph):
new_tail = id_to_new_node[edge.tail.id]
Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) \
-> Tuple[Model, Dict[Node, Union[GPUDevice, CPUDevice]]]:
def assemble(self, multi_model_placement: Dict[Model, Device]) \
-> Tuple[Model, Dict[Node, Device]]:
Contributor: It would be better to add a docstring for this function.

Contributor Author (hzhua, Sep 23, 2021): Done. I have added the docstring for assemble in AbstractLogicalNode. I have also added comments to each type of logical node to explain its function and how it should be assembled.

phy_model = Model(_internal=True)
phy_graph = self.lp_model.root_graph._fork_to(phy_model)
phy_graph._rename_graph(phy_graph.name, "_model")
@@ -224,7 +227,7 @@ def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) \
# If two nodes are placed on different devices, use ToDevice op to copy the node
existing_edges = phy_graph.edges.copy()
# Avoid a node is copied multiple times on the same device
copied_op: Dict[Tuple(Node, Union[GPUDevice, CPUDevice]), Node] = {}
copied_op: Dict[Tuple(Node, Device), Node] = {}
for edge in existing_edges:
head_placement = node_placements[edge.head]
tail_placement = node_placements[edge.tail]
@@ -238,11 +241,12 @@ def assemble(self, multi_model_placement: Dict[Model, GPUDevice]) \
dst_name = edge.head.name + "_to_" + edge.tail.name
to_operation = Operation.new(
'ToDevice', {
"device": tail_placement.device_repr(), "src": (
"device": tail_placement, "src": (
edge.head.name, edge.head_slot), "dst": dst_name})
to_node = Node(phy_graph, uid(), dst_name, to_operation)._register()
Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
copied_op[(edge.head, tail_placement)] = to_node
node_placements[to_node] = head_placement
edge.head = to_node
edge.head_slot = None

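To make the key change here concrete: the ToDevice parameters now carry the Device object itself rather than a pre-rendered string, so that codegen can remap it to a "cuda:X" ID later. A hypothetical parameter dict (node names and IDs are made up):

from nni.common.device import GPUDevice

# ToDevice op copying the output of "conv1" (slot 0) to the GPU hosting "fc1";
# keeping the GPUDevice object lets graph_to_pytorch_model rewrite it as a
# remapped "cuda:X" string at code-generation time.
to_params = {
    "device": GPUDevice("node0", 3),
    "src": ("conv1", 0),
    "dst": "conv1_to_fc1",
}
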
8 changes: 6 additions & 2 deletions nni/retiarii/integration.py
@@ -79,8 +79,12 @@ def _validate_placement_constraint(self, placement_constraint):
raise ValueError('placement_constraint.type must be either `None`,. `GPUNumber` or `Device`')
if placement_constraint['type'] == 'None' and len(placement_constraint['gpus']) > 0:
raise ValueError('placement_constraint.gpus must be an empty list when type == None')
if placement_constraint['type'] == 'Device' and len(placement_constraint['gpus']) != 1:
raise ValueError('placement_constraint.gpus must be a list of number (currently only support one host)')
if placement_constraint['type'] == 'GPUNumber':
if len(placement_constraint['gpus']) != 1:
raise ValueError('placement_constraint.gpus currently only support one host when type == GPUNumber')
for e in placement_constraint['gpus']:
if not isinstance(e, int):
raise ValueError('placement_constraint.gpus must be a list of number when type == GPUNumber')
if placement_constraint['type'] == 'Device':
for e in placement_constraint['gpus']:
if not isinstance(e, tuple):