Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

support frameworkcontroller log #572

Merged
merged 2 commits into from
Jan 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions examples/trials/mnist/config_frameworkcontroller.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
authorName: default
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
#choice: local, remote, pai, kubeflow
trainingServicePlatform: frameworkcontroller
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
#choice: TPE, Random, Anneal, Evolution
builtinTunerName: TPE
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
gpuNum: 0
trial:
codeDir: .
taskRoles:
- name: worker
taskNum: 1
command: python3 mnist.py
gpuNum: 1
cpuNum: 1
memoryMB: 8192
image: msranni/nni:latest
frameworkAttemptCompletionPolicy:
minFailedTaskCount: 1
minSucceededTaskCount: 1
frameworkcontrollerConfig:
storage: nfs
nfs:
# Your NFS server IP, like 10.10.10.10
server: {your_nfs_server_ip}
# Your NFS server export path, like /var/nfs/nni
path: {your_nfs_server_export_path}
4 changes: 2 additions & 2 deletions src/sdk/pynni/nni/platform/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
os.makedirs(_outputdir)

_nni_platform = os.environ['NNI_PLATFORM']
if _nni_platform not in ['pai', 'kubeflow']:
if _nni_platform not in ['pai', 'kubeflow', 'frameworkcontroller']:
_log_file_path = os.path.join(_outputdir, 'trial.log')
init_logger(_log_file_path)

Expand Down Expand Up @@ -77,7 +77,7 @@ def get_next_parameter():
return params

def send_metric(string):
if _nni_platform in ['pai', 'kubeflow']:
if _nni_platform in ['pai', 'kubeflow', 'frameworkcontroller']:
data = (string).encode('utf8')
assert len(data) < 1000000, 'Metric too long'
print('NNISDK_ME%s' % (data))
Expand Down