Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
[PAI training service] Support running multiple PAI experiments (#348)
Browse files Browse the repository at this point in the history
* Change base image from devel to runtime, to reduce docker image size

* Support running multiple experiments for PAI

* Fix a bug caused by a recursive reference between paiRestServer and
paiTrainingService
  • Loading branch information
yds05 authored Nov 12, 2018
1 parent 35e0832 commit b1d4c12
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 25 deletions.
20 changes: 16 additions & 4 deletions src/nni_manager/common/experimentStartupInfo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@ import * as component from '../common/component';
class ExperimentStartupInfo {
private experimentId: string = '';
private newExperiment: boolean = true;
private basePort: number = -1;
private initialized: boolean = false;
private initTrialSequenceID: number = 0;

public setStartupInfo(newExperiment: boolean, experimentId: string): void {
public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number): void {
assert(!this.initialized);
assert(experimentId.trim().length > 0);

this.newExperiment = newExperiment;
this.experimentId = experimentId;
this.basePort = basePort;
this.initialized = true;
}

Expand All @@ -44,6 +46,12 @@ class ExperimentStartupInfo {
return this.experimentId;
}

/**
 * Returns the base port that was stored by setStartupInfo().
 * Asserts that the startup info has already been initialized.
 */
public getBasePort(): number {
assert(this.initialized);

return this.basePort;
}

public isNewExperiment(): boolean {
assert(this.initialized);

Expand All @@ -66,6 +74,10 @@ function getExperimentId(): string {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getExperimentId();
}

/**
 * Module-level accessor: looks up the ExperimentStartupInfo instance through
 * component.get() and returns the result of its getBasePort() method.
 */
function getBasePort(): number {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getBasePort();
}

function isNewExperiment(): boolean {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isNewExperiment();
}
Expand All @@ -78,9 +90,9 @@ function getInitTrialSequenceId(): number {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getInitTrialSequenceId();
}

function setExperimentStartupInfo(newExperiment: boolean, experimentId: string): void {
component.get<ExperimentStartupInfo>(ExperimentStartupInfo).setStartupInfo(newExperiment, experimentId);
function setExperimentStartupInfo(newExperiment: boolean, experimentId: string, basePort: number): void {
component.get<ExperimentStartupInfo>(ExperimentStartupInfo).setStartupInfo(newExperiment, experimentId, basePort);
}

export { ExperimentStartupInfo, getExperimentId, isNewExperiment,
export { ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment,
setExperimentStartupInfo, setInitTrialSequenceId, getInitTrialSequenceId };
16 changes: 12 additions & 4 deletions src/nni_manager/common/restServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

'use strict';

import * as assert from 'assert';
import * as express from 'express';
import * as http from 'http';
import { Deferred } from 'ts-deferred';
import { getLogger, Logger } from './log';
import { getBasePort } from './experimentStartupInfo';

/**
* Abstraction class to create a RestServer
Expand All @@ -39,13 +41,20 @@ export abstract class RestServer {
protected port?: number;
protected app: express.Application = express();
protected log: Logger = getLogger();
protected basePort?: number;

/**
 * Initializes the server port from the experiment startup info (getBasePort()).
 * The assertion rejects a falsy port and any port that is not greater than 1024.
 */
constructor() {
this.port = getBasePort();
assert(this.port && this.port > 1024);
}

get endPoint(): string {
// tslint:disable-next-line:no-http-string
return `http://${this.hostName}:${this.port}`;
}

public start(port?: number, hostName?: string): Promise<void> {
public start(hostName?: string): Promise<void> {
this.log.info(`RestServer start`);
if (this.startTask !== undefined) {
return this.startTask.promise;
}
Expand All @@ -56,9 +65,8 @@ export abstract class RestServer {
if (hostName) {
this.hostName = hostName;
}
if (port) {
this.port = port;
}

this.log.info(`RestServer base port is ${this.port}`);

this.server = this.app.listen(this.port as number, this.hostName).on('listening', () => {
this.startTask.resolve();
Expand Down
2 changes: 1 addition & 1 deletion src/nni_manager/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ function prepareUnitTest(): void {
Container.snapshot(TrainingService);
Container.snapshot(Manager);

setExperimentStartupInfo(true, 'unittest');
setExperimentStartupInfo(true, 'unittest', 8080);
mkDirPSync(getLogDir());

const sqliteFile: string = path.join(getDefaultDatabaseDir(), 'nni.sqlite');
Expand Down
8 changes: 4 additions & 4 deletions src/nni_manager/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ import {
import { PAITrainingService } from './training_service/pai/paiTrainingService'


function initStartupInfo(startExpMode: string, resumeExperimentId: string) {
function initStartupInfo(startExpMode: string, resumeExperimentId: string, basePort: number) {
const createNew: boolean = (startExpMode === 'new');
const expId: string = createNew ? uniqueString(8) : resumeExperimentId;
setExperimentStartupInfo(createNew, expId);
setExperimentStartupInfo(createNew, expId, basePort);
}

async function initContainer(platformMode: string): Promise<void> {
Expand Down Expand Up @@ -93,14 +93,14 @@ if (startMode === 'resume' && experimentId.trim().length < 1) {
process.exit(1);
}

initStartupInfo(startMode, experimentId);
initStartupInfo(startMode, experimentId, port);

mkDirP(getLogDir()).then(async () => {
const log: Logger = getLogger();
try {
await initContainer(mode);
const restServer: NNIRestServer = component.get(NNIRestServer);
await restServer.start(port);
await restServer.start();
log.info(`Rest server listening on: ${restServer.endPoint}`);
} catch (err) {
log.error(`${err.stack}`);
Expand Down
4 changes: 2 additions & 2 deletions src/nni_manager/training_service/pai/paiData.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ fi`;
export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3}
&& cd $NNI_SYS_DIR && sh install_nni.sh
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{4}' --nnimanager_ip '{5}' --pai_hdfs_output_dir '{6}'
--pai_hdfs_host '{7}' --pai_user_name {8}`;
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{4}' --nnimanager_ip '{5}' --nnimanager_port '{6}'
--pai_hdfs_output_dir '{7}' --pai_hdfs_host '{8}' --pai_user_name {9}`;

export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`;
Expand Down
14 changes: 13 additions & 1 deletion src/nni_manager/training_service/pai/paiJobRestServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

'use strict';

import * as assert from 'assert';
import { Request, Response, Router } from 'express';
import * as bodyParser from 'body-parser';
import * as component from '../../common/component';
import { getBasePort } from '../../common/experimentStartupInfo';
import { getExperimentId } from '../../common/experimentStartupInfo';
import { Inject } from 'typescript-ioc';
import { PAITrainingService } from './paiTrainingService';
Expand All @@ -48,10 +50,20 @@ export class PAIJobRestServer extends RestServer{
*/
constructor() {
super();
this.port = PAIJobRestServer.DEFAULT_PORT;
const basePort: number = getBasePort();
assert(basePort && basePort > 1024);

this.port = basePort + 1; // PAIJobRestServer.DEFAULT_PORT;
this.paiTrainingService = component.get(PAITrainingService);
}

/**
 * Port currently held in this.port for this PAI job REST server.
 * @throws Error when this.port is falsy (e.g. undefined).
 */
public get paiRestServerPort(): number {
if(!this.port) {
throw new Error('PAI Rest server port is undefined');
}
return this.port;
}

/**
* NNIRestServer's own router registration
*/
Expand Down
7 changes: 7 additions & 0 deletions src/nni_manager/training_service/pai/paiTrainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class PAITrainingService implements TrainingService {
private hdfsBaseDir: string | undefined;
private hdfsOutputHost: string | undefined;
private trialSequenceId: number;
private paiRestServerPort?: number;

constructor() {
this.log = getLogger();
Expand Down Expand Up @@ -145,6 +146,11 @@ class PAITrainingService implements TrainingService {
throw new Error('hdfsOutputHost is not initialized');
}

if(!this.paiRestServerPort) {
const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
this.paiRestServerPort = restServer.paiRestServerPort;
}

this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);

const trialJobId: string = uniqueString(5);
Expand Down Expand Up @@ -200,6 +206,7 @@ class PAITrainingService implements TrainingService {
this.experimentId,
this.paiTrialConfig.command,
getIPV4Address(),
this.paiRestServerPort,
hdfsOutputDir,
this.hdfsOutputHost,
this.paiClusterConfig.userName
Expand Down
2 changes: 0 additions & 2 deletions tools/nni_trial_tool/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

BASE_URL = 'http://{}'

DEFAULT_REST_PORT = 51189

HOME_DIR = os.path.join(os.environ['HOME'], 'nni')

LOG_DIR = os.environ['NNI_OUTPUT_DIR']
Expand Down
9 changes: 4 additions & 5 deletions tools/nni_trial_tool/metrics_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import re
import requests

from .constants import BASE_URL, DEFAULT_REST_PORT
from .constants import BASE_URL
from .rest_utils import rest_get, rest_post, rest_put, rest_delete
from .url_utils import gen_update_metrics_url

Expand All @@ -40,11 +40,10 @@ class TrialMetricsReader():
'''
Read metrics data from a trial job
'''
def __init__(self, rest_port = DEFAULT_REST_PORT):
def __init__(self):
metrics_base_dir = os.path.join(NNI_SYS_DIR, '.nni')
self.offset_filename = os.path.join(metrics_base_dir, 'metrics_offset')
self.metrics_filename = os.path.join(metrics_base_dir, 'metrics')
self.rest_port = rest_port
if not os.path.exists(metrics_base_dir):
os.makedirs(metrics_base_dir)

Expand Down Expand Up @@ -107,7 +106,7 @@ def read_trial_metrics(self):
offset = self._get_offset()
return self._read_all_available_records(offset)

def read_experiment_metrics(nnimanager_ip):
def read_experiment_metrics(nnimanager_ip, nnimanager_port):
'''
Read metrics data for specified trial jobs
'''
Expand All @@ -118,7 +117,7 @@ def read_experiment_metrics(nnimanager_ip):
result['metrics'] = reader.read_trial_metrics()
print('Result metrics is {}'.format(json.dumps(result)))
if len(result['metrics']) > 0:
response = rest_post(gen_update_metrics_url(BASE_URL.format(nnimanager_ip), DEFAULT_REST_PORT, NNI_EXP_ID, NNI_TRIAL_JOB_ID), json.dumps(result), 10)
response = rest_post(gen_update_metrics_url(BASE_URL.format(nnimanager_ip), nnimanager_port, NNI_EXP_ID, NNI_TRIAL_JOB_ID), json.dumps(result), 10)
print('Response code is {}'.format(response.status_code))
except Exception:
#TODO error logging to file
Expand Down
5 changes: 3 additions & 2 deletions tools/nni_trial_tool/trial_keeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def main_loop(args):
while True:
retCode = process.poll()
## Read experiment metrics, to avoid missing metrics
read_experiment_metrics(args.nnimanager_ip)
read_experiment_metrics(args.nnimanager_ip, args.nnimanager_port)

if retCode is not None:
print('subprocess terminated. Exit code is {}. Quit'.format(retCode))
Expand Down Expand Up @@ -80,7 +80,8 @@ def trial_keeper_help_info(*args):
PARSER = argparse.ArgumentParser()
PARSER.set_defaults(func=trial_keeper_help_info)
PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager IP')
PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP')
PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port')
PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of hdfs')
PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of hdfs')
PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')
Expand Down

0 comments on commit b1d4c12

Please sign in to comment.