From 72b841ddcbac294ee02294645845e3edc592af8e Mon Sep 17 00:00:00 2001 From: MikhayEeer <814416829@qq.com> Date: Tue, 22 Oct 2024 19:40:13 +0800 Subject: [PATCH 1/3] No.40 part1 --- .../cuda/torch.cuda.comm.gather.md | 35 +++++++++++++ .../cuda/torch.cuda.comm.scatter.md | 50 +++++++++++++++++++ .../cuda/torch.cuda.device_of.md | 18 +++++++ .../cuda/torch.cuda.is_initialized.md | 28 +++++++++++ .../torch/torch.get_default_device.md | 17 +++++++ .../torch/torch.set_default_device.md | 17 +++++++ 6 files changed, 165 insertions(+) create mode 100644 docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.comm.gather.md create mode 100644 docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.comm.scatter.md create mode 100644 docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device_of.md create mode 100644 docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md create mode 100644 docs/guides/model_convert/convert_from_pytorch/api_difference/torch/torch.get_default_device.md create mode 100644 docs/guides/model_convert/convert_from_pytorch/api_difference/torch/torch.set_default_device.md diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.comm.gather.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.comm.gather.md new file mode 100644 index 00000000000..f785f39aada --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.comm.gather.md @@ -0,0 +1,35 @@ +## [组合替代实现] torch.cuda.comm.gather + +### [torch.cuda.comm.gather](https://pytorch.org/docs/stable/generated/torch.cuda.comm.gather.html) +```python +torch.cuda.comm.gather(tensors, dim=0, destination=None, *, out=None) +``` + +将多个设备的张量集中起来,Paddle 无此 API,需要组合替代实现。 + +### 转写示例 +```python +# PyTorch 写法 +destination = 'cuda:0' +gathered_tensor = torch.cuda.comm.gather(tensors, destination=destination) + +# Paddle 写法 +def paddle_comm_gather(tensors, dim=0, destination=None, *, out=None): + if destination is None: + destination = paddle.CPUPlace() + elif 'cuda' in destination: + destination = paddle.CUDAPlace(int(destination.split(':')[-1])) + + gathered_tensors = [t.cuda(destination) if 'cuda' in t.place.__str__() else t.cpu() for t in tensors] + + gathered_tensor = paddle.concat(gathered_tensors, axis=dim) + + if out is not None: + out.copy_(gathered_tensor) + return out + + return gathered_tensor + +destination = 'gpu:0' +gathered_tensor = paddle_comm_gather(tensors, dim=dim, destination=destination) +``` diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.comm.scatter.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.comm.scatter.md new file mode 100644 index 00000000000..72017d0ee16 --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.comm.scatter.md @@ -0,0 +1,50 @@ +## [组合替代实现] torch.cuda.comm.scatter + +### [torch.cuda.comm.scatter](https://pytorch.org/docs/stable/generated/torch.cuda.comm.scatter.html) + +```python +torch.cuda.comm.scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None) +``` + +将张量分散到多个设备上,Paddle 无此 API,需要组合替代实现 + +### 转写示例 +```python +# torch 写法 +devices = [torch.device('cuda:0'), torch.device('cuda:1')] +torch.cuda.comm.scatter(inputs, devices=devices) + +# paddle 写法 +def paddle_comm_scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, out=None): + if devices is None: + devices = ['cpu'] * len(tensor) + + if chunk_sizes is not None: + chunks = paddle.split(tensor, num_or_sections=chunk_sizes, dim=dim) + else: + chunks = tensor if isinstance(tensor, list) else [tensor] + + scattered_tensors = out if out is not None else [] + + for idx, (chunk, device) in enumerate(zip(chunks, devices)): + place = paddle.CUDAPlace(int(device.split(':')[-1])) if 'cuda' in device else paddle.CPUPlace() + + tensor_on_device = chunk.cuda(place) if 'cuda' in device else chunk.cpu() + + if streams is not None: + stream = streams[idx] + tensor_on_device = tensor_on_device.cuda(place, non_blocking=True) + tensor_on_device = tensor_on_device.cuda_stream(stream) + + if out is not None: + out[idx].copy_(tensor_on_device) + else: + scattered_tensors.append(tensor_on_device) + + if out is None: + return scattered_tensors + +devices = ['gpu:0', 'gpu:1'] +chunk_sizes = [5, 5] +scattered_tensors = paddle_comm_scatter(tensor, devices=devices, chunk_sizes=chunk_sizes) +``` diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device_of.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device_of.md new file mode 100644 index 00000000000..c3864a1db7d --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device_of.md @@ -0,0 +1,18 @@ +## [组合替代实现] torch.cuda.device_of + +### [torch.cuda.device_of](https://pytorch.org/docs/stable/generated/torch.cuda.device_of.html#torch.cuda.device_of) +```python +torch.cuda.device_of(obj) +``` + +获取张量所在的设备,Paddle 无此 api,需要组合实现 +可以通过`tensor.place`来获取张量所在的设备信息 + +### 转写示例 +```python +# torch 写法 +device = torch.cuda.device_of(tensor) + +# paddle 写法 +device = tensor.place +``` diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md new file mode 100644 index 00000000000..a30eeb54401 --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md @@ -0,0 +1,28 @@ +## [ 待定 ]torch.cuda.is_initialized + +### [torch.cuda.is_initialized]()(https://pytorch.org/docs/stable/generated/torch.cuda.is_initialized.html) + +```python +torch.cuda.is_initialized() +``` + +判断 cuda 是否初始化,Paddle 无此 API,需要组合实现。 +Paddle 可以通过检查是否支持 cuda,并且尝试创建一个张量来判断初始化是否成功。 + +### 转写示例 + +```python +# torch 写法 +torch.cuda.is_initialized() + +# paddle 写法 +def paddle_cuda_is_initialized(): + if not paddle.is_compiled_with_cuda(): + return False + try: + cuda_tensor = paddle.rand([1], place=paddle.CUDAPlace(0)) + return True + except Exception as e: + return False +paddle_cuda_is_initialized() +``` diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/torch/torch.get_default_device.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/torch/torch.get_default_device.md new file mode 100644 index 00000000000..16ba00eb4fa --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/torch/torch.get_default_device.md @@ -0,0 +1,17 @@ +## [组合替代实现] torch.get_default_device + +### [torch.get_default_device](https://pytorch.org/docs/stable/generated/torch.get_default_device.html) +```python +torch.get_default_device() +``` + +获取默认的设备,Paddle 无此 api, 需要组合实现 + +### 转写示例 +```python +# torch 写法 +device = torch.get_default_device() + +# paddle 写法 +device = paddle.device.get_device() +``` diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/torch/torch.set_default_device.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/torch/torch.set_default_device.md new file mode 100644 index 00000000000..1b873dc176c --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/torch/torch.set_default_device.md @@ -0,0 +1,17 @@ +## [组合替代实现] torch.set_default_device +### [torch.set_default_device](https://pytorch.org/docs/stable/generated/torch.set_default_device.html#torch.set_default_device) +```python +torch.set_default_device(device) +``` + +设置默认设备,Paddle 无此 api,需要组合替代实现。 + +### 转写示例 + +```python +# torch 写法 +torch.set_default_device(device) + +# paddle 写法 +paddle.device.set_device(device) +``` From bcbca05d2c35b49108f8ea75b9d114ffee347a2a Mon Sep 17 00:00:00 2001 From: MikhayEeer <814416829@qq.com> Date: Tue, 22 Oct 2024 20:08:28 +0800 Subject: [PATCH 2/3] Corrected the title naming of is_initialized --- .../api_difference/cuda/torch.cuda.is_initialized.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md index a30eeb54401..e41e27f0ff7 100644 --- a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md @@ -1,4 +1,4 @@ -## [ 待定 ]torch.cuda.is_initialized +## [组合替代实现]torch.cuda.is_initialized ### [torch.cuda.is_initialized]()(https://pytorch.org/docs/stable/generated/torch.cuda.is_initialized.html) From d8db4a833b4a4671ad2b4194b826c49c01fa303f Mon Sep 17 00:00:00 2001 From: MikhayEeer <814416829@qq.com> Date: Tue, 22 Oct 2024 20:43:35 +0800 Subject: [PATCH 3/3] Fix link format --- .../api_difference/cuda/torch.cuda.is_initialized.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md index e41e27f0ff7..503a56bccd1 100644 --- a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.is_initialized.md @@ -1,6 +1,6 @@ -## [组合替代实现]torch.cuda.is_initialized +## [ 组合替代实现 ]torch.cuda.is_initialized -### [torch.cuda.is_initialized]()(https://pytorch.org/docs/stable/generated/torch.cuda.is_initialized.html) +### [torch.cuda.is_initialized](xly.bce.baidu.com/paddlepaddle/fluid-doc/newipipe/detail/11746629/job/27824342/realTimeLog/479) ```python torch.cuda.is_initialized()