diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index 2ac44c023..0e0244e29 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -266,7 +266,7 @@ class AnnotationCounts: return_annotations: int = 0 classes_added: int = 0 - def applied_changes(self): + def any_changes(self): return ( self.global_annotations + self.attribute_annotations @@ -385,7 +385,13 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: self.annotations.class_definitions.update(visitor.class_definitions) tree_with_imports = AddImportsVisitor(self.context).transform_module(tree) - return tree_with_imports.visit(self) + tree_with_changes = tree_with_imports.visit(self) + + # don't modify the imports if we didn't actually add any type information + if self.annotation_counts.any_changes(): + return tree_with_changes + else: + return tree # smart constructors: all applied annotations happen via one of these diff --git a/libcst/codemod/visitors/tests/test_apply_type_annotations.py b/libcst/codemod/visitors/tests/test_apply_type_annotations.py index 4a63ab29e..e337f41a1 100644 --- a/libcst/codemod/visitors/tests/test_apply_type_annotations.py +++ b/libcst/codemod/visitors/tests/test_apply_type_annotations.py @@ -1012,11 +1012,7 @@ class C: def __init__(self): self.attr_will_not_be_found = None """, - # TODO: use the annotation counts to avoid adding - # the import in this case. """ - from bar import X - class C: def __init__(self): self.attr_will_not_be_found = None @@ -1032,7 +1028,7 @@ def test_count_annotations( before: str, after: str, annotation_counts: AnnotationCounts, - applied_changes: False, + any_changes: False, ): stub = self.make_fixture_data(stub) before = self.make_fixture_data(before) @@ -1048,4 +1044,4 @@ def test_count_annotations( self.assertEqual(after, output_code) self.assertEqual(str(annotation_counts), str(visitor.annotation_counts)) - self.assertEqual(applied_changes, visitor.annotation_counts.applied_changes()) + self.assertEqual(any_changes, visitor.annotation_counts.any_changes())