Skip to content

Commit 53dbdf0

Browse files
authored
Make methods DRY and add type annotations to server/connectionpoints.py. (enthought#826)
* Add type hints. * Split methods. * Update `test.test_comserver.TestEvents.test`.
1 parent 371848a commit 53dbdf0

File tree

2 files changed

+60
-64
lines changed

2 files changed

+60
-64
lines changed

comtypes/server/connectionpoints.py

+56-63
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
1+
import functools
12
import logging
23
from _ctypes import COMError
3-
from ctypes import pointer
4-
from typing import TYPE_CHECKING, ClassVar, Dict, List, Tuple, Type
4+
from ctypes import c_void_p, pointer
5+
from ctypes.wintypes import DWORD
6+
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Tuple, Type
7+
from typing import Union as _UnionT
58

6-
from comtypes import COMObject, IUnknown
9+
from comtypes import GUID, COMObject, IUnknown
710
from comtypes.automation import IDispatch
811
from comtypes.connectionpoints import IConnectionPoint
912
from comtypes.hresult import *
1013
from comtypes.typeinfo import ITypeInfo, LoadRegTypeLib
1114

15+
if TYPE_CHECKING:
16+
from ctypes import _Pointer
17+
from typing import ClassVar
18+
19+
from comtypes import hints # type: ignore
20+
1221
logger = logging.getLogger(__name__)
1322

1423
__all__ = ["ConnectableObjectMixin"]
@@ -31,7 +40,9 @@ def __init__(
3140
# per MSDN, all interface methods *must* be implemented, E_NOTIMPL
3241
# is no allowed return value
3342

34-
def IConnectionPoint_Advise(self, this, pUnk, pdwCookie):
43+
def IConnectionPoint_Advise(
44+
self, this: Any, pUnk: IUnknown, pdwCookie: "_Pointer[DWORD]"
45+
) -> "hints.Hresult":
3546
if not pUnk or not pdwCookie:
3647
return E_POINTER
3748
logger.debug("Advise")
@@ -43,21 +54,25 @@ def IConnectionPoint_Advise(self, this, pUnk, pdwCookie):
4354
self._connections[self._cookie] = ptr
4455
return S_OK
4556

46-
def IConnectionPoint_Unadvise(self, this, dwCookie):
57+
def IConnectionPoint_Unadvise(self, this: Any, dwCookie: int) -> "hints.Hresult":
4758
logger.debug("Unadvise %s", dwCookie)
4859
try:
4960
del self._connections[dwCookie]
5061
except KeyError:
5162
return CONNECT_E_NOCONNECTION
5263
return S_OK
5364

54-
def IConnectionPoint_GetConnectionPointContainer(self, this, ppCPC):
65+
def IConnectionPoint_GetConnectionPointContainer(
66+
self, this: Any, ppCPC: c_void_p
67+
) -> "hints.Hresult":
5568
return E_NOTIMPL
5669

57-
def IConnectionPoint_GetConnectionInterface(self, this, pIID):
70+
def IConnectionPoint_GetConnectionInterface(
71+
self, this: Any, pIID: "_Pointer[GUID]"
72+
) -> "hints.Hresult":
5873
return E_NOTIMPL
5974

60-
def _call_sinks(self, name, *args, **kw):
75+
def _call_sinks(self, name: str, *args: Any, **kw: Any) -> List[Any]:
6176
results = []
6277
logger.debug("_call_sinks(%s, %s, *%s, **%s)", self, name, args, kw)
6378
# Is it an IDispatch derived interface? Then, events have to be delivered
@@ -66,61 +81,33 @@ def _call_sinks(self, name, *args, **kw):
6681
# for better performance, we could cache the dispids.
6782
dispid = self._typeinfo.GetIDsOfNames(name)[0]
6883
for key, p in self._connections.items():
69-
try:
70-
result = p.Invoke(dispid, *args, **kw)
71-
except COMError as details:
72-
if details.hresult == RPC_S_SERVER_UNAVAILABLE:
73-
logger.warning(
74-
"_call_sinks(%s, %s, *%s, **%s) failed; removing connection",
75-
self,
76-
name,
77-
args,
78-
kw,
79-
exc_info=True,
80-
)
81-
try:
82-
del self._connections[key]
83-
except KeyError:
84-
pass # connection already gone
85-
else:
86-
logger.warning(
87-
"_call_sinks(%s, %s, *%s, **%s)",
88-
self,
89-
name,
90-
args,
91-
kw,
92-
exc_info=True,
93-
)
94-
else:
95-
results.append(result)
84+
mth = functools.partial(p.Invoke, dispid) # type: ignore
85+
results.extend(self._call_sink(name, key, mth, *args, **kw))
9686
else:
97-
for p in self._connections.values():
98-
try:
99-
result = getattr(p, name)(*args, **kw)
100-
except COMError as details:
101-
if details.hresult == RPC_S_SERVER_UNAVAILABLE:
102-
logger.warning(
103-
"_call_sinks(%s, %s, *%s, **%s) failed; removing connection",
104-
self,
105-
name,
106-
args,
107-
kw,
108-
exc_info=True,
109-
)
110-
del self._connections[key]
111-
else:
112-
logger.warning(
113-
"_call_sinks(%s, %s, *%s, **%s)",
114-
self,
115-
name,
116-
args,
117-
kw,
118-
exc_info=True,
119-
)
120-
else:
121-
results.append(result)
87+
for key, p in self._connections.items():
88+
mth = getattr(p, name)
89+
results.extend(self._call_sink(name, key, mth, *args, **kw))
12290
return results
12391

92+
def _call_sink(
93+
self, name: str, key: int, mth: Callable[..., Any], *args: Any, **kw: Any
94+
) -> Iterator[Any]:
95+
try:
96+
result = mth(*args, **kw)
97+
except COMError as details:
98+
if details.hresult == RPC_S_SERVER_UNAVAILABLE:
99+
warn_msg = "_call_sinks(%s, %s, *%s, **%s) failed; removing connection"
100+
logger.warning(warn_msg, self, name, args, kw, exc_info=True)
101+
try:
102+
del self._connections[key]
103+
except KeyError:
104+
pass # connection already gone
105+
else:
106+
warn_msg = "_call_sinks(%s, %s, *%s, **%s)"
107+
logger.warning(warn_msg, self, name, args, kw, exc_info=True)
108+
else:
109+
yield result
110+
124111

125112
class ConnectableObjectMixin(object):
126113
"""Mixin which implements IConnectionPointContainer.
@@ -143,13 +130,17 @@ def __init__(self) -> None:
143130
typeinfo = tlib.GetTypeInfoOfGuid(itf._iid_)
144131
self.__connections[itf] = ConnectionPointImpl(itf, typeinfo)
145132

146-
def IConnectionPointContainer_EnumConnectionPoints(self, this, ppEnum):
133+
def IConnectionPointContainer_EnumConnectionPoints(
134+
self, this: Any, ppEnum: c_void_p
135+
) -> "hints.Hresult":
147136
# according to MSDN, E_NOTIMPL is specificially disallowed
148137
# because, without typeinfo, there's no way for the caller to
149138
# find out.
150139
return E_NOTIMPL
151140

152-
def IConnectionPointContainer_FindConnectionPoint(self, this, refiid, ppcp):
141+
def IConnectionPointContainer_FindConnectionPoint(
142+
self, this: Any, refiid: "_Pointer[GUID]", ppcp: c_void_p
143+
) -> "hints.Hresult":
153144
iid = refiid[0]
154145
logger.debug("FindConnectionPoint %s", iid)
155146
if not ppcp:
@@ -169,7 +160,9 @@ def IConnectionPointContainer_FindConnectionPoint(self, this, refiid, ppcp):
169160
logger.debug("No connectionpoint found")
170161
return CONNECT_E_NOCONNECTION
171162

172-
def Fire_Event(self, itf, name, *args, **kw):
163+
def Fire_Event(
164+
self, itf: _UnionT[int, Type[IDispatch]], name: str, *args: Any, **kw: Any
165+
) -> Any:
173166
# Fire event 'name' with arguments *args and **kw.
174167
# Accepts either an interface index or an interface as first argument.
175168
# Returns a list of results.

comtypes/test/test_comserver.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,10 @@ def test(self):
297297
import comtypes.test.test_comserver
298298

299299
doctest.testmod(
300-
comtypes.test.test_comserver, verbose=False, optionflags=doctest.ELLIPSIS
300+
comtypes.test.test_comserver,
301+
verbose=False,
302+
optionflags=doctest.ELLIPSIS,
303+
raise_on_error=True,
301304
)
302305

303306

0 commit comments

Comments
 (0)