Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][C-116,C-117] Add type annotations for "paddle/geometric/*" #66792

Merged
merged 11 commits into from
Aug 1, 2024
25 changes: 21 additions & 4 deletions python/paddle/geometric/reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

import paddle
from paddle import _C_ops
from paddle.base.data_feeder import check_variable_and_dtype
from paddle.base.framework import Variable
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_or_pir_mode

if TYPE_CHECKING:
from paddle import Tensor

__all__ = []


def reindex_graph(
x, neighbors, count, value_buffer=None, index_buffer=None, name=None
):
x: Tensor,
neighbors: Tensor,
count: Tensor,
value_buffer: Tensor | None = None,
index_buffer: Tensor | None = None,
name: str | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
"""

Reindex Graph API.
Expand Down Expand Up @@ -137,8 +149,13 @@ def reindex_graph(


def reindex_heter_graph(
x, neighbors, count, value_buffer=None, index_buffer=None, name=None
):
x: Tensor,
neighbors: Sequence[Tensor],
count: Sequence[Tensor],
value_buffer: Tensor | None = None,
index_buffer: Tensor | None = None,
name: str | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
"""

Reindex HeterGraph API.
Expand Down
90 changes: 90 additions & 0 deletions python/paddle/geometric/sampling/neighbors.py
Copy link
Contributor

@megemini megemini Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

重写了一遍函数?overload 不是这么用的,可以搜一下已经使用了 overload 标注的地方 ~

一个简单的例子

@overload
def foo(arg0: Any, arg1: Literal[True]) -> int: ...
@overload
def foo(arg0: Any, arg1: Literal[False]) -> str: ...
@overload
def foo(arg0: Any, arg1: bool = ...) -> int | str: ...
def foo(arg0, arg1=True):
    if arg1:
        return int(arg0)
    return str(arg0)

x: int = foo(1, True)
y: str = foo(1, False)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经修正,感谢您的耐心指导

Copy link
Contributor

@megemini megemini Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改的还是有点问题,注意几个地方:

  • overload 中一般不写具体参数值,更不能更改原赋值
  • overload 中使用 Literal ,如果原定义中有默认值,可以跟 = ...
  • 需要单独一个 arg1: bool = ...overload ,用于针对参数传值而非直接传值
  • 使用 overload 后,原函数不需要再标注类型(也可以不定义最后的 overload ,而直接在原函数中标注)

overload 本身有些行为定义的还不是很清楚,但,一定要分清楚类型标注和默认值的关系 ~

p.s. 更新了上面的示例

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是这样吗?

Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Literal, overload

from paddle import _C_ops
from paddle.base.data_feeder import check_variable_and_dtype
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_or_pir_mode

if TYPE_CHECKING:
from paddle import Tensor
__all__ = []


@overload
def sample_neighbors(
row: Tensor,
colptr: Tensor,
input_nodes: Tensor,
sample_size: int = ...,
eids: Tensor | None = ...,
return_eids: Literal[True] = ...,
perm_buffer: Tensor | None = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor, Tensor]:
...


@overload
def sample_neighbors(
row: Tensor,
colptr: Tensor,
input_nodes: Tensor,
sample_size: int = ...,
eids: Tensor | None = ...,
return_eids: Literal[False] = ...,
perm_buffer: Tensor | None = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor]:
...


@overload
def sample_neighbors(
row: Tensor,
colptr: Tensor,
input_nodes: Tensor,
sample_size: int = ...,
eids: Tensor | None = ...,
return_eids: bool = ...,
perm_buffer: Tensor | None = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
...


def sample_neighbors(
row,
colptr,
Expand Down Expand Up @@ -169,6 +217,48 @@ def sample_neighbors(
return out_neighbors, out_count


@overload
def weighted_sample_neighbors(
row: Tensor,
colptr: Tensor,
edge_weight: Tensor,
input_nodes: Tensor,
sample_size: int = ...,
eids: Tensor | None = ...,
return_eids: Literal[True] = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor, Tensor]:
...


@overload
def weighted_sample_neighbors(
row: Tensor,
colptr: Tensor,
edge_weight: Tensor,
input_nodes: Tensor,
sample_size: int = ...,
eids: Tensor | None = ...,
return_eids: Literal[False] = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor]:
...


@overload
def weighted_sample_neighbors(
row: Tensor,
colptr: Tensor,
edge_weight: Tensor,
input_nodes: Tensor,
sample_size: int = ...,
eids: Tensor | None = ...,
return_eids: bool = ...,
name: str | None = ...,
) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, Tensor]:
...


def weighted_sample_neighbors(
row,
colptr,
Expand Down