Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

NAS for TensorFlow #2115

Merged
merged 42 commits into from
Apr 26, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
f5f38e1
mutable interface
liuzhe-lz Mar 3, 2020
fcc7688
mutator interface
liuzhe-lz Mar 3, 2020
8fb92a5
enas mutator
liuzhe-lz Mar 3, 2020
b4aa775
add enas trainer
liuzhe-lz Mar 28, 2020
710a672
fix typo
liuzhe-lz Mar 30, 2020
7fb98ad
Upgrade local windows tensorflow version (#2246) (#2247)
chicm-ms Mar 31, 2020
0e62122
update pipeline mac version (#2252)
liuzhe-lz Apr 1, 2020
2c5bde7
Fix security alerts (#2251)
liuzhe-lz Apr 1, 2020
63bd0f5
remove clean step in remote-windows (#2256)
SparkSnail Apr 2, 2020
a7b96de
Enable visualization in examples (#2261)
ultmaster Apr 3, 2020
a84b32b
improve doc for PBT tuner (#2258)
QuanluZhang Apr 3, 2020
c61700f
Add doc of TextNAS (#2260)
pkuyym Apr 3, 2020
4dc9eb9
Update pr test cases for windows (#2267) (#2268)
chicm-ms Apr 4, 2020
742c26e
Add NAS Visualization Documentation (#2257)
ultmaster Apr 5, 2020
4deb4b4
Fix pai-windows pipeline (#2270)
SparkSnail Apr 5, 2020
0fd8466
Fix incorrect doc (#2264)
liuzhe-lz Apr 6, 2020
dfe166e
Fix pruner issues (#2265)
chicm-ms Apr 6, 2020
470caf4
Fix broken link in readthedoc (#2266)
liuzhe-lz Apr 7, 2020
4b598dd
fix bug of refresh from disable refresh to refresh (#2274)
Lijiaoa Apr 7, 2020
d2c5777
Add supported data types for PBT tuner (#2271)
RayMeng8 Apr 7, 2020
a95ccb6
remove old nas examples (#2285)
ultmaster Apr 8, 2020
ffc23cf
Merge pull request #2282 from microsoft/master
liuzhe-lz Apr 8, 2020
065d788
Fix nasui installation (#2283)
SparkSnail Apr 8, 2020
41faab3
Fix lottery ticket (#2286)
chicm-ms Apr 9, 2020
2e88fe7
update doc: TextNAS and PBT tuner (#2279)
QuanluZhang Apr 9, 2020
7d586d3
Fix nas tests (#2291)
chicm-ms Apr 9, 2020
4dfd9d1
Merge pull request #2254 from microsoft/v1.5
liuzhe-lz Apr 14, 2020
f1ce164
Create studentProgram.md (#2273)
Lijiaoa Apr 14, 2020
f8d42a3
Add release note and update version numbers for v1.5 (#2300)
liuzhe-lz Apr 16, 2020
649eabc
Fix build docker image problem (#2326)
chicm-ms Apr 17, 2020
ae72aec
Show more log info for failed test cases (#2321)
chicm-ms Apr 17, 2020
f36b62a
Update sklearn regression example (#2330)
SparkSnail Apr 17, 2020
0a64d94
update typo in FeatureEngineering (#2292)
xuehui1991 Apr 19, 2020
6f6faff
Update GbdtExample.md (#2306)
xuehui1991 Apr 19, 2020
d2a0fc5
Fix doc title of DLTS (#2341)
SparkSnail Apr 19, 2020
aca75e8
Update outdated pai image (#2302)
SparkSnail Apr 20, 2020
bcb53a7
Chinese translation (#2250)
squirrelsc Apr 20, 2020
8c8220e
Overview concurrency tooltip (#2333)
Lijiaoa Apr 20, 2020
6b02f7a
add tooltip when there is no data in overview table (#2318)
Lijiaoa Apr 20, 2020
80242f2
Update issue template (#2355)
Lijiaoa Apr 23, 2020
1372cc8
naive example
liuzhe-lz Apr 24, 2020
a89495c
Merge branch 'master' into tf-nas
liuzhe-lz Apr 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file removed src/sdk/__init__.py
Empty file.
Empty file removed src/sdk/pynni/__init__.py
Empty file.
107 changes: 107 additions & 0 deletions src/sdk/pynni/nni/nas/tensorflow/base_mutator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from tensorflow.keras import Model

from .mutables import Mutable, MutableScope, InputChoice


class BaseMutator(Model):
def __init__(self, model):
super().__init__()
self.__dict__['model'] = model
self._structured_mutables = self._parse_search_space(self.model)

def _parse_search_space(self, module, root=None, prefix='', memo=None, nested_detection=None):
if memo is None:
memo = set()
if root is None:
root = StructuredMutableTreeNode(None)
if module not in memo:
memo.add(module)
if isinstance(module, Mutable):
if nested_detection is not None:
raise RuntimeError('Cannot have nested search space. Error at {} in {}'
.format(module, nested_detection))
module.name = prefix
module.set_mutator(self)
root = root.add_child(module)
if not isinstance(module, MutableScope):
nested_detection = module
if isinstance(module, InputChoice):
for k in module.choose_from:
if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
raise RuntimeError('"{}" required by "{}" not found in keys that appeared before, and is not NO_KEY.'
.format(k, module.key))
for submodule in module.layers:
if not isinstance(submodule, Model):
continue
submodule_prefix = prefix + ('.' if prefix else '') + submodule.name
self._parse_search_space(submodule, root, submodule_prefix, memo=memo, nested_detection=nested_detection)
return root

@property
def mutables(self):
return self._structured_mutables

def undedup_mutables(self):
return self._structured_mutables.traverse(deduplicate=False)

def forward(self, *inputs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think forward is not needed for tensorflow.

raise RuntimeError('Forward is undefined for mutators.')

def __setattr__(self, name, value):
if name == 'model':
raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
"include your network, as it will include all parameters in model into the mutator.")
return super().__setattr__(name, value)

def enter_mutable_scope(self, mutable_scope):
pass

def exit_mutable_scope(self, mutable_scope):
pass

def on_forward_layer_choice(self, mutable, *inputs):
raise NotImplementedError

def on_forward_input_choice(self, mutable, tensor_list):
raise NotImplementedError

def export(self):
raise NotImplementedError


# TODO: move to utils
class StructuredMutableTreeNode:
def __init__(self, mutable):
self.mutable = mutable
self.children = []

def add_child(self, mutable):
self.children.append(StructuredMutableTreeNode(mutable))
return self.children[-1]

def type(self):
return type(self.mutable)

def __iter__(self):
return self.traverse()

def traverse(self, order="pre", deduplicate=True, memo=None):
if memo is None:
memo = set()
assert order in ["pre", "post"]
if order == "pre":
if self.mutable is not None:
if not deduplicate or self.mutable.key not in memo:
memo.add(self.mutable.key)
yield self.mutable
for child in self.children:
for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo):
yield m
if order == "post":
if self.mutable is not None:
if not deduplicate or self.mutable.key not in memo:
memo.add(self.mutable.key)
yield self.mutable
5 changes: 5 additions & 0 deletions src/sdk/pynni/nni/nas/tensorflow/enas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import EnasMutator
from .trainer import EnasTrainer
173 changes: 173 additions & 0 deletions src/sdk/pynni/nni/nas/tensorflow/enas/mutator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Embedding, LSTMCell
from tensorflow.keras.losses import SparseCategoricalCrossentropy, Reduction

from nni.nas.tensorflow.mutator import Mutator
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope


class StackedLSTMCell(Model):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = [LSTMCell(units=size, use_bias=bias) for _ in range(layers)]

def call(self, inputs, hidden):
prev_c, prev_h = hidden
next_c, next_h = [], []
for i, m in enumerate(self.lstm_modules):
curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
next_c.append(curr_c)
next_h.append(curr_h)
inputs = curr_h[-1]
return next_c, next_h


class EnasMutator(Mutator):
def __init__(
self,
model,
lstm_size=64,
lstm_num_layers=1,
tanh_constant=1.5,
cell_exit_extra_step=False,
skip_target=0.4,
temperature=None,
branch_bias=0.25,
entropy_reduction='sum'):
super().__init__(model)
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.cell_exit_extra_step = cell_exit_extra_step
self.skip_target = skip_target
self.branch_bias = branch_bias

self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = Dense(self.lstm_size, use_bias=False)
self.attn_query = Dense(self.lstm_size, use_bias=False)
self.v_attn = Dense(1, use_bias=False)
self.g_emb = tf.Variable(tf.random.normal((1, self.lstm_size)) * 0.1)
self.skip_targets = tf.constant([1.0 - self.skip_target, self.skip_target])
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = tf.reduce_sum if entropy_reduction == 'sum' else tf.reduce_mean
self.cross_entropy_loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE)
self.bias_dict = {}

self.max_layer_choice = 0
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0:
self.max_layer_choice = mutable.length
assert self.max_layer_choice == mutable.length, \
"ENAS mutator requires all layer choice have the same number of candidates."
if 'reduce' in mutable.key:
bias = []
for choice in mutable.choices:
if 'conv' in str(type(choice)).lower():
bias.append(self.branch_bias)
else:
bias.append(-self.branch_bias)
self.bias_dict[mutable.key] = tf.constant(bias)

self.embedding = Embedding(self.max_layer_choice + 1, self.lstm_size)
self.soft = Dense(self.max_layer_choice, use_bias=False)

def sample_search(self):
self._initialize()
self._sample(self.mutables)
return self._choices

def sample_final(self):
return self.sample_search()

def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
if self.cell_exit_extra_step:
self._lstm_next_step()
self._mark_anchor(mutable.key)

def _initialize(self):
self._choices = {}
self._anchors_hid = {}
self._inputs = tf.Variable(self.g_emb)
self._c = [tf.zeros((1, self.lstm_size), dtype=self._inputs.dtype) for _ in range(self.lstm_num_layers)]
self._h = [tf.zeros((1, self.lstm_size), dtype=self._inputs.dtype) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0

def _lstm_next_step(self):
self._c, self._h = self.lstm(self._inputs, (self._c, self._h))

def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[1]

def _sample_layer_choice(self, mutable):
self._lstm_next_step()
logit = self.soft(self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * tf.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
branch_id = tf.random.categorical(tf.nn.softmax(logit, axis=-1), 1)
branch_id = tf.reshape(branch_id, [-1])
log_prob = self.cross_entropy_loss(branch_id, logit)
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = log_prob * tf.exp(-log_prob)
self.sample_entropy += self.entropy_reduction(entropy)
self._inputs = self.embedding(branch_id)
ret = tf.cast(tf.one_hot(branch_id, self.max_layer_choice), tf.bool)
return tf.reshape(ret, [-1])

def _sample_input_choice(self, mutable):
query, anchors = [], []
for label in mutable.choose_from:
if label not in self._anchors_hid:
self._lstm_next_step()
self._mark_anchor(label)
query.append(self.attn_anchor(self._anchors_hid[label]))
anchors.append(self._anchors_hid[label])
query = tf.concat(query, 0)
query = tf.tanh(query + self.attn_query(self._h[-1]))
query = self.v_attn(query)
if self.temperature is not None:
query /= self.temperature
if self.tanh_constant is not None:
query = self.tanh_constant * tf.tanh(query)

if mutable.n_chosen is None:
logit = tf.concat([-query, query], 1)

skip = tf.reshape(tf.random.categorical(tf.nn.softmax(logit, axis=-1), 1), [-1])
skip_prob = tf.math.sigmoid(logit)
kl = tf.reduce_sum(skip_prob * tf.math.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(skip, logit)
self._inputs = (tf.linalg.matmul(skip.float(), tf.concat(anchors, 0)) / (1. + tf.reduce_sum(skip))).unsqueeze(0)
else:
assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
logit = tf.reshape(query, [1, -1])
index = tf.reshape(tf.random.categorical(tf.nn.softmax(logit, axis=-1), 1), [-1])
skip = tf.reshape(tf.one_hot(index, mutable.n_candidates), [-1])
log_prob = self.cross_entropy_loss(index, logit)
self._inputs = anchors[index.item()]

self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = log_prob * tf.exp(-log_prob)
self.sample_entropy += self.entropy_reduction(entropy)
return tf.cast(skip, tf.bool)
Loading