Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[Retiarii] Remove unused code and enrich integration tests (#4097)
Browse files Browse the repository at this point in the history
  • Loading branch information
ultmaster authored Sep 13, 2021
1 parent 0918ea0 commit 619177b
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 330 deletions.
1 change: 0 additions & 1 deletion nni/retiarii/evaluator/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .base import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from .lightning import *
305 changes: 0 additions & 305 deletions nni/retiarii/evaluator/pytorch/base.py

This file was deleted.

Empty file.
25 changes: 12 additions & 13 deletions nni/retiarii/experiment/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,38 @@

import atexit
import logging
import os
import socket
import time
from dataclasses import dataclass
import os
from pathlib import Path
import socket
from subprocess import Popen
from threading import Thread
import time
from typing import Any, List, Optional, Union

import colorama
import psutil

import torch
import torch.nn as nn
import nni.runtime.log
from nni.experiment import Experiment, TrainingServiceConfig
from nni.experiment import management, launcher, rest
from nni.common.device import GPUDevice
from nni.experiment import Experiment, TrainingServiceConfig, launcher, management, rest
from nni.experiment.config import util
from nni.experiment.config.base import ConfigBase, PathLike
from nni.experiment.pipe import Pipe
from nni.tools.nnictl.command_utils import kill_command
from nni.common.device import GPUDevice

from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..converter.graph_gen import GraphConverterWithShape
from ..execution import list_models, set_execution_engine
from ..execution.python import get_mutation_dict
from ..graph import Model, Evaluator
from ..graph import Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from ..strategy import BaseStrategy
from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation
from ..oneshot.interface import BaseOneShotTrainer
from ..strategy import BaseStrategy

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -73,7 +70,7 @@ def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
assert 'training_service' not in kwargs
self.training_service = util.training_service_config_factory(platform = training_service_platform)
self.training_service = util.training_service_config_factory(platform=training_service_platform)
self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry py'

def __setattr__(self, key, value):
Expand Down Expand Up @@ -117,6 +114,7 @@ def _validation_rules(self):
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}


def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_input=None):
# TODO: this logic might need to be refactored into execution engine
if full_ir:
Expand Down Expand Up @@ -220,6 +218,7 @@ def start(self, port: int = 8080, debug: bool = False) -> None:
engine = BaseExecutionEngine()
elif self.config.execution_engine == 'cgo':
from ..execution.cgo_engine import CGOExecutionEngine

# assert self.config.trial_gpu_number==1, "trial_gpu_number must be 1 to use CGOExecutionEngine"
assert self.config.batch_waiting_time is not None
devices = self._construct_devices()
Expand Down Expand Up @@ -273,14 +272,14 @@ def start(self, port: int = 8080, debug: bool = False) -> None:
def _construct_devices(self):
devices = []
if hasattr(self.config.training_service, 'machine_list'):
for machine_idx, machine in enumerate(self.config.training_service.machine_list):
for machine in self.config.training_service.machine_list:
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
else:
for gpu_idx in self.config.training_service.gpu_indices:
devices.append(GPUDevice('local', gpu_idx))
return devices

def _create_dispatcher(self):
return self._dispatcher

Expand Down
Loading

0 comments on commit 619177b

Please sign in to comment.