diff --git a/numbat/src/bytecode_interpreter.rs b/numbat/src/bytecode_interpreter.rs index bfe8ee9e..26bf00a5 100644 --- a/numbat/src/bytecode_interpreter.rs +++ b/numbat/src/bytecode_interpreter.rs @@ -17,7 +17,7 @@ use crate::typed_ast::{ }; use crate::unit::{CanonicalName, Unit}; use crate::unit_registry::{UnitMetadata, UnitRegistry}; -use crate::value::FunctionReference; +use crate::value::{FunctionReference, Value}; use crate::vm::{Constant, ExecutionContext, Op, Vm}; use crate::{decorator, ffi, Type}; @@ -159,7 +159,7 @@ impl BytecodeInterpreter { Expression::FunctionCall(_span, _full_span, name, args, _type) => { // Put all arguments on top of the stack for arg in args { - self.compile_expression_with_simplify(arg)?; + self.compile_expression(arg)?; } if let Some(idx) = self.vm.get_ffi_callable_idx(name) { @@ -181,7 +181,7 @@ impl BytecodeInterpreter { .sorted_by_key(|(n, _)| struct_info.fields.get_index_of(n).unwrap()); for (_, expr) in sorted_exprs.rev() { - self.compile_expression_with_simplify(expr)?; + self.compile_expression(expr)?; } let struct_info_idx = self.vm.get_structinfo_idx(&struct_info.name).unwrap() as u16; @@ -190,7 +190,7 @@ impl BytecodeInterpreter { .add_op2(Op::BuildStructInstance, struct_info_idx, exprs.len() as u16); } Expression::AccessField(_span, _full_span, expr, attr, struct_type, _result_type) => { - self.compile_expression_with_simplify(expr)?; + self.compile_expression(expr)?; let Type::Struct(ref struct_info) = struct_type.to_concrete_type() else { unreachable!( @@ -205,7 +205,7 @@ impl BytecodeInterpreter { Expression::CallableCall(_span, callable, args, _type) => { // Put all arguments on top of the stack for arg in args { - self.compile_expression_with_simplify(arg)?; + self.compile_expression(arg)?; } // Put the callable on top of the stack @@ -229,7 +229,7 @@ impl BytecodeInterpreter { span: _, format_specifiers, } => { - self.compile_expression_with_simplify(expr)?; + self.compile_expression(expr)?; let index = self.vm.add_constant(Constant::FormatSpecifiers( format_specifiers.clone(), )); @@ -263,7 +263,7 @@ impl BytecodeInterpreter { } Expression::List(_, elements, _) => { for element in elements { - self.compile_expression_with_simplify(element)?; + self.compile_expression(element)?; } self.vm.add_op1(Op::BuildList, elements.len() as u16); @@ -276,32 +276,6 @@ impl BytecodeInterpreter { Ok(()) } - fn compile_expression_with_simplify(&mut self, expr: &Expression) -> Result<()> { - self.compile_expression(expr)?; - - match expr { - Expression::Scalar(..) - | Expression::Identifier(..) - | Expression::UnitIdentifier(..) - | Expression::FunctionCall(..) - | Expression::CallableCall(..) - | Expression::UnaryOperator(..) - | Expression::BinaryOperator(_, BinaryOperator::ConvertTo, _, _, _) - | Expression::Boolean(..) - | Expression::String(..) - | Expression::Condition(..) - | Expression::InstantiateStruct(..) - | Expression::AccessField(..) - | Expression::List(..) => {} - Expression::BinaryOperator(..) | Expression::BinaryOperatorForDate(..) => { - self.vm.add_op(Op::FullSimplify); - } - Expression::TypedHole(_, _) => unreachable!("Typed holes cause type inference errors"), - } - - Ok(()) - } - fn compile_define_variable(&mut self, define_variable: &DefineVariable) -> Result<()> { let DefineVariable(identifier, decorators, expr, _annotation, _type, _readable_type) = define_variable; @@ -320,7 +294,7 @@ impl BytecodeInterpreter { }; for alias_name in aliases { - self.compile_expression_with_simplify(expr)?; + self.compile_expression(expr)?; self.locals[current_depth].push(Local { identifier: alias_name.clone(), @@ -338,7 +312,7 @@ impl BytecodeInterpreter { ) -> Result<()> { match stmt { Statement::Expression(expr) => { - self.compile_expression_with_simplify(expr)?; + self.compile_expression(expr)?; self.vm.add_op(Op::Return); } Statement::DefineVariable(define_variable) => { @@ -351,7 +325,7 @@ impl BytecodeInterpreter { parameters, Some(expr), local_variables, - _return_type, + _function_type, _return_type_annotation, _readable_return_type, ) => { @@ -371,7 +345,8 @@ impl BytecodeInterpreter { self.compile_define_variable(local_variables)?; } - self.compile_expression_with_simplify(expr)?; + self.compile_expression(expr)?; + self.vm.add_op(Op::Return); self.locals.pop(); @@ -484,7 +459,7 @@ impl BytecodeInterpreter { }, ); // TODO: there is some asymmetry here because we do not introduce identifiers for base units - self.compile_expression_with_simplify(expr)?; + self.compile_expression(expr)?; self.vm .add_op2(Op::SetUnitConstant, unit_information_idx, constant_idx); @@ -507,7 +482,7 @@ impl BytecodeInterpreter { Statement::ProcedureCall(kind, args) => { // Put all arguments on top of the stack for arg in args { - self.compile_expression_with_simplify(arg)?; + self.compile_expression(arg)?; } let name = &ffi::procedures().get(kind).unwrap().name; @@ -542,6 +517,13 @@ impl BytecodeInterpreter { let result = self.vm.run(&mut ctx); + let result = match result { + Ok(InterpreterResult::Value(Value::Quantity(q))) => { + Ok(InterpreterResult::Value(Value::Quantity(q.full_simplify()))) + } + r => r, + }; + self.vm.debug(); result diff --git a/numbat/src/quantity.rs b/numbat/src/quantity.rs index 2cedfc53..eb3afcb3 100644 --- a/numbat/src/quantity.rs +++ b/numbat/src/quantity.rs @@ -24,20 +24,31 @@ pub type Result = std::result::Result; pub struct Quantity { value: Number, unit: Unit, + can_simplify: bool, } impl Quantity { pub fn new(value: Number, unit: Unit) -> Self { - Quantity { value, unit } + Quantity { + value, + unit, + can_simplify: true, + } } pub fn new_f64(value: f64, unit: Unit) -> Self { Quantity { value: Number::from_f64(value), unit, + can_simplify: true, } } + pub fn no_simplify(mut self) -> Self { + self.can_simplify = false; + self + } + pub fn from_scalar(value: f64) -> Quantity { Quantity::new_f64(value, Unit::scalar()) } @@ -130,6 +141,10 @@ impl Quantity { } pub fn full_simplify(&self) -> Self { + if !self.can_simplify { + return self.clone(); + } + // Heuristic 1 if let Ok(scalar_result) = self.convert_to(&Unit::scalar()) { return scalar_result; @@ -256,10 +271,10 @@ impl std::ops::Add for &Quantity { } else if rhs.is_zero() { Ok(self.clone()) } else { - Ok(Quantity { - value: self.value + rhs.convert_to(&self.unit)?.value, - unit: self.unit.clone(), - }) + Ok(Quantity::new( + self.value + rhs.convert_to(&self.unit)?.value, + self.unit.clone(), + )) } } } @@ -273,10 +288,10 @@ impl std::ops::Sub for &Quantity { } else if rhs.is_zero() { Ok(self.clone()) } else { - Ok(Quantity { - value: self.value - rhs.convert_to(&self.unit)?.value, - unit: self.unit.clone(), - }) + Ok(Quantity::new( + self.value - rhs.convert_to(&self.unit)?.value, + self.unit.clone(), + )) } } } @@ -285,10 +300,7 @@ impl std::ops::Mul for Quantity { type Output = Quantity; fn mul(self, rhs: Self) -> Self::Output { - Quantity { - value: self.value * rhs.value, - unit: self.unit * rhs.unit, - } + Quantity::new(self.value * rhs.value, self.unit * rhs.unit) } } @@ -296,10 +308,7 @@ impl std::ops::Div for Quantity { type Output = Quantity; fn div(self, rhs: Self) -> Self::Output { - Quantity { - value: self.value / rhs.value, - unit: self.unit / rhs.unit, - } + Quantity::new(self.value / rhs.value, self.unit / rhs.unit) } } @@ -307,10 +316,7 @@ impl std::ops::Neg for Quantity { type Output = Quantity; fn neg(self) -> Self::Output { - Quantity { - value: -self.value, - unit: self.unit, - } + Quantity::new(-self.value, self.unit) } } diff --git a/numbat/src/vm.rs b/numbat/src/vm.rs index 0933f7da..4878d514 100644 --- a/numbat/src/vm.rs +++ b/numbat/src/vm.rs @@ -108,9 +108,6 @@ pub enum Op { /// Combine N strings on the stack into a single part, used by string interpolation JoinString, - /// Perform a simplification operation to the current value on the stack - FullSimplify, - /// Build a struct from the field values on the stack BuildStructInstance, /// Access a single field of a struct @@ -159,7 +156,6 @@ impl Op { | Op::LogicalAnd | Op::LogicalOr | Op::LogicalNeg - | Op::FullSimplify | Op::Return | Op::GetLastResult => 0, } @@ -201,7 +197,6 @@ impl Op { Op::CallCallable => "CallCallable", Op::PrintString => "PrintString", Op::JoinString => "JoinString", - Op::FullSimplify => "FullSimplify", Op::Return => "Return", Op::BuildStructInstance => "BuildStructInstance", Op::AccessStructField => "AccessStructField", @@ -703,7 +698,9 @@ impl Vm { Ok(lhs.checked_div(rhs).ok_or(RuntimeError::DivisionByZero)?) } Op::Power => lhs.power(rhs), - Op::ConvertTo => lhs.convert_to(rhs.unit()), + // If the user specifically converted the type of a unit, we should NOT simplify this value + // before any operations are applied to it + Op::ConvertTo => lhs.convert_to(rhs.unit()).map(Quantity::no_simplify), _ => unreachable!(), }; self.push_quantity(result.map_err(RuntimeError::QuantityError)?); @@ -911,7 +908,7 @@ impl Vm { let dt = self.pop_datetime(); let tz = jiff::tz::TimeZone::get(&tz_name) - .map_err(|_| RuntimeError::UnknownTimezone(tz_name.into()))?; + .map_err(|_| RuntimeError::UnknownTimezone(tz_name))?; let dt = dt.with_time_zone(tz); @@ -928,7 +925,7 @@ impl Vm { let num_parts = self.read_u16() as usize; let mut joined = String::new(); let to_str = |value| match value { - Value::Quantity(q) => q.to_string(), + Value::Quantity(q) => q.full_simplify().to_string(), Value::Boolean(b) => b.to_string(), Value::String(s) => s, Value::DateTime(dt) => crate::datetime::to_string(&dt), @@ -950,6 +947,8 @@ impl Vm { let part = match self.pop() { Value::FormatSpecifiers(Some(specifiers)) => match self.pop() { Value::Quantity(q) => { + let q = q.full_simplify(); + let mut vars = HashMap::new(); vars.insert("value".to_string(), q.unsafe_value().to_f64()); @@ -981,13 +980,6 @@ impl Vm { } self.push(Value::String(joined)) } - Op::FullSimplify => match self.pop() { - Value::Quantity(q) => { - let simplified = q.full_simplify(); - self.push_quantity(simplified); - } - v => self.push(v), - }, Op::Return => { if self.frames.len() == 1 { let return_value = self.pop(); diff --git a/numbat/tests/interpreter.rs b/numbat/tests/interpreter.rs index a07665b1..497d369c 100644 --- a/numbat/tests/interpreter.rs +++ b/numbat/tests/interpreter.rs @@ -213,6 +213,13 @@ fn test_conversions() { expect_output("5m^2 -> cm*m", "500 cm·m"); expect_output("1 kB / 10 ms -> MB/s", "0.1 MB/s"); expect_output("55! / (6! (55 - 6)!) -> million", "28.9897 million"); + + // regression test for https://github.com/sharkdp/numbat/issues/534 + let mut ctx = get_test_context(); + let _ = ctx + .interpret("let x = 1 deg", CodeSource::Internal) + .unwrap(); + expect_output_with_context(&mut ctx, "12 deg -> x", "12°"); } #[test] @@ -943,5 +950,14 @@ mod tests { 212121001.1 cm "###); } + + #[test] + fn issue505_angles() { + insta::assert_snapshot!(fail("assert_eq(-77° + 0′ + 32″, -77.0089°, 1e-4°)"), @r###" + Assertion failed because the following two quantities differ by 0.0178°, which is more than 0.0001°: + -76.9911° + -77.0089° + "###); + } } } diff --git a/numbat/tests/prelude_and_examples.rs b/numbat/tests/prelude_and_examples.rs index 873405b8..da12e91e 100644 --- a/numbat/tests/prelude_and_examples.rs +++ b/numbat/tests/prelude_and_examples.rs @@ -21,7 +21,7 @@ fn assert_runs(code: &str) { fn assert_runs_without_prelude(code: &str) { let result = get_test_context_without_prelude().interpret(code, CodeSource::Internal); - assert!(result.is_ok()); + assert!(result.is_ok(), "Failed with: {}", result.unwrap_err()); assert!(matches!( result.unwrap().1, InterpreterResult::Value(_) | InterpreterResult::Continue