Skip to content

Commit

Permalink
search loop fly
Browse files Browse the repository at this point in the history
  • Loading branch information
aptsunny committed Nov 10, 2022
1 parent 4a43b9d commit 8c9a38b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 13 deletions.
27 changes: 18 additions & 9 deletions mmrazor/engine/runner/evolution_search_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import torch
from mmengine import fileio
from mmengine.dist import broadcast_object_list
from mmengine.evaluator import Evaluator
from mmengine.runner import EpochBasedTrainLoop
from mmengine.utils import is_list_of
Expand Down Expand Up @@ -212,19 +213,27 @@ def sample_candidates(self) -> None:
"""Update candidate pool contains specified number of candicates."""
candidates_resources = []
init_candidates = len(self.candidates)
while len(self.candidates) < self.num_candidates:
candidate = self.model.sample_subnet()
is_pass, result = self._check_constraints(random_subnet=candidate)
if is_pass:
self.candidates.append(candidate)
candidates_resources.append(result)
self.candidates = Candidates(self.candidates.data)
if self.runner.rank == 0:
while len(self.candidates) < self.num_candidates:
candidate = self.model.sample_subnet()
is_pass, result = self._check_constraints(
random_subnet=candidate)
if is_pass:
self.candidates.append(candidate)
candidates_resources.append(result)
self.candidates = Candidates(self.candidates.data)
else:
self.candidates = Candidates([dict(a=0)] * self.num_candidates)

if len(candidates_resources) > 0:
self.candidates.update_resources(
candidates_resources,
start=len(self.candidates.data) - len(candidates_resources))
assert init_candidates + len(
candidates_resources) == self.num_candidates
assert init_candidates + len(
candidates_resources) == self.num_candidates

# broadcast candidates to val with multi-GPUs.
broadcast_object_list(self.candidates.data)

def update_candidates_scores(self) -> None:
"""Validate candicate one by one from the candicate pool, and update
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def to_static_op(self: PatchEmbed) -> nn.Module:
img_size=self.img_size,
in_channels=3,
embed_dims=self.mutable_embed_dims.activated_channels)
# embed_dims=self.mutable_embed_dims.current_choice)

static_patch_embed.projection.weight = nn.Parameter(weight.clone())
static_patch_embed.projection.bias = nn.Parameter(bias.clone())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def to_static_op(self) -> nn.Module:
self.check_if_mutables_fixed()
assert self.mutable_head_dims is not None

# self.current_head_dim = self.mutable_head_dims.current_choice
self.current_head_dim = self.mutable_head_dims.activated_channels
static_relative_position = self.static_op_factory(
self.current_head_dim)
Expand Down
6 changes: 4 additions & 2 deletions mmrazor/structures/subnet/candidate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, initdata: Optional[_format_input] = None):
def scores(self) -> List[float]:
"""The scores of candidates."""
return [
value.get('score', 0.) for item in self.data
round(value.get('score', 0.), 2) for item in self.data
for _, value in item.items()
]

Expand Down Expand Up @@ -96,12 +96,14 @@ def _format(self, data: _format_input) -> _format_return:
def _format_item(
cond: Union[Dict, Dict[str, Dict]]) -> Dict[str, Dict]:
"""Transform Dict to Dict[str, Dict]."""
if isinstance(list(cond.values())[0], dict):
if len(cond.values()) > 0 and isinstance(
list(cond.values())[0], dict):
for value in list(cond.values()):
for key in list(self._indicators):
value.setdefault(key, 0.)
return cond
else:
# import pdb;pdb.set_trace()
return {str(cond): {}.fromkeys(self._indicators, -1)}

if isinstance(data, UserList):
Expand Down

0 comments on commit 8c9a38b

Please sign in to comment.