Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the number of calls to simplify #537

Merged
merged 8 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions numbat/src/bytecode_interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl BytecodeInterpreter {
Expression::FunctionCall(_span, _full_span, name, args, _type) => {
// Put all arguments on top of the stack
for arg in args {
self.compile_expression_with_simplify(arg)?;
self.compile_expression(arg)?;
}

if let Some(idx) = self.vm.get_ffi_callable_idx(name) {
Expand All @@ -181,7 +181,7 @@ impl BytecodeInterpreter {
.sorted_by_key(|(n, _)| struct_info.fields.get_index_of(n).unwrap());

for (_, expr) in sorted_exprs.rev() {
self.compile_expression_with_simplify(expr)?;
self.compile_expression(expr)?;
}

let struct_info_idx = self.vm.get_structinfo_idx(&struct_info.name).unwrap() as u16;
Expand All @@ -190,7 +190,7 @@ impl BytecodeInterpreter {
.add_op2(Op::BuildStructInstance, struct_info_idx, exprs.len() as u16);
}
Expression::AccessField(_span, _full_span, expr, attr, struct_type, _result_type) => {
self.compile_expression_with_simplify(expr)?;
self.compile_expression(expr)?;

let Type::Struct(ref struct_info) = struct_type.to_concrete_type() else {
unreachable!(
Expand All @@ -205,7 +205,7 @@ impl BytecodeInterpreter {
Expression::CallableCall(_span, callable, args, _type) => {
// Put all arguments on top of the stack
for arg in args {
self.compile_expression_with_simplify(arg)?;
self.compile_expression(arg)?;
}

// Put the callable on top of the stack
Expand Down Expand Up @@ -263,7 +263,7 @@ impl BytecodeInterpreter {
}
Expression::List(_, elements, _) => {
for element in elements {
self.compile_expression_with_simplify(element)?;
self.compile_expression(element)?;
}

self.vm.add_op1(Op::BuildList, elements.len() as u16);
Expand All @@ -280,20 +280,21 @@ impl BytecodeInterpreter {
self.compile_expression(expr)?;

match expr {
Expression::BinaryOperator(_, BinaryOperator::ConvertTo, _, _, _)
| Expression::InstantiateStruct(..)
| Expression::Boolean(..)
| Expression::String(..)
| Expression::Condition(..)
| Expression::List(..) => (),
Expression::Scalar(..)
| Expression::Identifier(..)
| Expression::UnitIdentifier(..)
| Expression::FunctionCall(..)
| Expression::CallableCall(..)
| Expression::UnaryOperator(..)
| Expression::BinaryOperator(_, BinaryOperator::ConvertTo, _, _, _)
| Expression::Boolean(..)
| Expression::String(..)
| Expression::Condition(..)
| Expression::InstantiateStruct(..)
| Expression::AccessField(..)
| Expression::List(..) => {}
Expression::BinaryOperator(..) | Expression::BinaryOperatorForDate(..) => {
| Expression::BinaryOperator(..)
| Expression::BinaryOperatorForDate(..) => {
self.vm.add_op(Op::FullSimplify);
}
Expression::TypedHole(_, _) => unreachable!("Typed holes cause type inference errors"),
Expand Down Expand Up @@ -351,7 +352,7 @@ impl BytecodeInterpreter {
parameters,
Some(expr),
local_variables,
_return_type,
_function_type,
_return_type_annotation,
_readable_return_type,
) => {
Expand All @@ -371,7 +372,8 @@ impl BytecodeInterpreter {
self.compile_define_variable(local_variables)?;
}

self.compile_expression_with_simplify(expr)?;
self.compile_expression(expr)?;

self.vm.add_op(Op::Return);

self.locals.pop();
Expand Down Expand Up @@ -484,7 +486,7 @@ impl BytecodeInterpreter {
},
); // TODO: there is some asymmetry here because we do not introduce identifiers for base units

self.compile_expression_with_simplify(expr)?;
self.compile_expression(expr)?;
self.vm
.add_op2(Op::SetUnitConstant, unit_information_idx, constant_idx);

Expand Down
48 changes: 27 additions & 21 deletions numbat/src/quantity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,31 @@ pub type Result<T> = std::result::Result<T, QuantityError>;
pub struct Quantity {
value: Number,
unit: Unit,
can_simplify: bool,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name confused me a bit when first reading this code. Maybe call it simplification_allowed? And rename no_simplify to prevent_simplification?

}

impl Quantity {
pub fn new(value: Number, unit: Unit) -> Self {
Quantity { value, unit }
Quantity {
value,
unit,
can_simplify: true,
}
}

pub fn new_f64(value: f64, unit: Unit) -> Self {
Quantity {
value: Number::from_f64(value),
unit,
can_simplify: true,
}
}

pub fn no_simplify(mut self) -> Self {
self.can_simplify = false;
self
}

pub fn from_scalar(value: f64) -> Quantity {
Quantity::new_f64(value, Unit::scalar())
}
Expand Down Expand Up @@ -130,6 +141,10 @@ impl Quantity {
}

pub fn full_simplify(&self) -> Self {
if !self.can_simplify {
return self.clone();
}

// Heuristic 1
if let Ok(scalar_result) = self.convert_to(&Unit::scalar()) {
return scalar_result;
Expand Down Expand Up @@ -256,10 +271,10 @@ impl std::ops::Add for &Quantity {
} else if rhs.is_zero() {
Ok(self.clone())
} else {
Ok(Quantity {
value: self.value + rhs.convert_to(&self.unit)?.value,
unit: self.unit.clone(),
})
Ok(Quantity::new(
self.value + rhs.convert_to(&self.unit)?.value,
self.unit.clone(),
))
}
}
}
Expand All @@ -273,10 +288,10 @@ impl std::ops::Sub for &Quantity {
} else if rhs.is_zero() {
Ok(self.clone())
} else {
Ok(Quantity {
value: self.value - rhs.convert_to(&self.unit)?.value,
unit: self.unit.clone(),
})
Ok(Quantity::new(
self.value - rhs.convert_to(&self.unit)?.value,
self.unit.clone(),
))
}
}
}
Expand All @@ -285,32 +300,23 @@ impl std::ops::Mul for Quantity {
type Output = Quantity;

fn mul(self, rhs: Self) -> Self::Output {
Quantity {
value: self.value * rhs.value,
unit: self.unit * rhs.unit,
}
Quantity::new(self.value * rhs.value, self.unit * rhs.unit)
}
}

impl std::ops::Div for Quantity {
type Output = Quantity;

fn div(self, rhs: Self) -> Self::Output {
Quantity {
value: self.value / rhs.value,
unit: self.unit / rhs.unit,
}
Quantity::new(self.value / rhs.value, self.unit / rhs.unit)
}
}

impl std::ops::Neg for Quantity {
type Output = Quantity;

fn neg(self) -> Self::Output {
Quantity {
value: -self.value,
unit: self.unit,
}
Quantity::new(-self.value, self.unit)
}
}

Expand Down
6 changes: 4 additions & 2 deletions numbat/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,9 @@ impl Vm {
Ok(lhs.checked_div(rhs).ok_or(RuntimeError::DivisionByZero)?)
}
Op::Power => lhs.power(rhs),
Op::ConvertTo => lhs.convert_to(rhs.unit()),
// If the user specifically converted the type of a unit, we should NOT simplify this value
// before any operations are applied to it
Op::ConvertTo => lhs.convert_to(rhs.unit()).map(Quantity::no_simplify),
_ => unreachable!(),
};
self.push_quantity(result.map_err(RuntimeError::QuantityError)?);
Expand Down Expand Up @@ -911,7 +913,7 @@ impl Vm {
let dt = self.pop_datetime();

let tz = jiff::tz::TimeZone::get(&tz_name)
.map_err(|_| RuntimeError::UnknownTimezone(tz_name.into()))?;
.map_err(|_| RuntimeError::UnknownTimezone(tz_name))?;

let dt = dt.with_time_zone(tz);

Expand Down
2 changes: 1 addition & 1 deletion numbat/tests/prelude_and_examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn assert_runs(code: &str) {

fn assert_runs_without_prelude(code: &str) {
let result = get_test_context_without_prelude().interpret(code, CodeSource::Internal);
assert!(result.is_ok());
assert!(result.is_ok(), "Failed with: {}", result.unwrap_err());
assert!(matches!(
result.unwrap().1,
InterpreterResult::Value(_) | InterpreterResult::Continue
Expand Down
Loading