Skip to content

Implement __index__ to avoid TypeError #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:

- name: Test
run: |
python -m unittest discover -v -s tests
pytest -v tests


publish:
Expand All @@ -41,12 +41,12 @@ jobs:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Set up Python 3.8
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: 3.8

Expand Down
3 changes: 3 additions & 0 deletions cwrap/basecenum.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def __and__(self, other):

def __int__(self):
return self.value

def __index__(self):
return self.value

def __contains__(self, item):
return self & item == item
Expand Down
125 changes: 75 additions & 50 deletions cwrap/prototype.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,14 @@
import ctypes
import inspect
import re
import six
import sys
from types import MethodType

import six


class TypeDefinition(object):
def __init__(self,
type_class_or_function,
is_return_type,
storage_type,
errcheck):
def __init__(self, type_class_or_function, is_return_type, storage_type, errcheck):
self.storage_type = storage_type
self.is_return_type = is_return_type
self.type_class_or_function = type_class_or_function
Expand All @@ -38,6 +35,7 @@ def __init__(self,


if six.PY3:

class CStringHelper(object):
@classmethod
def from_param(cls, value):
Expand Down Expand Up @@ -67,22 +65,24 @@ def toStr(result, func, arguments):
return None
return result.decode()


REGISTERED_TYPES = {}
""":type: dict[str,TypeDefinition]"""


def _registerType(type_name,
type_class_or_function,
is_return_type=True,
storage_type=None,
errcheck=None):
def _registerType(
type_name,
type_class_or_function,
is_return_type=True,
storage_type=None,
errcheck=None,
):
if type_name in REGISTERED_TYPES:
raise PrototypeError("Type: '%s' already registered!" % type_name)

REGISTERED_TYPES[type_name] = TypeDefinition(type_class_or_function,
is_return_type,
storage_type,
errcheck)
REGISTERED_TYPES[type_name] = TypeDefinition(
type_class_or_function, is_return_type, storage_type, errcheck
)


_registerType("void", None)
Expand All @@ -108,16 +108,19 @@ def _registerType(type_name,
"char*",
CStringHelper,
storage_type=ctypes.c_char_p,
errcheck=CStringHelper.toStr)
errcheck=CStringHelper.toStr,
)
_registerType("float", ctypes.c_float)
_registerType("float*", ctypes.POINTER(ctypes.c_float))
_registerType("double", ctypes.c_double)
_registerType("double*", ctypes.POINTER(ctypes.c_double))
_registerType("py_object", ctypes.py_object)

PROTOTYPE_PATTERN = ("(?P<return>[a-zA-Z][a-zA-Z0-9_*]*)"
" +(?P<function>[a-zA-Z]\w*)"
" *[(](?P<arguments>[a-zA-Z0-9_*, ]*)[)]")
PROTOTYPE_PATTERN = (
r"(?P<return>[a-zA-Z][a-zA-Z0-9_*]*)"
r" +(?P<function>[a-zA-Z]\w*)"
r" *[(](?P<arguments>[a-zA-Z0-9_*, ]*)[)]"
)


class PrototypeError(Exception):
Expand All @@ -127,7 +130,7 @@ class PrototypeError(Exception):
class Prototype(object):
pattern = re.compile(PROTOTYPE_PATTERN)

def __init__(self, lib, prototype, bind=False, allow_attribute_error = False):
def __init__(self, lib, prototype, bind=False, allow_attribute_error=False):
super(Prototype, self).__init__()
self._lib = lib
self._prototype = prototype
Expand All @@ -143,13 +146,13 @@ def _parseType(self, type_name):

if type_name in REGISTERED_TYPES:
type_definition = REGISTERED_TYPES[type_name]
return (type_definition.type_class_or_function,
type_definition.storage_type,
type_definition.errcheck)
return (
type_definition.type_class_or_function,
type_definition.storage_type,
type_definition.errcheck,
)
raise ValueError("Unknown type: %s" % type_name)



def shouldBeBound(self):
return self._bind

Expand All @@ -168,15 +171,27 @@ def resolve(self):
except AttributeError:
if self._allow_attribute_error:
return
raise PrototypeError("Can not find function: %s in library: %s" % (function_name , self._lib))

if not restype in REGISTERED_TYPES or not REGISTERED_TYPES[restype].is_return_type:
sys.stderr.write("The type used as return type: %s is not registered as a return type.\n" % restype)
raise PrototypeError(
"Can not find function: %s in library: %s"
% (function_name, self._lib)
)

if (
not restype in REGISTERED_TYPES
or not REGISTERED_TYPES[restype].is_return_type
):
sys.stderr.write(
"The type used as return type: %s is not registered as a return type.\n"
% restype
)

return_type, storage_type, errcheck = self._parseType(restype)

if inspect.isclass(return_type):
sys.stderr.write(" Correct type may be: %s_ref or %s_obj.\n" % (restype, restype))
sys.stderr.write(
" Correct type may be: %s_ref or %s_obj.\n"
% (restype, restype)
)

