@@ -118,7 +118,7 @@ fn infer_definition_types_cycle_recovery<'db>(
118118) -> TypeInference < ' db > {
119119 tracing:: trace!( "infer_definition_types_cycle_recovery" ) ;
120120 let mut inference = TypeInference :: empty ( input. scope ( db) ) ;
121- let category = input. category ( db) ;
121+ let category = input. kind ( db) . category ( ) ;
122122 if category. is_declaration ( ) {
123123 inference
124124 . declarations
@@ -198,6 +198,36 @@ pub(crate) fn infer_expression_types<'db>(
198198 TypeInferenceBuilder :: new ( db, InferenceRegion :: Expression ( expression) , index) . finish ( )
199199}
200200
201+ /// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query.
202+ ///
203+ /// This is a small helper around [`infer_expression_types()`] to reduce the boilerplate.
204+ /// Use [`infer_expression_type()`] if it isn't guaranteed that `expression` is in the same file to
205+ /// avoid cross-file query dependencies.
206+ pub ( super ) fn infer_same_file_expression_type < ' db > (
207+ db : & ' db dyn Db ,
208+ expression : Expression < ' db > ,
209+ ) -> Type < ' db > {
210+ let inference = infer_expression_types ( db, expression) ;
211+ let scope = expression. scope ( db) ;
212+ inference. expression_type ( expression. node_ref ( db) . scoped_expression_id ( db, scope) )
213+ }
214+
215+ /// Infers the type of an expression where the expression might come from another file.
216+ ///
217+ /// Use this over [`infer_expression_types`] if the expression might come from another file than the
218+ /// enclosing query to avoid cross-file query dependencies.
219+ ///
220+ /// Use [`infer_same_file_expression_type`] if it is guaranteed that `expression` is in the same
221+ /// to avoid unnecessary salsa ingredients. This is normally the case inside the `TypeInferenceBuilder`.
222+ #[ salsa:: tracked]
223+ pub ( crate ) fn infer_expression_type < ' db > (
224+ db : & ' db dyn Db ,
225+ expression : Expression < ' db > ,
226+ ) -> Type < ' db > {
227+ // It's okay to call the "same file" version here because we're inside a salsa query.
228+ infer_same_file_expression_type ( db, expression)
229+ }
230+
201231/// Infer the types for an [`Unpack`] operation.
202232///
203233/// This infers the expression type and performs structural match against the target expression
@@ -870,7 +900,7 @@ impl<'db> TypeInferenceBuilder<'db> {
870900 }
871901
872902 fn add_binding ( & mut self , node : AnyNodeRef , binding : Definition < ' db > , ty : Type < ' db > ) {
873- debug_assert ! ( binding. is_binding ( self . db( ) ) ) ;
903+ debug_assert ! ( binding. kind ( self . db( ) ) . category ( ) . is_binding ( ) ) ;
874904 let use_def = self . index . use_def_map ( binding. file_scope ( self . db ( ) ) ) ;
875905 let declarations = use_def. declarations_at_binding ( binding) ;
876906 let mut bound_ty = ty;
@@ -905,7 +935,7 @@ impl<'db> TypeInferenceBuilder<'db> {
905935 declaration : Definition < ' db > ,
906936 ty : TypeAndQualifiers < ' db > ,
907937 ) {
908- debug_assert ! ( declaration. is_declaration ( self . db( ) ) ) ;
938+ debug_assert ! ( declaration. kind ( self . db( ) ) . category ( ) . is_declaration ( ) ) ;
909939 let use_def = self . index . use_def_map ( declaration. file_scope ( self . db ( ) ) ) ;
910940 let prior_bindings = use_def. bindings_at_declaration ( declaration) ;
911941 // unbound_ty is Never because for this check we don't care about unbound
@@ -935,8 +965,8 @@ impl<'db> TypeInferenceBuilder<'db> {
935965 definition : Definition < ' db > ,
936966 declared_and_inferred_ty : & DeclaredAndInferredType < ' db > ,
937967 ) {
938- debug_assert ! ( definition. is_binding ( self . db( ) ) ) ;
939- debug_assert ! ( definition. is_declaration ( self . db( ) ) ) ;
968+ debug_assert ! ( definition. kind ( self . db( ) ) . category ( ) . is_binding ( ) ) ;
969+ debug_assert ! ( definition. kind ( self . db( ) ) . category ( ) . is_declaration ( ) ) ;
940970
941971 let ( declared_ty, inferred_ty) = match * declared_and_inferred_ty {
942972 DeclaredAndInferredType :: AreTheSame ( ty) => ( ty. into ( ) , ty) ,
@@ -6626,4 +6656,93 @@ mod tests {
66266656
66276657 Ok ( ( ) )
66286658 }
6659+
6660+ /// This test verifies that changing a class's declaration in a non-meaningful way (e.g. by adding a comment)
6661+ /// doesn't trigger type inference for expressions that depend on the class's members.
6662+ #[ test]
6663+ fn dependency_own_instance_member ( ) -> anyhow:: Result < ( ) > {
6664+ fn x_rhs_expression ( db : & TestDb ) -> Expression < ' _ > {
6665+ let file_main = system_path_to_file ( db, "/src/main.py" ) . unwrap ( ) ;
6666+ let ast = parsed_module ( db, file_main) ;
6667+ // Get the second statement in `main.py` (x = …) and extract the expression
6668+ // node on the right-hand side:
6669+ let x_rhs_node = & ast. syntax ( ) . body [ 1 ] . as_assign_stmt ( ) . unwrap ( ) . value ;
6670+
6671+ let index = semantic_index ( db, file_main) ;
6672+ index. expression ( x_rhs_node. as_ref ( ) )
6673+ }
6674+
6675+ let mut db = setup_db ( ) ;
6676+
6677+ db. write_dedented (
6678+ "/src/mod.py" ,
6679+ r#"
6680+ class C:
6681+ if random.choice([True, False]):
6682+ attr: int = 42
6683+ else:
6684+ attr: None = None
6685+ "# ,
6686+ ) ?;
6687+ db. write_dedented (
6688+ "/src/main.py" ,
6689+ r#"
6690+ from mod import C
6691+ x = C().attr
6692+ "# ,
6693+ ) ?;
6694+
6695+ let file_main = system_path_to_file ( & db, "/src/main.py" ) . unwrap ( ) ;
6696+ let attr_ty = global_symbol ( & db, file_main, "x" ) . expect_type ( ) ;
6697+ assert_eq ! ( attr_ty. display( & db) . to_string( ) , "Unknown | int | None" ) ;
6698+
6699+ // Change the type of `attr` to `str | None`; this should trigger the type of `x` to be re-inferred
6700+ db. write_dedented (
6701+ "/src/mod.py" ,
6702+ r#"
6703+ class C:
6704+ if random.choice([True, False]):
6705+ attr: str = "42"
6706+ else:
6707+ attr: None = None
6708+ "# ,
6709+ ) ?;
6710+
6711+ let events = {
6712+ db. clear_salsa_events ( ) ;
6713+ let attr_ty = global_symbol ( & db, file_main, "x" ) . expect_type ( ) ;
6714+ assert_eq ! ( attr_ty. display( & db) . to_string( ) , "Unknown | str | None" ) ;
6715+ db. take_salsa_events ( )
6716+ } ;
6717+ assert_function_query_was_run ( & db, infer_expression_types, x_rhs_expression ( & db) , & events) ;
6718+
6719+ // Add a comment; this should not trigger the type of `x` to be re-inferred
6720+ db. write_dedented (
6721+ "/src/mod.py" ,
6722+ r#"
6723+ class C:
6724+ # comment
6725+ if random.choice([True, False]):
6726+ attr: str = "42"
6727+ else:
6728+ attr: None = None
6729+ "# ,
6730+ ) ?;
6731+
6732+ let events = {
6733+ db. clear_salsa_events ( ) ;
6734+ let attr_ty = global_symbol ( & db, file_main, "x" ) . expect_type ( ) ;
6735+ assert_eq ! ( attr_ty. display( & db) . to_string( ) , "Unknown | str | None" ) ;
6736+ db. take_salsa_events ( )
6737+ } ;
6738+
6739+ assert_function_query_was_not_run (
6740+ & db,
6741+ infer_expression_types,
6742+ x_rhs_expression ( & db) ,
6743+ & events,
6744+ ) ;
6745+
6746+ Ok ( ( ) )
6747+ }
66296748}
0 commit comments