-
Notifications
You must be signed in to change notification settings - Fork 231
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
…#320) * support MethodInputsRecorder and FunctionInputsRecorder * fix bugs that the model can not be pickled * WIP: add pytest for ema model * fix bugs in recorder and delivery when ema_hook is used * don't register the DummyDataset * fix pytest
- Loading branch information
Showing
13 changed files
with
535 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
mmrazor/models/task_modules/recorder/function_inputs_recorder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import functools | ||
from inspect import signature | ||
from typing import Callable, List | ||
|
||
from mmrazor.registry import TASK_UTILS | ||
from .function_outputs_recorder import FunctionOutputsRecorder | ||
|
||
|
||
@TASK_UTILS.register_module() | ||
class FunctionInputsRecorder(FunctionOutputsRecorder): | ||
"""Recorder for intermediate results which are ``FunctionType``'s inputs. | ||
Notes: | ||
The form of `source` needs special attention. For example, | ||
`anchor_inside_flags` is a function in mmdetection to check whether the | ||
anchors are inside the border. This function is in | ||
`mmdet/core/anchor/utils.py` and used in | ||
`mmdet/models/dense_heads/anchor_head.py`. Then the source should be | ||
`mmdet.models.dense_heads.anchor_head.anchor_inside_flags` but not | ||
`mmdet.core.anchor.utils.anchor_inside_flags`. | ||
Examples: | ||
>>> # Below code in toy_module.py | ||
>>> import random | ||
>>> def toy_func(a, b): | ||
... return a, b | ||
>>> def execute_toy_func(a, b): | ||
... toy_func(a, b) | ||
>>> # Below code in main.py | ||
>>> # Now, we want to get teacher's inputs by recorder. | ||
>>> from toy_module import execute_toy_func | ||
>>> r1 = FunctionInputsRecorder('toy_module.toy_func') | ||
>>> r1.initialize() | ||
>>> with r1: | ||
... execute_toy_func(1, 2) | ||
... execute_toy_func(1, b=2) | ||
... execute_toy_func(b=2, a=1) | ||
>>> r1.data_buffer | ||
[[1, 2], [1, 2], [1, 2]] | ||
""" | ||
|
||
def func_record_wrapper(self, origin_func: Callable, | ||
data_buffer: List) -> Callable: | ||
"""Save the function's inputs. | ||
Args: | ||
origin_func (FunctionType): The method whose inputs need to be | ||
recorded. | ||
data_buffer (list): A list of data. | ||
""" | ||
|
||
func_input_params = signature(origin_func).parameters.keys() | ||
|
||
@functools.wraps(origin_func) | ||
def wrap_func(*args, **kwargs): | ||
outputs = origin_func(*args, **kwargs) | ||
inputs = list(args) | ||
for keyword in func_input_params: | ||
if keyword in kwargs: | ||
inputs.append(kwargs[keyword]) | ||
# assume a func execute N times, there will be N inputs need to | ||
# save. | ||
data_buffer.append(inputs) | ||
return outputs | ||
|
||
return wrap_func |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
mmrazor/models/task_modules/recorder/method_inputs_recorder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import functools | ||
from inspect import signature | ||
from typing import Callable, List | ||
|
||
from mmrazor.registry import TASK_UTILS | ||
from .method_outputs_recorder import MethodOutputsRecorder | ||
|
||
|
||
@TASK_UTILS.register_module() | ||
class MethodInputsRecorder(MethodOutputsRecorder): | ||
"""Recorder for intermediate results which are ``MethodType``'s inputs. | ||
Note: | ||
Different from ``FunctionType``, ``MethodType`` is the type of methods | ||
of class instances. | ||
Examples: | ||
>>> # Below code in toy_module.py | ||
>>> import random | ||
>>> class Toy(): | ||
... def toy_func(self, x, y=0): | ||
... return x + y | ||
>>> # Below code in main.py | ||
>>> # Now, we want to get teacher's inputs by recorder. | ||
>>> from toy_module import Toy | ||
>>> toy = Toy() | ||
>>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func') | ||
>>> r1.initialize() | ||
>>> with r1: | ||
... _ = toy.toy_func(1, 2) | ||
>>> r1.data_buffer | ||
[[1, 2]] | ||
>>> r1.get_record_data(record_idx=0, data_idx=0) | ||
1 | ||
>>> r1.get_record_data(record_idx=0, data_idx=1) | ||
2 | ||
>>> from toy_module import Toy | ||
>>> toy = Toy() | ||
>>> r1 = MethodInputsRecorder('toy_module.Toy.toy_func') | ||
>>> r1.initialize() | ||
>>> with r1: | ||
... _ = toy.toy_func(1, 2) | ||
... _ = toy.toy_func(y=2, x=1) | ||
>>> r1.data_buffer | ||
[[1, 2], [1, 2]] | ||
>>> r1.get_record_data(record_idx=1, data_idx=0) | ||
1 | ||
>>> r1.get_record_data(record_idx=1, data_idx=1) | ||
2 | ||
""" | ||
|
||
def method_record_wrapper(self, orgin_method: Callable, | ||
data_buffer: List) -> Callable: | ||
"""Save the method's inputs. | ||
Args: | ||
origin_method (MethodType): The method whose inputs need to be | ||
recorded. | ||
data_buffer (list): A list of data. | ||
""" | ||
|
||
method_input_params = signature(orgin_method).parameters.keys() | ||
|
||
@functools.wraps(orgin_method) | ||
def wrap_method(*args, **kwargs): | ||
outputs = orgin_method(*args, **kwargs) | ||
# the first element of a class method is the class itself | ||
inputs = list(args[1:]) | ||
for keyword in method_input_params: | ||
if keyword in kwargs: | ||
inputs.append(kwargs[keyword]) | ||
# Assume a func execute N times, there will be N inputs need to | ||
# save. | ||
data_buffer.append(inputs) | ||
return outputs | ||
|
||
return wrap_method |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.