Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid cloning the typechecker in elaborate_statement #543

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions numbat/src/name_resolution.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use std::collections::HashMap;

use thiserror::Error;

use crate::span::Span;
use crate::{span::Span, typechecker::map_stack::MapStack};

pub const LAST_RESULT_IDENTIFIERS: &[&str] = &["ans", "_"];

Expand All @@ -23,10 +21,18 @@ pub enum NameResolutionError {

#[derive(Debug, Clone, Default)]
pub struct Namespace {
seen: HashMap<String, (String, Span)>,
seen: MapStack<String, (String, Span)>,
}

impl Namespace {
pub(crate) fn save(&mut self) {
self.seen.save()
}

pub(crate) fn restore(&mut self) {
self.seen.restore()
}

pub fn add_identifier_allow_override(
&mut self,
name: String,
Expand Down Expand Up @@ -73,7 +79,7 @@ impl Namespace {
});
}

self.seen.insert(name, (item_type, span));
let _ = self.seen.insert(name, (item_type, span));

Ok(())
}
Expand Down
21 changes: 16 additions & 5 deletions numbat/src/typechecker/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ use crate::type_variable::TypeVariable;
use crate::typed_ast::pretty_print_function_signature;
use crate::Type;

use super::map_stack::MapStack;
use super::substitutions::{ApplySubstitution, Substitution, SubstitutionError};
use super::type_scheme::TypeScheme;

use std::collections::HashMap;

