This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
NAS for TensorFlow #2115
Merged
Merged
NAS for TensorFlow #2115
Changes from 5 commits
Commits
Show all changes
42 commits
Select commit
Hold shift + click to select a range
f5f38e1
mutable interface
liuzhe-lz fcc7688
mutator interface
liuzhe-lz 8fb92a5
enas mutator
liuzhe-lz b4aa775
add enas trainer
liuzhe-lz 710a672
fix typo
liuzhe-lz 7fb98ad
Upgrade local windows tensorflow version (#2246) (#2247)
chicm-ms 0e62122
update pipeline mac version (#2252)
liuzhe-lz 2c5bde7
Fix security alerts (#2251)
liuzhe-lz 63bd0f5
remove clean step in remote-windows (#2256)
SparkSnail a7b96de
Enable visualization in examples (#2261)
ultmaster a84b32b
improve doc for PBT tuner (#2258)
QuanluZhang c61700f
Add doc of TextNAS (#2260)
pkuyym 4dc9eb9
Update pr test cases for windows (#2267) (#2268)
chicm-ms 742c26e
Add NAS Visualization Documentation (#2257)
ultmaster 4deb4b4
Fix pai-windows pipeline (#2270)
SparkSnail 0fd8466
Fix incorrect doc (#2264)
liuzhe-lz dfe166e
Fix pruner issues (#2265)
chicm-ms 470caf4
Fix broken link in readthedoc (#2266)
liuzhe-lz 4b598dd
fix bug of refresh from disable refresh to refresh (#2274)
Lijiaoa d2c5777
Add supported data types for PBT tuner (#2271)
RayMeng8 a95ccb6
remove old nas examples (#2285)
ultmaster ffc23cf
Merge pull request #2282 from microsoft/master
liuzhe-lz 065d788
Fix nasui installation (#2283)
SparkSnail 41faab3
Fix lottery ticket (#2286)
chicm-ms 2e88fe7
update doc: TextNAS and PBT tuner (#2279)
QuanluZhang 7d586d3
Fix nas tests (#2291)
chicm-ms 4dfd9d1
Merge pull request #2254 from microsoft/v1.5
liuzhe-lz f1ce164
Create studentProgram.md (#2273)
Lijiaoa f8d42a3
Add release note and update version numbers for v1.5 (#2300)
liuzhe-lz 649eabc
Fix build docker image problem (#2326)
chicm-ms ae72aec
Show more log info for failed test cases (#2321)
chicm-ms f36b62a
Update sklearn regression example (#2330)
SparkSnail 0a64d94
update typo in FeatureEngineering (#2292)
xuehui1991 6f6faff
Update GbdtExample.md (#2306)
xuehui1991 d2a0fc5
Fix doc title of DLTS (#2341)
SparkSnail aca75e8
Update outdated pai image (#2302)
SparkSnail bcb53a7
Chinese translation (#2250)
squirrelsc 8c8220e
Overview concurrency tooltip (#2333)
Lijiaoa 6b02f7a
add tooltip when there is no data in overview table (#2318)
Lijiaoa 80242f2
Update issue template (#2355)
Lijiaoa 1372cc8
naive example
liuzhe-lz a89495c
Merge branch 'master' into tf-nas
liuzhe-lz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from tensorflow.keras import Model | ||
|
||
from .mutables import Mutable, MutableScope, InputChoice | ||
|
||
|
||
class BaseMutator(Model): | ||
def __init__(self, model): | ||
super().__init__() | ||
self.__dict__['model'] = model | ||
self._structured_mutables = self._parse_search_space(self.model) | ||
|
||
def _parse_search_space(self, module, root=None, prefix='', memo=None, nested_detection=None): | ||
if memo is None: | ||
memo = set() | ||
if root is None: | ||
root = StructuredMutableTreeNode(None) | ||
if module not in memo: | ||
memo.add(module) | ||
if isinstance(module, Mutable): | ||
if nested_detection is not None: | ||
raise RuntimeError('Cannot have nested search space. Error at {} in {}' | ||
.format(module, nested_detection)) | ||
module.name = prefix | ||
module.set_mutator(self) | ||
root = root.add_child(module) | ||
if not isinstance(module, MutableScope): | ||
nested_detection = module | ||
if isinstance(module, InputChoice): | ||
for k in module.choose_from: | ||
if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]: | ||
raise RuntimeError('"{}" required by "{}" not found in keys that appeared before, and is not NO_KEY.' | ||
.format(k, module.key)) | ||
for submodule in module.layers: | ||
if not isinstance(submodule, Model): | ||
continue | ||
submodule_prefix = prefix + ('.' if prefix else '') + submodule.name | ||
self._parse_search_space(submodule, root, submodule_prefix, memo=memo, nested_detection=nested_detection) | ||
return root | ||
|
||
@property | ||
def mutables(self): | ||
return self._structured_mutables | ||
|
||
def undedup_mutables(self): | ||
return self._structured_mutables.traverse(deduplicate=False) | ||
|
||
def forward(self, *inputs): | ||
raise RuntimeError('Forward is undefined for mutators.') | ||
|
||
def __setattr__(self, name, value): | ||
if name == 'model': | ||
raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to " | ||
"include your network, as it will include all parameters in model into the mutator.") | ||
return super().__setattr__(name, value) | ||
|
||
def enter_mutable_scope(self, mutable_scope): | ||
pass | ||
|
||
def exit_mutable_scope(self, mutable_scope): | ||
pass | ||
|
||
def on_forward_layer_choice(self, mutable, *inputs): | ||
raise NotImplementedError | ||
|
||
def on_forward_input_choice(self, mutable, tensor_list): | ||
raise NotImplementedError | ||
|
||
def export(self): | ||
raise NotImplementedError | ||
|
||
|
||
# TODO: move to utils | ||
class StructuredMutableTreeNode: | ||
def __init__(self, mutable): | ||
self.mutable = mutable | ||
self.children = [] | ||
|
||
def add_child(self, mutable): | ||
self.children.append(StructuredMutableTreeNode(mutable)) | ||
return self.children[-1] | ||
|
||
def type(self): | ||
return type(self.mutable) | ||
|
||
def __iter__(self): | ||
return self.traverse() | ||
|
||
def traverse(self, order="pre", deduplicate=True, memo=None): | ||
if memo is None: | ||
memo = set() | ||
assert order in ["pre", "post"] | ||
if order == "pre": | ||
if self.mutable is not None: | ||
if not deduplicate or self.mutable.key not in memo: | ||
memo.add(self.mutable.key) | ||
yield self.mutable | ||
for child in self.children: | ||
for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo): | ||
yield m | ||
if order == "post": | ||
if self.mutable is not None: | ||
if not deduplicate or self.mutable.key not in memo: | ||
memo.add(self.mutable.key) | ||
yield self.mutable |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from .mutator import EnasMutator | ||
from .trainer import EnasTrainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
import tensorflow as tf | ||
from tensorflow.keras import Model | ||
from tensorflow.keras.layers import Dense, Embedding, LSTMCell | ||
from tensorflow.keras.losses import SparseCategoricalCrossentropy, Reduction | ||
|
||
from nni.nas.tensorflow.mutator import Mutator | ||
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope | ||
|
||
|
||
class StackedLSTMCell(Model): | ||
def __init__(self, layers, size, bias): | ||
super().__init__() | ||
self.lstm_num_layers = layers | ||
self.lstm_modules = [LSTMCell(units=size, use_bias=bias) for _ in range(layers)] | ||
|
||
def call(self, inputs, hidden): | ||
prev_c, prev_h = hidden | ||
next_c, next_h = [], [] | ||
for i, m in enumerate(self.lstm_modules): | ||
curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i])) | ||
next_c.append(curr_c) | ||
next_h.append(curr_h) | ||
inputs = curr_h[-1] | ||
return next_c, next_h | ||
|
||
|
||
class EnasMutator(Mutator): | ||
def __init__( | ||
self, | ||
model, | ||
lstm_size=64, | ||
lstm_num_layers=1, | ||
tanh_constant=1.5, | ||
cell_exit_extra_step=False, | ||
skip_target=0.4, | ||
temperature=None, | ||
branch_bias=0.25, | ||
entropy_reduction='sum'): | ||
super().__init__(model) | ||
self.lstm_size = lstm_size | ||
self.lstm_num_layers = lstm_num_layers | ||
self.tanh_constant = tanh_constant | ||
self.temperature = temperature | ||
self.cell_exit_extra_step = cell_exit_extra_step | ||
self.skip_target = skip_target | ||
self.branch_bias = branch_bias | ||
|
||
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False) | ||
self.attn_anchor = Dense(self.lstm_size, use_bias=False) | ||
self.attn_query = Dense(self.lstm_size, use_bias=False) | ||
self.v_attn = Dense(1, use_bias=False) | ||
self.g_emb = tf.Variable(tf.random.normal((1, self.lstm_size)) * 0.1) | ||
self.skip_targets = tf.constant([1.0 - self.skip_target, self.skip_target]) | ||
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.' | ||
self.entropy_reduction = tf.reduce_sum if entropy_reduction == 'sum' else tf.reduce_mean | ||
self.cross_entropy_loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE) | ||
self.bias_dict = {} | ||
|
||
self.max_layer_choice = 0 | ||
for mutable in self.mutables: | ||
if isinstance(mutable, LayerChoice): | ||
if self.max_layer_choice == 0: | ||
self.max_layer_choice = mutable.length | ||
assert self.max_layer_choice == mutable.length, \ | ||
"ENAS mutator requires all layer choice have the same number of candidates." | ||
if 'reduce' in mutable.key: | ||
bias = [] | ||
for choice in mutable.choices: | ||
if 'conv' in str(type(choice)).lower(): | ||
bias.append(self.branch_bias) | ||
else: | ||
bias.append(-self.branch_bias) | ||
self.bias_dict[mutable.key] = tf.constant(bias) | ||
|
||
self.embedding = Embedding(self.max_layer_choice + 1, self.lstm_size) | ||
self.soft = Dense(self.max_layer_choice, use_bias=False) | ||
|
||
def sample_search(self): | ||
self._initialize() | ||
self._sample(self.mutables) | ||
return self._choices | ||
|
||
def sample_final(self): | ||
return self.sample_search() | ||
|
||
def _sample(self, tree): | ||
mutable = tree.mutable | ||
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices: | ||
self._choices[mutable.key] = self._sample_layer_choice(mutable) | ||
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices: | ||
self._choices[mutable.key] = self._sample_input_choice(mutable) | ||
for child in tree.children: | ||
self._sample(child) | ||
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid: | ||
if self.cell_exit_extra_step: | ||
self._lstm_next_step() | ||
self._mark_anchor(mutable.key) | ||
|
||
def _initialize(self): | ||
self._choices = {} | ||
self._anchors_hid = {} | ||
self._inputs = tf.Variable(self.g_emb) | ||
self._c = [tf.zeros((1, self.lstm_size), dtype=self._inputs.dtype) for _ in range(self.lstm_num_layers)] | ||
self._h = [tf.zeros((1, self.lstm_size), dtype=self._inputs.dtype) for _ in range(self.lstm_num_layers)] | ||
self.sample_log_prob = 0 | ||
self.sample_entropy = 0 | ||
self.sample_skip_penalty = 0 | ||
|
||
def _lstm_next_step(self): | ||
self._c, self._h = self.lstm(self._inputs, (self._c, self._h)) | ||
|
||
def _mark_anchor(self, key): | ||
self._anchors_hid[key] = self._h[1] | ||
|
||
def _sample_layer_choice(self, mutable): | ||
self._lstm_next_step() | ||
logit = self.soft(self._h[-1]) | ||
if self.temperature is not None: | ||
logit /= self.temperature | ||
if self.tanh_constant is not None: | ||
logit = self.tanh_constant * tf.tanh(logit) | ||
if mutable.key in self.bias_dict: | ||
logit += self.bias_dict[mutable.key] | ||
branch_id = tf.random.categorical(tf.nn.softmax(logit, axis=-1), 1) | ||
branch_id = tf.reshape(branch_id, [-1]) | ||
log_prob = self.cross_entropy_loss(branch_id, logit) | ||
self.sample_log_prob += self.entropy_reduction(log_prob) | ||
entropy = log_prob * tf.exp(-log_prob) | ||
self.sample_entropy += self.entropy_reduction(entropy) | ||
self._inputs = self.embedding(branch_id) | ||
ret = tf.cast(tf.one_hot(branch_id, self.max_layer_choice), tf.bool) | ||
return tf.reshape(ret, [-1]) | ||
|
||
def _sample_input_choice(self, mutable): | ||
query, anchors = [], [] | ||
for label in mutable.choose_from: | ||
if label not in self._anchors_hid: | ||
self._lstm_next_step() | ||
self._mark_anchor(label) | ||
query.append(self.attn_anchor(self._anchors_hid[label])) | ||
anchors.append(self._anchors_hid[label]) | ||
query = tf.concat(query, 0) | ||
query = tf.tanh(query + self.attn_query(self._h[-1])) | ||
query = self.v_attn(query) | ||
if self.temperature is not None: | ||
query /= self.temperature | ||
if self.tanh_constant is not None: | ||
query = self.tanh_constant * tf.tanh(query) | ||
|
||
if mutable.n_chosen is None: | ||
logit = tf.concat([-query, query], 1) | ||
|
||
skip = tf.reshape(tf.random.categorical(tf.nn.softmax(logit, axis=-1), 1), [-1]) | ||
skip_prob = tf.math.sigmoid(logit) | ||
kl = tf.reduce_sum(skip_prob * tf.math.log(skip_prob / self.skip_targets)) | ||
self.sample_skip_penalty += kl | ||
log_prob = self.cross_entropy_loss(skip, logit) | ||
self._inputs = (tf.linalg.matmul(skip.float(), tf.concat(anchors, 0)) / (1. + tf.reduce_sum(skip))).unsqueeze(0) | ||
else: | ||
assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS." | ||
logit = tf.reshape(query, [1, -1]) | ||
index = tf.reshape(tf.random.categorical(tf.nn.softmax(logit, axis=-1), 1), [-1]) | ||
skip = tf.reshape(tf.one_hot(index, mutable.n_candidates), [-1]) | ||
log_prob = self.cross_entropy_loss(index, logit) | ||
self._inputs = anchors[index.item()] | ||
|
||
self.sample_log_prob += self.entropy_reduction(log_prob) | ||
entropy = log_prob * tf.exp(-log_prob) | ||
self.sample_entropy += self.entropy_reduction(entropy) | ||
return tf.cast(skip, tf.bool) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think forward is not needed for tensorflow.