Skip to content

Commit

Permalink
Update codegenerator type annotations (enthought#457)
Browse files Browse the repository at this point in the history
* update `codegenerator` type annotations

* lcid

(cherry picked from commit 9e7b900)
  • Loading branch information
junkmd committed Feb 3, 2024
1 parent d5159f1 commit ee7ae9b
Showing 1 changed file with 56 additions and 87 deletions.
143 changes: 56 additions & 87 deletions comtypes/tools/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,16 @@ def name_friendly_module(tlib):

################################################################

_DefValType = _UnionT["lcid", Any, None]
_IdlFlagType = _UnionT[str, dispid, helpstring]

def _to_arg_definition(type_name, arg_name, idlflags, default):
# type: (str, str, List[str], _UnionT[lcid, Any, None]) -> str

def _to_arg_definition(
type_name: str,
arg_name: str,
idlflags: List[str],
default: _DefValType,
) -> str:
if default is not None:
elms = (idlflags, type_name, arg_name, default)
code = " (%r, %s, '%s', %r)" % elms
Expand Down Expand Up @@ -217,8 +224,7 @@ def _to_arg_definition(type_name, arg_name, idlflags, default):


class ComMethodGenerator(object):
def __init__(self, m, isdual):
# type: (typedesc.ComMethod, bool) -> None
def __init__(self, m: typedesc.ComMethod, isdual: bool) -> None:
self._m = m
self._isdual = isdual
self._stream = io.StringIO()
Expand All @@ -232,9 +238,8 @@ def generate(self):
self._make_withargs()
return self._stream.getvalue()

def _get_common_elms(self):
# type: () -> Tuple[List[_UnionT[str, dispid, helpstring]], str, str]
idlflags = [] # type: List[_UnionT[str, dispid, helpstring]]
def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]:
idlflags: List[_IdlFlagType] = []
if self._isdual:
idlflags.append(dispid(self._m.memid))
idlflags.extend(self._m.idlflags)
Expand All @@ -245,8 +250,7 @@ def _get_common_elms(self):
type_name = self._to_type_name(self._m.returns)
return (idlflags, type_name, self._m.name)

def _make_noargs(self):
# type: () -> None
def _make_noargs(self) -> None:
elms = self._get_common_elms()
code = " COMMETHOD(%r, %s, '%s')," % elms
if len(code) > 80:
Expand All @@ -259,8 +263,7 @@ def _make_noargs(self):
) % elms
print(code, file=self._stream)

def _make_withargs(self):
# type: () -> None
def _make_withargs(self) -> None:
code = (
" COMMETHOD(\n" " %r,\n" " %s,\n" " '%s',"
) % self._get_common_elms()
Expand All @@ -269,8 +272,7 @@ def _make_withargs(self):
print(",\n".join(arglist), file=self._stream)
print(" ),", file=self._stream)

def _iter_args(self):
# type: () -> Iterator[Tuple[str, str, List[str], _UnionT[lcid, Any, None]]]
def _iter_args(self) -> Iterator[Tuple[str, str, List[str], _DefValType]]:
for typ, arg_name, _f, _defval in self._m.arguments:
###########################################################
# IDL files that contain 'open arrays' or 'conformant
Expand Down Expand Up @@ -327,8 +329,7 @@ def _iter_args(self):


class DispMethodGenerator(object):
def __init__(self, m):
# type: (typedesc.DispMethod) -> None
def __init__(self, m: typedesc.DispMethod) -> None:
self._m = m
self._stream = io.StringIO()
self._to_type_name = TypeNamer()
Expand All @@ -341,18 +342,16 @@ def generate(self):
self._make_withargs()
return self._stream.getvalue()

def _get_common_elms(self):
# type: () -> Tuple[List[_UnionT[str, dispid, helpstring]], str, str]
idlflags = [] # type: List[_UnionT[str, dispid, helpstring]]
def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]:
idlflags: List[_IdlFlagType] = []
idlflags.append(dispid(self._m.dispid))
idlflags.extend(self._m.idlflags)
if __debug__ and self._m.doc:
idlflags.insert(1, helpstring(self._m.doc))
type_name = self._to_type_name(self._m.returns)
return (idlflags, type_name, self._m.name)

