Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make hir ProjectionKind more precise #74140

Merged
merged 1 commit into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/librustc_typeck/expr_use_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ use rustc_hir as hir;
use rustc_hir::def::Res;
use rustc_hir::def_id::LocalDefId;
use rustc_hir::PatKind;
use rustc_index::vec::Idx;
use rustc_infer::infer::InferCtxt;
use rustc_middle::ty::{self, adjustment, TyCtxt};
use rustc_target::abi::VariantIdx;

use crate::mem_categorization as mc;
use rustc_span::Span;
Expand Down Expand Up @@ -396,6 +398,7 @@ impl<'a, 'tcx> ExprUseVisitor<'a, 'tcx> {
&*with_expr,
with_place.clone(),
with_field.ty(self.tcx(), substs),
mc::ProjectionKind::Field(f_index as u32, VariantIdx::new(0)),
);
self.delegate_consume(&field_place);
}
Expand Down
170 changes: 156 additions & 14 deletions src/librustc_typeck/mem_categorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@ use rustc_middle::ty::{self, Ty, TyCtxt};

use rustc_data_structures::fx::FxIndexMap;
use rustc_hir as hir;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::def::{CtorOf, DefKind, Res};
use rustc_hir::def_id::LocalDefId;
use rustc_hir::pat_util::EnumerateAndAdjustIterator;
use rustc_hir::PatKind;
use rustc_index::vec::Idx;
use rustc_infer::infer::InferCtxt;
use rustc_span::Span;
use rustc_target::abi::VariantIdx;
use rustc_trait_selection::infer::InferCtxtExt;

