From eac2260c9ac5bfddba243be8b96467b1c866affd Mon Sep 17 00:00:00 2001
From: melt <thomas.meltzer1@gmail.com>
Date: Fri, 3 Nov 2023 17:06:01 +0000
Subject: [PATCH 01/20] add mwe for fypp preprocessor

This example generates multiple variants of the function
`torch_tensor_from_array_c_*`, one for each combination of data type and
rank of the input data.

We should discuss how we actually want to implement this.
---
 src/ftorch.fypp | 307 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 307 insertions(+)
 create mode 100644 src/ftorch.fypp

diff --git a/src/ftorch.fypp b/src/ftorch.fypp
new file mode 100644
index 00000000..93a7b4be
--- /dev/null
+++ b/src/ftorch.fypp
@@ -0,0 +1,307 @@
+#:def ranksuffix(RANK)
+$:'' if RANK == 0 else '(' + ':' + ',:' * (RANK - 1) + ')'
+#:enddef ranksuffix
+#:set PRECISIONS = ['float', 'double']
+#:set RANKS = range(1, 3)
+#:set ENUMS = dict(zip(PRECISIONS, ['torch_kFloat32', 'torch_kFloat64']))
+#:def enum_from_prec(PRECISION)
+$:ENUMS[PRECISION]
+#:enddef enum_from_prec
+module ftorch
+
+  use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, &
+    c_float, c_double, c_char, c_ptr, c_null_ptr
+  implicit none
+
+  type torch_module
+    type(c_ptr) :: p = c_null_ptr
+  end type torch_module
+
+  type torch_tensor
+    type(c_ptr) :: p = c_null_ptr
+  end type torch_tensor
+
+  ! From c_torch.h (torch_data_t)
+  enum, bind(c)
+    enumerator :: torch_kUInt8 = 0
+    enumerator :: torch_kInt8 = 1
+    enumerator :: torch_kInt16 = 2
+    enumerator :: torch_kInt32 = 3
+    enumerator :: torch_kInt64 = 4
+    enumerator :: torch_kFloat16 = 5
+    enumerator :: torch_kFloat32 = 6
+    enumerator :: torch_kFloat64 = 7
+  end enum
+
+  ! From c_torch.h (torch_device_t)
+  enum, bind(c)
+    enumerator :: torch_kCPU = 0
+    enumerator :: torch_kCUDA = 1
+  end enum
+
+  ! Interface for calculating tensor from array for different possible input types
+  interface torch_tensor_from_array
+    #:for PREC in PRECISIONS
+    #:for RANK in RANKS
+    module procedure torch_tensor_from_array_c_${PREC}$_${RANK}$
+    #:endfor
+    #:endfor
+    ! module procedure torch_tensor_from_array_c_int8_t
+    ! module procedure torch_tensor_from_array_c_int16_t
+    ! module procedure torch_tensor_from_array_c_int32_t
+    ! module procedure torch_tensor_from_array_c_int64_t
+  end interface
+
+contains
+
+  ! Torch Tensor API
+  !> Exposes the given data as a tensor without taking ownership of the original data.
+  !> This routine will take an (i, j, k) array and return an (k, j, i) tensor.
+  function torch_tensor_from_blob(data, ndims, tensor_shape, dtype, device, layout) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+    type(c_ptr), intent(in)        :: data       !! Pointer to data
+    integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
+    integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
+    integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+    integer(c_int), intent(in)     :: layout(*)  !! 1-based dimension indices, ordered from unit stride to largest stride
+    type(torch_tensor)             :: tensor     !! Returned tensor
+
+    integer(c_int)                 :: i          !! loop index
+    integer(c_int64_t)             :: strides(ndims) !! Strides for accessing data
+
+    interface
+      function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor) &
+          bind(c, name='torch_from_blob')
+        use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+        type(c_ptr), value, intent(in)    :: data
+        integer(c_int), value, intent(in) :: ndims
+        integer(c_int64_t), intent(in)    :: tensor_shape(*)
+        integer(c_int64_t), intent(in)    :: strides(*)
+        integer(c_int), value, intent(in) :: dtype
+        integer(c_int), value, intent(in) :: device
+        type(c_ptr)                       :: tensor
+      end function torch_from_blob_c
+    end interface
+
+    if (ndims > 0) strides(layout(1)) = 1  ! guard: with ndims == 0 neither layout(1) nor strides(1) exists
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1))
+    end do
+    tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device)
+  end function torch_tensor_from_blob
+
+  !> This routine will take an (i, j, k) array and return an (k, j, i) tensor
+  !> It is invoked by the specific procedures behind the generic `torch_tensor_from_array` interface.
+  function t_t_from_array(data_arr, tensor_shape, dtype, device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+    type(c_ptr), intent(in)          :: data_arr       !! Pointer to data
+    integer(c_int64_t), intent(in)   :: tensor_shape(:)   !! Shape of the tensor
+    integer(c_int), intent(in)       :: dtype      !! Data type of the tensor
+    integer(c_int), intent(in)       :: device     !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+    type(torch_tensor)               :: tensor     !! Returned tensor
+
+    integer(c_int)                   :: i          !! loop index
+    integer(c_int64_t), allocatable  :: strides(:) !! Strides for accessing data
+    integer(c_int), allocatable      :: layout(:)  !! Dimension indices ordered from unit stride to largest stride
+    integer(c_int)                   :: ndims      !! Number of dimensions of the tensor
+
+    interface
+      function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor) &
+          bind(c, name='torch_from_blob')
+        use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+        type(c_ptr), value, intent(in)    :: data
+        integer(c_int), value, intent(in) :: ndims
+        integer(c_int64_t), intent(in)    :: tensor_shape(*)
+        integer(c_int64_t), intent(in)    :: strides(*)
+        integer(c_int), value, intent(in) :: dtype
+        integer(c_int), value, intent(in) :: device
+        type(c_ptr)                       :: tensor
+      end function torch_from_blob_c
+    end interface
+
+    ndims = size(tensor_shape)
+
+    allocate(strides(ndims))
+    allocate(layout(ndims))
+
+    ! Fortran layout: dimensions are taken in their natural order, first index has unit stride
+    do i=1, ndims
+      layout(i) = i
+    end do
+
+    if (ndims > 0) strides(layout(1)) = 1  ! guard: zero-size tensor_shape leaves strides/layout empty
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1))
+    end do
+
+    tensor%p = torch_from_blob_c(data_arr, ndims, tensor_shape, strides, dtype, device)
+
+    deallocate(strides)
+    deallocate(layout)
+
+  end function t_t_from_array
+
+  !> Returns a tensor filled with the scalar value 1.
+  function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
+    integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
+    integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
+    integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+    type(torch_tensor)             :: tensor     !! Returned tensor
+
+    interface
+      function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) &
+          bind(c, name='torch_ones')
+        use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+        integer(c_int), value, intent(in) :: ndims
+        integer(c_int64_t), intent(in)    :: tensor_shape(*)
+        integer(c_int), value, intent(in) :: dtype
+        integer(c_int), value, intent(in) :: device
+        type(c_ptr)                       :: tensor
+      end function torch_ones_c
+    end interface
+
+    tensor%p = torch_ones_c(ndims, tensor_shape, dtype, device)
+  end function torch_tensor_ones
+
+  !> Returns a tensor filled with the scalar value 0.
+  function torch_tensor_zeros(ndims, tensor_shape, dtype, device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
+    integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
+    integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
+    integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+    type(torch_tensor)             :: tensor     !! Returned tensor
+
+    interface
+      function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) &
+          bind(c, name='torch_zeros')
+        use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+        integer(c_int), value, intent(in) :: ndims
+        integer(c_int64_t), intent(in)    :: tensor_shape(*)
+        integer(c_int), value, intent(in) :: dtype
+        integer(c_int), value, intent(in) :: device
+        type(c_ptr)                       :: tensor
+      end function torch_zeros_c
+    end interface
+
+    tensor%p = torch_zeros_c(ndims, tensor_shape, dtype, device)
+  end function torch_tensor_zeros
+
+  !> Prints the contents of a tensor.
+  subroutine torch_tensor_print(tensor)
+    type(torch_tensor), intent(in) :: tensor     !! Input tensor
+
+    interface
+      subroutine torch_tensor_print_c(tensor) &
+          bind(c, name='torch_tensor_print')
+        use, intrinsic :: iso_c_binding, only : c_ptr
+        type(c_ptr), value, intent(in) :: tensor
+      end subroutine torch_tensor_print_c
+    end interface
+
+    call torch_tensor_print_c(tensor%p)
+  end subroutine torch_tensor_print
+
+  !> Deallocates a tensor.
+  subroutine torch_tensor_delete(tensor)
+    type(torch_tensor), intent(in) :: tensor     !! Input tensor
+
+    interface
+      subroutine torch_tensor_delete_c(tensor) &
+          bind(c, name='torch_tensor_delete')
+        use, intrinsic :: iso_c_binding, only : c_ptr
+        type(c_ptr), value, intent(in) :: tensor
+      end subroutine torch_tensor_delete_c
+    end interface
+
+    call torch_tensor_delete_c(tensor%p)
+  end subroutine torch_tensor_delete
+
+  ! Torch Module API
+  !> Loads a Torch Script module (pre-trained PyTorch model saved with Torch Script)
+  function torch_module_load(filename) result(module)
+    use, intrinsic :: iso_c_binding, only : c_char
+    character(c_char), intent(in) :: filename(*) !! Filename of Torch Script module
+    type(torch_module)            :: module      !! Returned deserialized module
+
+    interface
+      function torch_jit_load_c(filename) result(module) &
+          bind(c, name='torch_jit_load')
+        use, intrinsic :: iso_c_binding, only : c_char, c_ptr
+        character(c_char), intent(in) :: filename(*)
+        type(c_ptr)                   :: module
+      end function torch_jit_load_c
+    end interface
+
+    ! NOTE: filename is passed through unchanged; no c_null_char is appended here, so the caller is expected to supply it
+    module%p = torch_jit_load_c(filename)
+    end function torch_module_load
+
+    !> Performs a forward pass of the module with the input tensors
+    subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor)
+      use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc
+      type(torch_module), intent(in) :: module        !! Module
+      type(torch_tensor), intent(in), dimension(:) :: input_tensors  !! Array of Input tensors
+      type(torch_tensor), intent(in) :: output_tensor !! Output tensor; only its handle is passed on to the C routine
+      integer(c_int), intent(in) :: n_inputs          !! Number of entries of input_tensors to pass to the module
+
+      integer :: i
+      type(c_ptr), dimension(n_inputs), target  :: input_ptrs
+
+      interface
+        subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, &
+            output_tensor) &
+            bind(c, name='torch_jit_module_forward')
+          use, intrinsic :: iso_c_binding, only : c_ptr, c_int
+          type(c_ptr), value, intent(in) :: module
+          type(c_ptr), value, intent(in) :: input_tensors
+          integer(c_int), value, intent(in) :: n_inputs
+          type(c_ptr), value, intent(in) :: output_tensor
+        end subroutine torch_jit_module_forward_c
+      end interface
+
+      ! Gather the raw handles into a contiguous array so its address can be handed to C
+      do i = 1, n_inputs
+        input_ptrs(i) = input_tensors(i)%p
+      end do
+
+      call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p)
+    end subroutine torch_module_forward
+
+    !> Deallocates a Torch Script module
+    subroutine torch_module_delete(module)
+      type(torch_module), intent(in) :: module     !! Module
+
+      interface
+        subroutine torch_jit_module_delete_c(module) &
+            bind(c, name='torch_jit_module_delete')
+          use, intrinsic :: iso_c_binding, only : c_ptr
+          type(c_ptr), value, intent(in) :: module
+        end subroutine torch_jit_module_delete_c
+      end interface
+
+      call torch_jit_module_delete_c(module%p)
+    end subroutine torch_module_delete
+
+    ! Specific procedures behind the generic torch_tensor_from_array interface, one per (precision, rank) pair
+    #:for PREC in PRECISIONS
+    #:for RANK in RANKS
+    function torch_tensor_from_array_c_${PREC}$_${RANK}$(data_arr, tensor_shape, device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_${PREC}$, c_loc
+      real(c_${PREC}$), intent(in), target :: data_arr${ranksuffix(RANK)}$   !! Fortran array of data
+      integer(c_int64_t), intent(in)   :: tensor_shape(:)   !! Shape of the tensor
+      integer(c_int), parameter :: dtype = ${enum_from_prec(PREC)}$  !! Torch data type matching the kind of data_arr
+      integer(c_int), intent(in)       :: device     !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+      type(torch_tensor)               :: tensor     !! Returned tensor
+
+      tensor = t_t_from_array(c_loc(data_arr), tensor_shape, dtype, device)
+
+    end function torch_tensor_from_array_c_${PREC}$_${RANK}$
+
+    #:endfor
+    #:endfor
+
+  end module ftorch

From e94e0c9c780383a24bc1932320aaf32e8930336a Mon Sep 17 00:00:00 2001
From: melt <thomas.meltzer1@gmail.com>
Date: Fri, 17 Nov 2023 13:02:08 +0000
Subject: [PATCH 02/20] working example of _fortran only_ interface

---
 examples/2_ResNet18/resnet_infer_fortran.f90 |   39 +-
 src/ftorch.f90                               | 1194 ++++++++++++++----
 src/ftorch.fypp                              |  147 +--
 3 files changed, 1017 insertions(+), 363 deletions(-)

diff --git a/examples/2_ResNet18/resnet_infer_fortran.f90 b/examples/2_ResNet18/resnet_infer_fortran.f90
index dfc012b1..1af256af 100644
--- a/examples/2_ResNet18/resnet_infer_fortran.f90
+++ b/examples/2_ResNet18/resnet_infer_fortran.f90
@@ -1,18 +1,12 @@
 program inference
 
-   ! Imports primitives used to interface with C
-   use, intrinsic :: iso_c_binding, only: c_sp=>c_float, c_dp=>c_double, c_int64_t, c_loc
-   use, intrinsic :: iso_fortran_env, only : sp => real32, dp => real64
+   use, intrinsic :: iso_fortran_env, only : sp => real32
    ! Import our library for interfacing with PyTorch
    use :: ftorch
 
    implicit none
 
-   ! Define working precision for C primitives
-   ! Precision must match `wp` in resnet18.py and `wp_torch` in pt2ts.py
-   integer, parameter :: c_wp = c_sp
    integer, parameter :: wp = sp
-   integer, parameter :: torch_wp = torch_kFloat32
 
    call main()
 
@@ -25,21 +19,21 @@ subroutine main()
       integer :: num_args, ix
       character(len=128), dimension(:), allocatable :: args
 
-      ! Set up types of input and output data and the interface with C
+      ! Set up types of input and output data
       type(torch_module) :: model
       type(torch_tensor), dimension(1) :: in_tensor
       type(torch_tensor) :: out_tensor
 
-      real(c_wp), dimension(:,:,:,:), allocatable, target :: in_data
-      integer(c_int), parameter :: n_inputs = 1
-      real(c_wp), dimension(:,:), allocatable, target :: out_data
+      real(wp), dimension(:,:,:,:), allocatable, target :: in_data
+      real(wp), dimension(:,:), allocatable, target :: out_data
+      integer, parameter :: n_inputs = 1
 
-      integer(c_int), parameter :: in_dims = 4
-      integer(c_int64_t) :: in_shape(in_dims) = [1, 3, 224, 224]
-      integer(c_int) :: in_layout(in_dims) = [1,2,3,4]
-      integer(c_int), parameter :: out_dims = 2
-      integer(c_int64_t) :: out_shape(out_dims) = [1, 1000]
-      integer(c_int) :: out_layout(out_dims) = [1,2]
+      integer, parameter :: in_dims = 4
+      integer :: in_shape(in_dims) = [1, 3, 224, 224]
+      integer :: in_layout(in_dims) = [1,2,3,4]
+      integer, parameter :: out_dims = 2
+      integer :: out_shape(out_dims) = [1, 1000]
+      integer :: out_layout(out_dims) = [1,2]
 
       ! Binary file containing input tensor
       character(len=*), parameter :: filename = '../data/image_tensor.dat'
@@ -72,8 +66,9 @@ subroutine main()
       call load_data(filename, tensor_length, in_data)
 
       ! Create input/output tensors from the above arrays
-      in_tensor(1) = torch_tensor_from_blob(c_loc(in_data), in_dims, in_shape, torch_wp, torch_kCPU, in_layout)
-      out_tensor = torch_tensor_from_blob(c_loc(out_data), out_dims, out_shape, torch_wp, torch_kCPU, out_layout)
+      in_tensor(1) = torch_tensor_from_array(in_data, in_layout, torch_kCPU)
+
+      out_tensor = torch_tensor_from_array(out_data, out_layout, torch_kCPU)
 
       ! Load ML model (edit this line to use different models)
       model = torch_module_load(args(1))
@@ -113,9 +108,9 @@ subroutine load_data(filename, tensor_length, in_data)
 
       character(len=*), intent(in) :: filename
       integer, intent(in) :: tensor_length
-      real(c_wp), dimension(:,:,:,:), intent(out) :: in_data
+      real(wp), dimension(:,:,:,:), intent(out) :: in_data
 
-      real(c_wp) :: flat_data(tensor_length)
+      real(wp) :: flat_data(tensor_length)
       integer :: ios
       character(len=100) :: ioerrmsg
 
@@ -166,7 +161,7 @@ subroutine calc_probs(out_data, probabilities)
 
       implicit none
 
-      real(c_wp), dimension(:,:), intent(in) :: out_data
+      real(wp), dimension(:,:), intent(in) :: out_data
       real(wp), dimension(:,:), intent(out) :: probabilities
       real(wp) :: prob_sum
 
diff --git a/src/ftorch.f90 b/src/ftorch.f90
index 3412320a..3e18df29 100644
--- a/src/ftorch.f90
+++ b/src/ftorch.f90
@@ -1,305 +1,991 @@
 module ftorch
 
-   use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, &
-                                          c_float, c_double, c_char, c_ptr, c_null_ptr
-   implicit none
-
-   type torch_module
-      type(c_ptr) :: p = c_null_ptr
-   end type torch_module
-
-   type torch_tensor
-      type(c_ptr) :: p = c_null_ptr
-   end type torch_tensor
-
-   ! From c_torch.h (torch_data_t)
-   enum, bind(c)
-      enumerator :: torch_kUInt8 = 0
-      enumerator :: torch_kInt8 = 1
-      enumerator :: torch_kInt16 = 2
-      enumerator :: torch_kInt32 = 3
-      enumerator :: torch_kInt64 = 4
-      enumerator :: torch_kFloat16 = 5
-      enumerator :: torch_kFloat32 = 6
-      enumerator :: torch_kFloat64 = 7
-   end enum
-
-   ! From c_torch.h (torch_device_t)
-   enum, bind(c)
-      enumerator :: torch_kCPU = 0
-      enumerator :: torch_kCUDA = 1
-   end enum
-
-   ! Interface for calculating tensor from array for different possible input types
-   interface torch_tensor_from_array
-      module procedure torch_tensor_from_array_c_float
-      module procedure torch_tensor_from_array_c_double
-      ! module procedure torch_tensor_from_array_c_int8_t
-      ! module procedure torch_tensor_from_array_c_int16_t
-      ! module procedure torch_tensor_from_array_c_int32_t
-      ! module procedure torch_tensor_from_array_c_int64_t
-   end interface
+  use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, &
+    c_float, c_double, c_char, c_ptr, c_null_ptr
+  use, intrinsic :: iso_fortran_env, only: int8, int16, int32, int64, real32, real64
+  implicit none
+
+  type torch_module
+    type(c_ptr) :: p = c_null_ptr
+  end type torch_module
+
+  type torch_tensor
+    type(c_ptr) :: p = c_null_ptr
+  end type torch_tensor
+
+  ! From c_torch.h (torch_data_t)
+  enum, bind(c)
+    enumerator :: torch_kUInt8 = 0 ! not supported in fortran
+    enumerator :: torch_kInt8 = 1
+    enumerator :: torch_kInt16 = 2
+    enumerator :: torch_kInt32 = 3
+    enumerator :: torch_kInt64 = 4
+    enumerator :: torch_kFloat16 = 5 ! not supported in fortran
+    enumerator :: torch_kFloat32 = 6
+    enumerator :: torch_kFloat64 = 7
+  end enum
+
+
+  ! From c_torch.h (torch_device_t)
+  enum, bind(c)
+    enumerator :: torch_kCPU = 0
+    enumerator :: torch_kCUDA = 1
+  end enum
+
+  ! Interface for calculating tensor from array for different possible input types
+  interface torch_tensor_from_array
+    module procedure torch_tensor_from_array_int8_1
+    module procedure torch_tensor_from_array_int8_2
+    module procedure torch_tensor_from_array_int8_3
+    module procedure torch_tensor_from_array_int8_4
+    module procedure torch_tensor_from_array_int16_1
+    module procedure torch_tensor_from_array_int16_2
+    module procedure torch_tensor_from_array_int16_3
+    module procedure torch_tensor_from_array_int16_4
+    module procedure torch_tensor_from_array_int32_1
+    module procedure torch_tensor_from_array_int32_2
+    module procedure torch_tensor_from_array_int32_3
+    module procedure torch_tensor_from_array_int32_4
+    module procedure torch_tensor_from_array_int64_1
+    module procedure torch_tensor_from_array_int64_2
+    module procedure torch_tensor_from_array_int64_3
+    module procedure torch_tensor_from_array_int64_4
+    module procedure torch_tensor_from_array_real32_1
+    module procedure torch_tensor_from_array_real32_2
+    module procedure torch_tensor_from_array_real32_3
+    module procedure torch_tensor_from_array_real32_4
+    module procedure torch_tensor_from_array_real64_1
+    module procedure torch_tensor_from_array_real64_2
+    module procedure torch_tensor_from_array_real64_3
+    module procedure torch_tensor_from_array_real64_4
+  end interface
+
+  interface
+    function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor_p) &
+        bind(c, name='torch_from_blob')
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+      type(c_ptr), value, intent(in)    :: data
+      integer(c_int), value, intent(in) :: ndims
+      integer(c_int64_t), intent(in)    :: tensor_shape(*)
+      integer(c_int64_t), intent(in)    :: strides(*)
+      integer(c_int), value, intent(in) :: dtype
+      integer(c_int), value, intent(in) :: device
+      type(c_ptr)                       :: tensor_p
+    end function torch_from_blob_c
+  end interface
 
 contains
 
