Skip to content

Commit

Permalink
Implement static import for ISequentialStream (enthought#474)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonschz committed Jun 29, 2024
1 parent 11d7651 commit 8fbc267
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 0 deletions.
1 change: 1 addition & 0 deletions comtypes/client/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def _get_known_namespaces() -> Tuple[
"comtypes.persist",
"comtypes.typeinfo",
"comtypes.automation",
"comtypes.objidl",
"comtypes",
"ctypes.wintypes",
"ctypes",
Expand Down
59 changes: 59 additions & 0 deletions comtypes/objidl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from ctypes import Array, c_ubyte, c_ulong, HRESULT, POINTER, pointer
from typing import Tuple, TYPE_CHECKING

from comtypes import COMMETHOD, GUID, IUnknown


class ISequentialStream(IUnknown):
"""Defines methods for the stream objects in sequence."""

_iid_ = GUID("{0C733A30-2A1C-11CE-ADE5-00AA0044773D}")
_idlflags_ = []

_methods_ = [
COMMETHOD(
[],
HRESULT,
"RemoteRead",
# This call only works if `pv` is pre-allocated with `cb` bytes,
# which cannot be done by the high level function generated by metaclasses.
# Therefore, we override the high level function to implement this behaviour
# and then delegate the call the raw COM method.
(["out"], POINTER(c_ubyte), "pv"),
(["in"], c_ulong, "cb"),
(["out"], POINTER(c_ulong), "pcbRead"),
),
COMMETHOD(
[],
HRESULT,
"RemoteWrite",
(["in"], POINTER(c_ubyte), "pv"),
(["in"], c_ulong, "cb"),
(["out"], POINTER(c_ulong), "pcbWritten"),
),
]

def RemoteRead(self, cb: int) -> Tuple["Array[c_ubyte]", int]:
"""Reads a specified number of bytes from the stream object into memory
starting at the current seek pointer.
"""
# Behaves as if `pv` is pre-allocated with `cb` bytes by the high level func.
pv = (c_ubyte * cb)()
pcb_read = pointer(c_ulong(0))
self.__com_RemoteRead(pv, c_ulong(cb), pcb_read) # type: ignore
return pv, pcb_read.contents.value

if TYPE_CHECKING:

def RemoteWrite(self, pv: "Array[c_ubyte]", cb: int) -> int:
"""Writes a specified number of bytes into the stream object starting at
the current seek pointer.
"""
...


# fmt: off
__known_symbols__ = [
'ISequentialStream',
]
# fmt: on
39 changes: 39 additions & 0 deletions comtypes/test/test_istream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import unittest as ut

from ctypes import POINTER, byref, c_bool, c_ubyte
import comtypes
import comtypes.client

comtypes.client.GetModule("portabledeviceapi.dll")
from comtypes.gen.PortableDeviceApiLib import IStream


class Test_IStream(ut.TestCase):
def test_istream(self):
# Create an IStream
stream: IStream = POINTER(IStream)() # type: ignore
comtypes._ole32.CreateStreamOnHGlobal(None, c_bool(True), byref(stream))

test_data = "Some data".encode("utf-8")
pv = (c_ubyte * len(test_data)).from_buffer(bytearray(test_data))

written = stream.RemoteWrite(pv, len(test_data))
self.assertEqual(written, len(test_data))
# make sure the data actually gets written before trying to read back
stream.Commit(0)

# Move the stream back to the beginning
STREAM_SEEK_SET = 0
stream.RemoteSeek(0, STREAM_SEEK_SET)

read_buffer_size = 1024

read_buffer, data_read = stream.RemoteRead(read_buffer_size)

# Verification
self.assertEqual(data_read, len(test_data))
self.assertEqual(bytearray(read_buffer)[0:data_read], test_data)


if __name__ == "__main__":
ut.main()

0 comments on commit 8fbc267

Please sign in to comment.