def _make_noargs(self):
# type: () -> None
def _make_noargs(self) -> None:
elms = self._get_common_elms()
code = " DISPMETHOD(%r, %s, '%s')," % elms
if len(code) > 80:
Expand All @@ -365,8 +364,7 @@ def _make_noargs(self):
) % elms
print(code, file=self._stream)

def _make_withargs(self):
# type: () -> None
def _make_withargs(self) -> None:
code = (
" DISPMETHOD(\n" " %r,\n" " %s,\n" " '%s',"
) % self._get_common_elms()
Expand All @@ -375,16 +373,14 @@ def _make_withargs(self):
print(",\n".join(arglist), file=self._stream)
print(" ),", file=self._stream)

def _iter_args(self):
# type: () -> Iterator[Tuple[str, str, List[str], _UnionT[lcid, Any, None]]]
def _iter_args(self) -> Iterator[Tuple[str, str, List[str], _DefValType]]:
for typ, arg_name, idlflags, default in self._m.arguments:
type_name = self._to_type_name(typ)
yield (type_name, arg_name, idlflags, default)


class DispPropertyGenerator(object):
def __init__(self, m):
# type: (typedesc.DispProperty) -> None
def __init__(self, m: typedesc.DispProperty) -> None:
self._m = m
self._to_type_name = TypeNamer()

Expand All @@ -402,9 +398,8 @@ def generate(self):
) % elms
return code + "\n"

def _get_common_elms(self):
# type: () -> Tuple[List[_UnionT[str, dispid, helpstring]], str, str]
idlflags = [] # type: List[_UnionT[str, dispid, helpstring]]
def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]:
idlflags: List[_IdlFlagType] = []
idlflags.append(dispid(self._m.dispid))
idlflags.extend(self._m.idlflags)
if __debug__ and self._m.doc:
Expand Down Expand Up @@ -569,21 +564,18 @@ def need_VARIANT_imports(self, value):
if "datetime.datetime(" in text:
self.imports.add("datetime")

def _to_docstring(self, orig, depth=1):
# type: (str, int) -> str
def _to_docstring(self, orig: str, depth: int = 1) -> str:
# increasing `depth` by one increases indentation by one
indent = " " * depth
# some chars are replaced to avoid causing a `SyntaxError`
repled = orig.replace("\\", r"\\").replace('"', r"'")
return '%s"""%s"""' % (indent, repled)

def ArrayType(self, tp):
# type: (typedesc.ArrayType) -> None
def ArrayType(self, tp: typedesc.ArrayType) -> None:
self.generate(get_real_type(tp.typ))
self.generate(tp.typ)

def EnumValue(self, tp):
# type: (typedesc.EnumValue) -> None
def EnumValue(self, tp: typedesc.EnumValue) -> None:
self.last_item_class = False
value = int(tp.value)
if keyword.iskeyword(tp.name):
Expand All @@ -594,8 +586,7 @@ def EnumValue(self, tp):
print("%s = %d" % (tp_name, value), file=self.stream)
self.names.add(tp_name)

def Enumeration(self, tp):
# type: (typedesc.Enumeration) -> None
def Enumeration(self, tp: typedesc.Enumeration) -> None:
self.last_item_class = False
if tp.name:
print("# values for enumeration '%s'" % tp.name, file=self.stream)
Expand All @@ -611,8 +602,7 @@ def Enumeration(self, tp):
print("%s = c_int # enum" % tp.name, file=self.stream)
self.names.add(tp.name)

def Typedef(self, tp):
# type: (typedesc.Typedef) -> None
def Typedef(self, tp: typedesc.Typedef) -> None:
if isinstance(tp.typ, (typedesc.Structure, typedesc.Union)):
self.generate(tp.typ.get_head())
self.more.add(tp.typ)
Expand All @@ -627,12 +617,10 @@ def Typedef(self, tp):
self.last_item_class = False
self.names.add(tp.name)

def FundamentalType(self, item):
# type: (typedesc.FundamentalType) -> None
def FundamentalType(self, item: typedesc.FundamentalType) -> None:
pass # we should check if this is known somewhere

