Skip to content

Commit

Permalink
[Typing][C-115,C-118] Add type annotations for python/paddle/geometr…
Browse files Browse the repository at this point in the history
…ic/{math,/message_passing/send_recv}.py (#67644)
  • Loading branch information
enkilee authored Aug 23, 2024
1 parent a5a32e4 commit dd3e0f7
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 16 deletions.
22 changes: 18 additions & 4 deletions python/paddle/geometric/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from paddle import _C_ops
from paddle.base.data_feeder import check_variable_and_dtype
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_or_pir_mode

if TYPE_CHECKING:
from paddle import Tensor

__all__ = []


def segment_sum(data, segment_ids, name=None):
def segment_sum(
data: Tensor, segment_ids: Tensor, name: str | None = None
) -> Tensor:
r"""
Segment Sum Operator.
Expand Down Expand Up @@ -77,7 +85,9 @@ def segment_sum(data, segment_ids, name=None):
return out


def segment_mean(data, segment_ids, name=None):
def segment_mean(
data: Tensor, segment_ids: Tensor, name: str | None = None
) -> Tensor:
r"""
Segment mean Operator.
Expand Down Expand Up @@ -136,7 +146,9 @@ def segment_mean(data, segment_ids, name=None):
return out


def segment_min(data, segment_ids, name=None):
def segment_min(
data: Tensor, segment_ids: Tensor, name: str | None = None
) -> Tensor:
r"""
Segment min operator.
Expand Down Expand Up @@ -194,7 +206,9 @@ def segment_min(data, segment_ids, name=None):
return out


def segment_max(data, segment_ids, name=None):
def segment_max(
data: Tensor, segment_ids: Tensor, name: str | None = None
) -> Tensor:
r"""
Segment max operator.
Expand Down
55 changes: 43 additions & 12 deletions python/paddle/geometric/message_passing/send_recv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import numpy as np
from typing_extensions import TypeAlias

from paddle import _C_ops
from paddle.base.data_feeder import (
Expand All @@ -30,12 +34,32 @@
reshape_lhs_rhs,
)

if TYPE_CHECKING:
from paddle import Tensor

_ReduceOp: TypeAlias = Literal[
"sum",
"mean",
"max",
"min",
]
_MessageOp: TypeAlias = Literal[
"add",
"sub",
"mul",
"div",
]
__all__ = []


def send_u_recv(
x, src_index, dst_index, reduce_op="sum", out_size=None, name=None
):
x: Tensor,
src_index: Tensor,
dst_index: Tensor,
reduce_op: _ReduceOp = "sum",
out_size: int | Tensor | None = None,
name: str | None = None,
) -> Tensor:
"""
Graph Learning message passing api.
Expand Down Expand Up @@ -184,15 +208,15 @@ def send_u_recv(


def send_ue_recv(
x,
y,
src_index,
dst_index,
message_op="add",
reduce_op="sum",
out_size=None,
name=None,
):
x: Tensor,
y: Tensor,
src_index: Tensor,
dst_index: Tensor,
message_op: _MessageOp = "add",
reduce_op: _ReduceOp = "sum",
out_size: int | Tensor | None = None,
name: str | None = None,
) -> Tensor:
"""
Graph Learning message passing api.
Expand Down Expand Up @@ -386,7 +410,14 @@ def send_ue_recv(
return out


def send_uv(x, y, src_index, dst_index, message_op="add", name=None):
def send_uv(
x: Tensor,
y: Tensor,
src_index: Tensor,
dst_index: Tensor,
message_op: _MessageOp = "add",
name: str | None = None,
) -> Tensor:
"""
Graph Learning message passing api.
Expand Down

0 comments on commit dd3e0f7

Please sign in to comment.