Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Dev enas - multi-phase hyper parameters support (#96)
Browse files Browse the repository at this point in the history
* Multi-phase support

* Updates

* Updates

* updates

* updates

* updates
  • Loading branch information
chicm-ms authored Sep 20, 2018
1 parent eedf095 commit 3c832bf
Show file tree
Hide file tree
Showing 20 changed files with 524 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ before_install:
- sudo sh -c 'PATH=/usr/local/node/bin:$PATH yarn global add serve'
install:
- make
- make install
- make dev-install
- export PATH=$HOME/.nni/bin:$PATH
before_script:
- cd test/naive
Expand Down
4 changes: 2 additions & 2 deletions src/nni_manager/common/datastore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import { ExperimentProfile, TrialJobStatistics } from './manager';
import { TrialJobDetail, TrialJobStatus } from './trainingService';

type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED';
type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED' | 'ADD_HYPERPARAMETER';
type MetricType = 'PERIODICAL' | 'FINAL' | 'CUSTOM';

interface ExperimentProfileRecord {
Expand Down Expand Up @@ -62,7 +62,7 @@ interface TrialJobInfo {
status: TrialJobStatus;
startTime?: number;
endTime?: number;
hyperParameters?: string;
hyperParameters?: string[];
logPath?: string;
finalMetricData?: string;
stderrPath?: string;
Expand Down
1 change: 1 addition & 0 deletions src/nni_manager/common/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ interface ExperimentParams {
maxExecDuration: number; //seconds
maxTrialNum: number;
searchSpace: string;
multiPhase?: boolean;
tuner: {
className: string;
builtinTunerName?: string;
Expand Down
9 changes: 7 additions & 2 deletions src/nni_manager/common/trainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,16 @@ interface JobApplicationForm {
readonly jobType: JobType;
}

/**
 * One set of hyper-parameters delivered to a trial job.
 * A trial job can receive several of these over its lifetime; each is
 * identified by its index.
 */
interface HyperParameters {
// Parameter payload; training services write it verbatim into the trial's parameter file.
readonly value: string;
// Position of this parameter set for the trial; used to build the file name `parameter_<index>.cfg`.
// The set submitted together with a new trial job uses index 0.
readonly index: number;
}

/**
* define TrialJobApplicationForm
*/
interface TrialJobApplicationForm extends JobApplicationForm {
readonly hyperParameters: string;
readonly hyperParameters: HyperParameters;
}

/**
Expand Down Expand Up @@ -116,6 +121,6 @@ abstract class TrainingService {

export {
TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
HostJobApplicationForm, JobApplicationForm, JobType
};
5 changes: 4 additions & 1 deletion src/nni_manager/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,11 @@ function parseArg(names: string[]): string {
* @param assessor: similar to tuner
*
*/
function getMsgDispatcherCommand(tuner: any, assessor: any): string {
function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean = false): string {
let command: string = `python3 -m nni --tuner_class_name ${tuner.className}`;
if (multiPhase) {
command += ' --multi_phase';
}

if (process.env.VIRTUAL_ENV) {
command = path.join(process.env.VIRTUAL_ENV, 'bin/') +command;
Expand Down
5 changes: 4 additions & 1 deletion src/nni_manager/core/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const TRIAL_END = 'EN';
const TERMINATE = 'TE';

const NEW_TRIAL_JOB = 'TR';
const SEND_TRIAL_JOB_PARAMETER = 'SP';
const NO_MORE_TRIAL_JOBS = 'NO';
const KILL_TRIAL_JOB = 'KI';

Expand All @@ -39,6 +40,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
TERMINATE,

NEW_TRIAL_JOB,
SEND_TRIAL_JOB_PARAMETER,
NO_MORE_TRIAL_JOBS
]);

Expand All @@ -63,5 +65,6 @@ export {
NO_MORE_TRIAL_JOBS,
KILL_TRIAL_JOB,
TUNER_COMMANDS,
ASSESSOR_COMMANDS
ASSESSOR_COMMANDS,
SEND_TRIAL_JOB_PARAMETER
};
30 changes: 26 additions & 4 deletions src/nni_manager/core/nniDataStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class NNIDataStore implements DataStore {
}

public async storeMetricData(trialJobId: string, data: string): Promise<void> {
this.log.debug(`storeMetricData: trialJobId: ${trialJobId}, data: ${data}`);
const metrics = JSON.parse(data) as MetricData;
assert(trialJobId === metrics.trial_job_id);
await this.db.storeMetricData(trialJobId, JSON.stringify({
Expand Down Expand Up @@ -168,18 +169,34 @@ class NNIDataStore implements DataStore {
}
}

private getJobStatusByLatestEvent(event: TrialJobEvent): TrialJobStatus {
/**
 * Derives the status a trial job has after an event is replayed.
 *
 * @param oldStatus status of the job before this event
 * @param event the event being applied
 * @returns the resulting status; events that are themselves statuses map to that status
 */
private getJobStatusByLatestEvent(oldStatus: TrialJobStatus, event: TrialJobEvent): TrialJobStatus {
    if (event === 'USER_TO_CANCEL') {
        return 'USER_CANCELED';
    }
    if (event === 'ADD_CUSTOMIZED') {
        return 'WAITING';
    }
    if (event === 'ADD_HYPERPARAMETER') {
        // Receiving an additional parameter set does not move the job to another state.
        return oldStatus;
    }

    // Every remaining event value is already a TrialJobStatus.
    return <TrialJobStatus>event;
}

/**
 * Merges a newly recorded hyper-parameter set into the sets already known for a trial job.
 * The new set is appended only when no known set carries the same `parameter_index`.
 *
 * @param hyperParamList JSON strings of the parameter sets collected so far
 * @param newParamStr JSON string of the parameter set to merge in
 * @returns JSON strings of the merged list (every entry is re-serialized)
 */
private mergeHyperParameters(hyperParamList: string[], newParamStr: string): string[] {
    const incoming: any = JSON.parse(newParamStr);
    const known: any[] = hyperParamList.map((raw: string) => JSON.parse(raw));

    // Inspect every known entry (no early exit) to find out whether the index is already taken.
    let duplicate: boolean = false;
    for (const entry of known) {
        if (entry.parameter_index === incoming.parameter_index) {
            duplicate = true;
        }
    }
    if (!duplicate) {
        known.push(incoming);
    }

    return known.map<string>((entry: any) => JSON.stringify(entry));
}

private getTrialJobsByReplayEvents(trialJobEvents: TrialJobEventRecord[]): Map<string, TrialJobInfo> {
const map: Map<string, TrialJobInfo> = new Map();
// assume data is stored by time ASC order
Expand All @@ -193,7 +210,8 @@ class NNIDataStore implements DataStore {
} else {
jobInfo = {
id: record.trialJobId,
status: this.getJobStatusByLatestEvent(record.event)
status: this.getJobStatusByLatestEvent('UNKNOWN', record.event),
hyperParameters: []
};
}
if (!jobInfo) {
Expand Down Expand Up @@ -222,9 +240,13 @@ class NNIDataStore implements DataStore {
}
default:
}
jobInfo.status = this.getJobStatusByLatestEvent(record.event);
jobInfo.status = this.getJobStatusByLatestEvent(jobInfo.status, record.event);
if (record.data !== undefined && record.data.trim().length > 0) {
jobInfo.hyperParameters = record.data;
if (jobInfo.hyperParameters !== undefined) {
jobInfo.hyperParameters = this.mergeHyperParameters(jobInfo.hyperParameters, record.data);
} else {
assert(false, 'jobInfo.hyperParameters is undefined');
}
}
map.set(record.trialJobId, jobInfo);
}
Expand Down
27 changes: 23 additions & 4 deletions src/nni_manager/core/nnimanager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import {
import { delay , getLogDir, getMsgDispatcherCommand} from '../common/utils';
import {
ADD_CUSTOMIZED_TRIAL_JOB, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, REPORT_METRIC_DATA,
REQUEST_TRIAL_JOBS, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
} from './commands';
import { createDispatcherInterface, IpcInterface } from './ipcInterface';
import { TrialJobMaintainerEvent, TrialJobs } from './trialJobs';
Expand Down Expand Up @@ -116,7 +116,7 @@ class NNIManager implements Manager {
await this.storeExperimentProfile();
this.log.debug('Setup tuner...');

const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor);
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
console.log(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner(
//expParams.tuner.tunerCommand,
Expand All @@ -140,7 +140,7 @@ class NNIManager implements Manager {
this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId);
const expParams: ExperimentParams = this.experimentProfile.params;

const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor);
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
console.log(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner(
dispatcherCommand,
Expand Down Expand Up @@ -460,7 +460,10 @@ class NNIManager implements Manager {
this.currSubmittedTrialNum++;
const trialJobAppForm: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: content
hyperParameters: {
value: content,
index: 0
}
};
const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(trialJobAppForm);
this.trialJobsMaintainer.setTrialJob(trialJobDetail.id, Object.assign({}, trialJobDetail));
Expand All @@ -472,6 +475,22 @@ class NNIManager implements Manager {
}
}
break;
case SEND_TRIAL_JOB_PARAMETER:
const tunerCommand: any = JSON.parse(content);
assert(tunerCommand.parameter_index >= 0);
assert(tunerCommand.trial_job_id !== undefined);

const trialJobForm: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: {
value: content,
index: tunerCommand.parameter_index
}
};
await this.trainingService.updateTrialJob(tunerCommand.trial_job_id, trialJobForm);
await this.dataStore.storeTrialJobEvent(
'ADD_HYPERPARAMETER', tunerCommand.trial_job_id, content, undefined);
break;
case NO_MORE_TRIAL_JOBS:
this.trialJobsMaintainer.setNoMoreTrials();
break;
Expand Down
1 change: 1 addition & 0 deletions src/nni_manager/rest_server/restValidationSchemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export namespace ValidationSchemas {
trialConcurrency: joi.number().min(0).required(),
searchSpace: joi.string().required(),
maxExecDuration: joi.number().min(0).required(),
multiPhase: joi.boolean(),
tuner: joi.object({
builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution'),
codeDir: joi.string(),
Expand Down
27 changes: 20 additions & 7 deletions src/nni_manager/training_service/local/localTrainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ import { getLogger, Logger } from '../../common/log';
import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import {
HostJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm,
HostJobApplicationForm, JobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService';
import { delay, getExperimentRootDir, uniqueString } from '../../common/utils';
import { file } from 'tmp';

const tkill = require('tree-kill');

Expand Down Expand Up @@ -210,8 +211,18 @@ class LocalTrainingService implements TrainingService {
* @param trialJobId trial job id
* @param form job application form
*/
public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
throw new MethodNotImplementedError();
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
    // The job must already be tracked by this training service.
    const jobDetail: undefined | TrialJobDetail = this.jobMap.get(trialJobId);
    if (jobDetail === undefined) {
        throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
    }

    // Only trial jobs can be updated.
    if (form.jobType !== 'TRIAL') {
        throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
    }

    // Drop the additional parameter set into the trial's working directory.
    await this.writeParameterFile(jobDetail.workingDirectory, (<TrialJobApplicationForm>form).hyperParameters);

    return jobDetail;
}

/**
Expand Down Expand Up @@ -332,10 +343,7 @@ class LocalTrainingService implements TrainingService {
await cpp.exec(`mkdir -p ${path.join(trialJobDetail.workingDirectory, '.nni')}`);
await cpp.exec(`touch ${path.join(trialJobDetail.workingDirectory, '.nni', 'metrics')}`);
await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, 'run.sh'), runScriptLines.join('\n'), { encoding: 'utf8' });
await fs.promises.writeFile(
path.join(trialJobDetail.workingDirectory, 'parameter.cfg'),
(<TrialJobApplicationForm>trialJobDetail.form).hyperParameters,
{ encoding: 'utf8' });
await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters);
const process: cp.ChildProcess = cp.exec(`bash ${path.join(trialJobDetail.workingDirectory, 'run.sh')}`);

this.setTrialJobStatus(trialJobDetail, 'RUNNING');
Expand Down Expand Up @@ -402,6 +410,11 @@ class LocalTrainingService implements TrainingService {
}
}
}

/**
 * Writes one hyper-parameter set to `parameter_<index>.cfg` inside the given directory.
 *
 * @param directory directory that receives the parameter file
 * @param hyperParameters parameter set whose value becomes the file content
 */
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
    const parameterFileName: string = `parameter_${hyperParameters.index}.cfg`;
    const destination: string = path.join(directory, parameterFileName);

    await fs.promises.writeFile(destination, hyperParameters.value, { encoding: 'utf8' });
}
}

export { LocalTrainingService };
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import { ObservableTimer } from '../../common/observableTimer';
import {
HostJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import { delay, getExperimentRootDir, uniqueString } from '../../common/utils';
import { GPUSummary } from '../common/gpuData';
Expand Down Expand Up @@ -198,8 +198,24 @@ class RemoteMachineTrainingService implements TrainingService {
* @param trialJobId trial job id
* @param form job application form
*/
public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
throw new MethodNotImplementedError();
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
    this.log.info(`updateTrialJob: form: ${JSON.stringify(form)}`);

    // The job must already be tracked by this training service.
    const jobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
    if (jobDetail === undefined) {
        throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
    }

    // Only trial jobs can be updated.
    if (form.jobType !== 'TRIAL') {
        throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
    }

    // The parameter file is written through the machine recorded on the job detail.
    const machine: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>jobDetail).rmMeta;
    if (machine === undefined) {
        throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
    }
    await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters, machine);

    return jobDetail;
}

/**
Expand Down Expand Up @@ -442,15 +458,13 @@ class RemoteMachineTrainingService implements TrainingService {
//create tmp trial working folder locally.
await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);

// Write file content ( run.sh and parameter.cfg ) to local tmp files
// Write file content ( run.sh and parameter_0.cfg ) to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptContent, { encoding: 'utf8' });
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'parameter.cfg'), form.hyperParameters, { encoding: 'utf8' });

// Copy local tmp files to remote machine
await SSHClientUtility.copyFileToRemote(
path.join(trialLocalTempFolder, 'run.sh'), path.join(trialWorkingFolder, 'run.sh'), sshClient);
await SSHClientUtility.copyFileToRemote(
path.join(trialLocalTempFolder, 'parameter.cfg'), path.join(trialWorkingFolder, 'parameter.cfg'), sshClient);
await this.writeParameterFile(trialJobId, form.hyperParameters, rmScheduleInfo.rmMeta);

// Copy files in codeDir to remote working directory
await SSHClientUtility.copyDirectoryToRemote(this.trialConfig.codeDir, trialWorkingFolder, sshClient);
Expand Down Expand Up @@ -562,6 +576,22 @@ class RemoteMachineTrainingService implements TrainingService {

return jobpidPath;
}

/**
 * Writes one hyper-parameter set to a local temp file and copies it to the
 * trial's working folder on the remote machine, as `parameter_<index>.cfg`.
 *
 * @param trialJobId id of the trial job the parameters belong to
 * @param hyperParameters parameter set whose value becomes the file content
 * @param rmMeta remote machine used to look up the SSH client
 * @throws Error when no SSH client is registered for the machine
 */
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters, rmMeta: RemoteMachineMeta): Promise<void> {
    const client: Client | undefined = this.machineSSHClientMap.get(rmMeta);
    if (client === undefined) {
        throw new Error('sshClient is undefined.');
    }

    const parameterFileName: string = `parameter_${hyperParameters.index}.cfg`;

    // Stage the file under the local per-trial temp folder first.
    const stagedFile: string = path.join(this.expRootDir, 'trials-local', trialJobId, parameterFileName);
    await fs.promises.writeFile(stagedFile, hyperParameters.value, { encoding: 'utf8' });

    // Then ship it to the trial's working folder on the remote side.
    const remoteFile: string = path.join(this.remoteExpRootDir, 'trials', trialJobId, parameterFileName);
    await SSHClientUtility.copyFileToRemote(stagedFile, remoteFile, client);
}
}

export { RemoteMachineTrainingService };
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
TrialConfigMetadataKey.TRIAL_CONFIG, `{"command":"sleep 1h && echo ","codeDir":"${localCodeDir}","gpuNum":1}`);
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: 'mock hyperparameters'
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
const trialJob = await remoteMachineTrainingService.submitTrialJob(form);

Expand Down Expand Up @@ -135,7 +138,10 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
// submit job
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: 'mock hyperparameters'
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
const jobDetail: TrialJobDetail = await remoteMachineTrainingService.submitTrialJob(form);
// Add metrics listeners
Expand Down
Loading

0 comments on commit 3c832bf

Please sign in to comment.