def StructureHead(self, head):
# type: (typedesc.StructureHead) -> None
def StructureHead(self, head: typedesc.StructureHead) -> None:
for struct in head.struct.bases:
self.generate(struct.get_head())
self.more.add(struct)
Expand Down Expand Up @@ -722,20 +710,17 @@ def StructureHead(self, head):
print(file=self.stream)
self.names.add(head.struct.name)

def Structure(self, struct):
# type: (typedesc.Structure) -> None
def Structure(self, struct: typedesc.Structure) -> None:
self.generate(struct.get_head())
self.generate(struct.get_body())

def Union(self, union):
# type: (typedesc.Union) -> None
def Union(self, union: typedesc.Union) -> None:
self.generate(union.get_head())
self.generate(union.get_body())

def StructureBody(self, body):
# type: (typedesc.StructureBody) -> None
fields = [] # type: List[typedesc.Field]
methods = [] # type: List[typedesc.Method]
def StructureBody(self, body: typedesc.StructureBody) -> None:
fields: List[typedesc.Field] = []
methods: List[typedesc.Method] = []
for m in body.struct.members:
if type(m) is typedesc.Field:
fields.append(m)
Expand Down Expand Up @@ -863,8 +848,7 @@ def StructureBody(self, body):
################################################################
# top-level typedesc generators
#
def TypeLib(self, lib):
# type: (typedesc.TypeLib) -> None
def TypeLib(self, lib: typedesc.TypeLib) -> None:
# Hm, in user code we have to write:
# class MyServer(COMObject, ...):
# _com_interfaces_ = [MyTypeLib.IInterface]
Expand Down Expand Up @@ -893,29 +877,25 @@ def TypeLib(self, lib):
print(file=self.stream)
print(file=self.stream)

def External(self, ext):
# type: (typedesc.External) -> None
def External(self, ext: typedesc.External) -> None:
modname = name_wrapper_module(ext.tlib)
if modname not in self.imports:
self.externals.append(ext.tlib)
self.imports.add(modname)

def Constant(self, tp):
# type: (typedesc.Constant) -> None
def Constant(self, tp: typedesc.Constant) -> None:
self.last_item_class = False
print(
"%s = %r # Constant %s" % (tp.name, tp.value, self._to_type_name(tp.typ)),
file=self.stream,
)
self.names.add(tp.name)

def SAFEARRAYType(self, sa):
# type: (typedesc.SAFEARRAYType) -> None
def SAFEARRAYType(self, sa: typedesc.SAFEARRAYType) -> None:
self.generate(sa.typ)
self.imports.add("comtypes.automation", "_midlSAFEARRAY")

def PointerType(self, tp):
# type: (typedesc.PointerType) -> None
def PointerType(self, tp: typedesc.PointerType) -> None:
if type(tp.typ) is typedesc.ComInterface:
# this defines the class
self.generate(tp.typ.get_head())
Expand All @@ -939,8 +919,7 @@ def PointerType(self, tp):
elif real_type.name == "wchar_t":
self.declarations.add("WSTRING", "c_wchar_p")

def CoClass(self, coclass):
# type: (typedesc.CoClass) -> None
def CoClass(self, coclass: typedesc.CoClass) -> None:
self.imports.add("comtypes", "GUID")
self.imports.add("comtypes", "CoClass")
if not self.last_item_class:
Expand Down Expand Up @@ -1001,14 +980,12 @@ def CoClass(self, coclass):

self.names.add(coclass.name)

def ComInterface(self, itf):
# type: (typedesc.ComInterface) -> None
def ComInterface(self, itf: typedesc.ComInterface) -> None:
self.generate(itf.get_head())
self.generate(itf.get_body())
self.names.add(itf.name)

def _is_enuminterface(self, itf):
# type: (typedesc.ComInterface) -> bool
def _is_enuminterface(self, itf: typedesc.ComInterface) -> bool:
# Check if this is an IEnumXXX interface
if not itf.name.startswith("IEnum"):
return False
Expand All @@ -1018,8 +995,7 @@ def _is_enuminterface(self, itf):
return False
return True

