@@ -841,6 +841,7 @@ def write_to_module(
841841 find_replace : ty .Optional [ty .List [ty .Tuple [str , str ]]] = None ,
842842 inline_intra_pkg : bool = False ,
843843 additional_imports : ty .Optional [ty .List [ImportStatement ]] = None ,
844+ interface_module : bool = False ,
844845 ):
845846 """Writes the given imports, constants, classes, and functions to the file at the given path,
846847 merging with existing code if it exists"""
@@ -875,9 +876,13 @@ def write_to_module(
875876 existing_imports = parse_imports (existing_import_strs , relative_to = module_name )
876877 converter_imports = []
877878
879+ src_module_name = self .untranslate_submodule (module_name )
880+ if interface_module :
881+ src_module_name = "." .join (src_module_name .split ("." )[:- 1 ])
882+
878883 for klass in used .classes :
879884 if (
880- klass .__module__ == module_name
885+ klass .__module__ == src_module_name
881886 and f"\n class { klass .__name__ } (" not in code_str
882887 ):
883888 try :
@@ -912,7 +917,7 @@ def write_to_module(
912917
913918 for func in sorted (used .functions , key = attrgetter ("__name__" )):
914919 if (
915- func .__module__ == module_name
920+ func .__module__ == src_module_name
916921 and f"\n def { func .__name__ } (" not in code_str
917922 ):
918923 if func .__name__ in self .functions :
0 commit comments