From 4ee4476e39335df180588e9531b38f2beeaedd60 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Wed, 27 Sep 2023 09:52:57 -0700 Subject: [PATCH] Helper: to_torch Add helper methods to generate PyTorch tensors. --- src/amrex/Array4.py | 3 +++ src/amrex/ArrayOfStructs.py | 3 +++ src/amrex/MultiFab.py | 5 +++++ src/amrex/PODVector.py | 14 ++++++++++++++ src/amrex/StructOfArrays.py | 14 ++++++++++++++ 5 files changed, 39 insertions(+) diff --git a/src/amrex/Array4.py b/src/amrex/Array4.py index 5bc28448..375560d2 100644 --- a/src/amrex/Array4.py +++ b/src/amrex/Array4.py @@ -82,6 +82,9 @@ def array4_to_cupy(self, copy=False, order="F"): raise ValueError("The order argument must be F or C.") +# torch + + def register_Array4_extension(amr): """Array4 helper methods""" import inspect diff --git a/src/amrex/ArrayOfStructs.py b/src/amrex/ArrayOfStructs.py index ff2ed4fd..74b6e609 100644 --- a/src/amrex/ArrayOfStructs.py +++ b/src/amrex/ArrayOfStructs.py @@ -75,6 +75,9 @@ def aos_to_cupy(self, copy=False): return cp.array(self, copy=copy) +# torch + + def register_AoS_extension(amr): """ArrayOfStructs helper methods""" import inspect diff --git a/src/amrex/MultiFab.py b/src/amrex/MultiFab.py index a2b46899..142b5733 100644 --- a/src/amrex/MultiFab.py +++ b/src/amrex/MultiFab.py @@ -89,6 +89,9 @@ def mf_to_cupy(self, copy=False, order="F"): return views +# torch + + def register_MultiFab_extension(amr): """MultiFab helper methods""" @@ -99,3 +102,5 @@ def register_MultiFab_extension(amr): amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__ amr.MultiFab.to_cupy = mf_to_cupy + + # torch diff --git a/src/amrex/PODVector.py b/src/amrex/PODVector.py index c241405c..611c0a9b 100644 --- a/src/amrex/PODVector.py +++ b/src/amrex/PODVector.py @@ -68,6 +68,19 @@ def podvector_to_cupy(self, copy=False): raise ValueError("Vector is empty.") +def podvector_to_torch(self, copy=False): + """ + Provide PyTorch tensor views into a PODVector (e.g., RealVector, IntVector). + + ... + """ + import torch + + # if CUDA else ... + # pick right device (context? device number?) + return torch.as_tensor(self.to_cupy(copy), device="cuda") + + def register_PODVector_extension(amr): """PODVector helper methods""" import inspect @@ -82,3 +95,4 @@ def register_PODVector_extension(amr): ): POD_type.to_numpy = podvector_to_numpy POD_type.to_cupy = podvector_to_cupy + POD_type.to_torch = podvector_to_torch diff --git a/src/amrex/StructOfArrays.py b/src/amrex/StructOfArrays.py index e906732b..126a7d31 100644 --- a/src/amrex/StructOfArrays.py +++ b/src/amrex/StructOfArrays.py @@ -83,6 +83,19 @@ def soa_to_cupy(self, copy=False): return soa_view +def soa_to_torch(self, copy=False): + """ + Provide PyTorch tensor views into a StructOfArrays. + + ... + """ + import torch + + # if CUDA else ... + # pick right device (context? device number?) + return torch.as_tensor(self.to_cupy(copy), device="cuda") + + def register_SoA_extension(amr): """StructOfArrays helper methods""" import inspect @@ -97,3 +110,4 @@ def register_SoA_extension(amr): ): SoA_type.to_numpy = soa_to_numpy SoA_type.to_cupy = soa_to_cupy + SoA_type.to_torch = soa_to_torch