Skip to content

Commit

Permalink
Allow constants to be shadowed
Browse files Browse the repository at this point in the history
  • Loading branch information
sharkdp committed Oct 8, 2023
1 parent 5451d97 commit 81620d3
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 61 deletions.
61 changes: 48 additions & 13 deletions numbat/src/bytecode_interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::ast::ProcedureKind;
use crate::interpreter::{
Interpreter, InterpreterResult, InterpreterSettings, Result, RuntimeError,
};
use crate::name_resolution::LAST_RESULT_IDENTIFIERS;
use crate::prefix::Prefix;
use crate::pretty_print::PrettyPrint;
use crate::typed_ast::{BinaryOperator, Expression, Statement, StringPart, UnaryOperator};
Expand All @@ -12,10 +13,16 @@ use crate::unit_registry::UnitRegistry;
use crate::vm::{Constant, ExecutionContext, Op, Vm};
use crate::{decorator, ffi};

#[derive(Debug)]
pub struct Local {
identifier: String,
depth: usize,
}

pub struct BytecodeInterpreter {
vm: Vm,
/// List of local variables currently in scope
local_variables: Vec<String>,
/// List of local variables currently in scope, one vector for each scope (for now: 0: 'global' scope, 1: function scope)
locals: Vec<Vec<Local>>,
// Maps names of units to indices of the respective constants in the VM
unit_name_to_constant_index: HashMap<String, u16>,
}
Expand All @@ -28,11 +35,24 @@ impl BytecodeInterpreter {
self.vm.add_op1(Op::LoadConstant, index);
}
Expression::Identifier(_span, identifier, _type) => {
if let Some(position) = self.local_variables.iter().position(|n| n == identifier) {
// Searching in reverse order ensures that we find the innermost identifer of that name first (shadowing)

let current_depth = self.locals.len() - 1;

if let Some(position) = self.locals[current_depth]
.iter()
.rposition(|l| &l.identifier == identifier && l.depth == current_depth)
{
self.vm.add_op1(Op::GetLocal, position as u16); // TODO: check overflow
} else if let Some(upvalue_position) = self.locals[0]
.iter()
.rposition(|l| &l.identifier == identifier)
{
self.vm.add_op1(Op::GetUpvalue, upvalue_position as u16);
} else if LAST_RESULT_IDENTIFIERS.contains(&identifier.as_str()) {
self.vm.add_op(Op::GetLastResult);
} else {
let identifier_idx = self.vm.add_global_identifier(identifier, None);
self.vm.add_op1(Op::GetVariable, identifier_idx);
unreachable!("Unknown identifier {identifier}")
}
}
Expression::UnitIdentifier(_span, prefix, unit_name, _full_name, _type) => {
Expand Down Expand Up @@ -165,8 +185,11 @@ impl BytecodeInterpreter {
}
Statement::DefineVariable(identifier, expr, _type_annotation, _type) => {
self.compile_expression_with_simplify(expr)?;
let identifier_idx = self.vm.add_global_identifier(identifier, None);
self.vm.add_op1(Op::SetVariable, identifier_idx);
let current_depth = self.current_depth();
self.locals[current_depth].push(Local {
identifier: identifier.clone(),
depth: 0,
});
}
Statement::DefineFunction(
name,
Expand All @@ -177,14 +200,22 @@ impl BytecodeInterpreter {
_return_type,
) => {
self.vm.begin_function(name);
for parameter in parameters.iter() {
self.local_variables.push(parameter.1.clone());

self.locals.push(vec![]);

let current_depth = self.current_depth();
for parameter in parameters {
self.locals[current_depth].push(Local {
identifier: parameter.1.clone(),
depth: current_depth,
});
}

self.compile_expression_with_simplify(expr)?;
self.vm.add_op(Op::Return);
for _ in parameters {
self.local_variables.pop();
}

self.locals.pop();

self.vm.end_function();
}
Statement::DefineFunction(
Expand Down Expand Up @@ -294,13 +325,17 @@ impl BytecodeInterpreter {
pub(crate) fn set_debug(&mut self, activate: bool) {
self.vm.set_debug(activate);
}

fn current_depth(&self) -> usize {
self.locals.len() - 1
}
}

impl Interpreter for BytecodeInterpreter {
fn new() -> Self {
Self {
vm: Vm::new(),
local_variables: vec![],
locals: vec![vec![]],
unit_name_to_constant_index: HashMap::new(),
}
}
Expand Down
8 changes: 0 additions & 8 deletions numbat/src/typechecker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -816,14 +816,6 @@ impl TypeChecker {
));
}

if let Some(entry) = self.identifiers.get(identifier) {
return Err(TypeCheckError::NameAlreadyUsedBy(
"a constant",
*identifier_span,
entry.1,
));
}

let expr_checked = self.check_expression(expr)?;
let type_deduced = expr_checked.get_type();

Expand Down
60 changes: 23 additions & 37 deletions numbat/src/vm.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::{collections::HashMap, fmt::Display};
use std::fmt::Display;

use crate::{
ffi::{self, ArityRange, Callable, ForeignFunction},
interpreter::{InterpreterResult, PrintFunction, Result, RuntimeError},
markup::Markup,
math,
name_resolution::LAST_RESULT_IDENTIFIERS,
prefix::Prefix,
quantity::Quantity,
unit::Unit,
Expand All @@ -31,15 +30,16 @@ pub enum Op {
/// `1 <new_unit>` to the constant with the given index.
SetUnitConstant,

/// Set the specified variable to the value on top of the stack
SetVariable,
/// Push the value of the specified variable onto the stack
GetVariable,

/// Push the value of the specified local variable onto the stack (even
/// though it is already on the stack, somewhere lower down).
GetLocal,

/// Similar to GetLocal, but get variable from surrounding scope
GetUpvalue,

/// Get the last stored result (_ and ans)
GetLastResult,

/// Negate the top of the stack
Negate,

Expand Down Expand Up @@ -98,9 +98,8 @@ impl Op {
Op::SetUnitConstant | Op::Call | Op::FFICallFunction | Op::FFICallProcedure => 2,
Op::LoadConstant
| Op::ApplyPrefix
| Op::SetVariable
| Op::GetVariable
| Op::GetLocal
| Op::GetUpvalue
| Op::PrintString
| Op::JoinString
| Op::JumpIfFalse
Expand All @@ -120,7 +119,8 @@ impl Op {
| Op::Equal
| Op::NotEqual
| Op::FullSimplify
| Op::Return => 0,
| Op::Return
| Op::GetLastResult => 0,
}
}

Expand All @@ -129,9 +129,9 @@ impl Op {
Op::LoadConstant => "LoadConstant",
Op::ApplyPrefix => "ApplyPrefix",
Op::SetUnitConstant => "SetUnitConstant",
Op::SetVariable => "SetVariable",
Op::GetVariable => "GetVariable",
Op::GetLocal => "GetLocal",
Op::GetUpvalue => "GetUpvalue",
Op::GetLastResult => "GetLastResult",
Op::Negate => "Negate",
Op::Factorial => "Factorial",
Op::Add => "Add",
Expand Down Expand Up @@ -237,8 +237,8 @@ pub struct Vm {
/// entry is the canonical name for units.
global_identifiers: Vec<(String, Option<String>)>,

/// A dictionary of global variables and their respective values.
globals: HashMap<String, Value>,
/// Result of the last expression
last_result: Option<Value>,

/// List of registered native/foreign functions
ffi_callables: Vec<&'static ForeignFunction>,
Expand All @@ -264,7 +264,7 @@ impl Vm {
prefixes: vec![],
strings: vec![],
global_identifiers: vec![],
globals: HashMap::new(),
last_result: None,
ffi_callables: ffi::procedures().iter().map(|(_, ff)| ff).collect(),
frames: vec![CallFrame::root()],
stack: vec![],
Expand Down Expand Up @@ -561,27 +561,18 @@ impl Vm {
defining_unit.clone(),
));
}
Op::SetVariable => {
let identifier_idx = self.read_u16();
let value = self.pop();
let identifier: String =
self.global_identifiers[identifier_idx as usize].0.clone();

self.globals.insert(identifier, value);
}
Op::GetVariable => {
let identifier_idx = self.read_u16();
let identifier = &self.global_identifiers[identifier_idx as usize].0;

let value = self.globals.get(identifier).expect("Variable exists");

self.push(value.clone());
}
Op::GetLocal => {
let slot_idx = self.read_u16() as usize;
let stack_idx = self.current_frame().fp + slot_idx;
self.push(self.stack[stack_idx].clone());
}
Op::GetUpvalue => {
let stack_idx = self.read_u16() as usize;
self.push(self.stack[stack_idx].clone());
}
Op::GetLastResult => {
self.push(self.last_result.as_ref().unwrap().clone());
}
op @ (Op::Add
| Op::Subtract
| Op::Multiply
Expand Down Expand Up @@ -725,10 +716,7 @@ impl Vm {
if self.frames.len() == 1 {
let return_value = self.pop();

// Save the returned value in `ans` and `_`:
for &identifier in LAST_RESULT_IDENTIFIERS {
self.globals.insert(identifier.into(), return_value.clone());
}
self.last_result = Some(return_value.clone());

result_last_statement = Some(return_value);
} else {
Expand All @@ -749,8 +737,6 @@ impl Vm {
}
}

debug_assert!(self.stack.is_empty());

if let Some(value) = result_last_statement {
Ok(InterpreterResult::Value(value))
} else {
Expand Down
23 changes: 20 additions & 3 deletions numbat/tests/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,6 @@ fn test_type_check_errors() {
"fn sin(x)=0",
"This name is already used by a foreign function",
);

// TODO: this restriction should be lifted in the future:
expect_failure("let pi = 1", "This name is already used by a constant");
}

#[test]
Expand Down Expand Up @@ -409,3 +406,23 @@ fn test_overwrite_inner_function() {
"0",
);
}

#[test]
fn test_override_constants() {
expect_output("let x = 1\nlet x = 2\nx", "2");
expect_output("let pi = 4\npi", "4");
}

#[test]
fn test_overwrite_captured_constant() {
expect_output(
"
let x = 1
fn f() = sin(x)
let x = 1 m
f()
",
"0.841471",
);
}

0 comments on commit 81620d3

Please sign in to comment.