return None

Expand Down Expand Up @@ -211,12 +226,16 @@ def __call__(self, *args):
self._resolved = True
if self._func is None:
if self._allow_attribute_error:
raise NotImplementedError("Function:%s has not been properly resolved" % self.__name__)
raise NotImplementedError(
"Function:%s has not been properly resolved" % self.__name__
)
else:
raise PrototypeError("Prototype has not been properly resolved")
if self._bind and not args[0].is_initialized():
raise ValueError("Called bound function with uninitialized object of type "
f"{type(args[0]).__name__}")
raise ValueError(
"Called bound function with uninitialized object of type "
f"{type(args[0]).__name__}"
)
try:
return self._func(*args)
except ctypes.ArgumentError as err:
Expand All @@ -225,19 +244,21 @@ def __call__(self, *args):
# ctypes and, as such, just an implementation detail.
# ArgumentError.message will look like this
# `argument 4: <type 'exceptions.TypeError'>: wrong type`
# The only useful information here is the index of the argument
# The only useful information here is the index of the argument
errMsg = err.message if hasattr(err, "message") else str(err)
tokens = re.split("[ :]", errMsg)
argidx = int(tokens[1]) - 1 # it starts from 1
raise TypeError((
"Argument {argidx}: cannot create a {argtype} from the given "
"value {actval} ({acttype})")
.format(argtype=self._func.argtypes[argidx],
argidx=argidx,
actval=repr(args[argidx]),
acttype=type(args[argidx]),
)
) from err
raise TypeError(
(
"Argument {argidx}: cannot create a {argtype} from the given "
"value {actval} ({acttype})"
).format(
argtype=self._func.argtypes[argidx],
argidx=argidx,
actval=repr(args[argidx]),
acttype=type(args[argidx]),
)
) from err

def __get__(self, instance, owner):
if not self._resolved:
Expand All @@ -259,11 +280,15 @@ def __repr__(self):
return 'Prototype("%s"%s)' % (self._prototype, bound)

@classmethod
def registerType(cls, type_name, type_class_or_function, is_return_type=True, storage_type=None):
def registerType(
cls, type_name, type_class_or_function, is_return_type=True, storage_type=None
):
if storage_type is None and (inspect.isfunction(type_class_or_function)):
storage_type = ctypes.c_void_p

_registerType(type_name,
type_class_or_function,
is_return_type = is_return_type,
storage_type = storage_type)
storage_type = ctypes.c_void_p

_registerType(
type_name,
type_class_or_function,
is_return_type=is_return_type,
storage_type=storage_type,
)
14 changes: 11 additions & 3 deletions tests/test_basecenum.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import ctypes
from cwrap import BaseCEnum, Prototype, load
import os
import unittest

from cwrap import BaseCEnum, Prototype, load


class BaseCEnumTest(unittest.TestCase):
def test_base_c_enum(self):
class enum(BaseCEnum):
Expand Down Expand Up @@ -121,3 +120,12 @@ class Endumb(BaseCEnum):

Endumb.addEnum("SOME_VALUE", -1)
assert Endumb.SOME_VALUE.abs() == 1

def test_base_c_enum_to_c_int():
class NumberEnum(BaseCEnum):
pass

NumberEnum.addEnum("ONE", 1)

c_int_value = ctypes.c_int(NumberEnum.ONE)
assert c_int_value.value == NumberEnum.ONE
6 changes: 3 additions & 3 deletions tests/test_basecvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

import unittest

class TestPrototype(Prototype):
class ForTestPrototype(Prototype):
lib = load("msvcrt" if os.name == "nt" else None)

def __init__(self, prototype):
super(TestPrototype, self).__init__(self.lib, prototype)
super(ForTestPrototype, self).__init__(self.lib, prototype)

class UnsignedByteValue(BaseCValue):
DATA_TYPE = c_ubyte
Expand All @@ -21,7 +21,7 @@ class SqrtDouble(BaseCValue):

class BaseCValueTest(unittest.TestCase):
def setUp(self):
self.sqrt_double = TestPrototype("sqrt_double sqrt(double)")
self.sqrt_double = ForTestPrototype("sqrt_double sqrt(double)")


def test_illegal_type(self):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_cfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from cwrap import Prototype, CFILE, load, open as copen

# Local copies so that the real ones don't get changed
class TestUtilPrototype(Prototype):
class ForTestUtilPrototype(Prototype):
lib = load("msvcrt" if os.name == "nt" else None)
def __init__(self, prototype, bind=False):
super(TestUtilPrototype, self).__init__(TestUtilPrototype.lib, prototype, bind=bind)
super(ForTestUtilPrototype, self).__init__(ForTestUtilPrototype.lib, prototype, bind=bind)

fileno = TestUtilPrototype("int fileno(FILE)")
fileno = ForTestUtilPrototype("int fileno(FILE)")


class CFILETest(unittest.TestCase):
Expand Down
Loading