Skip to content

Rust: Rework type inference for impl Trait in return position #19954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ module Impl {
}
}

/** Holds if the call expression dispatches to a trait method. */
/** Holds if the call expression dispatches to a method. */
private predicate callIsMethodCall(CallExpr call, Path qualifier, string methodName) {
exists(Path path, Function f |
path = call.getFunction().(PathExpr).getPath() and
Expand Down
24 changes: 17 additions & 7 deletions rust/ql/lib/codeql/rust/internal/PathResolution.qll
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ abstract class ItemNode extends Locatable {
exists(ItemNode node |
this = node.(ImplItemNode).resolveSelfTy() and
result = node.getASuccessorRec(name) and
result instanceof AssocItemNode
result instanceof AssocItemNode and
not result instanceof TypeAlias
)
or
// trait items with default implementations made available in an implementation
Expand All @@ -181,6 +182,10 @@ abstract class ItemNode extends Locatable {
result = this.(TypeParamItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
or
result = this.(ImplTraitTypeReprItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
or
result = this.(TypeAliasItemNode).resolveAlias().getASuccessorRec(name) and
// type parameters defined in the RHS are not available in the LHS
not result instanceof TypeParam
}

/**
Expand Down Expand Up @@ -289,6 +294,8 @@ abstract class ItemNode extends Locatable {
Location getLocation() { result = super.getLocation() }
}

abstract class TypeItemNode extends ItemNode { }

/** A module or a source file. */
abstract private class ModuleLikeNode extends ItemNode {
/** Gets an item that may refer directly to items defined in this module. */
Expand Down Expand Up @@ -438,7 +445,7 @@ private class ConstItemNode extends AssocItemNode instanceof Const {
override TypeParam getTypeParam(int i) { none() }
}

private class EnumItemNode extends ItemNode instanceof Enum {
private class EnumItemNode extends TypeItemNode instanceof Enum {
override string getName() { result = Enum.super.getName().getText() }

override Namespace getNamespace() { result.isType() }
Expand Down Expand Up @@ -746,7 +753,7 @@ private class ModuleItemNode extends ModuleLikeNode instanceof Module {
}
}

private class StructItemNode extends ItemNode instanceof Struct {
private class StructItemNode extends TypeItemNode instanceof Struct {
override string getName() { result = Struct.super.getName().getText() }

override Namespace getNamespace() {
Expand Down Expand Up @@ -781,7 +788,7 @@ private class StructItemNode extends ItemNode instanceof Struct {
}
}

class TraitItemNode extends ImplOrTraitItemNode instanceof Trait {
class TraitItemNode extends ImplOrTraitItemNode, TypeItemNode instanceof Trait {
pragma[nomagic]
Path getABoundPath() {
result = super.getTypeBoundList().getABound().getTypeRepr().(PathTypeRepr).getPath()
Expand Down Expand Up @@ -838,7 +845,10 @@ class TraitItemNode extends ImplOrTraitItemNode instanceof Trait {
}
}

class TypeAliasItemNode extends AssocItemNode instanceof TypeAlias {
class TypeAliasItemNode extends TypeItemNode, AssocItemNode instanceof TypeAlias {
pragma[nomagic]
ItemNode resolveAlias() { result = resolvePathFull(super.getTypeRepr().(PathTypeRepr).getPath()) }

override string getName() { result = TypeAlias.super.getName().getText() }

override predicate hasImplementation() { super.hasTypeRepr() }
Expand All @@ -854,7 +864,7 @@ class TypeAliasItemNode extends AssocItemNode instanceof TypeAlias {
override string getCanonicalPath(Crate c) { none() }
}

private class UnionItemNode extends ItemNode instanceof Union {
private class UnionItemNode extends TypeItemNode instanceof Union {
override string getName() { result = Union.super.getName().getText() }

override Namespace getNamespace() { result.isType() }
Expand Down Expand Up @@ -912,7 +922,7 @@ private class BlockExprItemNode extends ItemNode instanceof BlockExpr {
override string getCanonicalPath(Crate c) { none() }
}

class TypeParamItemNode extends ItemNode instanceof TypeParam {
class TypeParamItemNode extends TypeItemNode instanceof TypeParam {
private WherePred getAWherePred() {
exists(ItemNode declaringItem |
this = resolveTypeParamPathTypeRepr(result.getTypeRepr()) and
Expand Down
85 changes: 15 additions & 70 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ newtype TType =
TTrait(Trait t) or
TArrayType() or // todo: add size?
TRefType() or // todo: add mut?
TImplTraitType(ImplTraitTypeRepr impl) or
TImplTraitArgumentType(Function function, ImplTraitTypeRepr impl) {
impl = function.getAParam().getTypeRepr()
} or
TSliceType() or
TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
Expand Down Expand Up @@ -139,9 +141,6 @@ class TraitType extends Type, TTrait {

override TypeParameter getTypeParameter(int i) {
result = TTypeParamTypeParameter(trait.getGenericParamList().getTypeParam(i))
or
result =
any(AssociatedTypeTypeParameter param | param.getTrait() = trait and param.getIndex() = i)
}

override TypeMention getTypeParameterDefault(int i) {
Expand Down Expand Up @@ -199,53 +198,6 @@ class RefType extends Type, TRefType {
override Location getLocation() { result instanceof EmptyLocation }
}

/**
* An [impl Trait][1] type.
*
* Each syntactic `impl Trait` type gives rise to its own type, even if
* two `impl Trait` types have the same bounds.
*
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html
*/
class ImplTraitType extends Type, TImplTraitType {
ImplTraitTypeRepr impl;

ImplTraitType() { this = TImplTraitType(impl) }

/** Gets the underlying AST node. */
ImplTraitTypeRepr getImplTraitTypeRepr() { result = impl }

/** Gets the function that this `impl Trait` belongs to. */
abstract Function getFunction();

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) { none() }

override string toString() { result = impl.toString() }

override Location getLocation() { result = impl.getLocation() }
}

/**
* An [impl Trait in return position][1] type, for example:
*
* ```rust
* fn foo() -> impl Trait
* ```
*
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html#r-type.impl-trait.return
*/
class ImplTraitReturnType extends ImplTraitType {
private Function function;

ImplTraitReturnType() { impl = function.getRetType().getTypeRepr() }

override Function getFunction() { result = function }
}

/**
* A slice type.
*
Expand Down Expand Up @@ -299,20 +251,6 @@ class TypeParamTypeParameter extends TypeParameter, TTypeParamTypeParameter {
override Location getLocation() { result = typeParam.getLocation() }
}

/**
* Gets the type alias that is the `i`th type parameter of `trait`. Type aliases
* are numbered consecutively but in arbitrary order, starting from the index
* following the last ordinary type parameter.
*/
predicate traitAliasIndex(Trait trait, int i, TypeAlias typeAlias) {
typeAlias =
rank[i + 1 - trait.getNumberOfGenericParams()](TypeAlias alias |
trait.(TraitItemNode).getADescendant() = alias
|
alias order by idOfTypeParameterAstNode(alias)
)
}

/**
* A type parameter corresponding to an associated type in a trait.
*
Expand Down Expand Up @@ -341,8 +279,6 @@ class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypePara
/** Gets the trait that contains this associated type declaration. */
TraitItemNode getTrait() { result.getAnAssocItem() = typeAlias }

int getIndex() { traitAliasIndex(_, result, typeAlias) }

override string toString() { result = typeAlias.getName().getText() }

override Location getLocation() { result = typeAlias.getLocation() }
Expand Down Expand Up @@ -405,18 +341,27 @@ class SelfTypeParameter extends TypeParameter, TSelfTypeParameter {
*
* [1]: https://doc.rust-lang.org/reference/types/impl-trait.html#r-type.impl-trait.param
*/
class ImplTraitTypeTypeParameter extends ImplTraitType, TypeParameter {
class ImplTraitArgumentType extends TypeParameter, TImplTraitArgumentType {
private Function function;
private ImplTraitTypeRepr impl;

ImplTraitTypeTypeParameter() { impl = function.getAParam().getTypeRepr() }
ImplTraitArgumentType() { this = TImplTraitArgumentType(function, impl) }

override Function getFunction() { result = function }
/** Gets the function that this `impl Trait` belongs to. */
Function getFunction() { result = function }

/** Gets the underlying AST node. */
ImplTraitTypeRepr getImplTraitTypeRepr() { result = impl }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) { none() }

override string toString() { result = impl.toString() }

override Location getLocation() { result = impl.getLocation() }
}

/**
Expand Down
Loading