-   ! Torch Tensor API
-   !> Exposes the given data as a tensor without taking ownership of the original data.
-   !> This routine will take an (i, j, k) array and return an (k, j, i) tensor.
-   function torch_tensor_from_blob(data, ndims, tensor_shape, dtype, device, layout) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
-      type(c_ptr), intent(in)        :: data       !! Pointer to data
-      integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
-      integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
-      integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
-      integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA)
-      integer(c_int), intent(in)     :: layout(*)  !! Layout for strides for accessing data
-      type(torch_tensor)             :: tensor     !! Returned tensor
+  ! Torch Tensor API
+  !> Exposes the given data as a tensor without taking ownership of the original data.
+  !> This routine will take an (i, j, k) array and return an (k, j, i) tensor.
+  function torch_tensor_from_blob(data, ndims, tensor_shape, dtype, device, layout) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+    type(c_ptr), intent(in)        :: data       !! Pointer to data
+    integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
+    integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
+    integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+    integer(c_int), intent(in)     :: layout(*)  !! 1-based dimension indices, ordered from unit stride to largest stride
+    type(torch_tensor)             :: tensor     !! Returned tensor
+
+    integer(c_int)                 :: i          !! loop index
+    integer(c_int64_t)             :: strides(ndims) !! Strides for accessing data
+
+    if (ndims > 0) strides(layout(1)) = 1  ! guard: with ndims == 0 neither layout(1) nor strides(1) exists
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1))
+    end do
+    tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device)
+  end function torch_tensor_from_blob
+
+  !> Returns a tensor filled with the scalar value 1.
+  function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
+    integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
+    integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
+    integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+    type(torch_tensor)             :: tensor     !! Returned tensor
+
+    interface
+      function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) &
+          bind(c, name='torch_ones')
+        use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+        integer(c_int), value, intent(in) :: ndims
+        integer(c_int64_t), intent(in)    :: tensor_shape(*)
+        integer(c_int), value, intent(in) :: dtype
+        integer(c_int), value, intent(in) :: device
+        type(c_ptr)                       :: tensor
+      end function torch_ones_c
+    end interface
+
+    tensor%p = torch_ones_c(ndims, tensor_shape, dtype, device)
+  end function torch_tensor_ones
+
+  !> Returns a tensor filled with the scalar value 0.
+  function torch_tensor_zeros(ndims, tensor_shape, dtype, device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
+    integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
+    integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
+    integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+    type(torch_tensor)             :: tensor     !! Returned tensor
+
+    interface
+      function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) &
+          bind(c, name='torch_zeros')
+        use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+        integer(c_int), value, intent(in) :: ndims
+        integer(c_int64_t), intent(in)    :: tensor_shape(*)
+        integer(c_int), value, intent(in) :: dtype
+        integer(c_int), value, intent(in) :: device
+        type(c_ptr)                       :: tensor
+      end function torch_zeros_c
+    end interface
+
+    tensor%p = torch_zeros_c(ndims, tensor_shape, dtype, device)
+  end function torch_tensor_zeros
+
+  !> Prints the contents of a tensor.
+  subroutine torch_tensor_print(tensor)
+    type(torch_tensor), intent(in) :: tensor     !! Input tensor
+
+    interface
+      subroutine torch_tensor_print_c(tensor) &
+          bind(c, name='torch_tensor_print')
+        use, intrinsic :: iso_c_binding, only : c_ptr
+        type(c_ptr), value, intent(in) :: tensor
+      end subroutine torch_tensor_print_c
+    end interface
+
+    call torch_tensor_print_c(tensor%p)
+  end subroutine torch_tensor_print
+
+  !> Deallocates a tensor.
+  subroutine torch_tensor_delete(tensor)
+    type(torch_tensor), intent(in) :: tensor     !! Input tensor
+
+    interface
+      subroutine torch_tensor_delete_c(tensor) &
+          bind(c, name='torch_tensor_delete')
+        use, intrinsic :: iso_c_binding, only : c_ptr
+        type(c_ptr), value, intent(in) :: tensor
+      end subroutine torch_tensor_delete_c
+    end interface
+
+    call torch_tensor_delete_c(tensor%p)
+  end subroutine torch_tensor_delete
+
+  ! Torch Module API
+  !> Loads a Torch Script module (pre-trained PyTorch model saved with Torch Script)
+  function torch_module_load(filename) result(module)
+    use, intrinsic :: iso_c_binding, only : c_null_char
+    character(*), intent(in) :: filename !! Filename of Torch Script module
+    type(torch_module)            :: module      !! Returned deserialized module
+
+    interface
+      function torch_jit_load_c(filename) result(module) &
+          bind(c, name='torch_jit_load')
+        use, intrinsic :: iso_c_binding, only : c_char, c_ptr
+        character(c_char), intent(in) :: filename(*)
+        type(c_ptr)                   :: module
+      end function torch_jit_load_c
+    end interface
+
+    ! Strip leading/trailing blanks and append c_null_char before handing the name to the C routine
+    module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char)
+    end function torch_module_load
+
+    !> Performs a forward pass of the module with the input tensors
+    subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor)
+      use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc
+      type(torch_module), intent(in) :: module        !! Module
+      type(torch_tensor), intent(in), dimension(:) :: input_tensors  !! Array of Input tensors
+      type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors
+      integer(c_int) ::  n_inputs
 
-      integer(c_int)                 :: i          !! loop index
-      integer(c_int64_t)             :: strides(ndims) !! Strides for accessing data
+      integer :: i
+      type(c_ptr), dimension(n_inputs), target  :: input_ptrs
 
       interface
-         function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor) &
-            bind(c, name='torch_from_blob')
-            use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
-            type(c_ptr), value, intent(in)    :: data
-            integer(c_int), value, intent(in) :: ndims
-            integer(c_int64_t), intent(in)    :: tensor_shape(*)
-            integer(c_int64_t), intent(in)    :: strides(*)
-            integer(c_int), value, intent(in) :: dtype
-            integer(c_int), value, intent(in) :: device
-            type(c_ptr)                       :: tensor
-         end function torch_from_blob_c
+        subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, &
+            output_tensor) &
+            bind(c, name='torch_jit_module_forward')
+          use, intrinsic :: iso_c_binding, only : c_ptr, c_int
+          type(c_ptr), value, intent(in) :: module
+          type(c_ptr), value, intent(in) :: input_tensors
+          integer(c_int), value, intent(in) :: n_inputs
+          type(c_ptr), value, intent(in) :: output_tensor
+        end subroutine torch_jit_module_forward_c
       end interface
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1))
+      ! Assign array of pointers to the input tensors
+      do i = 1, n_inputs
+        input_ptrs(i) = input_tensors(i)%p
       end do
-      tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device)
-   end function torch_tensor_from_blob
-
-   !> This routine will take an (i, j, k) array and return an (k, j, i) tensor
-   !> it is invoked from a set of interfaces `torch_tensor_from_array_dtype`
-   function t_t_from_array(data_arr, tensor_shape, dtype, device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_double, c_loc
-      type(c_ptr), intent(in)          :: data_arr       !! Pointer to data
-      integer(c_int64_t), intent(in)   :: tensor_shape(:)   !! Shape of the tensor
-      integer(c_int), intent(in)       :: dtype      !! Data type of the tensor
-      integer(c_int), intent(in)       :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA)
-      type(torch_tensor)               :: tensor     !! Returned tensor
-
-      integer(c_int)                   :: i          !! loop index
-      integer(c_int64_t), allocatable  :: strides(:) !! Strides for accessing data
-      integer(c_int), allocatable      :: layout(:)  !! Layout for strides for accessing data
-      integer(c_int)                   :: ndims      !! Number of dimensions of the tensor
+
+      call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p)
+    end subroutine torch_module_forward
+
+    !> Deallocates a Torch Script module by calling the C routine torch_jit_module_delete on module%p
+    subroutine torch_module_delete(module)
+      type(torch_module), intent(in) :: module     !! Module whose C pointer (module%p) is handed to the C routine
 
       interface
-         function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor) &
-            bind(c, name='torch_from_blob')
-            use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
-            type(c_ptr), value, intent(in)    :: data
-            integer(c_int), value, intent(in) :: ndims
-            integer(c_int64_t), intent(in)    :: tensor_shape(*)
-            integer(c_int64_t), intent(in)    :: strides(*)
-            integer(c_int), value, intent(in) :: dtype
-            integer(c_int), value, intent(in) :: device
-            type(c_ptr)                       :: tensor
-         end function torch_from_blob_c
+        subroutine torch_jit_module_delete_c(module) &
+            bind(c, name='torch_jit_module_delete')
+          use, intrinsic :: iso_c_binding, only : c_ptr
+          type(c_ptr), value, intent(in) :: module  !! Module pointer, passed by value to C
+        end subroutine torch_jit_module_delete_c
       end interface
 
-      ndims = size(tensor_shape)
+      call torch_jit_module_delete_c(module%p)
+    end subroutine torch_module_delete
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-1 int8 array data_in
+    function torch_tensor_from_array_int8_1(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int8
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int8), intent(in), target :: data_in(:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(1) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
 
-      allocate(strides(ndims))
-      allocate(layout(ndims))
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt8 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(1)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 1                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
 
-      ! Fortran Layout
-      do i=1, ndims
-          layout(i) = i
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
       end do
 
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int8_1
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-2 int8 array data_in
+    function torch_tensor_from_array_int8_2(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int8
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int8), intent(in), target :: data_in(:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(2) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt8 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(2)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 2                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
       strides(layout(1)) = 1
       do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1))
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
       end do
 
-      tensor%p = torch_from_blob_c(data_arr, ndims, tensor_shape, strides, dtype, device)
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      deallocate(strides)
-      deallocate(layout)
+    end function torch_tensor_from_array_int8_2
 
-   end function t_t_from_array
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-3 int8 array data_in
+    function torch_tensor_from_array_int8_3(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int8
 
-   !> Returns a tensor filled with the scalar value 1.
-   function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
-      integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
-      integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
-      integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
-      integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA)
-      type(torch_tensor)             :: tensor     !! Returned tensor
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int8), intent(in), target :: data_in(:,:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(3) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
 
-      interface
-         function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) &
-            bind(c, name='torch_ones')
-            use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
-            integer(c_int), value, intent(in) :: ndims
-            integer(c_int64_t), intent(in)    :: tensor_shape(*)
-            integer(c_int), value, intent(in) :: dtype
-            integer(c_int), value, intent(in) :: device
-            type(c_ptr)                       :: tensor
-         end function torch_ones_c
-      end interface
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_ones_c(ndims, tensor_shape, dtype, device)
-   end function torch_tensor_ones
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt8 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(3)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 3                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
 
-   !> Returns a tensor filled with the scalar value 0.
-   function torch_tensor_zeros(ndims, tensor_shape, dtype, device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
-      integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
-      integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
-      integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
-      integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA)
-      type(torch_tensor)             :: tensor     !! Returned tensor
+      c_tensor_shape = shape(data_in)
 
-      interface
-         function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) &
-            bind(c, name='torch_zeros')
-            use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
-            integer(c_int), value, intent(in) :: ndims
-            integer(c_int64_t), intent(in)    :: tensor_shape(*)
-            integer(c_int), value, intent(in) :: dtype
-            integer(c_int), value, intent(in) :: device
-            type(c_ptr)                       :: tensor
-         end function torch_zeros_c
-      end interface
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
 
-      tensor%p = torch_zeros_c(ndims, tensor_shape, dtype, device)
-   end function torch_tensor_zeros
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-   !> Prints the contents of a tensor.
-   subroutine torch_tensor_print(tensor)
-      type(torch_tensor), intent(in) :: tensor     !! Input tensor
+    end function torch_tensor_from_array_int8_3
 
