Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow constants to be shadowed #197

Merged
merged 1 commit into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 48 additions & 13 deletions numbat/src/bytecode_interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::ast::ProcedureKind;
use crate::interpreter::{
Interpreter, InterpreterResult, InterpreterSettings, Result, RuntimeError,
};
use crate::name_resolution::LAST_RESULT_IDENTIFIERS;
use crate::prefix::Prefix;
use crate::pretty_print::PrettyPrint;
use crate::typed_ast::{BinaryOperator, Expression, Statement, StringPart, UnaryOperator};
Expand All @@ -12,10 +13,16 @@ use crate::unit_registry::UnitRegistry;
use crate::vm::{Constant, ExecutionContext, Op, Vm};
use crate::{decorator, ffi};

#[derive(Debug)]
pub struct Local {
identifier: String,
depth: usize,
}

pub struct BytecodeInterpreter {
vm: Vm,
/// List of local variables currently in scope
local_variables: Vec<String>,
/// List of local variables currently in scope, one vector for each scope (for now: 0: 'global' scope, 1: function scope)
locals: Vec<Vec<Local>>,
// Maps names of units to indices of the respective constants in the VM
unit_name_to_constant_index: HashMap<String, u16>,
}
Expand All @@ -28,11 +35,24 @@ impl BytecodeInterpreter {
self.vm.add_op1(Op::LoadConstant, index);
}
Expression::Identifier(_span, identifier, _type) => {
if let Some(position) = self.local_variables.iter().position(|n| n == identifier) {
// Searching in reverse order ensures that we find the innermost identifer of that name first (shadowing)

let current_depth = self.locals.len() - 1;

if let Some(position) = self.locals[current_depth]
.iter()
.rposition(|l| &l.identifier == identifier && l.depth == current_depth)
{
self.vm.add_op1(Op::GetLocal, position as u16); // TODO: check overflow
} else if let Some(upvalue_position) = self.locals[0]
.iter()
.rposition(|l| &l.identifier == identifier)
{
self.vm.add_op1(Op::GetUpvalue, upvalue_position as u16);
} else if LAST_RESULT_IDENTIFIERS.contains(&identifier.as_str()) {
self.vm.add_op(Op::GetLastResult);
} else {
let identifier_idx = self.vm.add_global_identifier(identifier, None);
self.vm.add_op1(Op::GetVariable, identifier_idx);
unreachable!("Unknown identifier {identifier}")
}
}
Expression::UnitIdentifier(_span, prefix, unit_name, _full_name, _type) => {
Expand Down Expand Up @@ -165,8 +185,11 @@ impl BytecodeInterpreter {
}
Statement::DefineVariable(identifier, expr, _type_annotation, _type) => {
self.compile_expression_with_simplify(expr)?;
let identifier_idx = self.vm.add_global_identifier(identifier, None);
self.vm.add_op1(Op::SetVariable, identifier_idx);
let current_depth = self.current_depth();
self.locals[current_depth].push(Local {
identifier: identifier.clone(),
depth: 0,
});
}
Statement::DefineFunction(
name,
Expand All @@ -177,14 +200,22 @@ impl BytecodeInterpreter {
_return_type,
) => {
self.vm.begin_function(name);
for parameter in parameters.iter() {
self.local_variables.push(parameter.1.clone());

self.locals.push(vec![]);

let current_depth = self.current_depth();
for parameter in parameters {
self.locals[current_depth].push(Local {
identifier: parameter.1.clone(),
depth: current_depth,
});
}

self.compile_expression_with_simplify(expr)?;
self.vm.add_op(Op::Return);
for _ in parameters {
self.local_variables.pop();
}

self.locals.pop();

self.vm.end_function();
}
Statement::DefineFunction(
Expand Down Expand Up @@ -294,13 +325,17 @@ impl BytecodeInterpreter {
pub(crate) fn set_debug(&mut self, activate: bool) {
self.vm.set_debug(activate);
}

fn current_depth(&self) -> usize {
self.locals.len() - 1
}
}

impl Interpreter for BytecodeInterpreter {
fn new() -> Self {
Self {
vm: Vm::new(),
local_variables: vec![],
locals: vec![vec![]],
unit_name_to_constant_index: HashMap::new(),
}
}
Expand Down
8 changes: 0 additions & 8 deletions numbat/src/typechecker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -816,14 +816,6 @@ impl TypeChecker {
));
}

if let Some(entry) = self.identifiers.get(identifier) {
return Err(TypeCheckError::NameAlreadyUsedBy(
"a constant",
*identifier_span,
entry.1,
));
}

let expr_checked = self.check_expression(expr)?;
let type_deduced = expr_checked.get_type();

Expand Down
60 changes: 23 additions & 37 deletions numbat/src/vm.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::{collections::HashMap, fmt::Display};
use std::fmt::Display;

use crate::{
ffi::{self, ArityRange, Callable, ForeignFunction},
interpreter::{InterpreterResult, PrintFunction, Result, RuntimeError},
markup::Markup,
math,
name_resolution::LAST_RESULT_IDENTIFIERS,
prefix::Prefix,
quantity::Quantity,
unit::Unit,
Expand All @@ -31,15 +30,16 @@ pub enum Op {
/// `1 <new_unit>` to the constant with the given index.
SetUnitConstant,

/// Set the specified variable to the value on top of the stack
SetVariable,
/// Push the value of the specified variable onto the stack
GetVariable,

/// Push the value of the specified local variable onto the stack (even
/// though it is already on the stack, somewhere lower down).
GetLocal,

/// Similar to GetLocal, but get variable from surrounding scope
GetUpvalue,

/// Get the last stored result (_ and ans)
GetLastResult,

/// Negate the top of the stack
Negate,

Expand Down Expand Up @@ -98,9 +98,8 @@ impl Op {
Op::SetUnitConstant | Op::Call | Op::FFICallFunction | Op::FFICallProcedure => 2,
Op::LoadConstant
| Op::ApplyPrefix
| Op::SetVariable
| Op::GetVariable
| Op::GetLocal
| Op::GetUpvalue
| Op::PrintString
| Op::JoinString
| Op::JumpIfFalse
Expand All @@ -120,7 +119,8 @@ impl Op {
| Op::Equal
| Op::NotEqual
| Op::FullSimplify
| Op::Return => 0,
| Op::Return
| Op::GetLastResult => 0,
}
}

Expand All @@ -129,9 +129,9 @@ impl Op {
Op::LoadConstant => "LoadConstant",
Op::ApplyPrefix => "ApplyPrefix",
Op::SetUnitConstant => "SetUnitConstant",
Op::SetVariable => "SetVariable",
Op::GetVariable => "GetVariable",
Op::GetLocal => "GetLocal",
Op::GetUpvalue => "GetUpvalue",
Op::GetLastResult => "GetLastResult",
Op::Negate => "Negate",
Op::Factorial => "Factorial",
Op::Add => "Add",
Expand Down Expand Up @@ -237,8 +237,8 @@ pub struct Vm {
/// entry is the canonical name for units.
global_identifiers: Vec<(String, Option<String>)>,

/// A dictionary of global variables and their respective values.
globals: HashMap<String, Value>,
/// Result of the last expression
last_result: Option<Value>,

/// List of registered native/foreign functions
ffi_callables: Vec<&'static ForeignFunction>,
Expand All @@ -264,7 +264,7 @@ impl Vm {
prefixes: vec![],
strings: vec![],
global_identifiers: vec![],
globals: HashMap::new(),
last_result: None,
ffi_callables: ffi::procedures().iter().map(|(_, ff)| ff).collect(),
frames: vec![CallFrame::root()],
stack: vec![],
Expand Down Expand Up @@ -561,27 +561,18 @@ impl Vm {
defining_unit.clone(),
));
}
Op::SetVariable => {
let identifier_idx = self.read_u16();
let value = self.pop();
let identifier: String =
self.global_identifiers[identifier_idx as usize].0.clone();

self.globals.insert(identifier, value);
}
Op::GetVariable => {
let identifier_idx = self.read_u16();
let identifier = &self.global_identifiers[identifier_idx as usize].0;

let value = self.globals.get(identifier).expect("Variable exists");

self.push(value.clone());
}
Op::GetLocal => {
let slot_idx = self.read_u16() as usize;
let stack_idx = self.current_frame().fp + slot_idx;
self.push(self.stack[stack_idx].clone());
}
Op::GetUpvalue => {
let stack_idx = self.read_u16() as usize;
self.push(self.stack[stack_idx].clone());
}
Op::GetLastResult => {
self.push(self.last_result.as_ref().unwrap().clone());
}
op @ (Op::Add
| Op::Subtract
| Op::Multiply
Expand Down Expand Up @@ -725,10 +716,7 @@ impl Vm {
if self.frames.len() == 1 {
let return_value = self.pop();

// Save the returned value in `ans` and `_`:
for &identifier in LAST_RESULT_IDENTIFIERS {
self.globals.insert(identifier.into(), return_value.clone());
}
self.last_result = Some(return_value.clone());

result_last_statement = Some(return_value);
} else {
Expand All @@ -749,8 +737,6 @@ impl Vm {
}
}

debug_assert!(self.stack.is_empty());

if let Some(value) = result_last_statement {
Ok(InterpreterResult::Value(value))
} else {
Expand Down
23 changes: 20 additions & 3 deletions numbat/tests/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,6 @@ fn test_type_check_errors() {
"fn sin(x)=0",
"This name is already used by a foreign function",
);

// TODO: this restriction should be lifted in the future:
expect_failure("let pi = 1", "This name is already used by a constant");
}

#[test]
Expand Down Expand Up @@ -409,3 +406,23 @@ fn test_overwrite_inner_function() {
"0",
);
}

#[test]
fn test_override_constants() {
expect_output("let x = 1\nlet x = 2\nx", "2");
expect_output("let pi = 4\npi", "4");
}

#[test]
fn test_overwrite_captured_constant() {
expect_output(
"
let x = 1
fn f() = sin(x)

let x = 1 m
f()
",
"0.841471",
);
}