-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Conversation
nni/retiarii/nn/pytorch/api.py
Outdated
return x | ||
|
||
|
||
class Cell(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think "cell" could be more general.
op_candidates and num_nodes (or num_ops maybe) are likely to be supported by almost all algorithms, but ops_per_node and merge_op might not.
And from end-user's perspective, without knowledge about the NAS algorithm, they might not understand why there are "ops per node".
My suggestion is, either make optional parameters kwargs, or only support basic params in Cell
and let algorithms to inherit it.
I do accept current version if it's not considered API freeze.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indeed, it is a little confusing that there are two Cells, another one is in our graph ir
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better name wanted. Feel free to propose a new one.
nni/retiarii/nn/pytorch/api.py
Outdated
blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]], | ||
depth: Union[int, Tuple[int, int]], label=None): | ||
super().__init__() | ||
self._label = label if label is not None else f'valuechoice_{uid()}' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
valuechoice_{uid()}?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oops
nni/retiarii/nn/pytorch/api.py
Outdated
self.min_depth = depth if isinstance(depth, int) else depth[0] | ||
self.max_depth = depth if isinstance(depth, int) else depth[1] | ||
assert self.max_depth >= self.min_depth > 0 | ||
if not isinstance(blocks, list): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the meaning of "list of module"? users replicate module to max_depth
by themselves?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes
nni/retiarii/nn/pytorch/api.py
Outdated
# Support loose end concat (shape inference on the following cells) | ||
# How to dynamically create convolution with stride as the first node | ||
|
||
def __init__(self, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems it is better to put this Cell class into search space zoo/hub
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's discuss on teams.
@@ -76,6 +76,40 @@ def mutate(self, model): | |||
target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value}) | |||
|
|||
|
|||
class RepeatMutator(Mutator): | |||
def __init__(self, nodes: List[Node]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better to explain the meaning of nodes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
for edge in rm_node.outgoing_edges: | ||
edge.remove() | ||
rm_node.remove() | ||
model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the meaning of this line, why it is necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To delete unused parameters. Otherwise, codegen will complain.
This PR proposes two new high-level APIs for Retiarii.
What will NOT be included in this PR: