Skip to content

Commit

Permalink
Implement static import for ISequentialStream (#474) (#578)
Browse files Browse the repository at this point in the history
* Implement static import for `ISequentialStream` (#474)

* add `Test_KnownSymbols.test_symbols_in_comtypes_objidl`

* add `Test_GetModule.test_portabledeviceapi`

* split `Test_IStream`

* Rename objidl.py to stream.py

* Add RemoteRead explanation to stream.py

* Fix refactoring issues

* Apply suggestions from code review

---------

Co-authored-by: jonschz <jonschz@users.noreply.github.com>
Co-authored-by: Jun Komoda <45822440+junkmd@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 6, 2024
1 parent 716719d commit b3a5ed7
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 1 deletion.
3 changes: 2 additions & 1 deletion comtypes/client/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def _get_known_namespaces() -> Tuple[
Note:
The interfaces that should be included in `__known_symbols__` should be limited
to those that can be said to be bound to the design concept of COM, such as
`IUnknown`, and those defined in `objidl` and `oaidl`.
`IUnknown`, `IDispatch` and `ITypeInfo`.
`comtypes` does NOT aim to statically define all COM object interfaces in
its repository.
"""
Expand All @@ -272,6 +272,7 @@ def _get_known_namespaces() -> Tuple[
"comtypes.persist",
"comtypes.typeinfo",
"comtypes.automation",
"comtypes.stream",
"comtypes",
"ctypes.wintypes",
"ctypes",
Expand Down
65 changes: 65 additions & 0 deletions comtypes/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from ctypes import Array, c_ubyte, c_ulong, HRESULT, POINTER, pointer
from typing import Tuple, TYPE_CHECKING

from comtypes import COMMETHOD, GUID, IUnknown


class ISequentialStream(IUnknown):
"""Defines methods for the stream objects in sequence."""

_iid_ = GUID("{0C733A30-2A1C-11CE-ADE5-00AA0044773D}")
_idlflags_ = []

_methods_ = [
# Note that these functions are called `Read` and `Write` in Microsoft's documentation,
# see https://learn.microsoft.com/en-us/windows/win32/api/objidl/nn-objidl-isequentialstream.
# However, the comtypes code generation detects these as `RemoteRead` and `RemoteWrite`
# for very subtle reasons, see e.g. https://stackoverflow.com/q/19820999/. We will not
# rename these in this manual import for the sake of consistency.
COMMETHOD(
[],
HRESULT,
"RemoteRead",
# This call only works if `pv` is pre-allocated with `cb` bytes,
# which cannot be done by the high level function generated by metaclasses.
# Therefore, we override the high level function to implement this behaviour
# and then delegate the call the raw COM method.
(["out"], POINTER(c_ubyte), "pv"),
(["in"], c_ulong, "cb"),
(["out"], POINTER(c_ulong), "pcbRead"),
),
COMMETHOD(
[],
HRESULT,
"RemoteWrite",
(["in"], POINTER(c_ubyte), "pv"),
(["in"], c_ulong, "cb"),
(["out"], POINTER(c_ulong), "pcbWritten"),
),
]

def RemoteRead(self, cb: int) -> Tuple["Array[c_ubyte]", int]:
"""Reads a specified number of bytes from the stream object into memory
starting at the current seek pointer.
"""
# Behaves as if `pv` is pre-allocated with `cb` bytes by the high level func.
pv = (c_ubyte * cb)()
pcb_read = pointer(c_ulong(0))
self.__com_RemoteRead(pv, c_ulong(cb), pcb_read) # type: ignore
# return both `out` parameters
return pv, pcb_read.contents.value

if TYPE_CHECKING:

def RemoteWrite(self, pv: "Array[c_ubyte]", cb: int) -> int:
"""Writes a specified number of bytes into the stream object starting at
the current seek pointer.
"""
...


# fmt: off
__known_symbols__ = [
'ISequentialStream',
]
# fmt: on
11 changes: 11 additions & 0 deletions comtypes/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ def test_mscorlib(self):
# the `_Pointer` interface, rather than importing `_Pointer` from `ctypes`.
self.assertTrue(issubclass(mod._Pointer, comtypes.IUnknown))

def test_portabledeviceapi(self):
mod = comtypes.client.GetModule("portabledeviceapi.dll")
from comtypes.stream import ISequentialStream

self.assertTrue(issubclass(mod.IStream, ISequentialStream))

def test_no_replacing_Patch_namespace(self):
# NOTE: An object named `Patch` is defined in some dll.
# Depending on how the namespace is defined in the static module,
Expand Down Expand Up @@ -117,6 +123,11 @@ def test_symbols_in_comtypes(self):

self._doit(comtypes)

def test_symbols_in_comtypes_stream(self):
import comtypes.stream

self._doit(comtypes.stream)

def test_symbols_in_comtypes_automation(self):
import comtypes.automation

Expand Down
53 changes: 53 additions & 0 deletions comtypes/test/test_istream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import unittest as ut

from ctypes import POINTER, byref, c_bool, c_ubyte
import comtypes
import comtypes.client

comtypes.client.GetModule("portabledeviceapi.dll")
from comtypes.gen.PortableDeviceApiLib import IStream


def _create_stream() -> IStream:
# Create an IStream
stream = POINTER(IStream)() # type: ignore
comtypes._ole32.CreateStreamOnHGlobal(None, c_bool(True), byref(stream))
return stream # type: ignore


class Test_RemoteWrite(ut.TestCase):
def test_RemoteWrite(self):
stream = _create_stream()
test_data = "Some data".encode("utf-8")
pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data))

written = stream.RemoteWrite(pv, len(test_data))

# Verification
self.assertEqual(written, len(test_data))


class Test_RemoteRead(ut.TestCase):
def test_RemoteRead(self):
stream = _create_stream()
test_data = "Some data".encode("utf-8")
pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data))
stream.RemoteWrite(pv, len(test_data))

# Make sure the data actually gets written before trying to read back
stream.Commit(0)
# Move the stream back to the beginning
STREAM_SEEK_SET = 0
stream.RemoteSeek(0, STREAM_SEEK_SET)

buffer_size = 1024

read_buffer, data_read = stream.RemoteRead(buffer_size)

# Verification
self.assertEqual(data_read, len(test_data))
self.assertEqual(bytearray(read_buffer)[0:data_read], test_data)


if __name__ == "__main__":
ut.main()

0 comments on commit b3a5ed7

Please sign in to comment.