Skip to content

Commit

Permalink
add type annotations other than Dispatch func (enthought#444)
Browse files Browse the repository at this point in the history
(cherry picked from commit 11410ce)
  • Loading branch information
junkmd committed Feb 3, 2024
1 parent 2a1e918 commit d9737d7
Showing 1 changed file with 35 additions and 29 deletions.
64 changes: 35 additions & 29 deletions comtypes/client/dynamic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import ctypes
from typing import Any, Dict, Optional, Set, Type, TypeVar

from comtypes import automation
from comtypes.client import lazybind
from comtypes import COMError, hresult as hres, _is_object
from comtypes import COMError, GUID, IUnknown, hresult as hres, _is_object


_T_IUnknown = TypeVar("_T_IUnknown", bound=IUnknown)
# These errors generally mean the property or method exists,
# but can't be used in this context - eg, property instead of a method, etc.
# Used to determine if we have a real error or not.
Expand All @@ -17,8 +20,9 @@


def Dispatch(obj):
# Wrap an object in a Dispatch instance, exposing methods and properties
# via fully dynamic dispatch
"""Wrap an object in a Dispatch instance, exposing methods and properties
via fully dynamic dispatch.
"""
if isinstance(obj, _Dispatch):
return obj
if isinstance(obj, ctypes.POINTER(automation.IDispatch)):
Expand All @@ -33,19 +37,19 @@ def Dispatch(obj):
class MethodCaller:
# Wrong name: does not only call methods but also handle
# property accesses.
def __init__(self, _id, _obj):
def __init__(self, _id: int, _obj: "_Dispatch") -> None:
self._id = _id
self._obj = _obj

def __call__(self, *args):
def __call__(self, *args: Any) -> Any:
return self._obj._comobj.Invoke(self._id, *args)

def __getitem__(self, *args):
def __getitem__(self, *args: Any) -> Any:
return self._obj._comobj.Invoke(
self._id, *args, _invkind=automation.DISPATCH_PROPERTYGET
)

def __setitem__(self, *args):
def __setitem__(self, *args: Any) -> None:
if _is_object(args[-1]):
self._obj._comobj.Invoke(
self._id, *args, _invkind=automation.DISPATCH_PROPERTYPUTREF
Expand All @@ -57,22 +61,26 @@ def __setitem__(self, *args):


class _Dispatch(object):
# Expose methods and properties via fully dynamic dispatch
def __init__(self, comobj):
"""Expose methods and properties via fully dynamic dispatch."""

_comobj: automation.IDispatch
_ids: Dict[str, int]
_methods: Set[str]

def __init__(self, comobj: "ctypes._Pointer[automation.IDispatch]"):
self.__dict__["_comobj"] = comobj
self.__dict__[
"_ids"
] = {} # Tiny optimization: trying not to use GetIDsOfNames more than once
# Tiny optimization: trying not to use GetIDsOfNames more than once
self.__dict__["_ids"] = {}
self.__dict__["_methods"] = set()

def __enum(self):
e = self._comobj.Invoke(-4) # DISPID_NEWENUM
def __enum(self) -> automation.IEnumVARIANT:
e: IUnknown = self._comobj.Invoke(-4) # DISPID_NEWENUM
return e.QueryInterface(automation.IEnumVARIANT)

def __hash__(self):
def __hash__(self) -> int:
return hash(self._comobj)

def __getitem__(self, index):
def __getitem__(self, index: Any) -> Any:
enum = self.__enum()
if index > 0:
if 0 != enum.Skip(index):
Expand All @@ -82,11 +90,13 @@ def __getitem__(self, index):
raise IndexError("index out of range")
return item

def QueryInterface(self, interface, iid=None):
def QueryInterface(
self, interface: Type[_T_IUnknown], iid: Optional[GUID] = None
) -> _T_IUnknown:
"""QueryInterface is forwarded to the real com object."""
return self._comobj.QueryInterface(interface, iid)

def _FlagAsMethod(self, *names):
def _FlagAsMethod(self, *names: str) -> None:
"""Flag these attribute names as being methods.
Some objects do not correctly differentiate methods and
properties, leading to problems when calling these methods.
Expand All @@ -100,7 +110,7 @@ def _FlagAsMethod(self, *names):
"""
self._methods.update(names)

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
if name.startswith("__") and name.endswith("__"):
raise AttributeError(name)
# tc = self._comobj.GetTypeInfo(0).QueryInterface(comtypes.typeinfo.ITypeComp)
Expand All @@ -119,20 +129,16 @@ def __getattr__(self, name):
try:
result = self._comobj.Invoke(dispid, _invkind=flags)
except COMError as err:
(hresult, text, details) = err.args
(hresult, _, _) = err.args
if hresult in ERRORS_BAD_CONTEXT:
result = MethodCaller(dispid, self)
self.__dict__[name] = result
else:
# The line break is important for 2to3 to work correctly
raise
except:
# The line break is important for 2to3 to work correctly
raise
raise err

return result

def __setattr__(self, name, value):
def __setattr__(self, name: str, value: Any) -> None:
dispid = self._ids.get(name)
if not dispid:
dispid = self._comobj.GetIDsOfNames(name)[0]
Expand All @@ -142,7 +148,7 @@ def __setattr__(self, name, value):
flags = 8 if _is_object(value) else 4
return self._comobj.Invoke(dispid, value, _invkind=flags)

def __iter__(self):
def __iter__(self) -> "_Collection":
return _Collection(self.__enum())

# def __setitem__(self, index, value):
Expand All @@ -156,10 +162,10 @@ def __iter__(self):


class _Collection(object):
def __init__(self, enum):
def __init__(self, enum: automation.IEnumVARIANT):
self.enum = enum

def __next__(self):
def __next__(self) -> Any:
item, fetched = self.enum.Next(1)
if fetched:
return item
Expand Down

0 comments on commit d9737d7

Please sign in to comment.