Training container for NAS Envelopenet #429
Conversation
@@ -0,0 +1,238 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
You probably don't need the future imports since you are using Python 3 anyway. Same goes for every import like this.
def cell(self, inputs, arch, is_training):
    """Create the cell by instantiating the cell blocks"""
    nscope = 'Cell_' + self.cellname + '_' + str(self.cellidx)
.format should be used instead of string concatenation.
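A minimal sketch of the suggested change, reusing the attribute names from the snippet above; format() renders self.cellidx as a string, so the explicit str() call is no longer needed:

# Same scope name, built with str.format instead of '+' concatenation.
nscope = 'Cell_{}_{}'.format(self.cellname, self.cellidx)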
By default use stride=1 and SAME padding
"""
dropout_keep_prob = 0.8
nscope = 'Cell_' + self.cellname + '_' + str(self.cellidx)
.format (then there will be no need to use str()).
net = tf.Print(
    net,
    [msss],
    message="MeanSSS=:" +
Use .format here as well.
self.batch_size = self.task_config["batch_size"]
self.num_examples = 10000
self.run_once = True
self.eval_dir = self.task_config["data_dir"] + "/results/" + \
Use os.path.join() instead, to take care of the "/".
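A sketch of the suggested change; the original assignment continues past the excerpt, so only the components visible above are joined here (os is already imported elsewhere in this file):

# os.path.join inserts the separators itself, so no hand-written "/" is needed.
self.eval_dir = os.path.join(self.task_config["data_dir"], "results")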
# Build a Graph that computes the logits predictions from the
# inference model.
# TODO: Clean up all args
Get rid of TODOs since I can't see a WIP status.
self.get_params(params)

def get_params(self, params):
    global global_batch_size
Should these really be globals rather than class attributes?
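A minimal sketch of the instance-attribute alternative the comment points at, assuming params is a dict that carries a batch_size entry (that detail is not visible in the excerpt):

def get_params(self, params):
    # Keep configuration on the object instead of mutating module-level globals.
    self.batch_size = params["batch_size"]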
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
    def _progress(count, block_size, total_size):
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
Do you really need a progress bar?
self._step = self.global_step_init
self._start_time = time.time()

def before_run(self, run_context):
run_context is never used. The function is actually never called.
from __future__ import print_function

from datetime import datetime
import ast
Is it used anywhere?
with tf.variable_scope(nscope, 'initial_block', [inputs], reuse=reuse) as scope:
    with slim.arg_scope([slim.conv2d, slim.max_pool2d], stride=1, padding='SAME'):
        net = inputs
        layeridx = 0
Do you need this? As far as I can see you only increment it.
outputs):
    self.cellidx = cellidx
    self.log_stats = log_stats
    self.res=sys.argv[2]
never used
self.log_stats = log_stats
self.res=sys.argv[2]
self.cellname = "Envelope"
self.numbranches = 4
never used
self.numbranches = 4
self.numbins = 100
self.batchsize = int(net.shape[0])
numfilters = len(filters)
never used
scope)
return softmax_linear

def maybe_download_and_extract(self):
Is this used anywhere? I see the actual line calling it is commented out.
is_training=True,
scope='Nacnet'):
    net = self.add_init(inputs, initcell, is_training)
    end_points = {}
Not used? The line using it is commented out.
self.global_step_init = global_step_init
self.loss = loss

def begin(self):
never used
# Asks for loss value.
return tf.train.SessionRunArgs(self.loss)

def after_run(self, run_context, run_values):
function is never called
These are required in the LoggerHook call during training.
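For context, a rough sketch of how such a hook is driven, assuming a tf.train.SessionRunHook subclass along the lines of the snippets above: tf.train.MonitoredTrainingSession calls begin, before_run and after_run on every registered hook itself, so these methods run during training even though nothing in the file calls them by name.

import tensorflow as tf

class _LoggerHook(tf.train.SessionRunHook):
    """Fetches the loss on every step; the method names are fixed by the hook API."""

    def __init__(self, loss):
        self.loss = loss

    def before_run(self, run_context):
        # Ask the monitored session to also evaluate the loss tensor on this step.
        return tf.train.SessionRunArgs(self.loss)

    def after_run(self, run_context, run_values):
        # run_values.results holds the loss requested in before_run.
        print('loss = {}'.format(run_values.results))

# Illustrative usage: the session, not user code, invokes the hook methods.
# with tf.train.MonitoredTrainingSession(hooks=[_LoggerHook(loss)]) as mon_sess:
#     while not mon_sess.should_stop():
#         mon_sess.run(train_op)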
mon_sess.run(train_op)

def evaluate():
    eval=Evaluate(self.arch, self.params, self.train_dir)
never used?
it's called from run_trail.py
@richardsliu @Akado2009 plz merge
Is the suggestion trained for one experiment (or search space)? Or is it trained once, and can be used for all experiments?
@gaocegege It is trained for every experiment. As there is no controller like in RL, the only training required is of the architectures being sampled, and even that is truncated, i.e., 10 epochs.
Now our suggestion is long-running, thus I am not sure if we could support the case. But the suggestion itself LGTM. Thanks for your contribution.
Can you add a README file that explains:
- How this code works
- How to run/debug locally, if possible
- Links to the design doc and/or paper
Without prior knowledge of NAS, it is hard to follow.
Also, do you have unit tests?
@gaocegege In fact it's the fastest NAS algorithm, even faster than the RL algorithm that is already in Katib. If you change the 'steps' parameter to 10 in the YAML file, it will complete in ~15 min. Also, you might want to run it on a GPU.
@richardsliu done
@richardsliu Plz merge
/lgtm
[APPROVALNOTIFIER] This PR is APPROVED
This pull-request has been approved by: richardsliu
The full list of commands accepted by this bot can be found here. The pull request process is described here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing /approve in a comment.
/retest
This PR has all files for the training and evaluation required in NAS Envelopenet Suggestion service.