Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type check for better fortran interface support #182

Merged
merged 3 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ some of the options are used:
--f90-max-line-length F90_MAX_LINE_LENGTH
Maximum length of lines in fortan files written.
Default: 120
--type-check Check for type/shape matching of Python argument with the wrapped Fortran subroutine


Author
Expand Down
3 changes: 2 additions & 1 deletion examples/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ EXAMPLES = arrayderivedtypes \
strings \
subroutine_contains_issue101 \
type_bn \
docstring
docstring \
type_check

PYTHON = python

Expand Down
42 changes: 42 additions & 0 deletions examples/type_check/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#=======================================================================
# define the compiler names
#=======================================================================

CC = gcc
F90 = gfortran
PYTHON = python
CFLAGS = -fPIC
F90FLAGS = -fPIC
PY_MOD = pywrapper
F90_SRC = main.f90
OBJ = $(F90_SRC:.f90=.o)
F90WRAP_SRC = $(addprefix f90wrap_,${F90_SRC})
WRAPFLAGS = -v --type-check
F2PYFLAGS = --build-dir build
F90WRAP = f90wrap
F2PY = f2py-f90wrap
.PHONY: all clean

all: test

clean:
rm -rf *.mod *.smod *.o f90wrap*.f90 ${PY_MOD}.py _${PY_MOD}*.so __pycache__/ .f2py_f2cmap build ${PY_MOD}/

main.o: ${F90_SRC}
${F90} ${F90FLAGS} -c $< -o $@

%.o: %.f90
${F90} ${F90FLAGS} -c $< -o $@

${F90WRAP_SRC}: ${OBJ}
${F90WRAP} -m ${PY_MOD} ${WRAPFLAGS} ${F90_SRC}

f90wrap: ${F90WRAP_SRC}

f2py: ${F90WRAP_SRC}
CFLAGS="${CFLAGS}" ${F2PY} -c -m _${PY_MOD} ${F2PYFLAGS} f90wrap_*.f90 *.o

wrapper: f2py

test: wrapper
python type_check_test.py
68 changes: 68 additions & 0 deletions examples/type_check/main.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@

module m_circle
implicit none
private

type, public :: t_square
real :: length
end type t_square

type, public :: t_circle
real :: radius
end type t_circle

interface is_circle
module procedure is_circle_circle
module procedure is_circle_square
end interface is_circle

interface write_array
module procedure write_array_int_1d
module procedure write_array_int_2d
module procedure write_array_real
module procedure write_array_double
end interface write_array

public :: is_circle
public :: write_array
public :: is_circle_circle
public :: is_circle_square
public :: write_array_int_1d

contains

subroutine is_circle_circle(circle, output)
type(t_circle) :: circle
integer :: output(:)
output(:) = 1
end subroutine is_circle_circle

subroutine is_circle_square(square, output)
type(t_square) :: square
integer :: output(:)
output(:) = 0
end subroutine is_circle_square

subroutine write_array_int_1d(output)
integer :: output(:)
output(:) = 1
end subroutine write_array_int_1d

subroutine write_array_int_2d(output)
integer :: output(:,:)
output(:,:) = 2
end subroutine write_array_int_2d

subroutine write_array_real(output)
real :: output(:)
output(:) = 3
end subroutine write_array_real

subroutine write_array_double(output)
double precision :: output(:)
output(:) = 4
end subroutine write_array_double

end module m_circle


87 changes: 87 additions & 0 deletions examples/type_check/type_check_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import unittest
import numpy as np

from pywrapper import m_circle

class TestTypeCheck(unittest.TestCase):

def __init__(self, *args, **kwargs):
super(TestTypeCheck, self).__init__(*args, **kwargs)
self._circle = m_circle.t_circle()
self._square = m_circle.t_square()

def test_derived_type_selection(self):
out_circle = np.array([-1], dtype=np.int32)
out_square = np.array([-1], dtype=np.int32)

