@@ -118,7 +118,7 @@ fn infer_definition_types_cycle_recovery<'db>(
118118) -> TypeInference < ' db > {
119119 tracing:: trace!( "infer_definition_types_cycle_recovery" ) ;
120120 let mut inference = TypeInference :: empty ( input. scope ( db) ) ;
121- let category = input. category ( db) ;
121+ let category = input. kind ( db) . category ( ) ;
122122 if category. is_declaration ( ) {
123123 inference
124124 . declarations
@@ -198,6 +198,36 @@ pub(crate) fn infer_expression_types<'db>(
198198 TypeInferenceBuilder :: new ( db, InferenceRegion :: Expression ( expression) , index) . finish ( )
199199}
200200
201+ /// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query.
202+ ///
203+ /// This is a small helper around [`infer_expression_types()`] to reduce the boilerplate.
204+ /// Use [`infer_expression_type()`] if it isn't guaranteed that `expression` is in the same file to
205+ /// avoid cross-file query dependencies.
206+ pub ( super ) fn infer_same_file_expression_type < ' db > (
207+ db : & ' db dyn Db ,
208+ expression : Expression < ' db > ,
209+ ) -> Type < ' db > {
210+ let inference = infer_expression_types ( db, expression) ;
211+ let scope = expression. scope ( db) ;
212+ inference. expression_type ( expression. node_ref ( db) . scoped_expression_id ( db, scope) )
213+ }
214+
215+ /// Infers the type of an expression where the expression might come from another file.
216+ ///
217+ /// Use this over [`infer_expression_types`] if the expression might come from another file than the
218+ /// enclosing query to avoid cross-file query dependencies.
219+ ///
220+ /// Use [`infer_same_file_expression_type`] if it is guaranteed that `expression` is in the same
221+ /// to avoid unnecessary salsa ingredients. This is normally the case inside the `TypeInferenceBuilder`.
222+ #[ salsa:: tracked]
223+ pub ( crate ) fn infer_expression_type < ' db > (
224+ db : & ' db dyn Db ,
225+ expression : Expression < ' db > ,
226+ ) -> Type < ' db > {
227+ // It's okay to call the "same file" version here because we're inside a salsa query.
228+ infer_same_file_expression_type ( db, expression)
229+ }
230+
201231/// Infer the types for an [`Unpack`] operation.
202232///
203233/// This infers the expression type and performs structural match against the target expression
@@ -870,7 +900,7 @@ impl<'db> TypeInferenceBuilder<'db> {
870900 }
871901
872902 fn add_binding ( & mut self , node : AnyNodeRef , binding : Definition < ' db > , ty : Type < ' db > ) {
873- debug_assert ! ( binding. is_binding ( self . db( ) ) ) ;
903+ debug_assert ! ( binding. kind ( self . db( ) ) . category ( ) . is_binding ( ) ) ;
874904 let use_def = self . index . use_def_map ( binding. file_scope ( self . db ( ) ) ) ;
875905 let declarations = use_def. declarations_at_binding ( binding) ;
876906 let mut bound_ty = ty;
@@ -905,7 +935,7 @@ impl<'db> TypeInferenceBuilder<'db> {
905935 declaration : Definition < ' db > ,
906936 ty : TypeAndQualifiers < ' db > ,
907937 ) {
908- debug_assert ! ( declaration. is_declaration ( self . db( ) ) ) ;
938+ debug_assert ! ( declaration. kind ( self . db( ) ) . category ( ) . is_declaration ( ) ) ;
909939 let use_def = self . index . use_def_map ( declaration. file_scope ( self . db ( ) ) ) ;
910940 let prior_bindings = use_def. bindings_at_declaration ( declaration) ;
911941 // unbound_ty is Never because for this check we don't care about unbound
@@ -935,8 +965,8 @@ impl<'db> TypeInferenceBuilder<'db> {
935965 definition : Definition < ' db > ,
936966 declared_and_inferred_ty : & DeclaredAndInferredType < ' db > ,
937967 ) {
938- debug_assert ! ( definition. is_binding ( self . db( ) ) ) ;
939- debug_assert ! ( definition. is_declaration ( self . db( ) ) ) ;
968+ debug_assert ! ( definition. kind ( self . db( ) ) . category ( ) . is_binding ( ) ) ;
969+ debug_assert ! ( definition. kind ( self . db( ) ) . category ( ) . is_declaration ( ) ) ;
940970
941971 let ( declared_ty, inferred_ty) = match * declared_and_inferred_ty {
942972 DeclaredAndInferredType :: AreTheSame ( ty) => ( ty. into ( ) , ty) ,
@@ -6626,4 +6656,92 @@ mod tests {
66266656
66276657 Ok ( ( ) )
66286658 }
6659+
6660+ /// This test verifies that queries
6661+ #[ test]
6662+ fn dependency_own_instance_member ( ) -> anyhow:: Result < ( ) > {
6663+ fn x_rhs_expression ( db : & TestDb ) -> Expression < ' _ > {
6664+ let file_main = system_path_to_file ( db, "/src/main.py" ) . unwrap ( ) ;
6665+ let ast = parsed_module ( db, file_main) ;
6666+ // Get the second statement in `main.py` (x = …) and extract the expression
6667+ // node on the right-hand side:
6668+ let x_rhs_node = & ast. syntax ( ) . body [ 1 ] . as_assign_stmt ( ) . unwrap ( ) . value ;
6669+
6670+ let index = semantic_index ( db, file_main) ;
6671+ index. expression ( x_rhs_node. as_ref ( ) )
6672+ }
6673+
6674+ let mut db = setup_db ( ) ;
6675+
6676+ db. write_dedented (
6677+ "/src/mod.py" ,
6678+ r#"
6679+ class C:
6680+ if random.choice([True, False]):
6681+ attr: int = 42
6682+ else:
6683+ attr: None = None
6684+ "# ,
6685+ ) ?;
6686+ db. write_dedented (
6687+ "/src/main.py" ,
6688+ r#"
6689+ from mod import C
6690+ x = C().attr
6691+ "# ,
6692+ ) ?;
6693+
6694+ let file_main = system_path_to_file ( & db, "/src/main.py" ) . unwrap ( ) ;
6695+ let attr_ty = global_symbol ( & db, file_main, "x" ) . expect_type ( ) ;
6696+ assert_eq ! ( attr_ty. display( & db) . to_string( ) , "Unknown | int | None" ) ;
6697+
6698+ // Change the type of `attr` to `str | None`; this should trigger the type of `x` to be re-inferred
6699+ db. write_dedented (
6700+ "/src/mod.py" ,
6701+ r#"
6702+ class C:
6703+ if random.choice([True, False]):
6704+ attr: str = "42"
6705+ else:
6706+ attr: None = None
6707+ "# ,
6708+ ) ?;
6709+
6710+ let events = {
6711+ db. clear_salsa_events ( ) ;
6712+ let attr_ty = global_symbol ( & db, file_main, "x" ) . expect_type ( ) ;
6713+ assert_eq ! ( attr_ty. display( & db) . to_string( ) , "Unknown | str | None" ) ;
6714+ db. take_salsa_events ( )
6715+ } ;
6716+ assert_function_query_was_run ( & db, infer_expression_types, x_rhs_expression ( & db) , & events) ;
6717+
6718+ // Add a comment; this should not trigger the type of `x` to be re-inferred
6719+ db. write_dedented (
6720+ "/src/mod.py" ,
6721+ r#"
6722+ class C:
6723+ # comment
6724+ if random.choice([True, False]):
6725+ attr: str = "42"
6726+ else:
6727+ attr: None = None
6728+ "# ,
6729+ ) ?;
6730+
6731+ let events = {
6732+ db. clear_salsa_events ( ) ;
6733+ let attr_ty = global_symbol ( & db, file_main, "x" ) . expect_type ( ) ;
6734+ assert_eq ! ( attr_ty. display( & db) . to_string( ) , "Unknown | str | None" ) ;
6735+ db. take_salsa_events ( )
6736+ } ;
6737+
6738+ assert_function_query_was_not_run (
6739+ & db,
6740+ infer_expression_types,
6741+ x_rhs_expression ( & db) ,
6742+ & events,
6743+ ) ;
6744+
6745+ Ok ( ( ) )
6746+ }
66296747}
0 commit comments