-      interface
-         subroutine torch_tensor_print_c(tensor) &
-            bind(c, name='torch_tensor_print')
-            use, intrinsic :: iso_c_binding, only : c_ptr
-            type(c_ptr), value, intent(in) :: tensor
-         end subroutine torch_tensor_print_c
-      end interface
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-4 int8 array data_in
+    function torch_tensor_from_array_int8_4(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int8
 
-      call torch_tensor_print_c(tensor%p)
-   end subroutine torch_tensor_print
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int8), intent(in), target :: data_in(:,:,:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(4) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
 
-   !> Deallocates a tensor.
-   subroutine torch_tensor_delete(tensor)
-      type(torch_tensor), intent(in) :: tensor     !! Input tensor
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
 
-      interface
-         subroutine torch_tensor_delete_c(tensor) &
-            bind(c, name='torch_tensor_delete')
-            use, intrinsic :: iso_c_binding, only : c_ptr
-            type(c_ptr), value, intent(in) :: tensor
-         end subroutine torch_tensor_delete_c
-      end interface
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt8 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(4)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 4                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
 
-      call torch_tensor_delete_c(tensor%p)
-   end subroutine torch_tensor_delete
+      c_tensor_shape = shape(data_in)
 
-   ! Torch Module API
-   !> Loads a Torch Script module (pre-trained PyTorch model saved with Torch Script)
-   function torch_module_load(filename) result(module)
-      use, intrinsic :: iso_c_binding, only : c_null_char
-      character(*), intent(in) :: filename !! Filename of Torch Script module
-      type(torch_module)            :: module      !! Returned deserialized module
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
 
-      interface
-         function torch_jit_load_c(filename) result(module) &
-            bind(c, name='torch_jit_load')
-            use, intrinsic :: iso_c_binding, only : c_char, c_ptr
-            character(c_char), intent(in) :: filename(*)
-            type(c_ptr)                   :: module
-         end function torch_jit_load_c
-      end interface
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! Need to append c_null_char at end of filename
-      module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char)
-   end function torch_module_load
+    end function torch_tensor_from_array_int8_4
 
-   !> Performs a forward pass of the module with the input tensors
-   subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor)
-      use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc
-      type(torch_module), intent(in) :: module        !! Module
-      type(torch_tensor), intent(in), dimension(:) :: input_tensors  !! Array of Input tensors
-      type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors
-      integer(c_int) ::  n_inputs
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-1 int16 array data_in
+    function torch_tensor_from_array_int16_1(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int16
 
-      integer :: i
-      type(c_ptr), dimension(n_inputs), target  :: input_ptrs
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int16), intent(in), target :: data_in(:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(1) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
 
-      interface
-         subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, &
-                                                        output_tensor) &
-            bind(c, name='torch_jit_module_forward')
-            use, intrinsic :: iso_c_binding, only : c_ptr, c_int
-            type(c_ptr), value, intent(in) :: module
-            type(c_ptr), value, intent(in) :: input_tensors
-            integer(c_int), value, intent(in) :: n_inputs
-            type(c_ptr), value, intent(in) :: output_tensor
-         end subroutine torch_jit_module_forward_c
-      end interface
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
 
-      ! Assign array of pointers to the input tensors
-      do i = 1, n_inputs
-          input_ptrs(i) = input_tensors(i)%p
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt16 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(1)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 1                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
       end do
-      
-      call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p)
-   end subroutine torch_module_forward
 
-   !> Deallocates a Torch Script module
-   subroutine torch_module_delete(module)
-      type(torch_module), intent(in) :: module     !! Module
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      interface
-         subroutine torch_jit_module_delete_c(module) &
-            bind(c, name='torch_jit_module_delete')
-            use, intrinsic :: iso_c_binding, only : c_ptr
-            type(c_ptr), value, intent(in) :: module
-         end subroutine torch_jit_module_delete_c
-      end interface
+    end function torch_tensor_from_array_int16_1
 