type Identifier = String;

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -94,7 +93,7 @@ impl IdentifierKind {

#[derive(Clone, Debug, Default)]
pub struct Environment {
identifiers: HashMap<Identifier, IdentifierKind>,
identifiers: MapStack<Identifier, IdentifierKind>,
}

impl Environment {
Expand All @@ -110,6 +109,14 @@ impl Environment {
.insert(i, IdentifierKind::Normal(scheme, span, is_unit));
}

pub(crate) fn save(&mut self) {
self.identifiers.save();
}

pub(crate) fn restore(&mut self) {
self.identifiers.restore();
}

pub(crate) fn add_function(
&mut self,
v: String,
Expand All @@ -126,7 +133,7 @@ impl Environment {
}

pub(crate) fn get_identifier_type(&self, v: &str) -> Option<TypeScheme> {
self.identifiers.get(v).map(|k| k.get_type())
self.find(v).map(|k| k.get_type())
}

pub(crate) fn iter_identifiers(&self) -> impl Iterator<Item = &Identifier> {
Expand All @@ -145,11 +152,15 @@ impl Environment {
.map(|(id, kind)| (id, kind.get_type()))
}

fn find(&self, name: &str) -> Option<&IdentifierKind> {
self.identifiers.get(name)
}

pub(crate) fn get_function_info(
&self,
name: &str,
) -> Option<(&FunctionSignature, &FunctionMetadata)> {
match self.identifiers.get(name) {
match self.find(name) {
Some(IdentifierKind::Function(signature, metadata)) => Some((signature, metadata)),
_ => None,
}
Expand Down
82 changes: 82 additions & 0 deletions numbat/src/typechecker/map_stack.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use std::hash::Hash;
use std::{borrow::Borrow, collections::HashMap};

/// A stack of hash maps. All insertions affect the hash map at the top of the
/// stack (which is the last element of the `stack` vector), preserving any
/// entries in maps below. The `save` function can be used to push a new map on
/// the top of the stack, in effect saving the current state of the map, which
/// one can then restore with `restore`.
///
/// The stack vector should never be empty
#[derive(Debug, Clone)]
pub(crate) struct MapStack<K, V> {
stack: Vec<HashMap<K, V>>,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably only use two layers at the moment, but it's great to have this more general stack, since we're going to need it once we add nested functions, I guess.

}

impl<K, V> Default for MapStack<K, V> {
fn default() -> Self {
MapStack {
stack: vec![Default::default()],
}
}
}

impl<K: Hash + Eq, V> MapStack<K, V> {
fn iter_dict(&self) -> impl Iterator<Item = &HashMap<K, V>> {
self.stack.iter().rev()
}

fn iter_dict_mut(&mut self) -> impl Iterator<Item = &mut HashMap<K, V>> {
self.stack.iter_mut().rev()
}

pub(crate) fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
self.iter_dict().flatten()
}

pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = (&K, &mut V)> {
self.iter_dict_mut().flatten()
}

pub(crate) fn keys(&self) -> impl Iterator<Item = &K> {
self.iter_dict().map(|dict| dict.keys()).flatten()
}

pub(crate) fn get<Q>(&self, key: &Q) -> Option<&V>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.iter_dict().filter_map(|dict| dict.get(key)).nth(0)
}

pub(crate) fn insert(&mut self, key: K, value: V) {
let _ = self.stack.last_mut().unwrap().insert(key, value);
}

pub(crate) fn contains_key<Q>(&self, key: &Q) -> bool
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.iter_dict().any(|dict| dict.contains_key(key))
}

/// Remove the top hash map from the stack, making the next one the
/// current top of the stack, restoring the state of the map before the
/// last call to save.
pub(crate) fn restore(&mut self) {
let _ = self.stack.pop();
// The stack should never be empty
assert!(
!self.stack.is_empty(),
"Tried to restore the last saved state but nothing was saved"
);
}

/// Save the current state of the map by making the top of the stack a
/// new empty map.
pub(crate) fn save(&mut self) {
self.stack.push(HashMap::default());
}
}
60 changes: 33 additions & 27 deletions numbat/src/typechecker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod constraints;
mod environment;
mod error;
mod incompatible_dimensions;
pub mod map_stack;
mod name_generator;
pub mod qualified_type;
mod substitutions;
Expand Down Expand Up @@ -55,8 +56,8 @@ pub struct TypeChecker {

type_namespace: Namespace,
value_namespace: Namespace,

env: Environment,

name_generator: NameGenerator,
constraints: ConstraintSet,
}
Expand Down Expand Up @@ -1310,32 +1311,35 @@ impl TypeChecker {
)?;
}

let mut typechecker_fn = self.clone();
// Save the environment and namespaces to avoid polluting
// their parents with the locals of this function
self.env.save();
self.type_namespace.save();
self.value_namespace.save();

let is_ffi_function = body.is_none();

for (span, type_parameter, bound) in type_parameters {
if typechecker_fn.type_namespace.has_identifier(type_parameter) {
if self.type_namespace.has_identifier(type_parameter) {
return Err(TypeCheckError::TypeParameterNameClash(
*span,
type_parameter.clone(),
));
}

typechecker_fn
.type_namespace
self.type_namespace
.add_identifier(type_parameter.clone(), *span, "type parameter".to_owned())
.ok(); // TODO: is this call even correct?

typechecker_fn.registry.introduced_type_parameters.push((
self.registry.introduced_type_parameters.push((
*span,
type_parameter.clone(),
bound.clone(),
));

match bound {
Some(TypeParameterBound::Dim) => {
typechecker_fn
.add_dtype_constraint(&Type::TPar(type_parameter.clone()))
self.add_dtype_constraint(&Type::TPar(type_parameter.clone()))
.ok();
}
None => {}
Expand All @@ -1346,12 +1350,12 @@ impl TypeChecker {
for (parameter_span, parameter, type_annotation) in parameters {
let annotated_type = type_annotation
.as_ref()
.map(|a| typechecker_fn.type_from_annotation(a))
.map(|a| self.type_from_annotation(a))
.transpose()?;

let parameter_type = match &annotated_type {
Some(annotated_type) => annotated_type.clone(),
None => typechecker_fn.fresh_type_variable(),
None => self.fresh_type_variable(),
};

if is_ffi_function && annotated_type.is_none() {
Expand All @@ -1361,7 +1365,7 @@ impl TypeChecker {
));
}

typechecker_fn.env.add_scheme(
self.env.add_scheme(
parameter.clone(),
TypeScheme::make_quantified(parameter_type.clone()),
*parameter_span,
Expand All @@ -1377,12 +1381,12 @@ impl TypeChecker {

let annotated_return_type = return_type_annotation
.as_ref()
.map(|annotation| typechecker_fn.type_from_annotation(annotation))
.map(|annotation| self.type_from_annotation(annotation))
.transpose()?;

let return_type = match &annotated_return_type {
Some(annotated_return_type) => annotated_return_type.clone(),
None => typechecker_fn.fresh_type_variable(),
None => self.fresh_type_variable(),
};

// Add the function to the environment, so it can be called recursively
Expand All @@ -1399,7 +1403,7 @@ impl TypeChecker {
let fn_type =
TypeScheme::Concrete(Type::Fn(parameter_types, Box::new(return_type.clone())));

typechecker_fn.env.add_function(
self.env.add_function(
function_name.clone(),
FunctionSignature {
name: function_name.clone(),
Expand All @@ -1418,19 +1422,18 @@ impl TypeChecker {

let mut typed_local_variables = vec![];
for local_variable in local_variables {
typed_local_variables
.push(typechecker_fn.elaborate_define_variable(local_variable)?);
typed_local_variables.push(self.elaborate_define_variable(local_variable)?);
}

let body_checked = body
.as_ref()
.map(|expr| typechecker_fn.elaborate_expression(expr))
.map(|expr| self.elaborate_expression(expr))
.transpose()?;

let return_type_inferred = if let Some(ref expr) = body_checked {
let return_type_inferred = expr.get_type();

if typechecker_fn
if self
.add_equal_constraint(&return_type_inferred, &return_type)
.is_trivially_violated()
{
Expand All @@ -1449,7 +1452,7 @@ impl TypeChecker {
.unwrap()
.full_span(),
expected_name: "specified return type",
expected_dimensions: typechecker_fn
expected_dimensions: self
.registry
.get_derived_entry_names_for(
&dtype_specified.to_base_representation(),
Expand All @@ -1461,7 +1464,7 @@ impl TypeChecker {
.unwrap(),
actual_name: " actual return type",
actual_name_for_fix: "expression in the function body",
actual_dimensions: typechecker_fn
actual_dimensions: self
.registry
.get_derived_entry_names_for(
&dtype_deduced.to_base_representation(),
Expand Down Expand Up @@ -1500,16 +1503,19 @@ impl TypeChecker {
})?
};

typechecker_fn
.add_equal_constraint(&return_type_inferred, &return_type)
self.add_equal_constraint(&return_type_inferred, &return_type)
.ok();

self.constraints = typechecker_fn.constraints;
self.name_generator = typechecker_fn.name_generator;
self.registry = typechecker_fn.registry;
// Copy identifier for the new function into local env:
let (signature, metadata) =
typechecker_fn.env.get_function_info(function_name).unwrap();
let (signature, metadata) = self.env.get_function_info(function_name).unwrap();
let signature = signature.clone();
let metadata = metadata.clone();

// Restore the environment and namespaces before exiting and
// add the function name to the environment
self.value_namespace.restore();
self.type_namespace.restore();
self.env.restore();
self.env
.add_function(function_name.clone(), signature.clone(), metadata.clone());

Expand Down
Loading