Skip to content

Commit a5070d1

Browse files
committed
Special case ParamSpec default when inferring deferred assignment definition
1 parent 1f692a8 commit a5070d1

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

crates/ty_python_semantic/resources/mdtest/paramspec.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ python-version = "3.13"
8888
```py
8989
from typing import ParamSpec
9090

91-
# TODO: This is not an error
92-
# error: [invalid-type-form]
93-
P = ParamSpec("P", default=[int, str])
91+
P1 = ParamSpec("P1", default=[int, str])
92+
P2 = ParamSpec("P2", default=...)
93+
P3 = ParamSpec("P3", default=P2)
9494
```
9595

9696
### PEP 695

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3236,14 +3236,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
32363236
else {
32373237
return;
32383238
};
3239-
32403239
let previous_deferred_state =
32413240
std::mem::replace(&mut self.deferred_state, DeferredExpressionState::Deferred);
3241+
let default_ty = self.infer_paramspec_default(default);
3242+
self.store_expression_type(default, default_ty);
3243+
self.deferred_state = previous_deferred_state;
3244+
}
32423245

3246+
fn infer_paramspec_default(&mut self, default: &ast::Expr) -> Type<'db> {
32433247
// This is the same logic as `TypeInferenceBuilder::infer_callable_parameter_types` except
32443248
// for the subscript branch which is required for `Concatenate` but that cannot be
32453249
// specified in this context.
3246-
let default_ty = match &**default {
3250+
match default {
32473251
ast::Expr::EllipsisLiteral(_) => {
32483252
CallableType::single(self.db(), Signature::new(Parameters::gradual_form(), None))
32493253
}
@@ -3289,8 +3293,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
32893293
if is_paramspec {
32903294
name_ty
32913295
} else {
3292-
if let Some(builder) = self.context.report_lint(&INVALID_PARAMSPEC, &**default)
3293-
{
3296+
if let Some(builder) = self.context.report_lint(&INVALID_PARAMSPEC, default) {
32943297
builder.into_diagnostic(
32953298
"The default value to `ParamSpec` must be either a list of types, \
32963299
`ParamSpec`, or `...`",
@@ -3300,18 +3303,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
33003303
}
33013304
}
33023305
_ => {
3303-
if let Some(builder) = self.context.report_lint(&INVALID_PARAMSPEC, &**default) {
3306+
if let Some(builder) = self.context.report_lint(&INVALID_PARAMSPEC, default) {
33043307
builder.into_diagnostic(
33053308
"The default value to `ParamSpec` must be either a list of types, \
33063309
`ParamSpec`, or `...`",
33073310
);
33083311
}
33093312
Type::unknown()
33103313
}
3311-
};
3312-
3313-
self.store_expression_type(default, default_ty);
3314-
self.deferred_state = previous_deferred_state;
3314+
}
33153315
}
33163316

33173317
fn infer_typevartuple_definition(
@@ -4761,8 +4761,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
47614761
}
47624762

47634763
fn infer_assignment_deferred(&mut self, value: &ast::Expr) {
4764-
// Infer deferred bounds/constraints/defaults of a legacy TypeVar.
4765-
let ast::Expr::Call(ast::ExprCall { arguments, .. }) = value else {
4764+
// Infer deferred bounds/constraints/defaults of a legacy TypeVar / ParamSpec.
4765+
let ast::Expr::Call(ast::ExprCall {
4766+
func, arguments, ..
4767+
}) = value
4768+
else {
47664769
return;
47674770
};
47684771
for arg in arguments.args.iter().skip(1) {
@@ -4771,10 +4774,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
47714774
if let Some(bound) = arguments.find_keyword("bound") {
47724775
self.infer_type_expression(&bound.value);
47734776
}
4774-
// TODO: We need to differentiate between the `default` argument to `TypeVar` and
4775-
// `ParamSpec` because the types they accept are different.
47764777
if let Some(default) = arguments.find_keyword("default") {
4777-
self.infer_type_expression(&default.value);
4778+
let func_ty = self.get_or_infer_expression(func, TypeContext::default());
4779+
if func_ty.as_class_literal().is_some_and(|class_literal| {
4780+
class_literal.is_known(self.db(), KnownClass::ParamSpec)
4781+
}) {
4782+
self.infer_paramspec_default(&default.value);
4783+
} else {
4784+
self.infer_type_expression(&default.value);
4785+
}
47784786
}
47794787
}
47804788

0 commit comments

Comments
 (0)