From 341ca6d69b979968b82fd8071cf57a7422a40fa6 Mon Sep 17 00:00:00 2001 From: Andrew Chin Date: Thu, 23 Nov 2023 15:33:27 -0500 Subject: [PATCH] Introduce DType::is_scalar function --- numbat/src/typechecker.rs | 6 +++--- numbat/src/typed_ast.rs | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/numbat/src/typechecker.rs b/numbat/src/typechecker.rs index c44ff184..822e5b49 100644 --- a/numbat/src/typechecker.rs +++ b/numbat/src/typechecker.rs @@ -457,7 +457,7 @@ impl TypeChecker { match *op { ast::UnaryOperator::Factorial => { - if dtype != DType::unity() { + if !dtype.is_scalar() { return Err(TypeCheckError::NonScalarFactorialArgument( expr.full_span(), dtype, @@ -539,7 +539,7 @@ impl TypeChecker { } typed_ast::BinaryOperator::Power => { let exponent_type = dtype(&rhs_checked)?; - if exponent_type != DType::unity() { + if !exponent_type.is_scalar() { return Err(TypeCheckError::NonScalarExponent( rhs.full_span(), exponent_type, @@ -547,7 +547,7 @@ impl TypeChecker { } let base_type = dtype(&lhs_checked)?; - if base_type == DType::unity() { + if base_type.is_scalar() { // Skip evaluating the exponent if the lhs is a scalar. This allows // for arbitrary (decimal) exponents, if the base is a scalar. diff --git a/numbat/src/typed_ast.rs b/numbat/src/typed_ast.rs index 297b46dc..0f856852 100644 --- a/numbat/src/typed_ast.rs +++ b/numbat/src/typed_ast.rs @@ -15,8 +15,11 @@ use crate::{ pub type DType = BaseRepresentation; impl DType { + pub fn is_scalar(&self) -> bool { + self == &DType::unity() + } pub fn to_readable_type(&self, registry: &DimensionRegistry) -> m::Markup { - if self == &DType::unity() { + if self.is_scalar() { return m::type_identifier("Scalar"); }