m_circle.is_circle(self._circle, out_circle)
m_circle.is_circle(self._square, out_square)

assert out_circle[0]==1
assert out_square[0]==0

def test_shape_selection_1d(self):
out = np.array([-1], dtype=np.int32)
m_circle.write_array(out)

assert out[0]==1

def test_shape_selection_2d(self):
out = np.array([[-1]], dtype=np.int32)
m_circle.write_array(out)

assert out[0]==2

def test_type_selection(self):
out = np.array([-1], dtype=np.float32)
m_circle.write_array(out)

assert out[0]==3

def test_kind_selection(self):
out = np.array([-1], dtype=np.float64)
m_circle.write_array(out)

assert out[0]==4

def test_wrong_derived_type(self):
out = np.array([-1], dtype=np.int32)

with self.assertRaises(TypeError):
m_circle._is_circle_square(self._circle, out)

with self.assertRaises(TypeError):
m_circle._is_circle_circle(self._square, out)

def test_wrong_kind(self):
out = np.array([-1], dtype=np.int64)

with self.assertRaises(TypeError):
m_circle._write_array_int_1d(out)

def test_wrong_type(self):
out = np.array([-1], dtype=np.float32)

with self.assertRaises(TypeError):
m_circle._write_array_int_1d(out)

def test_wrong_dim(self):
out = np.array([[-1]], dtype=np.int32)

with self.assertRaises(TypeError):
m_circle._write_array_int_1d(out)

def test_no_suitable_version(self):
with self.assertRaises(TypeError):
m_circle.is_circle(1., 1.)

def test_no_suitable_version_2(self):
out = np.array([-1], dtype=np.complex)

with self.assertRaises(TypeError):
m_circle.write_array(out)


if __name__ == '__main__':