#[derive(Clone, Debug)]
Expand All @@ -77,8 +80,20 @@ pub enum PlaceBase {
pub enum ProjectionKind {
/// A dereference of a pointer, reference or `Box<T>` of the given type
Deref,
/// An index or a field
Other,

/// `B.F` where `B` is the base expression and `F` is
/// the field. The field is identified by which variant
/// it appears in along with a field index. The variant
/// is used for enums.
Field(u32, VariantIdx),

/// Some index like `B[x]`, where `B` is the base
/// expression. We don't preserve the index `x` because
/// we won't need it.
Index,

/// A subslice covering a range of values like `B[x..y]`.
Subslice,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -406,7 +421,20 @@ impl<'a, 'tcx> MemCategorizationContext<'a, 'tcx> {
hir::ExprKind::Field(ref base, _) => {
let base = self.cat_expr(&base)?;
debug!("cat_expr(cat_field): id={} expr={:?} base={:?}", expr.hir_id, expr, base);
Ok(self.cat_projection(expr, base, expr_ty))

let field_idx = self
.tables
.field_indices()
.get(expr.hir_id)
.cloned()
.expect("Field index not found");

Ok(self.cat_projection(
expr,
base,
expr_ty,
ProjectionKind::Field(field_idx as u32, VariantIdx::new(0)),
))
}

hir::ExprKind::Index(ref base, _) => {
Expand All @@ -419,7 +447,7 @@ impl<'a, 'tcx> MemCategorizationContext<'a, 'tcx> {
self.cat_overloaded_place(expr, base)
} else {
let base = self.cat_expr(&base)?;
Ok(self.cat_projection(expr, base, expr_ty))
Ok(self.cat_projection(expr, base, expr_ty, ProjectionKind::Index))
}
}

Expand Down Expand Up @@ -533,9 +561,10 @@ impl<'a, 'tcx> MemCategorizationContext<'a, 'tcx> {
node: &N,
base_place: PlaceWithHirId<'tcx>,
ty: Ty<'tcx>,
kind: ProjectionKind,
) -> PlaceWithHirId<'tcx> {
let mut projections = base_place.place.projections;
projections.push(Projection { kind: ProjectionKind::Other, ty: ty });
projections.push(Projection { kind: kind, ty: ty });
let ret = PlaceWithHirId::new(
node.hir_id(),
base_place.place.base_ty,
Expand Down Expand Up @@ -609,6 +638,75 @@ impl<'a, 'tcx> MemCategorizationContext<'a, 'tcx> {
self.cat_pattern_(place, pat, &mut op)
}

/// Returns the variant index for an ADT used within a Struct or TupleStruct pattern
/// Here `pat_hir_id` is the HirId of the pattern itself.
fn variant_index_for_adt(
arora-aman marked this conversation as resolved.
Show resolved Hide resolved
&self,
qpath: &hir::QPath<'_>,
pat_hir_id: hir::HirId,
span: Span,
) -> McResult<VariantIdx> {
let res = self.tables.qpath_res(qpath, pat_hir_id);
let ty = self.tables.node_type(pat_hir_id);
let adt_def = match ty.kind {
ty::Adt(adt_def, _) => adt_def,
_ => {
self.tcx()
.sess
.delay_span_bug(span, "struct or tuple struct pattern not applied to an ADT");
return Err(());
}
};

match res {
Res::Def(DefKind::Variant, variant_id) => Ok(adt_def.variant_index_with_id(variant_id)),
Res::Def(DefKind::Ctor(CtorOf::Variant, ..), variant_ctor_id) => {
Ok(adt_def.variant_index_with_ctor_id(variant_ctor_id))
}
Res::Def(DefKind::Ctor(CtorOf::Struct, ..), _)
| Res::Def(DefKind::Struct | DefKind::Union | DefKind::TyAlias | DefKind::AssocTy, _)
| Res::SelfCtor(..)
| Res::SelfTy(..) => {
// Structs and Unions have only have one variant.
Ok(VariantIdx::new(0))
}
_ => bug!("expected ADT path, found={:?}", res),
}
}

/// Returns the total number of fields in an ADT variant used within a pattern.
/// Here `pat_hir_id` is the HirId of the pattern itself.
fn total_fields_in_adt_variant(
&self,
pat_hir_id: hir::HirId,
variant_index: VariantIdx,
span: Span,
) -> McResult<usize> {
let ty = self.tables.node_type(pat_hir_id);
match ty.kind {
ty::Adt(adt_def, _) => Ok(adt_def.variants[variant_index].fields.len()),
_ => {
self.tcx()
.sess
.delay_span_bug(span, "struct or tuple struct pattern not applied to an ADT");
return Err(());
}
}
}

/// Returns the total number of fields in a tuple used within a Tuple pattern.
/// Here `pat_hir_id` is the HirId of the pattern itself.
fn total_fields_in_tuple(&self, pat_hir_id: hir::HirId, span: Span) -> McResult<usize> {
let ty = self.tables.node_type(pat_hir_id);
match ty.kind {
ty::Tuple(substs) => Ok(substs.len()),
_ => {
self.tcx().sess.delay_span_bug(span, "tuple pattern not applied to a tuple");
return Err(());
}
}
}

// FIXME(#19596) This is a workaround, but there should be a better way to do this
fn cat_pattern_<F>(
&self,
Expand Down Expand Up @@ -679,20 +777,54 @@ impl<'a, 'tcx> MemCategorizationContext<'a, 'tcx> {
op(&place_with_id, pat);

match pat.kind {
PatKind::TupleStruct(_, ref subpats, _) | PatKind::Tuple(ref subpats, _) => {
// S(p1, ..., pN) or (p1, ..., pN)
for subpat in subpats.iter() {
PatKind::Tuple(ref subpats, dots_pos) => {
// (p1, ..., pN)
let total_fields = self.total_fields_in_tuple(pat.hir_id, pat.span)?;

for (i, subpat) in subpats.iter().enumerate_and_adjust(total_fields, dots_pos) {
let subpat_ty = self.pat_ty_adjusted(&subpat)?;
let sub_place = self.cat_projection(pat, place_with_id.clone(), subpat_ty);
let projection_kind = ProjectionKind::Field(i as u32, VariantIdx::new(0));
let sub_place =
self.cat_projection(pat, place_with_id.clone(), subpat_ty, projection_kind);
self.cat_pattern_(sub_place, &subpat, op)?;
}
}

PatKind::Struct(_, field_pats, _) => {
PatKind::TupleStruct(ref qpath, ref subpats, dots_pos) => {
// S(p1, ..., pN)
let variant_index = self.variant_index_for_adt(qpath, pat.hir_id, pat.span)?;
let total_fields =
self.total_fields_in_adt_variant(pat.hir_id, variant_index, pat.span)?;

for (i, subpat) in subpats.iter().enumerate_and_adjust(total_fields, dots_pos) {
let subpat_ty = self.pat_ty_adjusted(&subpat)?;
let projection_kind = ProjectionKind::Field(i as u32, variant_index);
let sub_place =
self.cat_projection(pat, place_with_id.clone(), subpat_ty, projection_kind);
self.cat_pattern_(sub_place, &subpat, op)?;
}
}

PatKind::Struct(ref qpath, field_pats, _) => {
// S { f1: p1, ..., fN: pN }

let variant_index = self.variant_index_for_adt(qpath, pat.hir_id, pat.span)?;

for fp in field_pats {
let field_ty = self.pat_ty_adjusted(&fp.pat)?;
let field_place = self.cat_projection(pat, place_with_id.clone(), field_ty);
let field_index = self
.tables
.field_indices()
.get(fp.hir_id)
.cloned()
.expect("no index for a field");

let field_place = self.cat_projection(
pat,
place_with_id.clone(),
field_ty,
ProjectionKind::Field(field_index as u32, variant_index),
);
self.cat_pattern_(field_place, &fp.pat, op)?;
}
}
Expand Down Expand Up @@ -723,13 +855,23 @@ impl<'a, 'tcx> MemCategorizationContext<'a, 'tcx> {
return Err(());
}
};
let elt_place = self.cat_projection(pat, place_with_id.clone(), element_ty);
let elt_place = self.cat_projection(
pat,
place_with_id.clone(),
element_ty,
ProjectionKind::Index,
);
for before_pat in before {
self.cat_pattern_(elt_place.clone(), &before_pat, op)?;
}
if let Some(ref slice_pat) = *slice {
let slice_pat_ty = self.pat_ty_adjusted(&slice_pat)?;
let slice_place = self.cat_projection(pat, place_with_id, slice_pat_ty);
let slice_place = self.cat_projection(
pat,
place_with_id,
slice_pat_ty,
ProjectionKind::Subslice,
);
self.cat_pattern_(slice_place, &slice_pat, op)?;
}
for after_pat in after {
Expand Down