Skip to content

Commit

Permalink
Merge pull request #364 from Nadrieril/debruijn2
Browse files Browse the repository at this point in the history
  • Loading branch information
Nadrieril authored Sep 16, 2024
2 parents 8579246 + 585c42e commit 0825fda
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 116 deletions.
93 changes: 47 additions & 46 deletions charon/src/bin/charon-driver/translate/translate_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use rustc_error_messages::MultiSpan;
use rustc_hir::def_id::DefId;
use rustc_hir::Node as HirNode;
use rustc_middle::ty::TyCtxt;
use std::borrow::Cow;
use std::cmp::{Ord, PartialOrd};
use std::collections::HashMap;
use std::collections::{BTreeMap, VecDeque};
Expand Down Expand Up @@ -238,6 +239,7 @@ pub(crate) struct BodyTransCtx<'tcx, 'ctx, 'ctx1> {
pub t_ctx: &'ctx mut TranslateCtx<'tcx, 'ctx1>,
/// A hax state with an owner id
pub hax_state: hax::State<hax::Base<'tcx>, (), (), rustc_hir::def_id::DefId>,

/// The regions.
/// We use DeBruijn indices, so we have a stack of regions.
/// See the comments for [Region::BVar].
Expand Down Expand Up @@ -268,25 +270,17 @@ pub(crate) struct BodyTransCtx<'tcx, 'ctx, 'ctx1> {
/// ==============
/// We use DeBruijn indices. See the comments for [Region::Var].
pub bound_region_vars: VecDeque<Box<[RegionId]>>,
/// The type variables
pub type_vars: Vector<TypeVarId, TypeVar>,
/// The map from rust type variable indices to translated type variable
/// indices.
/// The generic parameters for the item. `regions` must be empty, as regions are handled
/// separately.
pub generic_params: GenericParams,
/// The map from rust type variable indices to translated type variable indices.
pub type_vars_map: HashMap<u32, TypeVarId>,
/// The "regular" variables
pub vars: Vector<VarId, Var>,
/// The map from rust variable indices to translated variables indices.
pub vars_map: HashMap<usize, VarId>,
/// The const generic variables
pub const_generic_vars: Vector<ConstGenericVarId, ConstGenericVar>,
/// The map from rust const generic variables to translate const generic
/// variable indices.
pub const_generic_vars_map: HashMap<u32, ConstGenericVarId>,
/// Trait refs we couldn't solve at the moment of translating them and will solve in a second
/// pass before extracting the generic params.
pub unsolved_traits: Vector<UnsolvedTraitId, hax::TraitRef>,
/// Accumulated clauses to be put into the item's `GenericParams`.
pub param_trait_clauses: Vector<TraitClauseId, TraitClause>,
/// (For traits only) accumulated implied trait clauses.
pub parent_trait_clauses: Vector<TraitClauseId, TraitClause>,
/// (For traits only) accumulated trait clauses on associated types.
Expand All @@ -296,14 +290,13 @@ pub(crate) struct BodyTransCtx<'tcx, 'ctx, 'ctx1> {
/// corresponds to.
/// FIXME: hax should take care of this matching up.
/// We use a betreemap to get a consistent output order and `OrdRustId` to get an orderable
/// `DefId`.
pub trait_clauses: BTreeMap<OrdRustId, Vec<NonLocalTraitClause>>,
///
pub types_outlive: Vec<TypeOutlives>,
///
pub regions_outlive: Vec<RegionOutlives>,
///
pub trait_type_constraints: Vec<TraitTypeConstraint>,
/// `DefId` but they're all `OrdRustId::TraitDecl`.
pub trait_clauses_map: BTreeMap<OrdRustId, Vec<NonLocalTraitClause>>,