unittest.main()
104 changes: 102 additions & 2 deletions f90wrap/pywrapgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ def _format_line_no(lineno):
class PythonWrapperGenerator(ft.FortranVisitor, cg.CodeGenerator):
def __init__(self, prefix, mod_name, types, f90_mod_name=None,
make_package=False, kind_map=None, init_file=None,
py_mod_names=None, class_names=None, max_length=None):
py_mod_names=None, class_names=None, max_length=None,
type_check=False):
if max_length is None:
max_length = 80
cg.CodeGenerator.__init__(self, indent=' ' * 4,
Expand All @@ -177,11 +178,13 @@ def __init__(self, prefix, mod_name, types, f90_mod_name=None,
kind_map = {}
self.kind_map = kind_map
self.init_file = init_file
self.type_check = type_check

def write_imports(self, insert=0):
default_imports = [(self.f90_mod_name, None),
('f90wrap.runtime', None),
('logging', None)]
('logging', None),
('numpy', None)]
imp_lines = ['from __future__ import print_function, absolute_import, division']
for (mod, symbol) in default_imports + list(self.imports):
if symbol is None:
Expand Down Expand Up @@ -416,6 +419,10 @@ def visit_Procedure(self, node):
self.write("def %(method_name)s(%(py_arg_names)s):" % dct)
self.indent()
self.write(format_doc_string(node))

if self.type_check:
self.write_type_checks(node)

for arg in node.arguments:
if 'optional' in arg.attributes and '._handle' in arg.py_value:
dct['f90_arg_names'] = dct['f90_arg_names'].replace(arg.py_value,
Expand Down Expand Up @@ -488,6 +495,28 @@ def visit_Interface(self, node):
self.write('continue')
self.dedent()
self.dedent()
self.write()

if self.type_check:
self.write('argTypes=[]')
self.write('for arg in args:')
self.indent()
self.write('try:')
self.indent()
self.write('argTypes.append("%s: dims \'%s\', type \'%s\'"%(str(type(arg)),'
'arg.ndim, arg.dtype))')
self.dedent()
self.write('except AttributeError:')
self.indent()
self.write('argTypes.append(str(type(arg)))')
self.dedent()
self.dedent()

self.write('raise TypeError("Not able to call a version of "')
self.indent()
self.write('"\'%(intf_name)s\' compatible with the provided args:"' % dct)
self.write('"\\n%s\\n"%"\\n".join(argTypes))')
self.dedent()
self.dedent()
self.write()

Expand Down Expand Up @@ -762,3 +791,74 @@ def write_dt_array_wrapper(self, node, el, dims):
self.write('return %(selfdot)s%(el_name)s' % dct)
self.dedent()
self.write()

def write_type_checks(self, node):
# This adds tests that checks data types and dimensions
# to ensure either the correct version of an interface is used
# either an exception is returned
for arg in node.arguments:
if 'optional' not in arg.attributes:
ft_array_dim_list = list(filter(lambda x: x.startswith("dimension"),
arg.attributes))
if ft_array_dim_list:
if ':' in ft_array_dim_list[0]:
ft_array_dim = ft_array_dim_list[0].count(',')+1
else:
ft_array_dim = -1
else:
ft_array_dim = 0

# Checks for derived types
if (arg.type.startswith('type') or arg.type.startswith('class')):
cls_mod_name = self.types[ft.strip_type(arg.type)].mod_name
cls_mod_name = self.py_mod_names.get(cls_mod_name, cls_mod_name)

cls_name = normalise_class_name(ft.strip_type(arg.type), self.class_names)
self.write('if not isinstance({0}, {1}.{2}) :'\
.format(arg.py_name, cls_mod_name, cls_name) )
self.indent()
self.write('raise TypeError')
self.dedent()

if self.make_package:
self.imports.add((self.py_mod_name, cls_mod_name))
else:
# Checks for Numpy array dimension and types
# It will fail for types that are not in the kind map
# Good enough for now if it works on standrad types
try:
array_type=ft.fortran_array_type(arg.type, self.kind_map)
except RuntimeError:
continue

py_type = ft.f2py_type(arg.type)

# bool are ignored because fortran logical are mapped to integers
if py_type not in ['bool']:
self.write('if isinstance({0},(numpy.ndarray, numpy.generic)):'\
.format(arg.py_name))
self.indent()
if ft_array_dim == -1:
self.write('if {0}.dtype.num != {1}:'\
.format(arg.py_name, array_type))
else:
self.write('if {0}.ndim != {1} or {0}.dtype.num != {2}:'\
.format(arg.py_name, str(ft_array_dim), array_type))

self.indent()
self.write('raise TypeError')
self.dedent()
self.dedent()
if ft_array_dim == 0:
# Do not write checks for unknown types
if py_type not in ['unknown']:
self.write('elif not isinstance({0},{1}):'\
.format(arg.py_name,py_type))
self.indent()
self.write('raise TypeError')
self.dedent()
else:
self.write('else:')
self.indent()
self.write('raise TypeError')
self.dedent()
5 changes: 4 additions & 1 deletion f90wrap/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def main(argv=None):
parser.add_argument("--py-max-line-length", help="Maximum length of lines in python files written. Default: 80")
parser.add_argument("--f90-max-line-length", help="Maximum length of lines in fortan files written. "
"Default: 120")
parser.add_argument('--type-check', action='store_true', default=False,
help="Check for type/shape matching of Python argument with the wrapped Fortran subroutine")

args = parser.parse_args()

Expand Down Expand Up @@ -384,7 +386,8 @@ def main(argv=None):
init_file=args.init_file,
py_mod_names=py_mod_names,
class_names=class_names,
max_length=py_max_line_length).visit(py_tree)
max_length=py_max_line_length,
type_check=type_check).visit(py_tree)
fwrap.F90WrapperGenerator(prefix, fsize, string_lengths,
abort_func, kind_map, types, default_to_inout,
max_length=f90_max_line_length).visit(f90_tree)
Expand Down