Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Asynchronous dispatcher (#372)
Browse files Browse the repository at this point in the history
* Asynchronous dispatcher

* updates

* updates

* updates

* updates
  • Loading branch information
chicm-ms authored Nov 22, 2018
1 parent 8d63b10 commit a5d614d
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 31 deletions.
1 change: 1 addition & 0 deletions src/nni_manager/common/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ interface ExperimentParams {
searchSpace: string;
trainingServicePlatform: string;
multiPhase?: boolean;
multiThread?: boolean;
tuner: {
className: string;
builtinTunerName?: string;
Expand Down
6 changes: 5 additions & 1 deletion src/nni_manager/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,16 @@ function parseArg(names: string[]): string {
* @param assessor: similar to tuner
*
*/
function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean = false): string {
function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean = false, multiThread: boolean = false): string {
let command: string = `python3 -m nni --tuner_class_name ${tuner.className}`;
if (multiPhase) {
command += ' --multi_phase';
}

if (multiThread) {
command += ' --multi_thread';
}

if (tuner.classArgs !== undefined) {
command += ` --tuner_args ${JSON.stringify(JSON.stringify(tuner.classArgs))}`;
}
Expand Down
3 changes: 3 additions & 0 deletions src/nni_manager/core/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const ADD_CUSTOMIZED_TRIAL_JOB = 'AD';
const TRIAL_END = 'EN';
const TERMINATE = 'TE';

const INITIALIZED = 'ID';
const NEW_TRIAL_JOB = 'TR';
const SEND_TRIAL_JOB_PARAMETER = 'SP';
const NO_MORE_TRIAL_JOBS = 'NO';
Expand All @@ -39,6 +40,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
ADD_CUSTOMIZED_TRIAL_JOB,
TERMINATE,

INITIALIZED,
NEW_TRIAL_JOB,
SEND_TRIAL_JOB_PARAMETER,
NO_MORE_TRIAL_JOBS
Expand All @@ -61,6 +63,7 @@ export {
ADD_CUSTOMIZED_TRIAL_JOB,
TRIAL_END,
TERMINATE,
INITIALIZED,
NEW_TRIAL_JOB,
NO_MORE_TRIAL_JOBS,
KILL_TRIAL_JOB,
Expand Down
51 changes: 39 additions & 12 deletions src/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ import {
} from '../common/trainingService';
import { delay, getLogDir, getMsgDispatcherCommand } from '../common/utils';
import {
ADD_CUSTOMIZED_TRIAL_JOB, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, REPORT_METRIC_DATA,
REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
ADD_CUSTOMIZED_TRIAL_JOB, INITIALIZE, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS,
REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
} from './commands';
import { createDispatcherInterface, IpcInterface } from './ipcInterface';

Expand Down Expand Up @@ -127,7 +127,8 @@ class NNIManager implements Manager {
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
}

const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
const dispatcherCommand: string = getMsgDispatcherCommand(
expParams.tuner, expParams.assessor, expParams.multiPhase, expParams.multiThread);
this.log.debug(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner(
//expParams.tuner.tunerCommand,
Expand Down Expand Up @@ -159,7 +160,8 @@ class NNIManager implements Manager {
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
}

const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
const dispatcherCommand: string = getMsgDispatcherCommand(
expParams.tuner, expParams.assessor, expParams.multiPhase, expParams.multiThread);
this.log.debug(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner(
dispatcherCommand,
Expand Down Expand Up @@ -419,16 +421,20 @@ class NNIManager implements Manager {
} else {
this.trialConcurrencyChange = requestTrialNum;
}
for (let i: number = 0; i < requestTrialNum; i++) {

const requestCustomTrialNum: number = Math.min(requestTrialNum, this.customizedTrials.length);
for (let i: number = 0; i < requestCustomTrialNum; i++) {
// ask tuner for more trials
if (this.customizedTrials.length > 0) {
const hyperParams: string | undefined = this.customizedTrials.shift();
this.dispatcher.sendCommand(ADD_CUSTOMIZED_TRIAL_JOB, hyperParams);
} else {
this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, '1');
}
}

if (requestTrialNum - requestCustomTrialNum > 0) {
this.requestTrialJobs(requestTrialNum - requestCustomTrialNum);
}

// check maxtrialnum and maxduration here
if (this.experimentProfile.execDuration > this.experimentProfile.params.maxExecDuration ||
this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) {
Expand Down Expand Up @@ -526,11 +532,9 @@ class NNIManager implements Manager {
if (this.dispatcher === undefined) {
throw new Error('Dispatcher error: tuner has not been setup');
}
// TO DO: we should send INITIALIZE command to tuner if user's tuner needs to run init method in tuner
this.log.debug(`Send tuner command: update search space: ${this.experimentProfile.params.searchSpace}`);
this.dispatcher.sendCommand(UPDATE_SEARCH_SPACE, this.experimentProfile.params.searchSpace);
this.log.debug(`Send tuner command: ${this.experimentProfile.params.trialConcurrency}`);
this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, String(this.experimentProfile.params.trialConcurrency));
this.log.debug(`Send tuner command: INITIALIZE: ${this.experimentProfile.params.searchSpace}`);
// Tuner needs to be initialized with the search space before generating any hyper parameters
this.dispatcher.sendCommand(INITIALIZE, this.experimentProfile.params.searchSpace);
}

private async onTrialJobMetrics(metric: TrialJobMetric): Promise<void> {
Expand All @@ -541,9 +545,32 @@ class NNIManager implements Manager {
this.dispatcher.sendCommand(REPORT_METRIC_DATA, metric.data);
}

/**
 * Ask the dispatcher (tuner) to generate hyper parameters for the given
 * number of trial jobs. Does nothing for a non-positive count.
 * @param jobNum number of trial jobs to request from the tuner
 */
private requestTrialJobs(jobNum: number): void {
    if (jobNum < 1) {
        return;
    }
    if (this.dispatcher === undefined) {
        throw new Error('Dispatcher error: tuner has not been setup');
    }
    if (!this.experimentProfile.params.multiThread) {
        // Single-thread dispatcher: one request carrying the full count.
        this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, String(jobNum));
        return;
    }
    // Multi-thread dispatcher: issue one request per job so hyper parameters
    // can be generated concurrently rather than one by one sequentially.
    for (let requested: number = 0; requested < jobNum; requested += 1) {
        this.dispatcher.sendCommand(REQUEST_TRIAL_JOBS, '1');
    }
}

private async onTunerCommand(commandType: string, content: string): Promise<void> {
this.log.info(`Command from tuner: ${commandType}, ${content}`);
switch (commandType) {
case INITIALIZED:
// Tuner is initialized, search space is set, request tuner to generate hyper parameters
this.requestTrialJobs(this.experimentProfile.params.trialConcurrency);
break;
case NEW_TRIAL_JOB:
this.waitingTrials.push(content);
break;
Expand Down
1 change: 1 addition & 0 deletions src/nni_manager/rest_server/restValidationSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ export namespace ValidationSchemas {
searchSpace: joi.string().required(),
maxExecDuration: joi.number().min(0).required(),
multiPhase: joi.boolean(),
multiThread: joi.boolean(),
tuner: joi.object({
builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch'),
codeDir: joi.string(),
Expand Down
4 changes: 4 additions & 0 deletions src/sdk/pynni/nni/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import importlib

from .constants import ModuleName, ClassName, ClassArgs
from nni.common import enable_multi_thread
from nni.msg_dispatcher import MsgDispatcher
from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher
logger = logging.getLogger('nni.main')
Expand Down Expand Up @@ -91,6 +92,7 @@ def parse_args():
parser.add_argument('--assessor_class_filename', type=str, required=False,
help='Assessor class file path')
parser.add_argument('--multi_phase', action='store_true')
parser.add_argument('--multi_thread', action='store_true')

flags, _ = parser.parse_known_args()
return flags
Expand All @@ -101,6 +103,8 @@ def main():
'''

args = parse_args()
if args.multi_thread:
enable_multi_thread()

tuner = None
assessor = None
Expand Down
9 changes: 9 additions & 0 deletions src/sdk/pynni/nni/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,12 @@ def init_logger(logger_file_path):
logging.getLogger('matplotlib').setLevel(logging.INFO)

sys.stdout = _LoggerFileWrapper(logger_file)

# Process-wide flag recording whether the dispatcher runs command handlers on
# a thread pool. Off by default; flipped on (one-way) via enable_multi_thread().
_multi_thread = False


def enable_multi_thread():
    """Turn on multi-thread dispatcher mode for this process (one-way switch)."""
    global _multi_thread
    _multi_thread = True


def multi_thread_enabled():
    """Return True once enable_multi_thread() has been called, else False."""
    return _multi_thread
13 changes: 11 additions & 2 deletions src/sdk/pynni/nni/msg_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
from collections import defaultdict
import json_tricks
import threading

from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
Expand Down Expand Up @@ -69,7 +70,7 @@ def _pack_parameter(parameter_id, params, customized=False):

class MsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super()
super().__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
Expand All @@ -85,6 +86,14 @@ def save_checkpoint(self):
if self.assessor is not None:
self.assessor.save_checkpoint()

def handle_initialize(self, data):
    """Handle the INITIALIZE command from the NNI manager.

    data: the search space, used to prime the tuner before any hyper
    parameters are generated.
    """
    search_space = data
    self.tuner.update_search_space(search_space)
    # Acknowledge so the manager knows it may start requesting trial jobs.
    send(CommandType.Initialized, '')
    return True

def handle_request_trial_jobs(self, data):
# data: number of trial jobs
ids = [_create_parameter_id() for _ in range(data)]
Expand Down Expand Up @@ -127,7 +136,7 @@ def handle_report_metric_data(self, data):
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
else:
pass
pass
else:
raise ValueError('Data type not supported: {}'.format(data['type']))

Expand Down
35 changes: 25 additions & 10 deletions src/sdk/pynni/nni/msg_dispatcher_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@
import os
import logging
import json_tricks

from .common import init_logger
from multiprocessing.dummy import Pool as ThreadPool
from .common import init_logger, multi_thread_enabled
from .recoverable import Recoverable
from .protocol import CommandType, receive

init_logger('dispatcher.log')
_logger = logging.getLogger(__name__)

class MsgDispatcherBase(Recoverable):
def __init__(self):
    """Create the dispatcher.

    In multi-thread mode, commands are handled on a thread pool so slow
    handlers (e.g. hyper parameter generation) do not block the
    command-receiving loop in run().
    """
    super().__init__()
    # Always define the attribute: the original only created `self.pool`
    # when multi-thread mode was enabled, leaving an AttributeError hazard
    # for any code that inspects it in single-thread mode.
    self.pool = ThreadPool() if multi_thread_enabled() else None

def run(self):
"""Run the tuner.
This function will never return unless raise.
Expand All @@ -39,17 +43,24 @@ def run(self):
if mode == 'resume':
self.load_checkpoint()

while self.handle_request():
pass
while True:
_logger.debug('waiting receive_message')
command, data = receive()
if command is None:
break
if multi_thread_enabled():
self.pool.map_async(self.handle_request, [(command, data)])
else:
self.handle_request((command, data))

_logger.info('Terminated by NNI manager')
if multi_thread_enabled():
self.pool.close()
self.pool.join()

def handle_request(self):
_logger.debug('waiting receive_message')
_logger.info('Terminated by NNI manager')

command, data = receive()
if command is None:
return False
def handle_request(self, request):
command, data = request

_logger.debug('handle request: command: [{}], data: [{}]'.format(command, data))

Expand All @@ -60,6 +71,7 @@ def handle_request(self):

command_handlers = {
# Tunner commands:
CommandType.Initialize: self.handle_initialize,
CommandType.RequestTrialJobs: self.handle_request_trial_jobs,
CommandType.UpdateSearchSpace: self.handle_update_search_space,
CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial,
Expand All @@ -74,6 +86,9 @@ def handle_request(self):

return command_handlers[command](data)

def handle_initialize(self, data):
    """Handle the INITIALIZE command (data is the search space). Subclasses
    must override to prime their tuner and reply with an Initialized command."""
    raise NotImplementedError('handle_initialize not implemented')

def handle_request_trial_jobs(self, data):
raise NotImplementedError('handle_request_trial_jobs not implemented')

Expand Down
8 changes: 8 additions & 0 deletions src/sdk/pynni/nni/multi_phase/multi_phase_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def save_checkpoint(self):
if self.assessor is not None:
self.assessor.save_checkpoint()

def handle_initialize(self, data):
    """Load the search space into the tuner and acknowledge the manager.

    data: the search space for this experiment.
    """
    self.tuner.update_search_space(data)
    # The Initialized reply tells the manager it may start requesting trials.
    send(CommandType.Initialized, '')
    return True

def handle_request_trial_jobs(self, data):
# data: number of trial jobs
ids = [_create_parameter_id() for _ in range(data)]
Expand Down
23 changes: 17 additions & 6 deletions src/sdk/pynni/nni/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
# ==================================================================================================

import logging
import threading
from enum import Enum
from .common import multi_thread_enabled


class CommandType(Enum):
Expand All @@ -33,6 +35,7 @@ class CommandType(Enum):
Terminate = b'TE'

# out
Initialized = b'ID'
NewTrialJob = b'TR'
SendTrialJobParameter = b'SP'
NoMoreTrialJobs = b'NO'
Expand All @@ -42,6 +45,7 @@ class CommandType(Enum):
try:
_in_file = open(3, 'rb')
_out_file = open(4, 'wb')
_lock = threading.Lock()
except OSError:
_msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
import logging
def send(command, data):
    """Send command to Training Service.
    command: CommandType object.
    data: string payload.
    """
    # In multi-thread mode several handler threads may call send()
    # concurrently; hold the lock so length-prefixed frames are not
    # interleaved on the pipe. Acquire *before* the try block and remember
    # whether we did: the original re-evaluated multi_thread_enabled() in
    # the finally clause, so enabling multi-thread mode mid-call could make
    # it release a lock that was never acquired (RuntimeError).
    locked = False
    if multi_thread_enabled():
        _lock.acquire()
        locked = True
    try:
        data = data.encode('utf8')
        assert len(data) < 1000000, 'Command too long'
        msg = b'%b%06d%b' % (command.value, len(data), data)
        logging.getLogger(__name__).debug('Sending command, data: [%s]' % msg)
        _out_file.write(msg)
        _out_file.flush()
    finally:
        if locked:
            _lock.release()


def receive():
Expand Down
1 change: 1 addition & 0 deletions tools/nni_cmd/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
'trainingServicePlatform': And(str, lambda x: x in ['remote', 'local', 'pai', 'kubeflow']),
Optional('searchSpacePath'): os.path.exists,
Optional('multiPhase'): bool,
Optional('multiThread'): bool,
'useAnnotation': bool,
'tuner': Or({
'builtinTunerName': Or('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch'),
Expand Down
2 changes: 2 additions & 0 deletions tools/nni_cmd/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data['description'] = experiment_config['description']
if experiment_config.get('multiPhase'):
request_data['multiPhase'] = experiment_config.get('multiPhase')
if experiment_config.get('multiThread'):
request_data['multiThread'] = experiment_config.get('multiThread')
request_data['tuner'] = experiment_config['tuner']
if 'assessor' in experiment_config:
request_data['assessor'] = experiment_config['assessor']
Expand Down

0 comments on commit a5d614d

Please sign in to comment.