Skip to content

Commit

Permalink
Merge pull request #108 from Microsoft/master
Browse files Browse the repository at this point in the history
merge master
  • Loading branch information
SparkSnail authored Jan 2, 2019
2 parents be23f55 + 37354df commit 6f760ab
Show file tree
Hide file tree
Showing 15 changed files with 411 additions and 209 deletions.
4 changes: 4 additions & 0 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
trigger:
- master
- dev-remote-ci

jobs:

- job: 'Ubuntu_16_04'
Expand Down
37 changes: 37 additions & 0 deletions src/nni_manager/training_service/common/clusterJobRestServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ import * as assert from 'assert';
import { Request, Response, Router } from 'express';
import * as bodyParser from 'body-parser';
import * as component from '../../common/component';
import * as fs from 'fs'
import * as path from 'path'
import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
import { RestServer } from '../../common/restServer'
import { getLogDir } from '../../common/utils';
import { Writable } from 'stream';

/**
* Cluster Job Training service Rest server, provides rest API to support Cluster job metrics update
Expand All @@ -33,6 +37,7 @@ import { RestServer } from '../../common/restServer'
@component.Singleton
export abstract class ClusterJobRestServer extends RestServer{
private readonly API_ROOT_URL: string = '/api/v1/nni-pai';
private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?<metrics>.*?)'`;

private readonly expId: string = getExperimentId();

Expand Down Expand Up @@ -88,6 +93,38 @@ export abstract class ClusterJobRestServer extends RestServer{
}
});

router.post(`/stdout/${this.expId}/:trialId`, (req: Request, res: Response) => {
const trialLogPath: string = path.join(getLogDir(), `trial_${req.params.trialId}.log`);
try {
let skipLogging: boolean = false;
if(req.body.tag === 'trial' && req.body.msg !== undefined) {
const metricsContent = req.body.msg.match(this.NNI_METRICS_PATTERN);
if(metricsContent && metricsContent.groups) {
this.handleTrialMetrics(req.params.trialId, [metricsContent.groups['metrics']]);
skipLogging = true;
}
}

if(!skipLogging){
// Construct write stream to write remote trial's log into local file
const writeStream: Writable = fs.createWriteStream(trialLogPath, {
flags: 'a+',
encoding: 'utf8',
autoClose: true
});

writeStream.write(req.body.msg + '\n');
writeStream.end();
}
res.send();
}
catch(err) {
this.log.error(`json parse stdout data error: ${err}`);
res.status(500);
res.send(err.message);
}
});

return router;
}

Expand Down
1 change: 1 addition & 0 deletions src/nni_manager/training_service/pai/paiTrainingService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class PAITrainingService implements TrainingService {
public async run(): Promise<void> {
const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
await restServer.start();

this.log.info(`PAI Training service rest server listening on: ${restServer.endPoint}`);
while (!this.stopping) {
await this.updatePaiToken();
Expand Down
3 changes: 3 additions & 0 deletions src/sdk/pynni/nni/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ def _load_env_args():
class _LoggerFileWrapper(TextIOBase):
def __init__(self, logger_file):
self.file = logger_file
self.orig_stdout = sys.stdout

def write(self, s):
if s != '\n':
time = datetime.now().strftime(_time_format)
self.orig_stdout.write(s + '\n')
self.orig_stdout.flush()
self.file.write('[{}] PRINT '.format(time) + s + '\n')
self.file.flush()
return len(s)
Expand Down
22 changes: 15 additions & 7 deletions src/sdk/pynni/nni/platform/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@
_outputdir = os.environ['NNI_OUTPUT_DIR']
if not os.path.exists(_outputdir):
os.makedirs(_outputdir)
_log_file_path = os.path.join(_outputdir, 'trial.log')
init_logger(_log_file_path)

_nni_platform = os.environ['NNI_PLATFORM']
if _nni_platform not in ['pai', 'kubeflow']:
_log_file_path = os.path.join(_outputdir, 'trial.log')
init_logger(_log_file_path)

_multiphase = os.environ.get('MULTI_PHASE')

Expand Down Expand Up @@ -74,11 +77,16 @@ def get_next_parameter():
return params

def send_metric(string):
data = (string + '\n').encode('utf8')
assert len(data) < 1000000, 'Metric too long'
_metric_file.write(b'ME%06d%b' % (len(data), data))
_metric_file.flush()
subprocess.run(['touch', _metric_file.name], check = True)
if _nni_platform in ['pai', 'kubeflow']:
data = (string).encode('utf8')
assert len(data) < 1000000, 'Metric too long'
print('NNISDK_ME%s' % (data))
else:
data = (string + '\n').encode('utf8')
assert len(data) < 1000000, 'Metric too long'
_metric_file.write(b'ME%06d%b' % (len(data), data))
_metric_file.flush()
subprocess.run(['touch', _metric_file.name], check = True)

def get_sequence_id():
return os.environ['NNI_TRIAL_SEQ_ID']
Loading

0 comments on commit 6f760ab

Please sign in to comment.