Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.15】 为 Paddle 新增 Tensor.to() 以及 Layer.astype() API #6305

Merged
merged 7 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/paddle/Tensor/Overview_en.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ Methods
tanh
tanh_
tile
to
tolist
topk
trace
Expand Down
21 changes: 21 additions & 0 deletions docs/api/paddle/Tensor_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2345,6 +2345,27 @@ tile(repeat_times, name=None)

请参考 :ref:`cn_api_paddle_tile`

to(*args, **kwargs)
:::::::::

转换 Tensor 的设备或/和数据类型,并且返回转换后的 Tensor。该函数将会从 ``args`` 以及 ``kwargs`` 中解析出要转换到的目标类型 dtype 以及目标设备 place。
目前支持一下三种方式调用该方法:

1. to(dtype, blocking=True)
2. to(device, dtype=None, blocking=True)
3. to(other, blocking=True)

其中, ``dtype`` 可以是 ``paddle.dtype``, ``numpy.dtype`` 类型或者是 ``["bfloat16", "float16", "float32", "float64", "int8", "int16", "int32",
"int64", "uint8", "complex64", "complex128", "bool"]`` 中的任意一个 ``str``。 ``device`` 可以是 ``paddle.CPUPlace()``, ``paddle.CUDAPlace()``,
``paddle.CUDAPinnedPlace()``, ``paddle.XPUPlace()``, ``paddle.CustomPlace()`` 或者 ``str``。 ``other`` 需要是 ``Tensor`` 类型。

返回:类型转换后的新的 Tensor

返回类型:Tensor

**代码示例**
COPY-FROM: paddle.Tensor.to

tolist()
:::::::::

Expand Down
15 changes: 15 additions & 0 deletions docs/api/paddle/nn/Layer_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,21 @@ to(device=None, dtype=None, blocking=None)

COPY-FROM: paddle.nn.Layer.to

astype(dtype=None)
:::::::::

将 Layer 的所有 ``parameters`` 和 ``buffers`` 的数据类型转换为 ``dtype``,并返回这个 Layer。

**参数**
- **dtype** (str | paddle.dtype | numpy.dtype) - 转换后的 dtype,str 类型支持"bool", "bfloat16", "float16", "float32", "float64", "int8", "int16", "int32", "int64", "uint8", "complex64", "complex128"。
YibinLiu666 marked this conversation as resolved.
Show resolved Hide resolved

返回:类型转换后的 Layer

返回类型:Layer

**代码示例**
COPY-FROM: paddle.nn.Layer.astype

float(excluded_layers=None)
'''''''''

Expand Down
2 changes: 1 addition & 1 deletion docs/api/paddle/nn/Overview_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ Embedding 相关函数
:header: "API 名称", "API 功能"


" :ref:`paddle.nn.functional.diag_embed <cn_api_paddle_nn_functional_diag_embed>` ", "对角线 Embedding 方法"
" paddle.nn.functional.diag_embed ", "对角线 Embedding 方法,paddle.nn.functional.diag_embed 已废弃,请使用 :ref:`paddle.diag_embed <cn_api_paddle_diag_embed>` "
" :ref:`paddle.nn.functional.embedding <cn_api_paddle_nn_functional_embedding>` ", "Embedding 方法"

.. _loss_functional:
Expand Down