-      call torch_jit_module_delete_c(module%p)
-   end subroutine torch_module_delete
-
-   ! Series of interface functions
-   function torch_tensor_from_array_c_double(data_arr, tensor_shape, device) result(tensor)
-   !function torch_tensor_from_array_c_double(data_arr, tensor_shape) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_double, c_loc
-      real(c_double), intent(in), target :: data_arr(*)   !! Fortran array of data
-      ! real(c_double), intent(in), target :: data_arr(*)   !! Fortran array of data
-      integer(c_int64_t), intent(in)   :: tensor_shape(:)   !! Shape of the tensor
-      integer(c_int), parameter :: dtype = torch_kFloat64
-      integer(c_int), intent(in)       :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA)
-      type(torch_tensor)               :: tensor     !! Returned tensor
-     
-      
-      tensor = t_t_from_array(c_loc(data_arr), tensor_shape, dtype, device)
-
-   end function torch_tensor_from_array_c_double
-
-   function torch_tensor_from_array_c_float(data_arr, tensor_shape, device) result(tensor)
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-2 int16 array data_in
+    function torch_tensor_from_array_int16_2(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int16
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int16), intent(in), target :: data_in(:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(2) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt16 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(2)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 2                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int16_2
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-3 int16 array data_in
+    function torch_tensor_from_array_int16_3(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int16
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int16), intent(in), target :: data_in(:,:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(3) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt16 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(3)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 3                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int16_3
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-4 int16 array data_in
+    function torch_tensor_from_array_int16_4(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int16
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int16), intent(in), target :: data_in(:,:,:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(4) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt16 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(4)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 4                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int16_4
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-1 int32 array data_in
+    function torch_tensor_from_array_int32_1(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int32
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int32), intent(in), target :: data_in(:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(1) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt32 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(1)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 1                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int32_1
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-2 int32 array data_in
+    function torch_tensor_from_array_int32_2(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int32
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int32), intent(in), target :: data_in(:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(2) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt32 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(2)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 2                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int32_2
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-3 int32 array data_in
+    function torch_tensor_from_array_int32_3(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int32
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int32), intent(in), target :: data_in(:,:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(3) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt32 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(3)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 3                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int32_3
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-4 int32 array data_in
+    function torch_tensor_from_array_int32_4(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int32
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int32), intent(in), target :: data_in(:,:,:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(4) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt32 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(4)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 4                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int32_4
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-1 int64 array data_in
+    function torch_tensor_from_array_int64_1(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int64
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int64), intent(in), target :: data_in(:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(1) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt64 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(1)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 1                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int64_1
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-2 int64 array data_in
+    function torch_tensor_from_array_int64_2(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int64
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int64), intent(in), target :: data_in(:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(2) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt64 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(2)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 2                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int64_2
+
+    !> Return a torch_tensor created by torch_from_blob_c from the address (c_loc) of the rank-3 int64 array data_in
+    function torch_tensor_from_array_int64_3(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int64
+
+      ! inputs (NOTE(review): c_loc(data_in) presumably requires contiguous data - confirm with callers)
+      integer(kind=int64), intent(in), target :: data_in(:,:,:)   !! Input data that the tensor will point at
+      integer, intent(in)        :: layout(3) !! Dimension order for strides; layout(1) gets stride 1 (not validated)
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of data_in, passed to torch_from_blob_c
+      integer(c_int), parameter :: c_dtype = torch_kInt64 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(3)                  !! Stride per dimension, built from layout and c_tensor_shape
+      integer(c_int), parameter :: ndims = 3                   !! Number of dimensions of input data
+      integer                   :: i  !! Loop index over layout entries
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1
+      do i = 2, ndims
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int64_3
+
+    !> return a torch tensor pointing to data_in array
+    function torch_tensor_from_array_int64_4(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : int64
+
+      ! inputs -- NOTE(review): data_in is not declared contiguous but is passed to c_loc; confirm
+      integer(kind=int64), intent(in), target :: data_in(:,:,:,:)   !! input data that the tensor will point at
+      integer, intent(in)        :: layout(4) !! dimension order, stride-1 dimension first; presumably a permutation of 1..4
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of data_in, converted to c_int64_t
+      integer(c_int), parameter :: c_dtype = torch_kInt64 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(4)                  !! Stride of each dimension, in elements
+      integer(c_int), parameter :: ndims = 4                   !! number of dimensions of input data
+      integer                   :: i                           !! loop index
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1  ! dimension layout(1) gets unit stride
+      do i = 2, ndims  ! later dimensions: previous stride times previous extent (no iterations when ndims = 1)
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_int64_4
+
+    !> return a torch tensor pointing to data_in array
+    function torch_tensor_from_array_real32_1(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      real(c_float), intent(in), target :: data_arr(*)   !! Fortran array of data
-      integer(c_int64_t), intent(in)   :: tensor_shape(:)   !! Shape of the tensor
-      integer(c_int), parameter :: dtype = torch_kFloat32
-      integer(c_int), intent(in)       :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kCUDA)
-      type(torch_tensor)               :: tensor     !! Returned tensor
-     
-     tensor = t_t_from_array(c_loc(data_arr), tensor_shape, dtype, device)
+      use, intrinsic :: iso_fortran_env, only : real32
+
+      ! inputs -- NOTE(review): data_in is not declared contiguous but is passed to c_loc; confirm
+      real(kind=real32), intent(in), target :: data_in(:)   !! input data that the tensor will point at
+      integer, intent(in)        :: layout(1) !! dimension order, stride-1 dimension first; presumably a permutation of 1..1
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of data_in, converted to c_int64_t
+      integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(1)                  !! Stride of each dimension, in elements
+      integer(c_int), parameter :: ndims = 1                   !! number of dimensions of input data
+      integer                   :: i                           !! loop index
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1  ! dimension layout(1) gets unit stride
+      do i = 2, ndims  ! later dimensions: previous stride times previous extent (no iterations when ndims = 1)
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_real32_1
+
+    !> return a torch tensor pointing to data_in array
+    function torch_tensor_from_array_real32_2(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : real32
+
+      ! inputs -- NOTE(review): data_in is not declared contiguous but is passed to c_loc; confirm
+      real(kind=real32), intent(in), target :: data_in(:,:)   !! input data that the tensor will point at
+      integer, intent(in)        :: layout(2) !! dimension order, stride-1 dimension first; presumably a permutation of 1..2
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of data_in, converted to c_int64_t
+      integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(2)                  !! Stride of each dimension, in elements
+      integer(c_int), parameter :: ndims = 2                   !! number of dimensions of input data
+      integer                   :: i                           !! loop index
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1  ! dimension layout(1) gets unit stride
+      do i = 2, ndims  ! later dimensions: previous stride times previous extent (no iterations when ndims = 1)
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_real32_2
+
+    !> return a torch tensor pointing to data_in array
+    function torch_tensor_from_array_real32_3(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : real32
+
+      ! inputs -- NOTE(review): data_in is not declared contiguous but is passed to c_loc; confirm
+      real(kind=real32), intent(in), target :: data_in(:,:,:)   !! input data that the tensor will point at
+      integer, intent(in)        :: layout(3) !! dimension order, stride-1 dimension first; presumably a permutation of 1..3
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of data_in, converted to c_int64_t
+      integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(3)                  !! Stride of each dimension, in elements
+      integer(c_int), parameter :: ndims = 3                   !! number of dimensions of input data
+      integer                   :: i                           !! loop index
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1  ! dimension layout(1) gets unit stride
+      do i = 2, ndims  ! later dimensions: previous stride times previous extent (no iterations when ndims = 1)
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_real32_3
+
+    !> return a torch tensor pointing to data_in array
+    function torch_tensor_from_array_real32_4(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : real32
+
+      ! inputs -- NOTE(review): data_in is not declared contiguous but is passed to c_loc; confirm
+      real(kind=real32), intent(in), target :: data_in(:,:,:,:)   !! input data that the tensor will point at
+      integer, intent(in)        :: layout(4) !! dimension order, stride-1 dimension first; presumably a permutation of 1..4
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of data_in, converted to c_int64_t
+      integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(4)                  !! Stride of each dimension, in elements
+      integer(c_int), parameter :: ndims = 4                   !! number of dimensions of input data
+      integer                   :: i                           !! loop index
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1  ! dimension layout(1) gets unit stride
+      do i = 2, ndims  ! later dimensions: previous stride times previous extent (no iterations when ndims = 1)
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_real32_4
+
+    !> return a torch tensor pointing to data_in array
+    function torch_tensor_from_array_real64_1(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : real64
+
+      ! inputs -- NOTE(review): data_in is not declared contiguous but is passed to c_loc; confirm
+      real(kind=real64), intent(in), target :: data_in(:)   !! input data that the tensor will point at
+      integer, intent(in)        :: layout(1) !! dimension order, stride-1 dimension first; presumably a permutation of 1..1
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of data_in, converted to c_int64_t
+      integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(1)                  !! Stride of each dimension, in elements
+      integer(c_int), parameter :: ndims = 1                   !! number of dimensions of input data
+      integer                   :: i                           !! loop index
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1  ! dimension layout(1) gets unit stride
+      do i = 2, ndims  ! later dimensions: previous stride times previous extent (no iterations when ndims = 1)
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_real64_1
+
+    !> return a torch tensor pointing to data_in array
+    function torch_tensor_from_array_real64_2(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : real64
+
+      ! inputs -- NOTE(review): data_in is not declared contiguous but is passed to c_loc; confirm
+      real(kind=real64), intent(in), target :: data_in(:,:)   !! input data that the tensor will point at
+      integer, intent(in)        :: layout(2) !! dimension order, stride-1 dimension first; presumably a permutation of 1..2
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of data_in, converted to c_int64_t
+      integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(2)                  !! Stride of each dimension, in elements
+      integer(c_int), parameter :: ndims = 2                   !! number of dimensions of input data
+      integer                   :: i                           !! loop index
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1  ! dimension layout(1) gets unit stride
+      do i = 2, ndims  ! later dimensions: previous stride times previous extent (no iterations when ndims = 1)
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_real64_2
+
+    !> return a torch tensor pointing to data_in array
+    function torch_tensor_from_array_real64_3(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : real64
+
+      ! inputs -- NOTE(review): data_in is not declared contiguous but is passed to c_loc; confirm
+      real(kind=real64), intent(in), target :: data_in(:,:,:)   !! input data that the tensor will point at
+      integer, intent(in)        :: layout(3) !! dimension order, stride-1 dimension first; presumably a permutation of 1..3
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of data_in, converted to c_int64_t
+      integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(3)                  !! Stride of each dimension, in elements
+      integer(c_int), parameter :: ndims = 3                   !! number of dimensions of input data
+      integer                   :: i                           !! loop index
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1  ! dimension layout(1) gets unit stride
+      do i = 2, ndims  ! later dimensions: previous stride times previous extent (no iterations when ndims = 1)
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_real64_3
+
+    !> return a torch tensor pointing to data_in array
+    function torch_tensor_from_array_real64_4(data_in, layout, c_device) result(tensor)
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+      use, intrinsic :: iso_fortran_env, only : real64
+
+      ! inputs -- NOTE(review): data_in is not declared contiguous but is passed to c_loc; confirm
+      real(kind=real64), intent(in), target :: data_in(:,:,:,:)   !! input data that the tensor will point at
+      integer, intent(in)        :: layout(4) !! dimension order, stride-1 dimension first; presumably a permutation of 1..4
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of data_in, converted to c_int64_t
+      integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(4)                  !! Stride of each dimension, in elements
+      integer(c_int), parameter :: ndims = 4                   !! number of dimensions of input data
+      integer                   :: i                           !! loop index
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1  ! dimension layout(1) gets unit stride
+      do i = 2, ndims  ! later dimensions: previous stride times previous extent (no iterations when ndims = 1)
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
+
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+    end function torch_tensor_from_array_real64_4
 
-   end function torch_tensor_from_array_c_float
 
-end module ftorch
+  end module ftorch
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index 93a7b4be..790623ef 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -1,16 +1,25 @@
 #:def ranksuffix(RANK)
 $:'' if RANK == 0 else '(' + ':' + ',:' * (RANK - 1) + ')'
 #:enddef ranksuffix
-#:set PRECISIONS = ['float', 'double']
-#:set RANKS = range(1, 3)
-#:set ENUMS = dict(zip(PRECISIONS, ['torch_kFloat32', 'torch_kFloat64']))
+#:set PRECISIONS = ['int8', 'int16', 'int32', 'int64', 'real32', 'real64']
+#:set C_PRECISIONS = ['c_int8_t', 'c_int16_t', 'c_int32_t', 'c_int64_t', 'c_float', 'c_double']
+#:set C_PRECISIONS = dict(zip(PRECISIONS, C_PRECISIONS))
+#:set ENUMS = dict(zip(PRECISIONS, ['torch_kInt8', 'torch_kInt16', 'torch_kInt32', 'torch_kInt64', 'torch_kFloat32', 'torch_kFloat64']))
+#:set RANKS = range(1, 5)
 #:def enum_from_prec(PRECISION)
 $:ENUMS[PRECISION]
 #:enddef enum_from_prec
+#:def c_prec(PRECISION)
+$:C_PRECISIONS[PRECISION]
+#:enddef c_prec
+#:def f_type(PRECISION)
+$:'integer' if PRECISION[:3] == 'int' else 'real'
+#:enddef f_type
 module ftorch
 
   use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, &
     c_float, c_double, c_char, c_ptr, c_null_ptr
+  use, intrinsic :: iso_fortran_env, only: int8, int16, int32, int64, real32, real64
   implicit none
 
   type torch_module
@@ -23,16 +32,17 @@ module ftorch
 
   ! From c_torch.h (torch_data_t)
   enum, bind(c)
-    enumerator :: torch_kUInt8 = 0
+    enumerator :: torch_kUInt8 = 0 ! not supported in fortran
     enumerator :: torch_kInt8 = 1
     enumerator :: torch_kInt16 = 2
     enumerator :: torch_kInt32 = 3
     enumerator :: torch_kInt64 = 4
-    enumerator :: torch_kFloat16 = 5
+    enumerator :: torch_kFloat16 = 5 ! not supported in fortran
     enumerator :: torch_kFloat32 = 6
     enumerator :: torch_kFloat64 = 7
   end enum
 
+
   ! From c_torch.h (torch_device_t)
   enum, bind(c)
     enumerator :: torch_kCPU = 0
@@ -43,13 +53,23 @@ module ftorch
   interface torch_tensor_from_array
     #:for PREC in PRECISIONS
     #:for RANK in RANKS
-    module procedure torch_tensor_from_array_c_${PREC}$_${RANK}$
+    module procedure torch_tensor_from_array_${PREC}$_${RANK}$
     #:endfor
     #:endfor
-    ! module procedure torch_tensor_from_array_c_int8_t
-    ! module procedure torch_tensor_from_array_c_int16_t
-    ! module procedure torch_tensor_from_array_c_int32_t
-    ! module procedure torch_tensor_from_array_c_int64_t
+  end interface
+
+  interface  ! module-level C binding used by torch_tensor_from_blob and the torch_tensor_from_array_* variants
+    function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor_p) &
+        bind(c, name='torch_from_blob')
+      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+      type(c_ptr), value, intent(in)    :: data             !! Pointer to data
+      integer(c_int), value, intent(in) :: ndims            !! Number of dimensions of the tensor
+      integer(c_int64_t), intent(in)    :: tensor_shape(*)  !! Shape of the tensor
+      integer(c_int64_t), intent(in)    :: strides(*)       !! Strides for accessing data
+      integer(c_int), value, intent(in) :: dtype            !! Data type of the tensor
+      integer(c_int), value, intent(in) :: device           !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+      type(c_ptr)                       :: tensor_p         !! Pointer returned by the C routine; callers store it in tensor%p
+    end function torch_from_blob_c
   end interface
 
 contains
@@ -70,20 +90,6 @@ contains
     integer(c_int)                 :: i          !! loop index
     integer(c_int64_t)             :: strides(ndims) !! Strides for accessing data
 
-    interface
-      function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor) &
-          bind(c, name='torch_from_blob')
-        use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
-        type(c_ptr), value, intent(in)    :: data
-        integer(c_int), value, intent(in) :: ndims
-        integer(c_int64_t), intent(in)    :: tensor_shape(*)
-        integer(c_int64_t), intent(in)    :: strides(*)
-        integer(c_int), value, intent(in) :: dtype
-        integer(c_int), value, intent(in) :: device
-        type(c_ptr)                       :: tensor
-      end function torch_from_blob_c
-    end interface
-
     strides(layout(1)) = 1
     do i = 2, ndims
       strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1))
@@ -91,57 +97,6 @@ contains
     tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device)
   end function torch_tensor_from_blob
 
-  !> This routine will take an (i, j, k) array and return an (k, j, i) tensor
-  !> it is invoked from a set of interfaces `torch_tensor_from_array_dtype`
-  function t_t_from_array(data_arr, tensor_shape, dtype, device) result(tensor)
-    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_double, c_loc
-    type(c_ptr), intent(in)          :: data_arr       !! Pointer to data
-    integer(c_int64_t), intent(in)   :: tensor_shape(:)   !! Shape of the tensor
-    integer(c_int), intent(in)       :: dtype      !! Data type of the tensor
-    integer(c_int), intent(in)       :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
-    type(torch_tensor)               :: tensor     !! Returned tensor
-
-    integer(c_int)                   :: i          !! loop index
-    integer(c_int64_t), allocatable  :: strides(:) !! Strides for accessing data
-    integer(c_int), allocatable      :: layout(:)  !! Layout for strides for accessing data
-    integer(c_int)                   :: ndims      !! Number of dimensions of the tensor
-
-    interface
-      function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor) &
-          bind(c, name='torch_from_blob')
-        use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
-        type(c_ptr), value, intent(in)    :: data
-        integer(c_int), value, intent(in) :: ndims
-        integer(c_int64_t), intent(in)    :: tensor_shape(*)
-        integer(c_int64_t), intent(in)    :: strides(*)
-        integer(c_int), value, intent(in) :: dtype
-        integer(c_int), value, intent(in) :: device
-        type(c_ptr)                       :: tensor
-      end function torch_from_blob_c
-    end interface
-
-    ndims = size(tensor_shape)
-
-    allocate(strides(ndims))
-    allocate(layout(ndims))
-
-    ! Fortran Layout
-    do i=1, ndims
-      layout(i) = i
-    end do
-
-    strides(layout(1)) = 1
-    do i = 2, ndims
-      strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1))
-    end do
-
-    tensor%p = torch_from_blob_c(data_arr, ndims, tensor_shape, strides, dtype, device)
-
-    deallocate(strides)
-    deallocate(layout)
-
-  end function t_t_from_array
-
   !> Returns a tensor filled with the scalar value 1.
   function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor)
     use, intrinsic :: iso_c_binding, only : c_int, c_int64_t
@@ -223,8 +178,8 @@ contains
   ! Torch Module API
   !> Loads a Torch Script module (pre-trained PyTorch model saved with Torch Script)
   function torch_module_load(filename) result(module)
-    use, intrinsic :: iso_c_binding, only : c_char
-    character(c_char), intent(in) :: filename(*) !! Filename of Torch Script module
+    use, intrinsic :: iso_c_binding, only : c_null_char
+    character(*), intent(in) :: filename !! Filename of Torch Script module
     type(torch_module)            :: module      !! Returned deserialized module
 
     interface
@@ -237,7 +192,7 @@ contains
     end interface
 
     ! Need to append c_null_char at end of filename
-    module%p = torch_jit_load_c(filename)
+    module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char)
     end function torch_module_load
 
     !> Performs a forward pass of the module with the input tensors
@@ -286,20 +241,38 @@ contains
       call torch_jit_module_delete_c(module%p)
     end subroutine torch_module_delete
 
-    ! Series of interface functions
     #:for PREC in PRECISIONS
     #:for RANK in RANKS
-    function torch_tensor_from_array_c_${PREC}$_${RANK}$(data_arr, tensor_shape, device) result(tensor)
+    !> return a torch tensor pointing to data_in array
+    function torch_tensor_from_array_${PREC}$_${RANK}$(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      real(c_${PREC}$), intent(in), target :: data_arr${ranksuffix(RANK)}$   !! Fortran array of data
-      integer(c_int64_t), intent(in)   :: tensor_shape(:)   !! Shape of the tensor
-      integer(c_int), parameter :: dtype = ${enum_from_prec(PREC)}$
-      integer(c_int), intent(in)       :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
-      type(torch_tensor)               :: tensor     !! Returned tensor
+      use, intrinsic :: iso_fortran_env, only : ${PREC}$
+
+      ! inputs -- NOTE(review): data_in is not declared contiguous but is passed to c_loc; confirm
+      ${f_type(PREC)}$(kind=${PREC}$), intent(in), target :: data_in${ranksuffix(RANK)}$   !! input data that the tensor will point at
+      integer, intent(in)        :: layout(${RANK}$) !! dimension order, stride-1 dimension first; presumably a permutation of 1..${RANK}$
+      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (torch_kCPU or torch_kCUDA)
+
+      ! output tensor
+      type(torch_tensor) :: tensor     !! Returned tensor
+
+      ! local data
+      integer(c_int64_t)        :: c_tensor_shape(${RANK}$)           !! Shape of data_in, converted to c_int64_t
+      integer(c_int), parameter :: c_dtype = ${enum_from_prec(PREC)}$ !! Torch data type enumerator for this variant
+      integer(c_int64_t)        :: strides(${RANK}$)                  !! Stride of each dimension, in elements
+      integer(c_int), parameter :: ndims = ${RANK}$                   !! number of dimensions of input data
+      integer                   :: i                           !! loop index
+
+      c_tensor_shape = shape(data_in)
+
+      strides(layout(1)) = 1  ! dimension layout(1) gets unit stride
+      do i = 2, ndims  ! later dimensions: previous stride times previous extent (no iterations when ndims = 1)
+        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
+      end do
 
-      tensor = t_t_from_array(c_loc(data_arr), tensor_shape, dtype, device)
+      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_c_${PREC}$_${RANK}$
+    end function torch_tensor_from_array_${PREC}$_${RANK}$
 
     #:endfor
     #:endfor

From 2f08e98644dcad0f5684901d76a7144ecb36030e Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sat, 18 Nov 2023 14:37:17 +0000
Subject: [PATCH 03/20] Add `d` suffix to fypp functions to show dimensionality.

---
 src/ftorch.f90  | 144 ++++++++++++++++++++++++------------------------
 src/ftorch.fypp |   6 +-
 2 files changed, 75 insertions(+), 75 deletions(-)

diff --git a/src/ftorch.f90 b/src/ftorch.f90
index 3e18df29..752e5e51 100644
--- a/src/ftorch.f90
+++ b/src/ftorch.f90
@@ -34,30 +34,30 @@ module ftorch
 
   ! Interface for calculating tensor from array for different possible input types
   interface torch_tensor_from_array
-    module procedure torch_tensor_from_array_int8_1
-    module procedure torch_tensor_from_array_int8_2
-    module procedure torch_tensor_from_array_int8_3
-    module procedure torch_tensor_from_array_int8_4
-    module procedure torch_tensor_from_array_int16_1
-    module procedure torch_tensor_from_array_int16_2
-    module procedure torch_tensor_from_array_int16_3
-    module procedure torch_tensor_from_array_int16_4
-    module procedure torch_tensor_from_array_int32_1
-    module procedure torch_tensor_from_array_int32_2
-    module procedure torch_tensor_from_array_int32_3
-    module procedure torch_tensor_from_array_int32_4
-    module procedure torch_tensor_from_array_int64_1
-    module procedure torch_tensor_from_array_int64_2
-    module procedure torch_tensor_from_array_int64_3
-    module procedure torch_tensor_from_array_int64_4
-    module procedure torch_tensor_from_array_real32_1
-    module procedure torch_tensor_from_array_real32_2
-    module procedure torch_tensor_from_array_real32_3
-    module procedure torch_tensor_from_array_real32_4
-    module procedure torch_tensor_from_array_real64_1
-    module procedure torch_tensor_from_array_real64_2
-    module procedure torch_tensor_from_array_real64_3
-    module procedure torch_tensor_from_array_real64_4
+    module procedure torch_tensor_from_array_int8_1d
+    module procedure torch_tensor_from_array_int8_2d
+    module procedure torch_tensor_from_array_int8_3d
+    module procedure torch_tensor_from_array_int8_4d
+    module procedure torch_tensor_from_array_int16_1d
+    module procedure torch_tensor_from_array_int16_2d
+    module procedure torch_tensor_from_array_int16_3d
+    module procedure torch_tensor_from_array_int16_4d
+    module procedure torch_tensor_from_array_int32_1d
+    module procedure torch_tensor_from_array_int32_2d
+    module procedure torch_tensor_from_array_int32_3d
+    module procedure torch_tensor_from_array_int32_4d
+    module procedure torch_tensor_from_array_int64_1d
+    module procedure torch_tensor_from_array_int64_2d
+    module procedure torch_tensor_from_array_int64_3d
+    module procedure torch_tensor_from_array_int64_4d
+    module procedure torch_tensor_from_array_real32_1d
+    module procedure torch_tensor_from_array_real32_2d
+    module procedure torch_tensor_from_array_real32_3d
+    module procedure torch_tensor_from_array_real32_4d
+    module procedure torch_tensor_from_array_real64_1d
+    module procedure torch_tensor_from_array_real64_2d
+    module procedure torch_tensor_from_array_real64_3d
+    module procedure torch_tensor_from_array_real64_4d
   end interface
 
   interface
@@ -244,7 +244,7 @@ end subroutine torch_jit_module_delete_c
     end subroutine torch_module_delete
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int8_1(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int8_1d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int8
 
@@ -272,10 +272,10 @@ function torch_tensor_from_array_int8_1(data_in, layout, c_device) result(tensor
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int8_1
+    end function torch_tensor_from_array_int8_1d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int8_2(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int8_2d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int8
 
@@ -303,10 +303,10 @@ function torch_tensor_from_array_int8_2(data_in, layout, c_device) result(tensor
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int8_2
+    end function torch_tensor_from_array_int8_2d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int8_3(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int8_3d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int8
 
@@ -334,10 +334,10 @@ function torch_tensor_from_array_int8_3(data_in, layout, c_device) result(tensor
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int8_3
+    end function torch_tensor_from_array_int8_3d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int8_4(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int8_4d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int8
 
@@ -365,10 +365,10 @@ function torch_tensor_from_array_int8_4(data_in, layout, c_device) result(tensor
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int8_4
+    end function torch_tensor_from_array_int8_4d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int16_1(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int16_1d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int16
 
@@ -396,10 +396,10 @@ function torch_tensor_from_array_int16_1(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int16_1
+    end function torch_tensor_from_array_int16_1d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int16_2(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int16_2d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int16
 
@@ -427,10 +427,10 @@ function torch_tensor_from_array_int16_2(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int16_2
+    end function torch_tensor_from_array_int16_2d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int16_3(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int16_3d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int16
 
@@ -458,10 +458,10 @@ function torch_tensor_from_array_int16_3(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int16_3
+    end function torch_tensor_from_array_int16_3d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int16_4(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int16_4d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int16
 
@@ -489,10 +489,10 @@ function torch_tensor_from_array_int16_4(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int16_4
+    end function torch_tensor_from_array_int16_4d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int32_1(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int32_1d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int32
 
@@ -520,10 +520,10 @@ function torch_tensor_from_array_int32_1(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int32_1
+    end function torch_tensor_from_array_int32_1d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int32_2(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int32_2d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int32
 
@@ -551,10 +551,10 @@ function torch_tensor_from_array_int32_2(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int32_2
+    end function torch_tensor_from_array_int32_2d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int32_3(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int32_3d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int32
 
@@ -582,10 +582,10 @@ function torch_tensor_from_array_int32_3(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int32_3
+    end function torch_tensor_from_array_int32_3d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int32_4(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int32_4d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int32
 
@@ -613,10 +613,10 @@ function torch_tensor_from_array_int32_4(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int32_4
+    end function torch_tensor_from_array_int32_4d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int64_1(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int64_1d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int64
 
@@ -644,10 +644,10 @@ function torch_tensor_from_array_int64_1(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int64_1
+    end function torch_tensor_from_array_int64_1d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int64_2(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int64_2d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int64
 
@@ -675,10 +675,10 @@ function torch_tensor_from_array_int64_2(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int64_2
+    end function torch_tensor_from_array_int64_2d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int64_3(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int64_3d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int64
 
@@ -706,10 +706,10 @@ function torch_tensor_from_array_int64_3(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int64_3
+    end function torch_tensor_from_array_int64_3d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int64_4(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_int64_4d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : int64
 
@@ -737,10 +737,10 @@ function torch_tensor_from_array_int64_4(data_in, layout, c_device) result(tenso
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_int64_4
+    end function torch_tensor_from_array_int64_4d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real32_1(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_real32_1d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : real32
 
@@ -768,10 +768,10 @@ function torch_tensor_from_array_real32_1(data_in, layout, c_device) result(tens
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real32_1
+    end function torch_tensor_from_array_real32_1d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real32_2(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_real32_2d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : real32
 
@@ -799,10 +799,10 @@ function torch_tensor_from_array_real32_2(data_in, layout, c_device) result(tens
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real32_2
+    end function torch_tensor_from_array_real32_2d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real32_3(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_real32_3d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : real32
 
@@ -830,10 +830,10 @@ function torch_tensor_from_array_real32_3(data_in, layout, c_device) result(tens
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real32_3
+    end function torch_tensor_from_array_real32_3d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real32_4(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_real32_4d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : real32
 
@@ -861,10 +861,10 @@ function torch_tensor_from_array_real32_4(data_in, layout, c_device) result(tens
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real32_4
+    end function torch_tensor_from_array_real32_4d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real64_1(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_real64_1d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : real64
 
@@ -892,10 +892,10 @@ function torch_tensor_from_array_real64_1(data_in, layout, c_device) result(tens
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real64_1
+    end function torch_tensor_from_array_real64_1d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real64_2(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_real64_2d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : real64
 
@@ -923,10 +923,10 @@ function torch_tensor_from_array_real64_2(data_in, layout, c_device) result(tens
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real64_2
+    end function torch_tensor_from_array_real64_2d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real64_3(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_real64_3d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : real64
 
@@ -954,10 +954,10 @@ function torch_tensor_from_array_real64_3(data_in, layout, c_device) result(tens
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real64_3
+    end function torch_tensor_from_array_real64_3d
 
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real64_4(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_real64_4d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : real64
 
@@ -985,7 +985,7 @@ function torch_tensor_from_array_real64_4(data_in, layout, c_device) result(tens
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real64_4
+    end function torch_tensor_from_array_real64_4d
 
 
   end module ftorch
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index 790623ef..147552a6 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -53,7 +53,7 @@ module ftorch
   interface torch_tensor_from_array
     #:for PREC in PRECISIONS
     #:for RANK in RANKS
-    module procedure torch_tensor_from_array_${PREC}$_${RANK}$
+    module procedure torch_tensor_from_array_${PREC}$_${RANK}$d
     #:endfor
     #:endfor
   end interface
@@ -244,7 +244,7 @@ contains
     #:for PREC in PRECISIONS
     #:for RANK in RANKS
     !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_${PREC}$_${RANK}$(data_in, layout, c_device) result(tensor)
+    function torch_tensor_from_array_${PREC}$_${RANK}$d(data_in, layout, c_device) result(tensor)
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
       use, intrinsic :: iso_fortran_env, only : ${PREC}$
 
@@ -272,7 +272,7 @@ contains
 
       tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_${PREC}$_${RANK}$
+    end function torch_tensor_from_array_${PREC}$_${RANK}$d
 
     #:endfor
     #:endfor

From 2c471db8e53039a5419bb54a84b2ab0a985f20f6 Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sat, 18 Nov 2023 15:43:15 +0000
Subject: [PATCH 04/20] Tidy some of fypp file and add some documentation.

---
 src/ftorch.f90  | 1296 ++++++++++++++++++++++++-----------------------
 src/ftorch.fypp |  208 ++++----
 2 files changed, 768 insertions(+), 736 deletions(-)

diff --git a/src/ftorch.f90 b/src/ftorch.f90
index 752e5e51..d7432801 100644
--- a/src/ftorch.f90
+++ b/src/ftorch.f90
@@ -1,38 +1,52 @@
+!| Main module for FTorch containing types and procedures.
+!  Generated from `ftorch.fypp` using the [fypp Fortran preprocessor](https://fypp.readthedocs.io/en/stable/index.html).
+!
+!  * License  
+!    FTorch is released under an MIT license.
+!    See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE)
+!    file for details.
+
 module ftorch
 
   use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, &
-    c_float, c_double, c_char, c_ptr, c_null_ptr
+                                         c_float, c_double, c_char, c_ptr, c_null_ptr
   use, intrinsic :: iso_fortran_env, only: int8, int16, int32, int64, real32, real64
+
   implicit none
 
+  !> Type for holding a torch neural net (nn.Module).
   type torch_module
-    type(c_ptr) :: p = c_null_ptr
+    type(c_ptr) :: p = c_null_ptr  !! pointer to the neural net module in memory
   end type torch_module
 
+  !> Type for holding a Torch tensor.
   type torch_tensor
-    type(c_ptr) :: p = c_null_ptr
+    type(c_ptr) :: p = c_null_ptr  !! pointer to the tensor in memory
   end type torch_tensor
 
-  ! From c_torch.h (torch_data_t)
+  !| Enumerator for Torch data types
+  !  From c_torch.h (torch_data_t)
+  !  Note that torch_kUInt8 and torch_kFloat16 are not supported in Fortran
   enum, bind(c)
-    enumerator :: torch_kUInt8 = 0 ! not supported in fortran
+    enumerator :: torch_kUInt8 = 0 ! not supported in Fortran
     enumerator :: torch_kInt8 = 1
     enumerator :: torch_kInt16 = 2
     enumerator :: torch_kInt32 = 3
     enumerator :: torch_kInt64 = 4
-    enumerator :: torch_kFloat16 = 5 ! not supported in fortran
+    enumerator :: torch_kFloat16 = 5 ! not supported in Fortran
     enumerator :: torch_kFloat32 = 6
     enumerator :: torch_kFloat64 = 7
   end enum
 
 
-  ! From c_torch.h (torch_device_t)
+  !| Enumerator for Torch devices
+  !  From c_torch.h (torch_device_t)
   enum, bind(c)
     enumerator :: torch_kCPU = 0
     enumerator :: torch_kCUDA = 1
   end enum
 
-  ! Interface for calculating tensor from array for different possible input types
+  !> Interface for directing `torch_tensor_from_array` to possible input types and ranks
   interface torch_tensor_from_array
     module procedure torch_tensor_from_array_int8_1d
     module procedure torch_tensor_from_array_int8_2d
@@ -62,8 +76,10 @@ module ftorch
 
   interface
     function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor_p) &
-        bind(c, name='torch_from_blob')
+                               bind(c, name = 'torch_from_blob')
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+
+      ! Arguments
       type(c_ptr), value, intent(in)    :: data
       integer(c_int), value, intent(in) :: ndims
       integer(c_int64_t), intent(in)    :: tensor_shape(*)
@@ -77,15 +93,15 @@ end function torch_from_blob_c
 contains
 
   ! Torch Tensor API
-  !> Exposes the given data as a tensor without taking ownership of the original data.
-  !> This routine will take an (i, j, k) array and return an (k, j, i) tensor.
+  !| Exposes the given data as a tensor without taking ownership of the original data.
+  !  This routine will take an (i, j, k) array and return an (k, j, i) tensor.
   function torch_tensor_from_blob(data, ndims, tensor_shape, dtype, device, layout) result(tensor)
     use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
     type(c_ptr), intent(in)        :: data       !! Pointer to data
     integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
     integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
     integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
-    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
     integer(c_int), intent(in)     :: layout(*)  !! Layout for strides for accessing data
     type(torch_tensor)             :: tensor     !! Returned tensor
 
@@ -94,7 +110,7 @@ function torch_tensor_from_blob(data, ndims, tensor_shape, dtype, device, layout
 
     strides(layout(1)) = 1
     do i = 2, ndims
-      strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1))
+      strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1))
     end do
     tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device)
   end function torch_tensor_from_blob
@@ -105,12 +121,12 @@ function torch_tensor_ones(ndims, tensor_shape, dtype, device) result(tensor)
     integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
     integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
     integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
-    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
     type(torch_tensor)             :: tensor     !! Returned tensor
 
     interface
       function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) &
-          bind(c, name='torch_ones')
+          bind(c, name = 'torch_ones')
         use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
         integer(c_int), value, intent(in) :: ndims
         integer(c_int64_t), intent(in)    :: tensor_shape(*)
@@ -129,12 +145,12 @@ function torch_tensor_zeros(ndims, tensor_shape, dtype, device) result(tensor)
     integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
     integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
     integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
-    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
     type(torch_tensor)             :: tensor     !! Returned tensor
 
     interface
       function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) &
-          bind(c, name='torch_zeros')
+          bind(c, name = 'torch_zeros')
         use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
         integer(c_int), value, intent(in) :: ndims
         integer(c_int64_t), intent(in)    :: tensor_shape(*)
@@ -149,11 +165,11 @@ end function torch_tensor_zeros
 
   !> Prints the contents of a tensor.
   subroutine torch_tensor_print(tensor)
-    type(torch_tensor), intent(in) :: tensor     !! Input tensor
+    type(torch_tensor), intent(in) :: tensor  !! Input tensor
 
     interface
       subroutine torch_tensor_print_c(tensor) &
-          bind(c, name='torch_tensor_print')
+          bind(c, name = 'torch_tensor_print')
         use, intrinsic :: iso_c_binding, only : c_ptr
         type(c_ptr), value, intent(in) :: tensor
       end subroutine torch_tensor_print_c
@@ -168,7 +184,7 @@ subroutine torch_tensor_delete(tensor)
 
     interface
       subroutine torch_tensor_delete_c(tensor) &
-          bind(c, name='torch_tensor_delete')
+          bind(c, name = 'torch_tensor_delete')
         use, intrinsic :: iso_c_binding, only : c_ptr
         type(c_ptr), value, intent(in) :: tensor
       end subroutine torch_tensor_delete_c
@@ -186,7 +202,7 @@ function torch_module_load(filename) result(module)
 
     interface
       function torch_jit_load_c(filename) result(module) &
-          bind(c, name='torch_jit_load')
+          bind(c, name = 'torch_jit_load')
         use, intrinsic :: iso_c_binding, only : c_char, c_ptr
         character(c_char), intent(in) :: filename(*)
         type(c_ptr)                   :: module
@@ -195,797 +211,797 @@ end function torch_jit_load_c
 
     ! Need to append c_null_char at end of filename
     module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char)
-    end function torch_module_load
-
-    !> Performs a forward pass of the module with the input tensors
-    subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor)
-      use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc
-      type(torch_module), intent(in) :: module        !! Module
-      type(torch_tensor), intent(in), dimension(:) :: input_tensors  !! Array of Input tensors
-      type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors
-      integer(c_int) ::  n_inputs
-
-      integer :: i
-      type(c_ptr), dimension(n_inputs), target  :: input_ptrs
-
-      interface
-        subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, &
-            output_tensor) &
-            bind(c, name='torch_jit_module_forward')
-          use, intrinsic :: iso_c_binding, only : c_ptr, c_int
-          type(c_ptr), value, intent(in) :: module
-          type(c_ptr), value, intent(in) :: input_tensors
-          integer(c_int), value, intent(in) :: n_inputs
-          type(c_ptr), value, intent(in) :: output_tensor
-        end subroutine torch_jit_module_forward_c
-      end interface
-
-      ! Assign array of pointers to the input tensors
-      do i = 1, n_inputs
-        input_ptrs(i) = input_tensors(i)%p
-      end do
-
-      call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p)
-    end subroutine torch_module_forward
-
-    !> Deallocates a Torch Script module
-    subroutine torch_module_delete(module)
-      type(torch_module), intent(in) :: module     !! Module
-
-      interface
-        subroutine torch_jit_module_delete_c(module) &
-            bind(c, name='torch_jit_module_delete')
-          use, intrinsic :: iso_c_binding, only : c_ptr
-          type(c_ptr), value, intent(in) :: module
-        end subroutine torch_jit_module_delete_c
-      end interface
-
-      call torch_jit_module_delete_c(module%p)
-    end subroutine torch_module_delete
-
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int8_1d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int8
-
-      ! inputs
-      integer(kind=int8), intent(in), target :: data_in(:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(1) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
-
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
-
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt8 !! data type
-      integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 1                   !! number of dimension of input data
-      integer                   :: i
-
-      c_tensor_shape = shape(data_in)
-
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
-
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
-
-    end function torch_tensor_from_array_int8_1d
-
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int8_2d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int8
-
-      ! inputs
-      integer(kind=int8), intent(in), target :: data_in(:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(2) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
-
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
-
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt8 !! data type
-      integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 2                   !! number of dimension of input data
-      integer                   :: i
-
-      c_tensor_shape = shape(data_in)
-
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
-
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
-
-    end function torch_tensor_from_array_int8_2d
-
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int8_3d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int8
-
-      ! inputs
-      integer(kind=int8), intent(in), target :: data_in(:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(3) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
-
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_module_load
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt8 !! data type
-      integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 3                   !! number of dimension of input data
-      integer                   :: i
+  !> Performs a forward pass of the module with the input tensors
+  subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor)
+    use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc
+    type(torch_module), intent(in) :: module        !! Module
+    type(torch_tensor), intent(in), dimension(:) :: input_tensors  !! Array of Input tensors
+    type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors
+    integer(c_int) ::  n_inputs
 
-      c_tensor_shape = shape(data_in)
+    integer :: i
+    type(c_ptr), dimension(n_inputs), target  :: input_ptrs
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    interface
+      subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, &
+          output_tensor) &
+          bind(c, name = 'torch_jit_module_forward')
+        use, intrinsic :: iso_c_binding, only : c_ptr, c_int
+        type(c_ptr), value, intent(in) :: module
+        type(c_ptr), value, intent(in) :: input_tensors
+        integer(c_int), value, intent(in) :: n_inputs
+        type(c_ptr), value, intent(in) :: output_tensor
+      end subroutine torch_jit_module_forward_c
+    end interface
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! Assign array of pointers to the input tensors
+    do i = 1, n_inputs
+      input_ptrs(i) = input_tensors(i)%p
+    end do
 
-    end function torch_tensor_from_array_int8_3d
+    call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p)
+  end subroutine torch_module_forward
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int8_4d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int8
+  !> Deallocates a Torch Script module
+  subroutine torch_module_delete(module)
+    type(torch_module), intent(in) :: module     !! Module to deallocate
 
-      ! inputs
-      integer(kind=int8), intent(in), target :: data_in(:,:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(4) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    interface
+      subroutine torch_jit_module_delete_c(module) &
+          bind(c, name = 'torch_jit_module_delete')
+        use, intrinsic :: iso_c_binding, only : c_ptr
+        type(c_ptr), value, intent(in) :: module
+      end subroutine torch_jit_module_delete_c
+    end interface
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+    call torch_jit_module_delete_c(module%p)
+  end subroutine torch_module_delete
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt8 !! data type
-      integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 4                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int8`
+  function torch_tensor_from_array_int8_1d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int8
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int8), intent(in), target :: data_in(:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(1) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type
+    integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 1                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int8_4d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int16_1d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int16
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int16), intent(in), target :: data_in(:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(1) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int8_1d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt16 !! data type
-      integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 1                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int8`
+  function torch_tensor_from_array_int8_2d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int8
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int8), intent(in), target :: data_in(:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(2) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type
+    integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 2                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int16_1d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int16_2d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int16
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int16), intent(in), target :: data_in(:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(2) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int8_2d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt16 !! data type
-      integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 2                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int8`
+  function torch_tensor_from_array_int8_3d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int8
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int8), intent(in), target :: data_in(:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(3) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type
+    integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 3                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int16_2d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int16_3d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int16
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int16), intent(in), target :: data_in(:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(3) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int8_3d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt16 !! data type
-      integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 3                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int8`
+  function torch_tensor_from_array_int8_4d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int8
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int8), intent(in), target :: data_in(:,:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(4) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt8 !! Data type
+    integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 4                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int16_3d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int16_4d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int16
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int16), intent(in), target :: data_in(:,:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(4) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int8_4d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt16 !! data type
-      integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 4                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int16`
+  function torch_tensor_from_array_int16_1d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int16
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int16), intent(in), target :: data_in(:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(1) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type
+    integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 1                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int16_4d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int32_1d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int32
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int32), intent(in), target :: data_in(:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(1) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int16_1d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt32 !! data type
-      integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 1                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int16`
+  function torch_tensor_from_array_int16_2d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int16
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int16), intent(in), target :: data_in(:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(2) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type
+    integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 2                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int32_1d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int32_2d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int32
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int32), intent(in), target :: data_in(:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(2) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int16_2d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt32 !! data type
-      integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 2                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int16`
+  function torch_tensor_from_array_int16_3d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int16
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int16), intent(in), target :: data_in(:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(3) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type
+    integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 3                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int32_2d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int32_3d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int32
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int32), intent(in), target :: data_in(:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(3) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int16_3d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt32 !! data type
-      integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 3                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int16`
+  function torch_tensor_from_array_int16_4d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int16
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int16), intent(in), target :: data_in(:,:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(4) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt16 !! Data type
+    integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 4                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int32_3d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int32_4d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int32
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int32), intent(in), target :: data_in(:,:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(4) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int16_4d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt32 !! data type
-      integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 4                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int32`
+  function torch_tensor_from_array_int32_1d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int32
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int32), intent(in), target :: data_in(:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(1) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type
+    integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 1                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int32_4d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int64_1d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int64
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int64), intent(in), target :: data_in(:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(1) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int32_1d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt64 !! data type
-      integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 1                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int32`
+  function torch_tensor_from_array_int32_2d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int32
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int32), intent(in), target :: data_in(:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(2) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type
+    integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 2                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int64_1d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int64_2d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int64
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int64), intent(in), target :: data_in(:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(2) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int32_2d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt64 !! data type
-      integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 2                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int32`
+  function torch_tensor_from_array_int32_3d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int32
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int32), intent(in), target :: data_in(:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(3) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type
+    integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 3                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int64_2d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int64_3d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int64
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int64), intent(in), target :: data_in(:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(3) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int32_3d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt64 !! data type
-      integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 3                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int32`
+  function torch_tensor_from_array_int32_4d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int32
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int32), intent(in), target :: data_in(:,:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(4) !! Order in which dimensions are assigned strides; layout(1) gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt32 !! Data type
+    integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = 4                   !! Number of dimensions of input data
+    integer                   :: i  !! Loop index over layout entries
 
-    end function torch_tensor_from_array_int64_3d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_int64_4d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : int64
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! each stride = previous stride * previous extent, following layout order
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      integer(kind=int64), intent(in), target :: data_in(:,:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(4) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int32_4d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kInt64 !! data type
-      integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 4                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `int64`
+  function torch_tensor_from_array_int64_1d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int64
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int64), intent(in), target :: data_in(:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(1) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type
+    integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 1                   !! Number of dimensions of input data
+    integer                   :: i
 
-    end function torch_tensor_from_array_int64_4d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real32_1d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : real32
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! zero iterations here, since ndims = 1
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      real(kind=real32), intent(in), target :: data_in(:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(1) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int64_1d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kFloat32 !! data type
-      integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 1                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `int64`
+  function torch_tensor_from_array_int64_2d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int64
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int64), intent(in), target :: data_in(:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(2) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type
+    integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 2                   !! Number of dimensions of input data
+    integer                   :: i
 
-    end function torch_tensor_from_array_real32_1d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real32_2d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : real32
+    strides(layout(1)) = 1
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      real(kind=real32), intent(in), target :: data_in(:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(2) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int64_2d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kFloat32 !! data type
-      integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 2                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `int64`
+  function torch_tensor_from_array_int64_3d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int64
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int64), intent(in), target :: data_in(:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(3) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type
+    integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 3                   !! Number of dimensions of input data
+    integer                   :: i
 
-    end function torch_tensor_from_array_real32_2d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real32_3d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : real32
+    strides(layout(1)) = 1
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      ! inputs
-      real(kind=real32), intent(in), target :: data_in(:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(3) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+  end function torch_tensor_from_array_int64_3d
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kFloat32 !! data type
-      integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 3                   !! number of dimension of input data
-      integer                   :: i
+  !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `int64`
+  function torch_tensor_from_array_int64_4d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : int64
 
-      c_tensor_shape = shape(data_in)
+    ! inputs
+    integer(kind=int64), intent(in), target :: data_in(:,:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(4) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kInt64 !! Data type
+    integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 4                   !! Number of dimensions of input data
+    integer                   :: i
 
-    end function torch_tensor_from_array_real32_3d
+    c_tensor_shape = shape(data_in)
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real32_4d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : real32
+    strides(layout(1)) = 1
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
+
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+  end function torch_tensor_from_array_int64_4d
+
+  !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `real32`
+  function torch_tensor_from_array_real32_1d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : real32
+
+    ! inputs
+    real(kind=real32), intent(in), target :: data_in(:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(1) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
+
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
+
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type
+    integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 1                   !! Number of dimensions of input data
+    integer                   :: i
+
+    c_tensor_shape = shape(data_in)
+
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! zero iterations here, since ndims = 1
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
+
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+  end function torch_tensor_from_array_real32_1d
+
+  !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `real32`
+  function torch_tensor_from_array_real32_2d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : real32
+
+    ! inputs
+    real(kind=real32), intent(in), target :: data_in(:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(2) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
+
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
+
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type
+    integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 2                   !! Number of dimensions of input data
+    integer                   :: i
+
+    c_tensor_shape = shape(data_in)
+
+    strides(layout(1)) = 1
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
+
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+  end function torch_tensor_from_array_real32_2d
 
-      ! inputs
-      real(kind=real32), intent(in), target :: data_in(:,:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(4) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+  !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `real32`
+  function torch_tensor_from_array_real32_3d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : real32
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+    ! inputs
+    real(kind=real32), intent(in), target :: data_in(:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(3) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kFloat32 !! data type
-      integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 4                   !! number of dimension of input data
-      integer                   :: i
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      c_tensor_shape = shape(data_in)
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type
+    integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 3                   !! Number of dimensions of input data
+    integer                   :: i
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    c_tensor_shape = shape(data_in)
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    strides(layout(1)) = 1
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
+
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real32_4d
+  end function torch_tensor_from_array_real32_3d
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real64_1d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : real64
+  !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `real32`
+  function torch_tensor_from_array_real32_4d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : real32
 
-      ! inputs
-      real(kind=real64), intent(in), target :: data_in(:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(1) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    ! inputs
+    real(kind=real32), intent(in), target :: data_in(:,:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(4) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kFloat64 !! data type
-      integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 1                   !! number of dimension of input data
-      integer                   :: i
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kFloat32 !! Data type
+    integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 4                   !! Number of dimensions of input data
+    integer                   :: i
 
-      c_tensor_shape = shape(data_in)
+    c_tensor_shape = shape(data_in)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    strides(layout(1)) = 1
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real64_1d
+  end function torch_tensor_from_array_real32_4d
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real64_2d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : real64
+  !> Return a Torch tensor pointing to data_in array of rank 1 containing data of type `real64`
+  function torch_tensor_from_array_real64_1d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : real64
 
-      ! inputs
-      real(kind=real64), intent(in), target :: data_in(:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(2) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    ! inputs
+    real(kind=real64), intent(in), target :: data_in(:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(1) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kFloat64 !! data type
-      integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 2                   !! number of dimension of input data
-      integer                   :: i
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(1)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type
+    integer(c_int64_t)        :: strides(1)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 1                   !! Number of dimensions of input data
+    integer                   :: i
 
-      c_tensor_shape = shape(data_in)
+    c_tensor_shape = shape(data_in)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    strides(layout(1)) = 1
+    do i = 2, ndims  ! zero iterations here, since ndims = 1
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real64_2d
+  end function torch_tensor_from_array_real64_1d
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real64_3d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : real64
+  !> Return a Torch tensor pointing to data_in array of rank 2 containing data of type `real64`
+  function torch_tensor_from_array_real64_2d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : real64
 
-      ! inputs
-      real(kind=real64), intent(in), target :: data_in(:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(3) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    ! inputs
+    real(kind=real64), intent(in), target :: data_in(:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(2) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kFloat64 !! data type
-      integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 3                   !! number of dimension of input data
-      integer                   :: i
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(2)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type
+    integer(c_int64_t)        :: strides(2)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 2                   !! Number of dimensions of input data
+    integer                   :: i
 
-      c_tensor_shape = shape(data_in)
+    c_tensor_shape = shape(data_in)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    strides(layout(1)) = 1
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real64_3d
+  end function torch_tensor_from_array_real64_2d
 
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_real64_4d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : real64
+  !> Return a Torch tensor pointing to data_in array of rank 3 containing data of type `real64`
+  function torch_tensor_from_array_real64_3d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : real64
 
-      ! inputs
-      real(kind=real64), intent(in), target :: data_in(:,:,:,:)   !! input data that tensor will point at
-      integer, intent(in)        :: layout(4) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    ! inputs
+    real(kind=real64), intent(in), target :: data_in(:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(3) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = torch_kFloat64 !! data type
-      integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = 4                   !! number of dimension of input data
-      integer                   :: i
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(3)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type
+    integer(c_int64_t)        :: strides(3)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 3                   !! Number of dimensions of input data
+    integer                   :: i
 
-      c_tensor_shape = shape(data_in)
+    c_tensor_shape = shape(data_in)
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+    strides(layout(1)) = 1
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
+
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+  end function torch_tensor_from_array_real64_3d
+
+  !> Return a Torch tensor pointing to data_in array of rank 4 containing data of type `real64`
+  function torch_tensor_from_array_real64_4d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : real64
+
+    ! inputs
+    real(kind=real64), intent(in), target :: data_in(:,:,:,:)   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(4) !! Control order of indices; `layout(1)` gets stride 1
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live (`torch_kCPU` or `torch_kCUDA`)
+
+    ! output tensor
+    type(torch_tensor) :: tensor     !! Returned tensor
+
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(4)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = torch_kFloat64 !! Data type
+    integer(c_int64_t)        :: strides(4)                  !! Strides for accessing data (products of extents, starting from 1)
+    integer(c_int), parameter :: ndims = 4                   !! Number of dimensions of input data
+    integer                   :: i
+
+    c_tensor_shape = shape(data_in)
+
+    strides(layout(1)) = 1
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
 
-    end function torch_tensor_from_array_real64_4d
+  end function torch_tensor_from_array_real64_4d
 
 
-  end module ftorch
+end module ftorch
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index 147552a6..0e7e9cfb 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -15,41 +15,55 @@ $:C_PRECISIONS[PRECISION]
 #:def f_type(PRECISION)
 $:'integer' if PRECISION[:3] == 'int' else 'real'
 #:enddef f_type
+!| Main module for FTorch containing types and procedures.
+!  Generated from `ftorch.fypp` using the [fypp Fortran preprocessor](https://fypp.readthedocs.io/en/stable/index.html).
+!
+!  * License  
+!    FTorch is released under an MIT license.
+!    See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE)
+!    file for details.
+
 module ftorch
 
   use, intrinsic :: iso_c_binding, only: c_int, c_int8_t, c_int16_t, c_int32_t, c_int64_t, c_int64_t, &
-    c_float, c_double, c_char, c_ptr, c_null_ptr
+                                         c_float, c_double, c_char, c_ptr, c_null_ptr
   use, intrinsic :: iso_fortran_env, only: int8, int16, int32, int64, real32, real64
+
   implicit none
 
+  !> Type for holding a torch neural net (nn.Module).
   type torch_module
-    type(c_ptr) :: p = c_null_ptr
+    type(c_ptr) :: p = c_null_ptr  !! pointer to the neural net module in memory
   end type torch_module
 
+  !> Type for holding a Torch tensor.
   type torch_tensor
-    type(c_ptr) :: p = c_null_ptr
+    type(c_ptr) :: p = c_null_ptr  !! pointer to the tensor in memory
   end type torch_tensor
 
-  ! From c_torch.h (torch_data_t)
+  !| Enumerator for Torch data types
+  !  From c_torch.h (torch_data_t)
+  !  Note that torch_kUInt8 and torch_kFloat16 are not sypported in Fortran
   enum, bind(c)
-    enumerator :: torch_kUInt8 = 0 ! not supported in fortran
+    enumerator :: torch_kUInt8 = 0 ! not supported in Fortran
     enumerator :: torch_kInt8 = 1
     enumerator :: torch_kInt16 = 2
     enumerator :: torch_kInt32 = 3
     enumerator :: torch_kInt64 = 4
-    enumerator :: torch_kFloat16 = 5 ! not supported in fortran
+    enumerator :: torch_kFloat16 = 5 ! not supported in Fortran
     enumerator :: torch_kFloat32 = 6
     enumerator :: torch_kFloat64 = 7
   end enum
 
 
-  ! From c_torch.h (torch_device_t)
+  !| Enumerator for Torch devices
+  !  From c_torch.h (torch_device_t)
   enum, bind(c)
     enumerator :: torch_kCPU = 0
     enumerator :: torch_kCUDA = 1
   end enum
 
-  ! Interface for calculating tensor from array for different possible input types
+  !> Interface for directing `torch_tensor_from_array` to possible input types and ranks
   interface torch_tensor_from_array
     #:for PREC in PRECISIONS
     #:for RANK in RANKS
@@ -60,8 +74,10 @@ module ftorch
 
   interface
     function torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device) result(tensor_p) &
-        bind(c, name='torch_from_blob')
+                               bind(c, name = 'torch_from_blob')
       use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
+
+      ! Arguments
       type(c_ptr), value, intent(in)    :: data
       integer(c_int), value, intent(in) :: ndims
       integer(c_int64_t), intent(in)    :: tensor_shape(*)
@@ -75,15 +91,15 @@ module ftorch
 contains
 
   ! Torch Tensor API
-  !> Exposes the given data as a tensor without taking ownership of the original data.
-  !> This routine will take an (i, j, k) array and return an (k, j, i) tensor.
+  !| Exposes the given data as a tensor without taking ownership of the original data.
+  !  This routine will take an (i, j, k) array and return an (k, j, i) tensor.
   function torch_tensor_from_blob(data, ndims, tensor_shape, dtype, device, layout) result(tensor)
     use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
     type(c_ptr), intent(in)        :: data       !! Pointer to data
     integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
     integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
     integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
-    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
     integer(c_int), intent(in)     :: layout(*)  !! Layout for strides for accessing data
     type(torch_tensor)             :: tensor     !! Returned tensor
 
@@ -92,7 +108,7 @@ contains
 
     strides(layout(1)) = 1
     do i = 2, ndims
-      strides(layout(i)) = strides(layout(i-1)) * tensor_shape(layout(i-1))
+      strides(layout(i)) = strides(layout(i - 1)) * tensor_shape(layout(i - 1))
     end do
     tensor%p = torch_from_blob_c(data, ndims, tensor_shape, strides, dtype, device)
   end function torch_tensor_from_blob
@@ -103,12 +119,12 @@ contains
     integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
     integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
     integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
-    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
     type(torch_tensor)             :: tensor     !! Returned tensor
 
     interface
       function torch_ones_c(ndims, tensor_shape, dtype, device) result(tensor) &
-          bind(c, name='torch_ones')
+          bind(c, name = 'torch_ones')
         use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
         integer(c_int), value, intent(in) :: ndims
         integer(c_int64_t), intent(in)    :: tensor_shape(*)
@@ -127,12 +143,12 @@ contains
     integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
     integer(c_int64_t), intent(in) :: tensor_shape(*)   !! Shape of the tensor
     integer(c_int), intent(in)     :: dtype      !! Data type of the tensor
-    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    integer(c_int), intent(in)     :: device     !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
     type(torch_tensor)             :: tensor     !! Returned tensor
 
     interface
       function torch_zeros_c(ndims, tensor_shape, dtype, device) result(tensor) &
-          bind(c, name='torch_zeros')
+          bind(c, name = 'torch_zeros')
         use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
         integer(c_int), value, intent(in) :: ndims
         integer(c_int64_t), intent(in)    :: tensor_shape(*)
@@ -147,11 +163,11 @@ contains
 
   !> Prints the contents of a tensor.
   subroutine torch_tensor_print(tensor)
-    type(torch_tensor), intent(in) :: tensor     !! Input tensor
+    type(torch_tensor), intent(in) :: tensor  !! Input tensor
 
     interface
       subroutine torch_tensor_print_c(tensor) &
-          bind(c, name='torch_tensor_print')
+          bind(c, name = 'torch_tensor_print')
         use, intrinsic :: iso_c_binding, only : c_ptr
         type(c_ptr), value, intent(in) :: tensor
       end subroutine torch_tensor_print_c
@@ -166,7 +182,7 @@ contains
 
     interface
       subroutine torch_tensor_delete_c(tensor) &
-          bind(c, name='torch_tensor_delete')
+          bind(c, name = 'torch_tensor_delete')
         use, intrinsic :: iso_c_binding, only : c_ptr
         type(c_ptr), value, intent(in) :: tensor
       end subroutine torch_tensor_delete_c
@@ -184,7 +200,7 @@ contains
 
     interface
       function torch_jit_load_c(filename) result(module) &
-          bind(c, name='torch_jit_load')
+          bind(c, name = 'torch_jit_load')
         use, intrinsic :: iso_c_binding, only : c_char, c_ptr
         character(c_char), intent(in) :: filename(*)
         type(c_ptr)                   :: module
@@ -193,88 +209,88 @@ contains
 
     ! Need to append c_null_char at end of filename
     module%p = torch_jit_load_c(trim(adjustl(filename))//c_null_char)
-    end function torch_module_load
-
-    !> Performs a forward pass of the module with the input tensors
-    subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor)
-      use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc
-      type(torch_module), intent(in) :: module        !! Module
-      type(torch_tensor), intent(in), dimension(:) :: input_tensors  !! Array of Input tensors
-      type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors
-      integer(c_int) ::  n_inputs
-
-      integer :: i
-      type(c_ptr), dimension(n_inputs), target  :: input_ptrs
-
-      interface
-        subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, &
-            output_tensor) &
-            bind(c, name='torch_jit_module_forward')
-          use, intrinsic :: iso_c_binding, only : c_ptr, c_int
-          type(c_ptr), value, intent(in) :: module
-          type(c_ptr), value, intent(in) :: input_tensors
-          integer(c_int), value, intent(in) :: n_inputs
-          type(c_ptr), value, intent(in) :: output_tensor
-        end subroutine torch_jit_module_forward_c
-      end interface
-
-      ! Assign array of pointers to the input tensors
-      do i = 1, n_inputs
-        input_ptrs(i) = input_tensors(i)%p
-      end do
-
-      call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p)
-    end subroutine torch_module_forward
-
-    !> Deallocates a Torch Script module
-    subroutine torch_module_delete(module)
-      type(torch_module), intent(in) :: module     !! Module
-
-      interface
-        subroutine torch_jit_module_delete_c(module) &
-            bind(c, name='torch_jit_module_delete')
-          use, intrinsic :: iso_c_binding, only : c_ptr
-          type(c_ptr), value, intent(in) :: module
-        end subroutine torch_jit_module_delete_c
-      end interface
-
-      call torch_jit_module_delete_c(module%p)
-    end subroutine torch_module_delete
+  end function torch_module_load
 
-    #:for PREC in PRECISIONS
-    #:for RANK in RANKS
-    !> return a torch tensor pointing to data_in array
-    function torch_tensor_from_array_${PREC}$_${RANK}$d(data_in, layout, c_device) result(tensor)
-      use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
-      use, intrinsic :: iso_fortran_env, only : ${PREC}$
+  !> Performs a forward pass of the module with the input tensors
+  subroutine torch_module_forward(module, input_tensors, n_inputs, output_tensor)
+    use, intrinsic :: iso_c_binding, only : c_ptr, c_int, c_loc
+    type(torch_module), intent(in) :: module        !! Module
+    type(torch_tensor), intent(in), dimension(:) :: input_tensors  !! Array of Input tensors
+    type(torch_tensor), intent(in) :: output_tensor !! Returned output tensors
+    integer(c_int) ::  n_inputs
 
-      ! inputs
-      ${f_type(PREC)}$(kind=${PREC}$), intent(in), target :: data_in${ranksuffix(RANK)}$   !! input data that tensor will point at
-      integer, intent(in)        :: layout(${RANK}$) !! control order of indices
-      integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (torch_kCPU or torch_kGPU)
+    integer :: i
+    type(c_ptr), dimension(n_inputs), target  :: input_ptrs
 
-      ! output tensory
-      type(torch_tensor) :: tensor     !! Returned tensor
+    interface
+      subroutine torch_jit_module_forward_c(module, input_tensors, n_inputs, &
+          output_tensor) &
+          bind(c, name = 'torch_jit_module_forward')
+        use, intrinsic :: iso_c_binding, only : c_ptr, c_int
+        type(c_ptr), value, intent(in) :: module
+        type(c_ptr), value, intent(in) :: input_tensors
+        integer(c_int), value, intent(in) :: n_inputs
+        type(c_ptr), value, intent(in) :: output_tensor
+      end subroutine torch_jit_module_forward_c
+    end interface
 
-      ! local data
-      integer(c_int64_t)        :: c_tensor_shape(${RANK}$)           !! Shape of the tensor
-      integer(c_int), parameter :: c_dtype = ${enum_from_prec(PREC)}$ !! data type
-      integer(c_int64_t)        :: strides(${RANK}$)                  !! Strides for accessing data
-      integer(c_int), parameter :: ndims = ${RANK}$                   !! number of dimension of input data
-      integer                   :: i
+    ! Assign array of pointers to the input tensors
+    do i = 1, n_inputs
+      input_ptrs(i) = input_tensors(i)%p
+    end do
 
-      c_tensor_shape = shape(data_in)
+    call torch_jit_module_forward_c(module%p, c_loc(input_ptrs), n_inputs, output_tensor%p)
+  end subroutine torch_module_forward
 
-      strides(layout(1)) = 1
-      do i = 2, ndims
-        strides(layout(i)) = strides(layout(i-1)) * c_tensor_shape(layout(i-1))
-      end do
+  !> Deallocates a Torch Script module
+  subroutine torch_module_delete(module)
+    type(torch_module), intent(in) :: module     !! Module to deallocate
 
-      tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+    interface
+      subroutine torch_jit_module_delete_c(module) &
+          bind(c, name = 'torch_jit_module_delete')
+        use, intrinsic :: iso_c_binding, only : c_ptr
+        type(c_ptr), value, intent(in) :: module
+      end subroutine torch_jit_module_delete_c
+    end interface
 
-    end function torch_tensor_from_array_${PREC}$_${RANK}$d
+    call torch_jit_module_delete_c(module%p)
+  end subroutine torch_module_delete
 
-    #:endfor
-    #:endfor
+  #:for PREC in PRECISIONS
+  #:for RANK in RANKS
+  !> Return a Torch tensor pointing to data_in array of rank ${RANK}$ containing data of type `${PREC}$`
+  function torch_tensor_from_array_${PREC}$_${RANK}$d(data_in, layout, c_device) result(tensor)
+    use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_float, c_loc
+    use, intrinsic :: iso_fortran_env, only : ${PREC}$
+
+    ! inputs
+    ${f_type(PREC)}$(kind=${PREC}$), intent(in), target :: data_in${ranksuffix(RANK)}$   !! Input data that tensor will point at
+    integer, intent(in)        :: layout(${RANK}$) !! Control order of indices
+    integer(c_int), intent(in) :: c_device         !! Device on which the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
+
+    ! output tensory
+    type(torch_tensor) :: tensor     !! Returned tensor
+
+    ! local data
+    integer(c_int64_t)        :: c_tensor_shape(${RANK}$)           !! Shape of the tensor
+    integer(c_int), parameter :: c_dtype = ${enum_from_prec(PREC)}$ !! Data type
+    integer(c_int64_t)        :: strides(${RANK}$)                  !! Strides for accessing data
+    integer(c_int), parameter :: ndims = ${RANK}$                   !! Number of dimension of input data
+    integer                   :: i
+
+    c_tensor_shape = shape(data_in)
+
+    strides(layout(1)) = 1
+    do i = 2, ndims
+      strides(layout(i)) = strides(layout(i - 1)) * c_tensor_shape(layout(i - 1))
+    end do
+
+    tensor%p = torch_from_blob_c(c_loc(data_in), ndims, c_tensor_shape, strides, c_dtype, c_device)
+
+  end function torch_tensor_from_array_${PREC}$_${RANK}$d
+
+  #:endfor
+  #:endfor
 
-  end module ftorch
+end module ftorch

From c2e50efa033c06235df3e06a80a79909f1c50fa7 Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sat, 18 Nov 2023 15:45:21 +0000
Subject: [PATCH 05/20] Tidy enumerator docs

---
 src/ftorch.fypp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index 0e7e9cfb..6f9e7735 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -43,7 +43,7 @@ module ftorch
 
   !| Enumerator for Torch data types
   !  From c_torch.h (torch_data_t)
-  !  Note that torch_kUInt8 and torch_kFloat16 are not sypported in Fortran
+  !  Note that 0 `torch_kUInt8` and 5 `torch_kFloat16` are not sypported in Fortran
   enum, bind(c)
     enumerator :: torch_kUInt8 = 0 ! not supported in Fortran
     enumerator :: torch_kInt8 = 1

From e89507fe16bc9fc7e2ebaff2c4e67acc9ceb541f Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sat, 18 Nov 2023 18:08:08 +0000
Subject: [PATCH 06/20] Add pre-commit hook file to check fypp validity.

---
 pre-commit | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)
 create mode 100755 pre-commit

diff --git a/pre-commit b/pre-commit
new file mode 100755
index 00000000..d0f12fac
--- /dev/null
+++ b/pre-commit
@@ -0,0 +1,69 @@
+#!/bin/bash
+#
+# A hook script to verify what is about to be committed.
+# Called by "git commit" with no arguments.  The hook should
+# exit with non-zero status after issuing an appropriate message if
+# it wants to stop the commit.
+
+# Fail at the first issue; `pipefail` and `&>` below are bash features, hence the bash shebang.
+set -eo pipefail
+
+# ===================================================================
+
+if git rev-parse --verify HEAD >/dev/null 2>&1
+then
+	against=HEAD
+else
+	# Initial commit: diff against an empty tree object
+	against=$(git hash-object -t tree /dev/null)
+fi
+
+# ===================================================================
+
+# Check that ftorch.f90 is not modified and staged alone.
+git diff --cached --name-only | if grep --quiet "ftorch.f90"; then
+  git diff --cached --name-only | if ! grep --quiet "ftorch.fypp"; then
+    cat <<\EOF
+Error: File ftorch.f90 has been modified and staged without ftorch.fypp being changed.
+ftorch.f90 should be generated from ftorch.fypp using fypp.
+Please restore ftorch.f90 and make your modifications to ftorch.fypp instead.
+EOF
+    exit 1
+  fi
+fi
+
+# Check to see if ftorch.fypp has been modified AND is staged.
+git diff --cached --name-only | if grep --quiet "ftorch.fypp"; then
+
+  # Check that ftorch.f90 is also modified and staged.
+  git diff --cached --name-only | if ! grep --quiet "ftorch.f90"; then
+    cat <<\EOF
+Error: File ftorch.fypp has been modified and staged, but ftorch.f90 has not.
+ftorch.f90 should be generated from ftorch.fypp and both committed together.
+Please run fypp on ftorch.fypp to generate ftorch.f90 and commit together.
+EOF
+    exit 1
+  else
+    # Check fypp available, and raise error and exit if not.
+    if ! command -v fypp &> /dev/null; then
+      cat <<\EOF
+Error: Could not find fypp to run on ftorch.fypp.
+Please install fypp using "pip install fypp" and then try committing again.
+EOF
+      exit 1
+    fi
+
+    # If fypp is available and both .f90 and .fypp staged, check they match.
+    fypp src/ftorch.fypp src/ftorch.f90_tmp
+    if ! diff -q "src/ftorch.f90" "src/ftorch.f90_tmp" &> /dev/null; then
+      rm src/ftorch.f90_tmp
+      cat <<\EOF
+Error: The code in ftorch.f90 does not match that expected from ftorch.fypp.
+Please re-run fypp on ftorch.fypp to ensure consistency before committing.
+EOF
+      exit 1
+    else
+      rm src/ftorch.f90_tmp
+    fi
+  fi
+fi

From 424b3ed6acb445aed049439cce4c3961dceb88cc Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sat, 18 Nov 2023 18:16:12 +0000
Subject: [PATCH 07/20] Bring fypp and f90 in line and tidy docs.

---
 src/ftorch.f90  | 8 ++++----
 src/ftorch.fypp | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/ftorch.f90 b/src/ftorch.f90
index d7432801..5b6389c2 100644
--- a/src/ftorch.f90
+++ b/src/ftorch.f90
@@ -24,9 +24,9 @@ module ftorch
     type(c_ptr) :: p = c_null_ptr  !! pointer to the tensor in memory
   end type torch_tensor
 
-  !| Enumerator for Torch data types
-  !  From c_torch.h (torch_data_t)
-  !  Note that torch_kUInt8 and torch_kFloat16 are not sypported in Fortran
+  !| Enumerator for Torch data types  
+  !  From c_torch.h (torch_data_t)  
+  !  Note that 0 `torch_kUInt8` and 5 `torch_kFloat16` are not sypported in Fortran
   enum, bind(c)
     enumerator :: torch_kUInt8 = 0 ! not supported in Fortran
     enumerator :: torch_kInt8 = 1
@@ -39,7 +39,7 @@ module ftorch
   end enum
 
 
-  !| Enumerator for Torch devices
+  !| Enumerator for Torch devices  
   !  From c_torch.h (torch_device_t)
   enum, bind(c)
     enumerator :: torch_kCPU = 0
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index 6f9e7735..aafe5d11 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -41,8 +41,8 @@ module ftorch
     type(c_ptr) :: p = c_null_ptr  !! pointer to the tensor in memory
   end type torch_tensor
 
-  !| Enumerator for Torch data types
-  !  From c_torch.h (torch_data_t)
+  !| Enumerator for Torch data types  
+  !  From c_torch.h (torch_data_t)  
   !  Note that 0 `torch_kUInt8` and 5 `torch_kFloat16` are not sypported in Fortran
   enum, bind(c)
     enumerator :: torch_kUInt8 = 0 ! not supported in Fortran
@@ -56,7 +56,7 @@ module ftorch
   end enum
 
 
-  !| Enumerator for Torch devices
+  !| Enumerator for Torch devices  
   !  From c_torch.h (torch_device_t)
   enum, bind(c)
     enumerator :: torch_kCPU = 0

From 7fe963b5b1e57e75b3c00ad47220a4af5d8e01ec Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sat, 18 Nov 2023 18:28:59 +0000
Subject: [PATCH 08/20] Add fypp workflow attempt.

---
 .github/workflows/fypp.yml | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)
 create mode 100644 .github/workflows/fypp.yml

diff --git a/.github/workflows/fypp.yml b/.github/workflows/fypp.yml
new file mode 100644
index 00000000..765823e7
--- /dev/null
+++ b/.github/workflows/fypp.yml
@@ -0,0 +1,25 @@
+name: fypp-checks
+
+on:
+  # run on every push
+  push:
+  # run on every push (not commit) to a PR, plus open/reopen
+  pull_request:
+    types:
+    - synchronize
+    - opened
+    - reopened
+
+jobs:
+  various:
+    name: FYPP checks - run pre-commit hook in repo
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/checkout@v3
+    - uses: actions/setup-python@v4
+      with:
+        python-version: "3.11"
+    - run: pip install fypp
+
+    - name: Check pre-commit hook
+      run: ./pre-commit

From 1c5b81460a33cb173021695428ccc4e98e570752 Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sat, 18 Nov 2023 18:32:59 +0000
Subject: [PATCH 09/20] Allow fypp workflow to run hook using chmod +x.

---
 .github/workflows/fypp.yml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/fypp.yml b/.github/workflows/fypp.yml
index 765823e7..4ce302f7 100644
--- a/.github/workflows/fypp.yml
+++ b/.github/workflows/fypp.yml
@@ -22,4 +22,5 @@ jobs:
     - run: pip install fypp
 
     - name: Check pre-commit hook
-      run: ./pre-commit
+      run: chmod +x pre-commit
+           ./pre-commit

From 9d4b3d78ab130fa5fd4f1a8f68e9212b2d71a277 Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sat, 18 Nov 2023 18:36:15 +0000
Subject: [PATCH 10/20] fypp runs on all pushes, so no need to duplicate on
 PRs.

---
 .github/workflows/fypp.yml | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/.github/workflows/fypp.yml b/.github/workflows/fypp.yml
index 4ce302f7..679ec3d4 100644
--- a/.github/workflows/fypp.yml
+++ b/.github/workflows/fypp.yml
@@ -3,12 +3,6 @@ name: fypp-checks
 on:
   # run on every push
   push:
-  # run on every push (not commit) to a PR, plus open/reopen
-  pull_request:
-    types:
-    - synchronize
-    - opened
-    - reopened
 
 jobs:
   various:

From 0cc57b6ed0e8bb24a712dcbc0f62b4e9797b48fe Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sat, 18 Nov 2023 19:58:09 +0000
Subject: [PATCH 11/20] Modify fypp workflow to fail if ftorch.f90 does not
 match expected result of ftorch.fypp on every push.

---
 .github/workflows/fypp.yml | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/fypp.yml b/.github/workflows/fypp.yml
index 679ec3d4..88f5af8c 100644
--- a/.github/workflows/fypp.yml
+++ b/.github/workflows/fypp.yml
@@ -6,7 +6,7 @@ on:
 
 jobs:
   various:
-    name: FYPP checks - run pre-commit hook in repo
+    name: FYPP checks - runs check on fypp and f90 files
     runs-on: ubuntu-latest
     steps:
     - uses: actions/checkout@v3
@@ -15,6 +15,13 @@ jobs:
         python-version: "3.11"
     - run: pip install fypp
 
-    - name: Check pre-commit hook
-      run: chmod +x pre-commit
-           ./pre-commit
+    - name: Check fypp matches f90
+      run: |
+        fypp src/ftorch.fypp src/temp.f90_temp
+        if ! diff -q src/ftorch.f90 src/temp.f90_temp; then
+          echo "Error: The code in ftorch.f90 does not match that expected from ftorch.fypp."
+          echo "Please re-run fypp on ftorch.fypp to ensure consistency and re-commit."
+          exit 1
+        else
+          exit 0
+        fi

From f53cbc74cec28c89fbf2c7c419bb5dc3096ec2eb Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sat, 18 Nov 2023 20:24:07 +0000
Subject: [PATCH 12/20] Update README example consistent with
 torch_tensor_from_array approach.

---
 README.md | 34 ++++++++++++++--------------------
 1 file changed, 14 insertions(+), 20 deletions(-)

diff --git a/README.md b/README.md
index 7852c53c..77449b2b 100644
--- a/README.md
+++ b/README.md
@@ -128,11 +128,10 @@ To use the trained Torch model from within Fortran we need to import the `ftorch
 A very simple example is given below.
 For more detailed documentation please consult the API documentation, source code, and examples.
 
-This minimal snippet loads a saved Torch model, creates inputs consisting of two `10x10` matrices (one of ones, and one of zeros), and runs the model to infer output.
+This minimal snippet loads a saved Torch model, creates an input consisting of a `10x10` matrix of ones, and runs the model to infer output.  
+This is illustrative only, and we recommend following the [examples](examples/) before writing your own code to explore more features.
 
 ```fortran
-! Import any C bindings as required for this code
-use, intrinsic :: iso_c_binding, only: c_int, c_int64_t, c_loc
 ! Import library for interfacing with PyTorch
 use ftorch
 
@@ -141,34 +140,30 @@ implicit none
 ! Generate an object to hold the Torch model
 type(torch_module) :: model
 
-! Set up types of input and output data and the interface with C
-integer(c_int), parameter :: dims_input = 2
-integer(c_int64_t) :: shape_input(dims_input)
-integer(c_int), parameter :: n_inputs = 2
+! Set up types of input and output data
+integer, parameter :: n_inputs = 1
 type(torch_tensor), dimension(n_inputs) :: model_input_arr
-integer(c_int), parameter :: dims_output = 1
-integer(c_int64_t) :: shape_output(dims_output)
 type(torch_tensor) :: model_output
 
-! Set up the model inputs as Fortran arrays
-real, dimension(10,10), target  :: input_1, input_2
+! Set up the model input and output as Fortran arrays
+real, dimension(10,10), target  :: input
 real, dimension(5), target   :: output
 
+! Set up number of dimensions of input tensor and axis order
+integer, parameter :: in_dims = 2
+integer :: in_layout(in_dims) = [1,2]
+
 ! Initialise the Torch model to be used
 model = torch_module_load("/path/to/saved/model.pt")
 
-! Initialise the inputs as Fortran
-input_1 = 0.0
-input_2 = 1.0
+! Initialise the inputs as Fortran array of ones
+input = 1.0
 
 ! Wrap Fortran data as no-copy Torch Tensors
 ! There may well be some reshaping required depending on the 
 ! structure of the model which is not covered here (see examples)
-shape_input = (/10, 10/)
-shape_output = (/5/)
-model_input_arr(1) = torch_tensor_from_blob(c_loc(input_1), dims_input, shape_input, torch_kFloat64, torch_kCPU)
-model_input_arr(2) = torch_tensor_from_blob(c_loc(input_2), dims_input, shape_input, torch_kFloat64, torch_kCPU)
-model_output = torch_tensor_from_blob(c_loc(output), dims_output, shape_output, torch_kFloat64, torch_kCPU)
+model_input_arr(1) = torch_tensor_from_array(input, in_layout, torch_kCPU)
+model_output = torch_tensor_from_array(output, out_layout, torch_kCPU)
 
 ! Run model and Infer
 ! Again, there may be some reshaping required depending on model design
@@ -180,7 +175,6 @@ write(*,*) output
 ! Clean up
 call torch_module_delete(model)
 call torch_tensor_delete(model_input_arr(1))
-call torch_tensor_delete(model_input_arr(2))
 call torch_tensor_delete(model_output)
 ```
 

From a55cdc35f6b855ceee0c1e183bcca589d97ce59c Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Sun, 19 Nov 2023 18:18:21 +0000
Subject: [PATCH 13/20] Update example 1 to use torch_tensor_from_array.

---
 .../1_SimpleNet/simplenet_infer_fortran.f90   | 40 +++++++++----------
 1 file changed, 18 insertions(+), 22 deletions(-)

diff --git a/examples/1_SimpleNet/simplenet_infer_fortran.f90 b/examples/1_SimpleNet/simplenet_infer_fortran.f90
index 2f1ff385..199b984c 100644
--- a/examples/1_SimpleNet/simplenet_infer_fortran.f90
+++ b/examples/1_SimpleNet/simplenet_infer_fortran.f90
@@ -1,28 +1,30 @@
 program inference
 
-   ! Imports primitives used to interface with C
-   use, intrinsic :: iso_c_binding, only: c_int64_t, c_float, c_char, c_ptr, c_loc
+   ! Import precision info from iso
+   use, intrinsic :: iso_fortran_env, only : sp => real32
+
    ! Import our library for interfacing with PyTorch
    use ftorch
 
    implicit none
-
+  
+   ! Set precision for reals
+   integer, parameter :: wp = sp
+   
    integer :: num_args, ix
    character(len=128), dimension(:), allocatable :: args
 
-   ! Set up types of input and output data and the interface with C
+   ! Set up Fortran data structures
+   real(wp), dimension(5), target :: in_data
+   real(wp), dimension(5), target :: out_data
+   integer, parameter :: n_inputs = 1
+   integer :: tensor_layout(1) = [1]
+
+   ! Set up Torch data structures
    type(torch_module) :: model
    type(torch_tensor), dimension(1) :: in_tensor
    type(torch_tensor) :: out_tensor
 
-   real(c_float), dimension(:), allocatable, target :: in_data
-   integer(c_int), parameter :: n_inputs = 1
-   real(c_float), dimension(:), allocatable, target :: out_data
-
-   integer(c_int), parameter :: tensor_dims = 1
-   integer(c_int64_t) :: tensor_shape(tensor_dims) = [5]
-   integer(c_int) :: tensor_layout(tensor_dims) = [1]
-
    ! Get TorchScript model file as a command line argument
    num_args = command_argument_count()
    allocate(args(num_args))
@@ -30,18 +32,14 @@ program inference
        call get_command_argument(ix,args(ix))
    end do
 
-   ! Allocate one-dimensional input/output arrays, based on multiplication of all input/output dimension sizes
-   allocate(in_data(tensor_shape(1)))
-   allocate(out_data(tensor_shape(1)))
-
    ! Initialise data
    in_data = [0.0, 1.0, 2.0, 3.0, 4.0]
 
-   ! Create input/output tensors from the above arrays
-   in_tensor(1) = torch_tensor_from_blob(c_loc(in_data), tensor_dims, tensor_shape, torch_kFloat32, torch_kCPU, tensor_layout)
-   out_tensor = torch_tensor_from_blob(c_loc(out_data), tensor_dims, tensor_shape, torch_kFloat32, torch_kCPU, tensor_layout)
+   ! Create Torch input/output tensors from the above arrays
+   in_tensor(1) = torch_tensor_from_array(in_data, tensor_layout, torch_kCPU)
+   out_tensor = torch_tensor_from_array(out_data, tensor_layout, torch_kCPU)
 
-   ! Load ML model (edit this line to use different models)
+   ! Load ML model
    model = torch_module_load(args(1))
 
    ! Infer
@@ -52,7 +50,5 @@ program inference
    call torch_module_delete(model)
    call torch_tensor_delete(in_tensor(1))
    call torch_tensor_delete(out_tensor)
-   deallocate(in_data)
-   deallocate(out_data)
 
 end program inference

From b3c479f5bcf724966be065c7ea2efe9602bdba06 Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Mon, 20 Nov 2023 07:45:14 +0000
Subject: [PATCH 14/20] Correction to example Fortran in README to provide
 missing output params.

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index 77449b2b..5f6c9524 100644
--- a/README.md
+++ b/README.md
@@ -152,6 +152,8 @@ real, dimension(5), target   :: output
 ! Set up number of dimensions of input tensor and axis order
 integer, parameter :: in_dims = 2
 integer :: in_layout(in_dims) = [1,2]
+integer, parameter :: out_dims = 1
+integer :: out_layout(out_dims) = [1]
 
 ! Initialise the Torch model to be used
 model = torch_module_load("/path/to/saved/model.pt")

From 6616c3552bd6f3565d57ebbc2369f106b3418d93 Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Tue, 21 Nov 2023 15:06:04 +0000
Subject: [PATCH 15/20] Update ftorch to standardise argument order.

---
 src/ftorch.f90  | 2 +-
 src/ftorch.fypp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/ftorch.f90 b/src/ftorch.f90
index 5b6389c2..572fd323 100644
--- a/src/ftorch.f90
+++ b/src/ftorch.f90
@@ -95,7 +95,7 @@ end function torch_from_blob_c
   ! Torch Tensor API
   !| Exposes the given data as a tensor without taking ownership of the original data.
   !  This routine will take an (i, j, k) array and return an (k, j, i) tensor.
-  function torch_tensor_from_blob(data, ndims, tensor_shape, dtype, device, layout) result(tensor)
+  function torch_tensor_from_blob(data, layout, ndims, tensor_shape, dtype, device) result(tensor)
     use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
     type(c_ptr), intent(in)        :: data       !! Pointer to data
     integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index aafe5d11..c26984a6 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -93,7 +93,7 @@ contains
   ! Torch Tensor API
   !| Exposes the given data as a tensor without taking ownership of the original data.
   !  This routine will take an (i, j, k) array and return an (k, j, i) tensor.
-  function torch_tensor_from_blob(data, ndims, tensor_shape, dtype, device, layout) result(tensor)
+  function torch_tensor_from_blob(data, layout, ndims, tensor_shape, dtype, device) result(tensor)
     use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
     type(c_ptr), intent(in)        :: data       !! Pointer to data
     integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor

From 8f8345dbea5504faa84c130bd1d410f4961b59b3 Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Tue, 21 Nov 2023 15:43:18 +0000
Subject: [PATCH 16/20] Remove tensor_tests as outdated and to be replaced by
 cgdrag examples in future.

---
 examples/tensor_tests/CMakeLists.txt  | 19 --------
 examples/tensor_tests/test_tensor.f90 | 69 ---------------------------
 2 files changed, 88 deletions(-)
 delete mode 100644 examples/tensor_tests/CMakeLists.txt
 delete mode 100644 examples/tensor_tests/test_tensor.f90

diff --git a/examples/tensor_tests/CMakeLists.txt b/examples/tensor_tests/CMakeLists.txt
deleted file mode 100644
index 6571f1cb..00000000
--- a/examples/tensor_tests/CMakeLists.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
-#policy CMP0076 - target_sources source files are relative to file where target_sources is run
-cmake_policy (SET CMP0076 NEW)
-
-set(PROJECT_NAME test_tensor)
-
-project(${PROJECT_NAME} LANGUAGES Fortran)
-
-# Build in Debug mode if not specified
-if(NOT CMAKE_BUILD_TYPE)
-    set(CMAKE_BUILD_TYPE Debug CACHE STRING "" FORCE)
-endif()
-
-find_package(FTorch)
-message(STATUS "Building with Fortran PyTorch coupling")
-
-# Some tests for tensor generation.
-add_executable(test_tensor test_tensor.f90)
-target_link_libraries(test_tensor PRIVATE FTorch::ftorch)
diff --git a/examples/tensor_tests/test_tensor.f90 b/examples/tensor_tests/test_tensor.f90
deleted file mode 100644
index d7bc1410..00000000
--- a/examples/tensor_tests/test_tensor.f90
+++ /dev/null
@@ -1,69 +0,0 @@
-program test_tensor
-  use, intrinsic :: iso_c_binding, only: c_int64_t, c_float, c_char, c_ptr, c_loc
-  use ftorch
-  implicit none
-
-  real(kind=8), dimension(:,:), allocatable, target  :: uuu_flattened, vvv_flattened
-  real(kind=8), dimension(:,:), allocatable, target    :: lat_reshaped, psfc_reshaped
-  real(kind=8), dimension(:,:), allocatable, target  :: gwfcng_x_flattened, gwfcng_y_flattened
-  type(torch_tensor), target :: output_tensor
-  integer(c_int), parameter :: dims_1D = 2
-  integer(c_int), parameter :: dims_2D = 2
-  integer(c_int64_t) :: shape_2D_F(dims_2D), shape_2D_C(dims_2D)
-  integer(c_int64_t) :: shape_1D_F(dims_1D), shape_1D_C(dims_1D)
-  integer(c_int) :: layout_F(dims_1D), layout_C(dims_1D)
-  integer :: imax, jmax, kmax, i, j, k
-
-  imax = 1
-  jmax = 5
-  kmax = 7
-
-  shape_2D_F = (/ kmax, imax*jmax /)
-  shape_1D_F = (/ 1, imax*jmax /)
-  shape_2D_C = (/ imax*jmax, kmax /)
-  shape_1D_C = (/ imax*jmax, 1 /)
-
-  layout_F = (/ 1, 2 /)
-  layout_C = (/ 2, 1 /)
-
-  allocate( lat_reshaped(imax*jmax, 1) )
-  allocate( uuu_flattened(imax*jmax, kmax) )
-  do i = 1, imax*jmax
-    lat_reshaped(i, 1) = i
-    do k = 1, kmax
-      uuu_flattened(i, k) = i + k*100
-    end do
-  end do
-
-  write(*,*) uuu_flattened
-
-  output_tensor = torch_tensor_from_blob(c_loc(uuu_flattened), &
-  dims_2D, shape_2D_C, torch_kFloat64, torch_kCPU, layout_F)
-
-  call torch_tensor_print(output_tensor)
-
-  output_tensor = torch_tensor_from_blob(c_loc(uuu_flattened), &
-  dims_2D, shape_2D_F, torch_kFloat64, torch_kCPU, layout_C)
-
-  call torch_tensor_print(output_tensor)
-
-  shape_2D_F = shape(uuu_flattened)
-  output_tensor = torch_tensor_from_array_c_double(uuu_flattened, shape_2D_F, torch_kCPU)
-
-  call torch_tensor_print(output_tensor)
-
-  output_tensor = torch_tensor_from_array(uuu_flattened, shape_2D_F, torch_kCPU)
-  
-  call torch_tensor_print(output_tensor)
-
-  ! output_tensor = torch_tensor_zeros( &
-  ! dims_2D, shape_2D_C, torch_kFloat64, torch_kCPU)
-
-  ! call torch_tensor_print(output_tensor)
-
-  ! output_tensor = torch_tensor_ones( &
-  ! dims_2D, shape_2D_C, torch_kFloat64, torch_kCPU)
-
-  ! call torch_tensor_print(output_tensor)
-
-end program test_tensor

From 656fa9d60044601623cf8125fb04a55577dedc90 Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Tue, 21 Nov 2023 17:03:56 +0000
Subject: [PATCH 17/20] Update pre-commit hook to be bash.

---
 pre-commit | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pre-commit b/pre-commit
index d0f12fac..f3a5aa98 100755
--- a/pre-commit
+++ b/pre-commit
@@ -1,4 +1,4 @@
-#!/bin/sh
+#!/bin/bash
 #
 # A hook script to verify what is about to be committed.
 # Called by "git commit" with no arguments.  The hook should

From 703bc5ae48becb261196bcfe5ebb19d3d501a2f4 Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Tue, 21 Nov 2023 17:06:16 +0000
Subject: [PATCH 18/20] Create a githooks directory for pre-commit hook.

---
 pre-commit | 69 ------------------------------------------------------
 1 file changed, 69 deletions(-)
 delete mode 100755 pre-commit

diff --git a/pre-commit b/pre-commit
deleted file mode 100755
index f3a5aa98..00000000
--- a/pre-commit
+++ /dev/null
@@ -1,69 +0,0 @@
-#!/bin/bash
-#
-# A hook script to verify what is about to be committed.
-# Called by "git commit" with no arguments.  The hook should
-# exit with non-zero status after issuing an appropriate message if
-# it wants to stop the commit.
-
-# Fail immediately at first issue with the relevant exit status.
-set -eo pipefail
-
-# ===================================================================
-
-if git rev-parse --verify HEAD >/dev/null 2>&1
-then
-	against=HEAD
-else
-	# Initial commit: diff against an empty tree object
-	against=$(git hash-object -t tree /dev/null)
-fi
-
-# ===================================================================
-
-# Check that ftorch.90 is not modified and staged alone.
-git diff --cached --name-only | if grep --quiet "ftorch.f90"; then
-  git diff --cached --name-only | if ! grep --quiet "ftorch.fypp"; then
-    cat <<\EOF
-Error: File ftorch.f90 has been modified and staged without ftorch.fypp being changed.
-ftorch.90 should be generated from ftorch.fypp using fypp.
-Please restore ftorch.f90 and make your modifications to ftorch.fypp instead.
-EOF
-    exit 1
-  fi
-fi
-
-# Check to see if ftorch.fypp has been modified AND is staged.
-git diff --cached --name-only | if grep --quiet "ftorch.fypp"; then
-
-  # Check that ftorch.90 is also modified and staged.
-  git diff --cached --name-only | if ! grep --quiet "ftorch.f90"; then
-    cat <<\EOF
-Error: File ftorch.fypp has been modified and staged, but ftorch.f90 has not.
-ftorch.90 should be generated from ftorch.fypp and both committed together.
-Please run fypp on ftorch.fypp to generate ftorch.f90 and commit together.
-EOF
-    exit 1
-  else
-    # Check fypp available, and raise error and exit if not.
-    if ! command -v fypp &> /dev/null; then
-      cat <<\EOF
-echo "Error: Could not find fypp to run on ftorch.fypp.
-Please install fypp using "pip install fypp" and then try committing again.
-EOF
-      exit 1
-    fi
-
-    # If fypp is available and both .f90 and .fypp staged, check they match.
-    fypp src/ftorch.fypp src/ftorch.f90_tmp
-    if ! diff -q "src/ftorch.f90" "src/ftorch.f90_tmp" &> /dev/null; then
-      rm src/ftorch.f90_tmp
-      cat <<\EOF
-Error: The code in ftorch.f90 does not match that expected from ftorch.fypp.
-Please re-run fypp on ftorch.fypp to ensure consistency before committing.
-EOF
-      exit 1
-    else
-      rm src/ftorch.f90_tmp
-    fi
-  fi
-fi

From d74e650ca8ce03740dfab42ff983b8232c82e799 Mon Sep 17 00:00:00 2001
From: jatkinson1000 <109271713+jatkinson1000@users.noreply.github.com>
Date: Tue, 21 Nov 2023 17:27:00 +0000
Subject: [PATCH 19/20] Create a githooks directory for pre-commit hook

---
 .githooks/pre-commit | 69 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)
 create mode 100755 .githooks/pre-commit

diff --git a/.githooks/pre-commit b/.githooks/pre-commit
new file mode 100755
index 00000000..f3a5aa98
--- /dev/null
+++ b/.githooks/pre-commit
@@ -0,0 +1,69 @@
+#!/bin/bash
+#
+# A hook script to verify what is about to be committed.
+# Called by "git commit" with no arguments.  The hook should
+# exit with non-zero status after issuing an appropriate message if
+# it wants to stop the commit.
+
+# Abort on any failure; pipefail lets an "exit 1" inside a piped "if" stop the commit.
+set -eo pipefail
+
+# ===================================================================
+
+if git rev-parse --verify HEAD >/dev/null 2>&1
+then
+	against=HEAD
+else
+	# Initial commit: diff against an empty tree object
+	against=$(git hash-object -t tree /dev/null)
+fi
+
+# ===================================================================
+
+# Check that ftorch.f90 is not modified and staged alone.
+git diff --cached --name-only | if grep --quiet "ftorch.f90"; then
+  git diff --cached --name-only | if ! grep --quiet "ftorch.fypp"; then
+    cat >&2 <<\EOF
+Error: File ftorch.f90 has been modified and staged without ftorch.fypp being changed.
+ftorch.f90 should be generated from ftorch.fypp using fypp.
+Please restore ftorch.f90 and make your modifications to ftorch.fypp instead.
+EOF
+    exit 1
+  fi
+fi
+
+# Check to see if ftorch.fypp has been modified AND is staged.
+git diff --cached --name-only | if grep --quiet "ftorch.fypp"; then
+
+  # Check that ftorch.f90 is also modified and staged.
+  git diff --cached --name-only | if ! grep --quiet "ftorch.f90"; then
+    cat >&2 <<\EOF
+Error: File ftorch.fypp has been modified and staged, but ftorch.f90 has not.
+ftorch.f90 should be generated from ftorch.fypp and both committed together.
+Please run fypp on ftorch.fypp to generate ftorch.f90 and commit together.
+EOF
+    exit 1
+  else
+    # Check fypp is available; report on stderr and abort the commit if not.
+    if ! command -v fypp &> /dev/null; then
+      cat >&2 <<\EOF
+Error: Could not find fypp to run on ftorch.fypp.
+Please install fypp using "pip install fypp" and then try committing again.
+EOF
+      exit 1
+    fi
+
+    # Regenerate ftorch.f90 into a scratch file and compare it with the one on disk.
+    fypp src/ftorch.fypp src/ftorch.f90_tmp
+    if ! diff -q "src/ftorch.f90" "src/ftorch.f90_tmp" &> /dev/null; then
+      rm src/ftorch.f90_tmp
+      cat >&2 <<\EOF
+Error: The code in ftorch.f90 does not match that expected from ftorch.fypp.
+Please re-run fypp on ftorch.fypp to ensure consistency before committing.
+EOF
+      exit 1
+    else
+      rm src/ftorch.f90_tmp
+    fi
+  fi
+fi

From 7943978fe54a5010419ed2a1ca5e635aa613b517 Mon Sep 17 00:00:00 2001
From: melt <thomas.meltzer1@gmail.com>
Date: Tue, 21 Nov 2023 18:23:43 +0000
Subject: [PATCH 20/20] update order of args for torch_tensor_from_blob

---
 src/ftorch.f90  | 2 +-
 src/ftorch.fypp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/ftorch.f90 b/src/ftorch.f90
index 572fd323..945673cf 100644
--- a/src/ftorch.f90
+++ b/src/ftorch.f90
@@ -95,7 +95,7 @@ end function torch_from_blob_c
   ! Torch Tensor API
   !| Exposes the given data as a tensor without taking ownership of the original data.
   !  This routine will take an (i, j, k) array and return an (k, j, i) tensor.
-  function torch_tensor_from_blob(data, layout, ndims, tensor_shape, dtype, device) result(tensor)
+  function torch_tensor_from_blob(data, ndims, tensor_shape, layout, dtype, device) result(tensor)
     use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
     type(c_ptr), intent(in)        :: data       !! Pointer to data
     integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index c26984a6..16371e48 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -93,7 +93,7 @@ contains
   ! Torch Tensor API
   !| Exposes the given data as a tensor without taking ownership of the original data.
   !  This routine will take an (i, j, k) array and return an (k, j, i) tensor.
-  function torch_tensor_from_blob(data, layout, ndims, tensor_shape, dtype, device) result(tensor)
+  function torch_tensor_from_blob(data, ndims, tensor_shape, layout, dtype, device) result(tensor)
     use, intrinsic :: iso_c_binding, only : c_int, c_int64_t, c_ptr
     type(c_ptr), intent(in)        :: data       !! Pointer to data
     integer(c_int), intent(in)     :: ndims      !! Number of dimensions of the tensor