diff --git a/numbat/src/bytecode_interpreter.rs b/numbat/src/bytecode_interpreter.rs index 2e7c7816..a3a9bce5 100644 --- a/numbat/src/bytecode_interpreter.rs +++ b/numbat/src/bytecode_interpreter.rs @@ -4,6 +4,7 @@ use crate::ast::ProcedureKind; use crate::interpreter::{ Interpreter, InterpreterResult, InterpreterSettings, Result, RuntimeError, }; +use crate::name_resolution::LAST_RESULT_IDENTIFIERS; use crate::prefix::Prefix; use crate::pretty_print::PrettyPrint; use crate::typed_ast::{BinaryOperator, Expression, Statement, StringPart, UnaryOperator}; @@ -12,10 +13,16 @@ use crate::unit_registry::UnitRegistry; use crate::vm::{Constant, ExecutionContext, Op, Vm}; use crate::{decorator, ffi}; +#[derive(Debug)] +pub struct Local { + identifier: String, + depth: usize, +} + pub struct BytecodeInterpreter { vm: Vm, - /// List of local variables currently in scope - local_variables: Vec, + /// List of local variables currently in scope, one vector for each scope (for now: 0: 'global' scope, 1: function scope) + locals: Vec>, // Maps names of units to indices of the respective constants in the VM unit_name_to_constant_index: HashMap, } @@ -28,11 +35,24 @@ impl BytecodeInterpreter { self.vm.add_op1(Op::LoadConstant, index); } Expression::Identifier(_span, identifier, _type) => { - if let Some(position) = self.local_variables.iter().position(|n| n == identifier) { + // Searching in reverse order ensures that we find the innermost identifer of that name first (shadowing) + + let current_depth = self.locals.len() - 1; + + if let Some(position) = self.locals[current_depth] + .iter() + .rposition(|l| &l.identifier == identifier && l.depth == current_depth) + { self.vm.add_op1(Op::GetLocal, position as u16); // TODO: check overflow + } else if let Some(upvalue_position) = self.locals[0] + .iter() + .rposition(|l| &l.identifier == identifier) + { + self.vm.add_op1(Op::GetUpvalue, upvalue_position as u16); + } else if LAST_RESULT_IDENTIFIERS.contains(&identifier.as_str()) { + self.vm.add_op(Op::GetLastResult); } else { - let identifier_idx = self.vm.add_global_identifier(identifier, None); - self.vm.add_op1(Op::GetVariable, identifier_idx); + unreachable!("Unknown identifier {identifier}") } } Expression::UnitIdentifier(_span, prefix, unit_name, _full_name, _type) => { @@ -165,8 +185,11 @@ impl BytecodeInterpreter { } Statement::DefineVariable(identifier, expr, _type_annotation, _type) => { self.compile_expression_with_simplify(expr)?; - let identifier_idx = self.vm.add_global_identifier(identifier, None); - self.vm.add_op1(Op::SetVariable, identifier_idx); + let current_depth = self.current_depth(); + self.locals[current_depth].push(Local { + identifier: identifier.clone(), + depth: 0, + }); } Statement::DefineFunction( name, @@ -177,14 +200,22 @@ impl BytecodeInterpreter { _return_type, ) => { self.vm.begin_function(name); - for parameter in parameters.iter() { - self.local_variables.push(parameter.1.clone()); + + self.locals.push(vec![]); + + let current_depth = self.current_depth(); + for parameter in parameters { + self.locals[current_depth].push(Local { + identifier: parameter.1.clone(), + depth: current_depth, + }); } + self.compile_expression_with_simplify(expr)?; self.vm.add_op(Op::Return); - for _ in parameters { - self.local_variables.pop(); - } + + self.locals.pop(); + self.vm.end_function(); } Statement::DefineFunction( @@ -294,13 +325,17 @@ impl BytecodeInterpreter { pub(crate) fn set_debug(&mut self, activate: bool) { self.vm.set_debug(activate); } + + fn current_depth(&self) -> usize { + self.locals.len() - 1 + } } impl Interpreter for BytecodeInterpreter { fn new() -> Self { Self { vm: Vm::new(), - local_variables: vec![], + locals: vec![vec![]], unit_name_to_constant_index: HashMap::new(), } } diff --git a/numbat/src/typechecker.rs b/numbat/src/typechecker.rs index 53e9e2d8..0a5dfcf5 100644 --- a/numbat/src/typechecker.rs +++ b/numbat/src/typechecker.rs @@ -816,14 +816,6 @@ impl TypeChecker { )); } - if let Some(entry) = self.identifiers.get(identifier) { - return Err(TypeCheckError::NameAlreadyUsedBy( - "a constant", - *identifier_span, - entry.1, - )); - } - let expr_checked = self.check_expression(expr)?; let type_deduced = expr_checked.get_type(); diff --git a/numbat/src/vm.rs b/numbat/src/vm.rs index 328e358c..32cba279 100644 --- a/numbat/src/vm.rs +++ b/numbat/src/vm.rs @@ -1,11 +1,10 @@ -use std::{collections::HashMap, fmt::Display}; +use std::fmt::Display; use crate::{ ffi::{self, ArityRange, Callable, ForeignFunction}, interpreter::{InterpreterResult, PrintFunction, Result, RuntimeError}, markup::Markup, math, - name_resolution::LAST_RESULT_IDENTIFIERS, prefix::Prefix, quantity::Quantity, unit::Unit, @@ -31,15 +30,16 @@ pub enum Op { /// `1 ` to the constant with the given index. SetUnitConstant, - /// Set the specified variable to the value on top of the stack - SetVariable, - /// Push the value of the specified variable onto the stack - GetVariable, - /// Push the value of the specified local variable onto the stack (even /// though it is already on the stack, somewhere lower down). GetLocal, + /// Similar to GetLocal, but get variable from surrounding scope + GetUpvalue, + + /// Get the last stored result (_ and ans) + GetLastResult, + /// Negate the top of the stack Negate, @@ -98,9 +98,8 @@ impl Op { Op::SetUnitConstant | Op::Call | Op::FFICallFunction | Op::FFICallProcedure => 2, Op::LoadConstant | Op::ApplyPrefix - | Op::SetVariable - | Op::GetVariable | Op::GetLocal + | Op::GetUpvalue | Op::PrintString | Op::JoinString | Op::JumpIfFalse @@ -120,7 +119,8 @@ impl Op { | Op::Equal | Op::NotEqual | Op::FullSimplify - | Op::Return => 0, + | Op::Return + | Op::GetLastResult => 0, } } @@ -129,9 +129,9 @@ impl Op { Op::LoadConstant => "LoadConstant", Op::ApplyPrefix => "ApplyPrefix", Op::SetUnitConstant => "SetUnitConstant", - Op::SetVariable => "SetVariable", - Op::GetVariable => "GetVariable", Op::GetLocal => "GetLocal", + Op::GetUpvalue => "GetUpvalue", + Op::GetLastResult => "GetLastResult", Op::Negate => "Negate", Op::Factorial => "Factorial", Op::Add => "Add", @@ -237,8 +237,8 @@ pub struct Vm { /// entry is the canonical name for units. global_identifiers: Vec<(String, Option)>, - /// A dictionary of global variables and their respective values. - globals: HashMap, + /// Result of the last expression + last_result: Option, /// List of registered native/foreign functions ffi_callables: Vec<&'static ForeignFunction>, @@ -264,7 +264,7 @@ impl Vm { prefixes: vec![], strings: vec![], global_identifiers: vec![], - globals: HashMap::new(), + last_result: None, ffi_callables: ffi::procedures().iter().map(|(_, ff)| ff).collect(), frames: vec![CallFrame::root()], stack: vec![], @@ -561,27 +561,18 @@ impl Vm { defining_unit.clone(), )); } - Op::SetVariable => { - let identifier_idx = self.read_u16(); - let value = self.pop(); - let identifier: String = - self.global_identifiers[identifier_idx as usize].0.clone(); - - self.globals.insert(identifier, value); - } - Op::GetVariable => { - let identifier_idx = self.read_u16(); - let identifier = &self.global_identifiers[identifier_idx as usize].0; - - let value = self.globals.get(identifier).expect("Variable exists"); - - self.push(value.clone()); - } Op::GetLocal => { let slot_idx = self.read_u16() as usize; let stack_idx = self.current_frame().fp + slot_idx; self.push(self.stack[stack_idx].clone()); } + Op::GetUpvalue => { + let stack_idx = self.read_u16() as usize; + self.push(self.stack[stack_idx].clone()); + } + Op::GetLastResult => { + self.push(self.last_result.as_ref().unwrap().clone()); + } op @ (Op::Add | Op::Subtract | Op::Multiply @@ -725,10 +716,7 @@ impl Vm { if self.frames.len() == 1 { let return_value = self.pop(); - // Save the returned value in `ans` and `_`: - for &identifier in LAST_RESULT_IDENTIFIERS { - self.globals.insert(identifier.into(), return_value.clone()); - } + self.last_result = Some(return_value.clone()); result_last_statement = Some(return_value); } else { @@ -749,8 +737,6 @@ impl Vm { } } - debug_assert!(self.stack.is_empty()); - if let Some(value) = result_last_statement { Ok(InterpreterResult::Value(value)) } else { diff --git a/numbat/tests/interpreter.rs b/numbat/tests/interpreter.rs index 8f664c7d..494c748b 100644 --- a/numbat/tests/interpreter.rs +++ b/numbat/tests/interpreter.rs @@ -334,9 +334,6 @@ fn test_type_check_errors() { "fn sin(x)=0", "This name is already used by a foreign function", ); - - // TODO: this restriction should be lifted in the future: - expect_failure("let pi = 1", "This name is already used by a constant"); } #[test] @@ -409,3 +406,23 @@ fn test_overwrite_inner_function() { "0", ); } + +#[test] +fn test_override_constants() { + expect_output("let x = 1\nlet x = 2\nx", "2"); + expect_output("let pi = 4\npi", "4"); +} + +#[test] +fn test_overwrite_captured_constant() { + expect_output( + " + let x = 1 + fn f() = sin(x) + + let x = 1 m + f() + ", + "0.841471", + ); +}