Skip to content

Commit

Permalink
[3.13] gh-127183: Add _ctypes.CopyComPointer tests (GH-127184) (GH-…
Browse files Browse the repository at this point in the history
…127251)

gh-127183: Add `_ctypes.CopyComPointer` tests (GH-127184)

* Make `create_shelllink_persist` top level function.

* Add `CopyComPointerTests`.

* Add more tests.

* Update tests.

* Add assertions for `Release`'s return value.
(cherry picked from commit c7f1e3e)

Co-authored-by: Jun Komoda <45822440+junkmd@users.noreply.github.com>
  • Loading branch information
miss-islington and junkmd authored Nov 26, 2024
1 parent 8b2e303 commit 0f77357
Showing 1 changed file with 115 additions and 17 deletions.
132 changes: 115 additions & 17 deletions Lib/test/test_ctypes/test_win32_com_foreign_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
raise unittest.SkipTest("Windows-specific test")


from _ctypes import COMError
from _ctypes import COMError, CopyComPointer
from ctypes import HRESULT


Expand Down Expand Up @@ -78,6 +78,19 @@ def is_equal_guid(guid1, guid2):
)


def create_shelllink_persist(typ):
ppst = typ()
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
ole32.CoCreateInstance(
byref(CLSID_ShellLink),
None,
CLSCTX_SERVER,
byref(IID_IPersist),
byref(ppst),
)
return ppst


class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
def setUp(self):
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
Expand All @@ -88,19 +101,6 @@ def tearDown(self):
ole32.CoUninitialize()
gc.collect()

@staticmethod
def create_shelllink_persist(typ):
ppst = typ()
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
ole32.CoCreateInstance(
byref(CLSID_ShellLink),
None,
CLSCTX_SERVER,
byref(IID_IPersist),
byref(ppst),
)
return ppst

def test_without_paramflags_and_iid(self):
class IUnknown(c_void_p):
QueryInterface = proto_query_interface()
Expand All @@ -110,7 +110,7 @@ class IUnknown(c_void_p):
class IPersist(IUnknown):
GetClassID = proto_get_class_id()

ppst = self.create_shelllink_persist(IPersist)
ppst = create_shelllink_persist(IPersist)

clsid = GUID()
hr_getclsid = ppst.GetClassID(byref(clsid))
Expand Down Expand Up @@ -142,7 +142,7 @@ class IUnknown(c_void_p):
class IPersist(IUnknown):
GetClassID = proto_get_class_id(((OUT, "pClassID"),))

ppst = self.create_shelllink_persist(IPersist)
ppst = create_shelllink_persist(IPersist)

clsid = ppst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
Expand All @@ -167,7 +167,7 @@ class IUnknown(c_void_p):
class IPersist(IUnknown):
GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)

ppst = self.create_shelllink_persist(IPersist)
ppst = create_shelllink_persist(IPersist)

clsid = ppst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
Expand All @@ -184,5 +184,103 @@ class IPersist(IUnknown):
self.assertEqual(0, ppst.Release())


class CopyComPointerTests(unittest.TestCase):
def setUp(self):
ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)

class IUnknown(c_void_p):
QueryInterface = proto_query_interface(None, IID_IUnknown)
AddRef = proto_add_ref()
Release = proto_release()

class IPersist(IUnknown):
GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)

self.IUnknown = IUnknown
self.IPersist = IPersist

def tearDown(self):
ole32.CoUninitialize()
gc.collect()

def test_both_are_null(self):
src = self.IPersist()
dst = self.IPersist()

hr = CopyComPointer(src, byref(dst))

self.assertEqual(S_OK, hr)

self.assertIsNone(src.value)
self.assertIsNone(dst.value)

def test_src_is_nonnull_and_dest_is_null(self):
# The reference count of the COM pointer created by `CoCreateInstance`
# is initially 1.
src = create_shelllink_persist(self.IPersist)
dst = self.IPersist()

# `CopyComPointer` calls `AddRef` explicitly in the C implementation.
# The refcount of `src` is incremented from 1 to 2 here.
hr = CopyComPointer(src, byref(dst))

self.assertEqual(S_OK, hr)
self.assertEqual(src.value, dst.value)

# This indicates that the refcount was 2 before the `Release` call.
self.assertEqual(1, src.Release())

clsid = dst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

self.assertEqual(0, dst.Release())

def test_src_is_null_and_dest_is_nonnull(self):
src = self.IPersist()
dst_orig = create_shelllink_persist(self.IPersist)
dst = self.IPersist()
CopyComPointer(dst_orig, byref(dst))
self.assertEqual(1, dst_orig.Release())

clsid = dst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

# This does NOT affects the refcount of `dst_orig`.
hr = CopyComPointer(src, byref(dst))

self.assertEqual(S_OK, hr)
self.assertIsNone(dst.value)

with self.assertRaises(ValueError):
dst.GetClassID() # NULL COM pointer access

# This indicates that the refcount was 1 before the `Release` call.
self.assertEqual(0, dst_orig.Release())

def test_both_are_nonnull(self):
src = create_shelllink_persist(self.IPersist)
dst_orig = create_shelllink_persist(self.IPersist)
dst = self.IPersist()
CopyComPointer(dst_orig, byref(dst))
self.assertEqual(1, dst_orig.Release())

self.assertEqual(dst.value, dst_orig.value)
self.assertNotEqual(src.value, dst.value)

hr = CopyComPointer(src, byref(dst))

self.assertEqual(S_OK, hr)
self.assertEqual(src.value, dst.value)
self.assertNotEqual(dst.value, dst_orig.value)

self.assertEqual(1, src.Release())

clsid = dst.GetClassID()
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))

self.assertEqual(0, dst.Release())
self.assertEqual(0, dst_orig.Release())


if __name__ == '__main__':
unittest.main()

0 comments on commit 0f77357

Please sign in to comment.