diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index 903e113a..2e0f13f6 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -187,9 +187,16 @@ def name_friendly_module(tlib): ################################################################ +_DefValType = _UnionT["lcid", Any, None] +_IdlFlagType = _UnionT[str, dispid, helpstring] -def _to_arg_definition(type_name, arg_name, idlflags, default): - # type: (str, str, List[str], _UnionT[lcid, Any, None]) -> str + +def _to_arg_definition( + type_name: str, + arg_name: str, + idlflags: List[str], + default: _DefValType, +) -> str: if default is not None: elms = (idlflags, type_name, arg_name, default) code = " (%r, %s, '%s', %r)" % elms @@ -217,8 +224,7 @@ def _to_arg_definition(type_name, arg_name, idlflags, default): class ComMethodGenerator(object): - def __init__(self, m, isdual): - # type: (typedesc.ComMethod, bool) -> None + def __init__(self, m: typedesc.ComMethod, isdual: bool) -> None: self._m = m self._isdual = isdual self._stream = io.StringIO() @@ -232,9 +238,8 @@ def generate(self): self._make_withargs() return self._stream.getvalue() - def _get_common_elms(self): - # type: () -> Tuple[List[_UnionT[str, dispid, helpstring]], str, str] - idlflags = [] # type: List[_UnionT[str, dispid, helpstring]] + def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]: + idlflags: List[_IdlFlagType] = [] if self._isdual: idlflags.append(dispid(self._m.memid)) idlflags.extend(self._m.idlflags) @@ -245,8 +250,7 @@ def _get_common_elms(self): type_name = self._to_type_name(self._m.returns) return (idlflags, type_name, self._m.name) - def _make_noargs(self): - # type: () -> None + def _make_noargs(self) -> None: elms = self._get_common_elms() code = " COMMETHOD(%r, %s, '%s')," % elms if len(code) > 80: @@ -259,8 +263,7 @@ def _make_noargs(self): ) % elms print(code, file=self._stream) - def _make_withargs(self): - # type: () -> None + def _make_withargs(self) -> None: code = ( " COMMETHOD(\n" " %r,\n" " %s,\n" " '%s'," ) % self._get_common_elms() @@ -269,8 +272,7 @@ def _make_withargs(self): print(",\n".join(arglist), file=self._stream) print(" ),", file=self._stream) - def _iter_args(self): - # type: () -> Iterator[Tuple[str, str, List[str], _UnionT[lcid, Any, None]]] + def _iter_args(self) -> Iterator[Tuple[str, str, List[str], _DefValType]]: for typ, arg_name, _f, _defval in self._m.arguments: ########################################################### # IDL files that contain 'open arrays' or 'conformant @@ -327,8 +329,7 @@ def _iter_args(self): class DispMethodGenerator(object): - def __init__(self, m): - # type: (typedesc.DispMethod) -> None + def __init__(self, m: typedesc.DispMethod) -> None: self._m = m self._stream = io.StringIO() self._to_type_name = TypeNamer() @@ -341,9 +342,8 @@ def generate(self): self._make_withargs() return self._stream.getvalue() - def _get_common_elms(self): - # type: () -> Tuple[List[_UnionT[str, dispid, helpstring]], str, str] - idlflags = [] # type: List[_UnionT[str, dispid, helpstring]] + def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]: + idlflags: List[_IdlFlagType] = [] idlflags.append(dispid(self._m.dispid)) idlflags.extend(self._m.idlflags) if __debug__ and self._m.doc: @@ -351,8 +351,7 @@ def _get_common_elms(self): type_name = self._to_type_name(self._m.returns) return (idlflags, type_name, self._m.name) - def _make_noargs(self): - # type: () -> None + def _make_noargs(self) -> None: elms = self._get_common_elms() code = " DISPMETHOD(%r, %s, '%s')," % elms if len(code) > 80: @@ -365,8 +364,7 @@ def _make_noargs(self): ) % elms print(code, file=self._stream) - def _make_withargs(self): - # type: () -> None + def _make_withargs(self) -> None: code = ( " DISPMETHOD(\n" " %r,\n" " %s,\n" " '%s'," ) % self._get_common_elms() @@ -375,16 +373,14 @@ def _make_withargs(self): print(",\n".join(arglist), file=self._stream) print(" ),", file=self._stream) - def _iter_args(self): - # type: () -> Iterator[Tuple[str, str, List[str], _UnionT[lcid, Any, None]]] + def _iter_args(self) -> Iterator[Tuple[str, str, List[str], _DefValType]]: for typ, arg_name, idlflags, default in self._m.arguments: type_name = self._to_type_name(typ) yield (type_name, arg_name, idlflags, default) class DispPropertyGenerator(object): - def __init__(self, m): - # type: (typedesc.DispProperty) -> None + def __init__(self, m: typedesc.DispProperty) -> None: self._m = m self._to_type_name = TypeNamer() @@ -402,9 +398,8 @@ def generate(self): ) % elms return code + "\n" - def _get_common_elms(self): - # type: () -> Tuple[List[_UnionT[str, dispid, helpstring]], str, str] - idlflags = [] # type: List[_UnionT[str, dispid, helpstring]] + def _get_common_elms(self) -> Tuple[List[_IdlFlagType], str, str]: + idlflags: List[_IdlFlagType] = [] idlflags.append(dispid(self._m.dispid)) idlflags.extend(self._m.idlflags) if __debug__ and self._m.doc: @@ -569,21 +564,18 @@ def need_VARIANT_imports(self, value): if "datetime.datetime(" in text: self.imports.add("datetime") - def _to_docstring(self, orig, depth=1): - # type: (str, int) -> str + def _to_docstring(self, orig: str, depth: int = 1) -> str: # increasing `depth` by one increases indentation by one indent = " " * depth # some chars are replaced to avoid causing a `SyntaxError` repled = orig.replace("\\", r"\\").replace('"', r"'") return '%s"""%s"""' % (indent, repled) - def ArrayType(self, tp): - # type: (typedesc.ArrayType) -> None + def ArrayType(self, tp: typedesc.ArrayType) -> None: self.generate(get_real_type(tp.typ)) self.generate(tp.typ) - def EnumValue(self, tp): - # type: (typedesc.EnumValue) -> None + def EnumValue(self, tp: typedesc.EnumValue) -> None: self.last_item_class = False value = int(tp.value) if keyword.iskeyword(tp.name): @@ -594,8 +586,7 @@ def EnumValue(self, tp): print("%s = %d" % (tp_name, value), file=self.stream) self.names.add(tp_name) - def Enumeration(self, tp): - # type: (typedesc.Enumeration) -> None + def Enumeration(self, tp: typedesc.Enumeration) -> None: self.last_item_class = False if tp.name: print("# values for enumeration '%s'" % tp.name, file=self.stream) @@ -611,8 +602,7 @@ def Enumeration(self, tp): print("%s = c_int # enum" % tp.name, file=self.stream) self.names.add(tp.name) - def Typedef(self, tp): - # type: (typedesc.Typedef) -> None + def Typedef(self, tp: typedesc.Typedef) -> None: if isinstance(tp.typ, (typedesc.Structure, typedesc.Union)): self.generate(tp.typ.get_head()) self.more.add(tp.typ) @@ -627,12 +617,10 @@ def Typedef(self, tp): self.last_item_class = False self.names.add(tp.name) - def FundamentalType(self, item): - # type: (typedesc.FundamentalType) -> None + def FundamentalType(self, item: typedesc.FundamentalType) -> None: pass # we should check if this is known somewhere - def StructureHead(self, head): - # type: (typedesc.StructureHead) -> None + def StructureHead(self, head: typedesc.StructureHead) -> None: for struct in head.struct.bases: self.generate(struct.get_head()) self.more.add(struct) @@ -722,20 +710,17 @@ def StructureHead(self, head): print(file=self.stream) self.names.add(head.struct.name) - def Structure(self, struct): - # type: (typedesc.Structure) -> None + def Structure(self, struct: typedesc.Structure) -> None: self.generate(struct.get_head()) self.generate(struct.get_body()) - def Union(self, union): - # type: (typedesc.Union) -> None + def Union(self, union: typedesc.Union) -> None: self.generate(union.get_head()) self.generate(union.get_body()) - def StructureBody(self, body): - # type: (typedesc.StructureBody) -> None - fields = [] # type: List[typedesc.Field] - methods = [] # type: List[typedesc.Method] + def StructureBody(self, body: typedesc.StructureBody) -> None: + fields: List[typedesc.Field] = [] + methods: List[typedesc.Method] = [] for m in body.struct.members: if type(m) is typedesc.Field: fields.append(m) @@ -863,8 +848,7 @@ def StructureBody(self, body): ################################################################ # top-level typedesc generators # - def TypeLib(self, lib): - # type: (typedesc.TypeLib) -> None + def TypeLib(self, lib: typedesc.TypeLib) -> None: # Hm, in user code we have to write: # class MyServer(COMObject, ...): # _com_interfaces_ = [MyTypeLib.IInterface] @@ -893,15 +877,13 @@ def TypeLib(self, lib): print(file=self.stream) print(file=self.stream) - def External(self, ext): - # type: (typedesc.External) -> None + def External(self, ext: typedesc.External) -> None: modname = name_wrapper_module(ext.tlib) if modname not in self.imports: self.externals.append(ext.tlib) self.imports.add(modname) - def Constant(self, tp): - # type: (typedesc.Constant) -> None + def Constant(self, tp: typedesc.Constant) -> None: self.last_item_class = False print( "%s = %r # Constant %s" % (tp.name, tp.value, self._to_type_name(tp.typ)), @@ -909,13 +891,11 @@ def Constant(self, tp): ) self.names.add(tp.name) - def SAFEARRAYType(self, sa): - # type: (typedesc.SAFEARRAYType) -> None + def SAFEARRAYType(self, sa: typedesc.SAFEARRAYType) -> None: self.generate(sa.typ) self.imports.add("comtypes.automation", "_midlSAFEARRAY") - def PointerType(self, tp): - # type: (typedesc.PointerType) -> None + def PointerType(self, tp: typedesc.PointerType) -> None: if type(tp.typ) is typedesc.ComInterface: # this defines the class self.generate(tp.typ.get_head()) @@ -939,8 +919,7 @@ def PointerType(self, tp): elif real_type.name == "wchar_t": self.declarations.add("WSTRING", "c_wchar_p") - def CoClass(self, coclass): - # type: (typedesc.CoClass) -> None + def CoClass(self, coclass: typedesc.CoClass) -> None: self.imports.add("comtypes", "GUID") self.imports.add("comtypes", "CoClass") if not self.last_item_class: @@ -1001,14 +980,12 @@ def CoClass(self, coclass): self.names.add(coclass.name) - def ComInterface(self, itf): - # type: (typedesc.ComInterface) -> None + def ComInterface(self, itf: typedesc.ComInterface) -> None: self.generate(itf.get_head()) self.generate(itf.get_body()) self.names.add(itf.name) - def _is_enuminterface(self, itf): - # type: (typedesc.ComInterface) -> bool + def _is_enuminterface(self, itf: typedesc.ComInterface) -> bool: # Check if this is an IEnumXXX interface if not itf.name.startswith("IEnum"): return False @@ -1018,8 +995,7 @@ def _is_enuminterface(self, itf): return False return True - def ComInterfaceHead(self, head): - # type: (typedesc.ComInterfaceHead) -> None + def ComInterfaceHead(self, head: typedesc.ComInterfaceHead) -> None: if head.itf.name in self.known_symbols: return base = head.itf.base @@ -1071,8 +1047,7 @@ def ComInterfaceHead(self, head): print(file=self.stream) print(file=self.stream) - def ComInterfaceBody(self, body): - # type: (typedesc.ComInterfaceBody) -> None + def ComInterfaceBody(self, body: typedesc.ComInterfaceBody) -> None: # The base class must be fully generated, including the # _methods_ list. self.generate(body.itf.base) @@ -1166,14 +1141,12 @@ def ComInterfaceBody(self, body): raise RuntimeError("BUG") print("#", file=self.stream) - def DispInterface(self, itf): - # type: (typedesc.DispInterface) -> None + def DispInterface(self, itf: typedesc.DispInterface) -> None: self.generate(itf.get_head()) self.generate(itf.get_body()) self.names.add(itf.name) - def DispInterfaceHead(self, head): - # type: (typedesc.DispInterfaceHead) -> None + def DispInterfaceHead(self, head: typedesc.DispInterfaceHead) -> None: self.generate(head.itf.base) basename = self._to_type_name(head.itf.base) @@ -1195,8 +1168,7 @@ def DispInterfaceHead(self, head): print(file=self.stream) print(file=self.stream) - def DispInterfaceBody(self, body): - # type: (typedesc.DispInterfaceBody) -> None + def DispInterfaceBody(self, body: typedesc.DispInterfaceBody) -> None: # make sure we can generate the body for m in body.itf.members: if isinstance(m, typedesc.DispMethod): @@ -1226,8 +1198,7 @@ def DispInterfaceBody(self, body): ################################################################ # non-toplevel method generators # - def make_ComMethod(self, m, isdual): - # type: (typedesc.ComMethod, bool) -> None + def make_ComMethod(self, m: typedesc.ComMethod, isdual: bool) -> None: self.imports.add("comtypes", "COMMETHOD") if isdual: self.imports.add("comtypes", "dispid") @@ -1246,8 +1217,7 @@ def make_ComMethod(self, m, isdual): if default is not None: self.need_VARIANT_imports(default) - def make_DispMethod(self, m): - # type: (typedesc.DispMethod) -> None + def make_DispMethod(self, m: typedesc.DispMethod) -> None: self.imports.add("comtypes", "DISPMETHOD") self.imports.add("comtypes", "dispid") if __debug__ and m.doc: @@ -1259,8 +1229,7 @@ def make_DispMethod(self, m): if default is not None: self.need_VARIANT_imports(default) - def make_DispProperty(self, prop): - # type: (typedesc.DispProperty) -> None + def make_DispProperty(self, prop: typedesc.DispProperty) -> None: self.imports.add("comtypes", "DISPPROPERTY") self.imports.add("comtypes", "dispid") if __debug__ and prop.doc: @@ -1271,8 +1240,7 @@ def make_DispProperty(self, prop): class TypeNamer(object): - def __call__(self, t): - # type: (Any) -> str + def __call__(self, t: Any) -> str: # Return a string, containing an expression which can be used # to refer to the type. Assumes the 'from ctypes import *' # namespace is available. @@ -1315,8 +1283,9 @@ def __call__(self, t): return "%s.%s" % (modname, t.symbol_name) return t.name - def _inspect_PointerType(self, t, count=0): - # type: (typedesc.PointerType, int) -> Tuple[Any, int] + def _inspect_PointerType( + self, t: typedesc.PointerType, count: int = 0 + ) -> Tuple[Any, int]: if ASSUME_STRINGS: x = get_real_type(t.typ) if isinstance(x, typedesc.FundamentalType):