diff --git a/numbat-cli/src/main.rs b/numbat-cli/src/main.rs index ae32d3af..3bab1cfc 100644 --- a/numbat-cli/src/main.rs +++ b/numbat-cli/src/main.rs @@ -9,7 +9,7 @@ use highlighter::NumbatHighlighter; use numbat::diagnostic::ErrorDiagnostic; use numbat::pretty_print::PrettyPrint; use numbat::resolver::{CodeSource, FileSystemImporter, ResolverError}; -use numbat::{markup, NameResolutionError, RuntimeError}; +use numbat::{markup, InterpreterSettings, NameResolutionError, RuntimeError}; use numbat::{Context, ExitStatus, InterpreterResult, NumbatError}; use anyhow::{bail, Context as AnyhowContext, Result}; @@ -149,6 +149,12 @@ impl Cli { if load_prelude { let ctx = self.context.clone(); + let mut no_print_settings = InterpreterSettings { + print_fn: Box::new( + move |_: &str| { // ignore any print statements when loading this module asynchronously + }, + ), + }; thread::spawn(move || { numbat::Context::fetch_exchange_rates(); @@ -158,7 +164,11 @@ impl Cli { // a short delay (the limiting factor is the HTTP request). ctx.lock() .unwrap() - .interpret("use units::currencies", CodeSource::Internal) + .interpret_with_settings( + &mut no_print_settings, + "use units::currencies", + CodeSource::Internal, + ) .ok(); }); } @@ -331,7 +341,20 @@ impl Cli { execution_mode: ExecutionMode, pretty_print_mode: PrettyPrintMode, ) -> ControlFlow { - let result = { self.context.lock().unwrap().interpret(input, code_source) }; + let to_be_printed: Arc>> = Arc::new(Mutex::new(vec![])); + let to_be_printed_c = to_be_printed.clone(); + let mut settings = InterpreterSettings { + print_fn: Box::new(move |s: &str| { + to_be_printed_c.lock().unwrap().push(s.to_string()); + }), + }; + + let result = { + self.context + .lock() + .unwrap() + .interpret_with_settings(&mut settings, input, code_source) + }; let pretty_print = match pretty_print_mode { PrettyPrintMode::Always => true, @@ -352,6 +375,22 @@ impl Cli { println!(); } + let to_be_printed = to_be_printed.lock().unwrap(); + for s in to_be_printed.iter() { + print!( + "{}{}", + if execution_mode == ExecutionMode::Interactive { + " " + } else { + "" + }, + s + ); + } + if !to_be_printed.is_empty() && execution_mode == ExecutionMode::Interactive { + println!(); + } + match interpreter_result { InterpreterResult::Quantity(quantity) => { let q_markup = markup::whitespace(" ") diff --git a/numbat/src/bytecode_interpreter.rs b/numbat/src/bytecode_interpreter.rs index e1e909e1..c688564d 100644 --- a/numbat/src/bytecode_interpreter.rs +++ b/numbat/src/bytecode_interpreter.rs @@ -1,12 +1,14 @@ use std::collections::HashMap; use crate::ast::ProcedureKind; -use crate::interpreter::{Interpreter, InterpreterResult, Result, RuntimeError}; +use crate::interpreter::{ + Interpreter, InterpreterResult, InterpreterSettings, Result, RuntimeError, +}; use crate::prefix::Prefix; use crate::typed_ast::{BinaryOperator, Expression, Statement, UnaryOperator}; use crate::unit::Unit; use crate::unit_registry::UnitRegistry; -use crate::vm::{Constant, Op, Vm}; +use crate::vm::{Constant, ExecutionContext, Op, Vm}; use crate::{decorator, ffi}; pub struct BytecodeInterpreter { @@ -209,12 +211,16 @@ impl BytecodeInterpreter { Ok(()) } - fn run(&mut self) -> Result { - self.vm.disassemble(); + fn run(&mut self, settings: &mut InterpreterSettings) -> Result { + let mut ctx = ExecutionContext { + print_fn: &mut settings.print_fn, + }; + + self.vm.disassemble(&mut ctx); - let result = self.vm.run(); + let result = self.vm.run(&mut ctx); - self.vm.debug(); + self.vm.debug(&mut ctx); result } @@ -233,7 +239,11 @@ impl Interpreter for BytecodeInterpreter { } } - fn interpret_statements(&mut self, statements: &[Statement]) -> Result { + fn interpret_statements( + &mut self, + settings: &mut InterpreterSettings, + statements: &[Statement], + ) -> Result { if statements.is_empty() { return Err(RuntimeError::NoStatements); }; @@ -242,7 +252,7 @@ impl Interpreter for BytecodeInterpreter { self.compile_statement(statement)?; } - self.run() + self.run(settings) } fn get_unit_registry(&self) -> &UnitRegistry { diff --git a/numbat/src/ffi.rs b/numbat/src/ffi.rs index 1feb3319..b0920a75 100644 --- a/numbat/src/ffi.rs +++ b/numbat/src/ffi.rs @@ -4,6 +4,7 @@ use std::sync::OnceLock; use crate::currency::ExchangeRatesCache; use crate::interpreter::RuntimeError; +use crate::vm::ExecutionContext; use crate::{ast::ProcedureKind, quantity::Quantity}; type ControlFlow = std::ops::ControlFlow; @@ -14,7 +15,7 @@ type BoxedFunction = Box Quantity + Send + Sync>; pub(crate) enum Callable { Function(BoxedFunction), - Procedure(fn(&[Quantity]) -> ControlFlow), + Procedure(fn(&mut ExecutionContext, &[Quantity]) -> ControlFlow), } pub(crate) struct ForeignFunction { @@ -284,15 +285,15 @@ pub(crate) fn functions() -> &'static HashMap { }) } -fn print(args: &[Quantity]) -> ControlFlow { +fn print(ctx: &mut ExecutionContext, args: &[Quantity]) -> ControlFlow { assert!(args.len() == 1); - println!("{}", args[0]); + (ctx.print_fn)(&format!("{}\n", args[0])); ControlFlow::Continue(()) } -fn assert_eq(args: &[Quantity]) -> ControlFlow { +fn assert_eq(_: &mut ExecutionContext, args: &[Quantity]) -> ControlFlow { assert!(args.len() == 2 || args.len() == 3); if args.len() == 2 { diff --git a/numbat/src/interpreter.rs b/numbat/src/interpreter.rs index add2559b..d67c518c 100644 --- a/numbat/src/interpreter.rs +++ b/numbat/src/interpreter.rs @@ -54,10 +54,30 @@ impl InterpreterResult { pub type Result = std::result::Result; +pub type PrintFunction = dyn FnMut(&str) -> () + Send; + +pub struct InterpreterSettings { + pub print_fn: Box, +} + +impl Default for InterpreterSettings { + fn default() -> Self { + Self { + print_fn: Box::new(move |s: &str| { + print!("{}", s); + }), + } + } +} + pub trait Interpreter { fn new() -> Self; - fn interpret_statements(&mut self, statements: &[Statement]) -> Result; + fn interpret_statements( + &mut self, + settings: &mut InterpreterSettings, + statements: &[Statement], + ) -> Result; fn get_unit_registry(&self) -> &UnitRegistry; } @@ -105,7 +125,8 @@ mod tests { let statements_typechecked = crate::typechecker::TypeChecker::default() .check_statements(statements_transformed) .expect("No type check errors for inputs in this test suite"); - BytecodeInterpreter::new().interpret_statements(&statements_typechecked) + BytecodeInterpreter::new() + .interpret_statements(&mut InterpreterSettings::default(), &statements_typechecked) } fn assert_evaluates_to(input: &str, expected: Quantity) { diff --git a/numbat/src/lib.rs b/numbat/src/lib.rs index 7b1b66fa..6cb06e30 100644 --- a/numbat/src/lib.rs +++ b/numbat/src/lib.rs @@ -48,6 +48,7 @@ use ast::Statement; pub use diagnostic::Diagnostic; pub use interpreter::ExitStatus; pub use interpreter::InterpreterResult; +pub use interpreter::InterpreterSettings; pub use interpreter::RuntimeError; pub use name_resolution::NameResolutionError; pub use parser::ParseError; @@ -138,6 +139,15 @@ impl Context { &mut self, code: &str, code_source: CodeSource, + ) -> Result<(Vec, InterpreterResult)> { + self.interpret_with_settings(&mut InterpreterSettings::default(), code, code_source) + } + + pub fn interpret_with_settings( + &mut self, + settings: &mut InterpreterSettings, + code: &str, + code_source: CodeSource, ) -> Result<(Vec, InterpreterResult)> { let statements = self .resolver @@ -174,7 +184,9 @@ impl Context { let typed_statements = result?; - let result = self.interpreter.interpret_statements(&typed_statements); + let result = self + .interpreter + .interpret_statements(settings, &typed_statements); if result.is_err() { // Similar to above: we need to reset the state of the typechecker and the prefix transformer diff --git a/numbat/src/vm.rs b/numbat/src/vm.rs index 40f07453..c3bc61d3 100644 --- a/numbat/src/vm.rs +++ b/numbat/src/vm.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, fmt::Display}; use crate::{ ffi::{self, ArityRange, Callable, ForeignFunction}, - interpreter::{InterpreterResult, Result, RuntimeError}, + interpreter::{InterpreterResult, PrintFunction, Result, RuntimeError}, math, name_resolution::LAST_RESULT_IDENTIFIERS, prefix::Prefix, @@ -169,6 +169,10 @@ impl CallFrame { } } +pub struct ExecutionContext<'a> { + pub print_fn: &'a mut PrintFunction, +} + pub struct Vm { /// The actual code of the program, structured by function name. The code /// for the global scope is at index 0 under the function name `
`. @@ -327,22 +331,25 @@ impl Vm { Some(position as u16) } - pub fn disassemble(&self) { + pub fn disassemble(&self, ctx: &mut ExecutionContext) { if !self.debug { return; } - println!(); - println!(".CONSTANTS"); + self.println(ctx, ""); + self.println(ctx, ".CONSTANTS"); for (idx, constant) in self.constants.iter().enumerate() { - println!(" {:04} {}", idx, constant); + self.println(ctx, format!(" {:04} {}", idx, constant)); } - println!(".IDENTIFIERS"); + self.println(ctx, ".IDENTIFIERS"); for (idx, identifier) in self.global_identifiers.iter().enumerate() { - println!(" {:04} {}", idx, identifier.0); + self.println(ctx, format!(" {:04} {}", idx, identifier.0)); } for (idx, (function_name, bytecode)) in self.bytecode.iter().enumerate() { - println!(".CODE {idx} ({name})", idx = idx, name = function_name); + self.println( + ctx, + format!(".CODE {idx} ({name})", idx = idx, name = function_name), + ); let mut offset = 0; while offset < bytecode.len() { let this_offset = offset; @@ -364,25 +371,34 @@ impl Vm { .collect::>() .join(" "); - print!( - " {:04} {:<13} {}", - this_offset, - op.to_string(), - operands_str + self.print( + ctx, + format!( + " {:04} {:<13} {}", + this_offset, + op.to_string(), + operands_str, + ), ); if op == Op::LoadConstant { - print!(" (value: {})", self.constants[operands[0] as usize]); + self.print( + ctx, + format!(" (value: {})", self.constants[operands[0] as usize]), + ); } else if op == Op::Call { - print!( - " ({}, num_args={})", - self.bytecode[operands[0] as usize].0, operands[1] as usize + self.print( + ctx, + format!( + " ({}, num_args={})", + self.bytecode[operands[0] as usize].0, operands[1] as usize + ), ); } - println!(); + self.println(ctx, ""); } } - println!(); + self.println(ctx, ""); } // The following functions are helpers for the actual execution of the code @@ -415,8 +431,8 @@ impl Vm { self.stack.pop().expect("stack should not be empty") } - pub fn run(&mut self) -> Result { - let result = self.run_without_cleanup(); + pub fn run(&mut self, ctx: &mut ExecutionContext) -> Result { + let result = self.run_without_cleanup(ctx); if result.is_err() { // Perform cleanup: clear the stack and move IP to the end. // This is useful for the REPL. @@ -438,10 +454,10 @@ impl Vm { self.current_frame().ip >= self.bytecode[self.current_frame().function_idx].1.len() } - fn run_without_cleanup(&mut self) -> Result { + fn run_without_cleanup(&mut self, ctx: &mut ExecutionContext) -> Result { let mut result_last_statement = None; while !self.is_at_the_end() { - self.debug(); + self.debug(ctx); let op = unsafe { std::mem::transmute::(self.read_byte()) }; @@ -576,7 +592,7 @@ impl Vm { self.push(result); } Callable::Procedure(procedure) => { - let result = (procedure)(&args[..]); + let result = (procedure)(ctx, &args[..]); match result { std::ops::ControlFlow::Continue(()) => {} @@ -590,7 +606,7 @@ impl Vm { Op::PrintString => { let s_idx = self.read_u16() as usize; let s = &self.strings[s_idx]; - println!("{}", s); + self.println(ctx, s); } Op::FullSimplify => { let simplified = self.pop().full_simplify(); @@ -633,23 +649,29 @@ impl Vm { } } - pub fn debug(&self) { + pub fn debug(&self, ctx: &mut ExecutionContext) { if !self.debug { return; } let frame = self.current_frame(); - print!( - "FRAME = {}, IP = {}, ", - self.bytecode[frame.function_idx].0, frame.ip + self.print( + ctx, + format!( + "FRAME = {}, IP = {}, ", + self.bytecode[frame.function_idx].0, frame.ip + ), ); - println!( - "Stack: [{}]", - self.stack - .iter() - .map(|x| x.to_string()) - .collect::>() - .join("] [") + self.println( + ctx, + format!( + "Stack: [{}]", + self.stack + .iter() + .map(|x| x.to_string()) + .collect::>() + .join("] [") + ), ); } @@ -658,6 +680,14 @@ impl Vm { assert!(self.strings.len() <= u16::MAX as usize); (self.strings.len() - 1) as u16 // TODO: this can overflow, see above } + + fn print>(&self, ctx: &mut ExecutionContext, s: S) { + (ctx.print_fn)(s.as_ref()); + } + + fn println>(&self, ctx: &mut ExecutionContext, s: S) { + self.print(ctx, format!("{}\n", s.as_ref())); + } } #[test] @@ -671,8 +701,13 @@ fn vm_basic() { vm.add_op(Op::Add); vm.add_op(Op::Return); + let mut print_fn = |_: &str| {}; + let mut ctx = ExecutionContext { + print_fn: &mut print_fn, + }; + assert_eq!( - vm.run().unwrap(), + vm.run(&mut ctx).unwrap(), InterpreterResult::Quantity(Quantity::from_scalar(42.0 + 1.0)) ); }