def ComInterfaceHead(self, head):
# type: (typedesc.ComInterfaceHead) -> None
def ComInterfaceHead(self, head: typedesc.ComInterfaceHead) -> None:
if head.itf.name in self.known_symbols:
return
base = head.itf.base
Expand Down Expand Up @@ -1071,8 +1047,7 @@ def ComInterfaceHead(self, head):
print(file=self.stream)
print(file=self.stream)

def ComInterfaceBody(self, body):
# type: (typedesc.ComInterfaceBody) -> None
def ComInterfaceBody(self, body: typedesc.ComInterfaceBody) -> None:
# The base class must be fully generated, including the
# _methods_ list.
self.generate(body.itf.base)
Expand Down Expand Up @@ -1166,14 +1141,12 @@ def ComInterfaceBody(self, body):
raise RuntimeError("BUG")
print("#", file=self.stream)

def DispInterface(self, itf):
# type: (typedesc.DispInterface) -> None
def DispInterface(self, itf: typedesc.DispInterface) -> None:
self.generate(itf.get_head())
self.generate(itf.get_body())
self.names.add(itf.name)

def DispInterfaceHead(self, head):
# type: (typedesc.DispInterfaceHead) -> None
def DispInterfaceHead(self, head: typedesc.DispInterfaceHead) -> None:
self.generate(head.itf.base)
basename = self._to_type_name(head.itf.base)

Expand All @@ -1195,8 +1168,7 @@ def DispInterfaceHead(self, head):
print(file=self.stream)
print(file=self.stream)

def DispInterfaceBody(self, body):
# type: (typedesc.DispInterfaceBody) -> None
def DispInterfaceBody(self, body: typedesc.DispInterfaceBody) -> None:
# make sure we can generate the body
for m in body.itf.members:
if isinstance(m, typedesc.DispMethod):
Expand Down Expand Up @@ -1226,8 +1198,7 @@ def DispInterfaceBody(self, body):
################################################################
# non-toplevel method generators
#
def make_ComMethod(self, m, isdual):
# type: (typedesc.ComMethod, bool) -> None
def make_ComMethod(self, m: typedesc.ComMethod, isdual: bool) -> None:
self.imports.add("comtypes", "COMMETHOD")
if isdual:
self.imports.add("comtypes", "dispid")
Expand All @@ -1246,8 +1217,7 @@ def make_ComMethod(self, m, isdual):
if default is not None:
self.need_VARIANT_imports(default)

def make_DispMethod(self, m):
# type: (typedesc.DispMethod) -> None
def make_DispMethod(self, m: typedesc.DispMethod) -> None:
self.imports.add("comtypes", "DISPMETHOD")
self.imports.add("comtypes", "dispid")
if __debug__ and m.doc:
Expand All @@ -1259,8 +1229,7 @@ def make_DispMethod(self, m):
if default is not None:
self.need_VARIANT_imports(default)

def make_DispProperty(self, prop):
# type: (typedesc.DispProperty) -> None
def make_DispProperty(self, prop: typedesc.DispProperty) -> None:
self.imports.add("comtypes", "DISPPROPERTY")
self.imports.add("comtypes", "dispid")
if __debug__ and prop.doc:
Expand All @@ -1271,8 +1240,7 @@ def make_DispProperty(self, prop):


class TypeNamer(object):
def __call__(self, t):
# type: (Any) -> str
def __call__(self, t: Any) -> str:
# Return a string, containing an expression which can be used
# to refer to the type. Assumes the 'from ctypes import *'
# namespace is available.
Expand Down Expand Up @@ -1315,8 +1283,9 @@ def __call__(self, t):
return "%s.%s" % (modname, t.symbol_name)
return t.name

def _inspect_PointerType(self, t, count=0):
# type: (typedesc.PointerType, int) -> Tuple[Any, int]
def _inspect_PointerType(
self, t: typedesc.PointerType, count: int = 0
) -> Tuple[Any, int]:
if ASSUME_STRINGS:
x = get_real_type(t.typ)
if isinstance(x, typedesc.FundamentalType):
Expand Down

0 comments on commit ee7ae9b

Please sign in to comment.