Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][C-119] Add type annotations for python/paddle/quantization/config.py #66684

Merged
merged 7 commits into from
Aug 2, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 31 additions & 22 deletions python/paddle/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
from typing import Dict, Union
from typing import TYPE_CHECKING

import paddle
from paddle import nn
from paddle.nn import Layer

from .factory import QuanterFactory
from .wrapper import ObserveWrapper

if TYPE_CHECKING:
from paddle.nn import Layer

from .factory import QuanterFactory


# TODO: Implement quanted layer and fill the mapping dict
DEFAULT_QAT_LAYER_MAPPINGS: Dict[Layer, Layer] = {
DEFAULT_QAT_LAYER_MAPPINGS: dict[Layer, Layer] = {
nn.quant.Stub: nn.quant.stub.QuanterStub,
nn.Linear: nn.quant.qat.QuantedLinear,
nn.Conv2D: nn.quant.qat.QuantedConv2D,
Expand All @@ -41,16 +46,18 @@ class SingleLayerConfig:
weight(QuanterFactory): The factory to create instance of quanter used to quantize weights.
"""

def __init__(self, activation: QuanterFactory, weight: QuanterFactory):
def __init__(
self, activation: QuanterFactory, weight: QuanterFactory
) -> None:
self._activation = activation
self._weight = weight

@property
def activation(self):
def activation(self) -> QuanterFactory:
return self._activation

@property
def weight(self):
def weight(self) -> QuanterFactory:
return self._weight

def __str__(self):
Expand Down Expand Up @@ -82,7 +89,9 @@ class QuantConfig:

"""

def __init__(self, activation: QuanterFactory, weight: QuanterFactory):
def __init__(
self, activation: QuanterFactory, weight: QuanterFactory
enkilee marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
if activation is None and weight is None:
self._global_config = None
else:
Expand All @@ -98,16 +107,16 @@ def __init__(self, activation: QuanterFactory, weight: QuanterFactory):

def add_layer_config(
self,
layer: Union[Layer, list],
layer: Layer | list[Layer],
activation: QuanterFactory = None,
enkilee marked this conversation as resolved.
Show resolved Hide resolved
weight: QuanterFactory = None,
):
) -> None:
r"""
Set the quantization config by layer. It has the highest priority among
all the setting methods.

Args:
layer(Union[Layer, list]): One or a list of layers.
            layer(Layer|list[Layer]): One or a list of layers.
activation(QuanterFactory): Quanter used for activations.
weight(QuanterFactory): Quanter used for weights.

Expand Down Expand Up @@ -147,16 +156,16 @@ def add_layer_config(

def add_name_config(
self,
layer_name: Union[str, list],
layer_name: str | list[str],
activation: QuanterFactory = None,
weight: QuanterFactory = None,
):
) -> None:
r"""
Set the quantization config by full name of layer. Its priority is
lower than `add_layer_config`.

Args:
layer_name(Union[str, list]): One or a list of layers' full name.
layer_name(str|list[str]): One or a list of layers' full name.
activation(QuanterFactory): Quanter used for activations.
weight(QuanterFactory): Quanter used for weights.

Expand Down Expand Up @@ -195,17 +204,17 @@ def add_name_config(

def add_type_config(
self,
layer_type: Union[type, list],
layer_type: type | list[type],
enkilee marked this conversation as resolved.
Show resolved Hide resolved
activation: QuanterFactory = None,
weight: QuanterFactory = None,
):
) -> None:
r"""
Set the quantization config by the type of layer. The `layer_type` should be
subclass of `paddle.nn.Layer`. Its priority is lower than `add_layer_config`
and `add_name_config`.

Args:
layer_type(Union[type, list]): One or a list of layers' type. It should be subclass of
layer_type(type|list[type]): One or a list of layers' type. It should be subclass of
`paddle.nn.Layer`. Python build-in function `type()` can be used to get the type of a layer.
activation(QuanterFactory): Quanter used for activations.
weight(QuanterFactory): Quanter used for weights.
Expand Down Expand Up @@ -245,7 +254,7 @@ def add_type_config(
_element, activation=activation, weight=weight
)

def add_qat_layer_mapping(self, source: type, target: type):
def add_qat_layer_mapping(self, source: type, target: type) -> None:
r"""
Add rules converting layers to simulated quantization layers
before quantization-aware training. It will convert layers
Expand Down Expand Up @@ -280,7 +289,7 @@ def add_qat_layer_mapping(self, source: type, target: type):
self._qat_layer_mapping[source] = target
self._customized_qat_layer_mapping[source] = target

def add_customized_leaf(self, layer_type: type):
def add_customized_leaf(self, layer_type: type) -> None:
r"""
Declare the customized layer as leaf of model for quantization.
The leaf layer is quantized as one layer. The sublayers of
Expand All @@ -302,7 +311,7 @@ def add_customized_leaf(self, layer_type: type):
self._customized_leaves.append(layer_type)

@property
def customized_leaves(self):
def customized_leaves(self) -> list:
enkilee marked this conversation as resolved.
Show resolved Hide resolved
r"""
Get all the customized leaves.
"""
Expand Down Expand Up @@ -362,11 +371,11 @@ def _get_observe_wrapper(self, layer: Layer):
return ObserveWrapper(_observer, layer)

@property
def qat_layer_mappings(self):
def qat_layer_mappings(self) -> dict[Layer, Layer]:
return self._qat_layer_mapping

@property
def default_qat_layer_mapping(self):
def default_qat_layer_mapping(self) -> dict[Layer, Layer]:
return DEFAULT_QAT_LAYER_MAPPINGS

@property
Expand Down