Skip to content

Commit

Permalink
1.Replace connector_name to connector in record_info. 2.Add assert th…
Browse files Browse the repository at this point in the history
…at each connector must be in connectors.
  • Loading branch information
zhangzhongyu.vendor committed Jul 28, 2022
1 parent cbcd8a3 commit d75d120
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@
from_student=True,
recorder='bb_s4',
record_idx=1,
connector_name='loss_s4_sfeat'),
connector='loss_s4_sfeat'),
t_feature=dict(
from_student=False, recorder='bb_s4', record_idx=2)),
loss_s3=dict(
s_feature=dict(
from_student=True,
recorder='bb_s3',
record_idx=1,
connector_name='loss_s3_sfeat'),
connector='loss_s3_sfeat'),
t_feature=dict(
from_student=False, recorder='bb_s3', record_idx=2)),
loss_kl=dict(
Expand Down
19 changes: 12 additions & 7 deletions mmrazor/models/distillers/configurable_distiller.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ class ConfigurableDistiller(BaseDistiller):
>>> loss_forward_mappings = dict(
... loss_neck=dict(
... s_feature=dict(from_recorder='feat', from_student=True,
... connector_name='loss_neck_sfeat'),
... connector='loss_neck_sfeat'),
... t_feature=dict(from_recorder='feat', from_student=False,
... connector_name='loss_neck_tfeat')))
... connector='loss_neck_tfeat')))
"""

def __init__(self,
Expand All @@ -113,6 +113,8 @@ def __init__(self,

self.distill_losses = self.build_distill_losses(distill_losses)

self.connectors = self.build_connectors(connectors)

if loss_forward_mappings:
# Check if loss_forward_mappings is in the correct format.
self._check_loss_forward_mappings(self.distill_losses,
Expand All @@ -123,8 +125,6 @@ def __init__(self,
else:
self.loss_forward_mappings = dict()

self.connectors = self.build_connectors(connectors)

def set_deliveries_override(self, override: bool) -> None:
"""Set the `override_data` of all deliveries."""
self.deliveries.override_data = override
Expand Down Expand Up @@ -183,7 +183,7 @@ def get_record(self,
from_student: bool,
record_idx: int = 0,
data_idx: Optional[int] = None,
connector_name: Optional[str] = None) -> List:
connector: Optional[str] = None) -> List:
"""According to each item in ``record_infos``, get the corresponding
record in ``recorder_manager``."""

Expand All @@ -193,8 +193,8 @@ def get_record(self,
recorder_ = self.teacher_recorders.get_recorder(recorder)
record_data = recorder_.get_record_data(record_idx, data_idx)

if connector_name:
record_data = self.connectors[connector_name](record_data)
if connector:
record_data = self.connectors[connector](record_data)

return record_data

Expand Down Expand Up @@ -272,3 +272,8 @@ def _check_loss_forward_mappings(
assert recorder in teacher_recorders.recorders, \
f'For {forward_key}, "{recorder}" must be in \
`teacher_recorders`.'

if 'connector' in record_info:
connector: str = record_info['connector']
assert connector in self.connectors, \
f'{connector} must be in "connectors".'

0 comments on commit d75d120

Please sign in to comment.