Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add channel-flow #301

Merged
merged 15 commits into from
Sep 29, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __init__(self, *args, **kwargs) -> None:
self.candidate_bn = nn.ModuleDict()

def init_candidates(self, candidates: List):
"""Initialize candicates."""
assert len(self.candidate_bn) == 0
self._check_candidates(candidates)
for num in candidates:
Expand All @@ -155,6 +156,7 @@ def init_candidates(self, candidates: List):
self.weight.dtype)

def forward(self, input: Tensor) -> Tensor:
"""Forward."""
choice_num = self.activated_channel_num()
if choice_num == self.num_features:
return super().forward(input)
Expand All @@ -163,6 +165,7 @@ def forward(self, input: Tensor) -> Tensor:
return self.candidate_bn[str(choice_num)](input)

def to_static_op(self: _BatchNorm) -> nn.Module:
"""Convert to a normal BatchNorm."""
choice_num = self.activated_channel_num()
if choice_num == self.num_features:
return super().to_static_op()
Expand All @@ -173,15 +176,18 @@ def to_static_op(self: _BatchNorm) -> nn.Module:
# private methods

def activated_channel_num(self):
"""Number to activated channels."""
LKJacky marked this conversation as resolved.
Show resolved Hide resolved
mask = self._get_num_features_mask()
choice_num = (mask == 1).sum().item()
return choice_num

def _check_candidates(self, candidates: List):
"""Check if candidates aviliable."""
for value in candidates:
assert isinstance(value, int)
assert 0 < value <= self.num_features

@property
def static_op_factory(self):
"""Return initializer of static op."""
return nn.BatchNorm2d
Loading