/// The "regular" variables
pub vars: Vector<VarId, ast::Var>,
/// The map from rust variable indices to translated variables indices.
pub vars_map: HashMap<usize, VarId>,
/// The translated blocks. We can't use `ast::Vector<BlockId, ast::BlockData>`
/// here because we might generate several fresh indices before actually
/// adding the resulting blocks to the map.
Expand Down Expand Up @@ -926,20 +919,15 @@ impl<'tcx, 'ctx, 'ctx1> BodyTransCtx<'tcx, 'ctx, 'ctx1> {
region_vars: [Vector::new()].into(),
free_region_vars: Default::default(),
bound_region_vars: Default::default(),
type_vars: Default::default(),
generic_params: Default::default(),
type_vars_map: Default::default(),
vars: Default::default(),
vars_map: Default::default(),
const_generic_vars: Default::default(),
const_generic_vars_map: Default::default(),
unsolved_traits: Default::default(),
param_trait_clauses: Default::default(),
parent_trait_clauses: Default::default(),
item_trait_clauses: Default::default(),
trait_clauses: Default::default(),
regions_outlive: Default::default(),
types_outlive: Default::default(),
trait_type_constraints: Default::default(),
trait_clauses_map: Default::default(),
vars: Default::default(),
vars_map: Default::default(),
blocks: Default::default(),
blocks_map: Default::default(),
blocks_stack: Default::default(),
Expand Down Expand Up @@ -1079,7 +1067,10 @@ impl<'tcx, 'ctx, 'ctx1> BodyTransCtx<'tcx, 'ctx, 'ctx1> {
}

pub(crate) fn push_type_var(&mut self, rid: u32, name: String) -> TypeVarId {
let var_id = self.type_vars.push_with(|index| TypeVar { index, name });
let var_id = self
.generic_params
.types
.push_with(|index| TypeVar { index, name });
self.type_vars_map.insert(rid, var_id);
var_id
}
Expand All @@ -1091,7 +1082,8 @@ impl<'tcx, 'ctx, 'ctx1> BodyTransCtx<'tcx, 'ctx, 'ctx1> {

pub(crate) fn push_const_generic_var(&mut self, rid: u32, ty: LiteralTy, name: String) {
let var_id = self
.const_generic_vars
.generic_params
.const_generics
.push_with(|index| ConstGenericVar { index, name, ty });
self.const_generic_vars_map.insert(rid, var_id);
}
Expand All @@ -1111,20 +1103,15 @@ impl<'tcx, 'ctx, 'ctx1> BodyTransCtx<'tcx, 'ctx, 'ctx1> {
// Sanity checks
self.check_generics();
assert!(self.region_vars.len() == 1);
assert!(self
.param_trait_clauses
let mut generic_params = self.generic_params.clone();
assert!(generic_params.regions.is_empty());
generic_params.regions = self.region_vars[0].clone();
assert!(generic_params
.trait_clauses
.iter()
.enumerate()
.all(|(i, c)| c.clause_id.index() == i));
let mut generic_params = GenericParams {
regions: self.region_vars[0].clone(),
types: self.type_vars.clone(),
const_generics: self.const_generic_vars.clone(),
trait_clauses: self.param_trait_clauses.clone(),
regions_outlive: self.regions_outlive.clone(),
types_outlive: self.types_outlive.clone(),
trait_type_constraints: self.trait_type_constraints.clone(),
};

// Solve trait refs now that all clauses have been registered.
generic_params.drive_mut(&mut visitor_enter_fn_mut(|tref_kind: &mut TraitRefKind| {
if let TraitRefKind::Unsolved(unsolved_trait_id) = *tref_kind {
Expand All @@ -1136,7 +1123,7 @@ impl<'tcx, 'ctx, 'ctx1> BodyTransCtx<'tcx, 'ctx, 'ctx1> {
let fmt_ctx = self.into_fmt();
let trait_ref = format!("{:?}", hax_trait_ref);
let clauses: Vec<String> = self
.trait_clauses
.trait_clauses_map
.values()
.flat_map(|x| x)
.map(|x| x.fmt_with_ctx(&fmt_ctx))
Expand Down Expand Up @@ -1193,11 +1180,25 @@ impl<'tcx, 'ctx, 'ctx1, 'a> IntoFormatter for &'a BodyTransCtx<'tcx, 'ctx, 'ctx1
type C = FmtCtx<'a>;

fn into_fmt(self) -> Self::C {
// Translate our generics into a stack of generics. Only the outermost binder has
// non-region parameters.
let mut generics: VecDeque<Cow<'_, GenericParams>> = self
.region_vars
.iter()
.cloned()
.map(|regions| {
Cow::Owned(GenericParams {
regions,
..Default::default()
})
})
.collect();
let outermost_generics = generics.back_mut().unwrap().to_mut();
outermost_generics.types = self.generic_params.types.clone();
outermost_generics.const_generics = self.generic_params.const_generics.clone();
FmtCtx {
translated: Some(&self.t_ctx.translated),
region_vars: self.region_vars.clone(),
type_vars: Some(&self.type_vars),
const_generic_vars: Some(&self.const_generic_vars),
generics,
locals: Some(&self.vars),
}
}
Expand Down
17 changes: 10 additions & 7 deletions charon/src/bin/charon-driver/translate/translate_predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,11 @@ impl<'tcx, 'ctx, 'ctx1> BodyTransCtx<'tcx, 'ctx, 'ctx1> {
match self.translate_predicate(pred, span, origin.clone(), location)? {
None => (),
Some(pred) => match pred {
Predicate::TypeOutlives(p) => self.types_outlive.push(p),
Predicate::RegionOutlives(p) => self.regions_outlive.push(p),
Predicate::TraitType(p) => self.trait_type_constraints.push(p),
Predicate::TypeOutlives(p) => self.generic_params.types_outlive.push(p),
Predicate::RegionOutlives(p) => self.generic_params.regions_outlive.push(p),
Predicate::TraitType(p) => {
self.generic_params.trait_type_constraints.push(p)
}
Predicate::Trait(_) => unreachable!(),
},
}
Expand Down Expand Up @@ -193,7 +195,7 @@ impl<'tcx, 'ctx, 'ctx1> BodyTransCtx<'tcx, 'ctx, 'ctx1> {

let trait_decl_ref = self.translate_trait_predicate(hspan, trait_pred)?;
let vec = match location {
PredicateLocation::Base => &mut self.param_trait_clauses,
PredicateLocation::Base => &mut self.generic_params.trait_clauses,
PredicateLocation::Parent(..) => &mut self.parent_trait_clauses,
PredicateLocation::Item(.., item_name) => self
.item_trait_clauses
Expand Down Expand Up @@ -222,7 +224,7 @@ impl<'tcx, 'ctx, 'ctx1> BodyTransCtx<'tcx, 'ctx, 'ctx1> {
),
};
let def_id = DefId::from(&trait_pred.trait_ref.def_id);
self.trait_clauses
self.trait_clauses_map
.entry(OrdRustId::TraitDecl(def_id))
.or_default()
.push(NonLocalTraitClause {
Expand All @@ -247,7 +249,7 @@ impl<'tcx, 'ctx, 'ctx1> BodyTransCtx<'tcx, 'ctx, 'ctx1> {
trait_ref_kind: TraitRefKind::SelfId,
};
let def_id = DefId::from(&trait_pred.trait_ref.def_id);
self.trait_clauses
self.trait_clauses_map
.entry(OrdRustId::TraitDecl(def_id))
.or_default()
.push(clause);
Expand Down Expand Up @@ -553,7 +555,8 @@ impl<'tcx, 'ctx, 'ctx1> BodyTransCtx<'tcx, 'ctx, 'ctx1> {

// Simply explore the trait clauses
let def_id = DefId::from(&hax_trait_ref.def_id);
if let Some(clauses_for_this_trait) = self.trait_clauses.get(&OrdRustId::TraitDecl(def_id))
if let Some(clauses_for_this_trait) =
self.trait_clauses_map.get(&OrdRustId::TraitDecl(def_id))
{
for trait_clause in clauses_for_this_trait {
if trait_clause.matches(hax_trait_ref) {
Expand Down
2 changes: 1 addition & 1 deletion charon/src/bin/charon-driver/translate/translate_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ impl<'tcx, 'ctx> TranslateCtx<'tcx, 'ctx> {
{
let ctx = bt_ctx.into_fmt();
let clauses = bt_ctx
.trait_clauses
.trait_clauses_map
.values()
.flat_map(|x| x)
.map(|c| c.fmt_with_ctx(&ctx))
Expand Down
88 changes: 26 additions & 62 deletions charon/src/pretty/formatter.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::collections::VecDeque;

use crate::ast::*;
Expand Down Expand Up @@ -77,22 +78,10 @@ impl<'a, 'b> SetGenerics<'a> for FmtCtx<'b> {
type C = FmtCtx<'a>;

fn set_generics(&'a self, generics: &'a GenericParams) -> Self::C {
let FmtCtx {
translated,
region_vars: _,
type_vars: _,
const_generic_vars: _,
locals,
} = self;

let translated = translated.as_deref();
let locals = locals.as_deref();
FmtCtx {
translated,
region_vars: [generics.regions.clone()].into(),
type_vars: Some(&generics.types),
const_generic_vars: Some(&generics.const_generics),
locals,
translated: self.translated.as_deref(),
generics: [Cow::Borrowed(generics)].into(),
locals: self.locals.as_deref(),
}
}
}
Expand All @@ -109,22 +98,9 @@ impl<'a, 'b> SetLocals<'a> for FmtCtx<'b> {
type C = FmtCtx<'a>;

fn set_locals(&'a self, locals: &'a Vector<VarId, ast::Var>) -> Self::C {
let FmtCtx {
translated,
region_vars,
type_vars,
const_generic_vars,
locals: _,
} = self;

let translated = translated.as_deref();
let type_vars = type_vars.as_deref();
let const_generic_vars = const_generic_vars.as_deref();
FmtCtx {
translated,
region_vars: region_vars.clone(),
type_vars,
const_generic_vars,
translated: self.translated.as_deref(),
generics: self.generics.clone(),
locals: Some(locals),
}
}
Expand All @@ -134,33 +110,22 @@ impl<'a, 'b> SetLocals<'a> for FmtCtx<'b> {
pub trait PushBoundRegions<'a> {
type C: 'a + AstFormatter;

fn push_bound_regions(&'a self, regions: &Vector<RegionId, RegionVar>) -> Self::C;
fn push_bound_regions(&'a self, regions: &'a Vector<RegionId, RegionVar>) -> Self::C;
}

impl<'a, 'b> PushBoundRegions<'a> for FmtCtx<'b> {
type C = FmtCtx<'a>;

fn push_bound_regions(&'a self, regions: &Vector<RegionId, RegionVar>) -> Self::C {
let FmtCtx {
translated,
region_vars,
type_vars,
const_generic_vars,
locals,
} = self;

let translated = translated.as_deref();
let type_vars = type_vars.as_deref();
let const_generic_vars = const_generic_vars.as_deref();
let locals = locals.as_deref();
let mut region_vars = region_vars.clone();
region_vars.push_front(regions.clone());
fn push_bound_regions(&'a self, regions: &'a Vector<RegionId, RegionVar>) -> Self::C {
let mut generics = self.generics.clone();
generics.push_front(Cow::Owned(GenericParams {
regions: regions.clone(),
..Default::default()
}));
FmtCtx {
translated,
region_vars,
type_vars,
const_generic_vars,
locals,
translated: self.translated.as_deref(),
generics,
locals: self.locals.as_deref(),
}
}
}
Expand Down Expand Up @@ -200,10 +165,9 @@ pub trait AstFormatter = Formatter<TypeVarId>
#[derive(Default)]
pub struct FmtCtx<'a> {
pub translated: Option<&'a TranslatedCrate>,
/// The region variables are not an option, because we need to be able to push/pop
pub region_vars: VecDeque<Vector<RegionId, RegionVar>>,
pub type_vars: Option<&'a Vector<TypeVarId, TypeVar>>,
pub const_generic_vars: Option<&'a Vector<ConstGenericVarId, ConstGenericVar>>,
/// Generics form a stack, where each binder introduces a new level. For DeBruijn indices to
/// work, we keep the innermost parameters at the start of the vector.
pub generics: VecDeque<Cow<'a, GenericParams>>,
pub locals: Option<&'a Vector<VarId, ast::Var>>,
}

Expand Down Expand Up @@ -331,9 +295,9 @@ impl<'a> Formatter<AnyTransId> for FmtCtx<'a> {

impl<'a> Formatter<(DeBruijnId, RegionId)> for FmtCtx<'a> {
fn format_object(&self, (grid, id): (DeBruijnId, RegionId)) -> String {
match self.region_vars.get(grid.index) {
match self.generics.get(grid.index) {
None => Region::BVar(grid, id).to_string(),
Some(gr) => match gr.get(id) {
Some(generics) => match generics.regions.get(id) {
None => {
let region = Region::BVar(grid, id);
tracing::warn!(
Expand All @@ -353,7 +317,7 @@ impl<'a> Formatter<&RegionVar> for FmtCtx<'a> {
match &var.name {
Some(name) => name.to_string(),
None => {
let depth = self.region_vars.len() - 1;
let depth = self.generics.len() - 1;
if depth == 0 {
format!("'_{}", var.index)
} else {
Expand All @@ -366,9 +330,9 @@ impl<'a> Formatter<&RegionVar> for FmtCtx<'a> {

impl<'a> Formatter<TypeVarId> for FmtCtx<'a> {
fn format_object(&self, id: TypeVarId) -> String {
match &self.type_vars {
match &self.generics.back() {
None => id.to_pretty_string(),
Some(vars) => match vars.get(id) {
Some(generics) => match generics.types.get(id) {
None => id.to_pretty_string(),
Some(v) => v.to_string(),
},
Expand All @@ -378,9 +342,9 @@ impl<'a> Formatter<TypeVarId> for FmtCtx<'a> {

impl<'a> Formatter<ConstGenericVarId> for FmtCtx<'a> {
fn format_object(&self, id: ConstGenericVarId) -> String {
match &self.const_generic_vars {
match &self.generics.back() {
None => id.to_pretty_string(),
Some(vars) => match vars.get(id) {
Some(generics) => match generics.const_generics.get(id) {
None => id.to_pretty_string(),
Some(v) => v.to_string(),
},
Expand Down
Loading

0 comments on commit 0825fda

Please sign in to comment.