Skip to content

Commit

Permalink
Initial work on function types
Browse files Browse the repository at this point in the history
  • Loading branch information
sharkdp authored and David Peter committed Feb 10, 2024
1 parent 650a064 commit c42fa17
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 4 deletions.
6 changes: 6 additions & 0 deletions examples/function_types.nbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
let eps = 1e-10

#fn diff<X, Y>(f: Fn[(X) -> Y], x: X) -> Y / X =
# (f(x + eps · unit_of(x)) - f(x)) / (eps · unit_of(x))

fn test(f: Fn[(Scalar) -> Scalar], x0: Scalar) = x0
21 changes: 21 additions & 0 deletions numbat/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
arithmetic::Exponent, decorator::Decorator, markup::Markup, number::Number, prefix::Prefix,
pretty_print::PrettyPrint, resolver::ModulePath,
};
use itertools::Itertools;
use num_traits::Signed;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -216,6 +217,7 @@ pub enum TypeAnnotation {
Bool(Span),
String(Span),
DateTime(Span),
Fn(Span, Vec<TypeAnnotation>, Box<TypeAnnotation>),
}

impl TypeAnnotation {
Expand All @@ -225,6 +227,7 @@ impl TypeAnnotation {
TypeAnnotation::Bool(span) => *span,
TypeAnnotation::String(span) => *span,
TypeAnnotation::DateTime(span) => *span,
TypeAnnotation::Fn(span, _, _) => *span,
}
}
}
Expand All @@ -236,6 +239,21 @@ impl PrettyPrint for TypeAnnotation {
TypeAnnotation::Bool(_) => m::type_identifier("Bool"),
TypeAnnotation::String(_) => m::type_identifier("String"),
TypeAnnotation::DateTime(_) => m::type_identifier("DateTime"),
TypeAnnotation::Fn(_, parameter_types, return_type) => {
m::type_identifier("Fn")
+ m::operator("[(")
+ Itertools::intersperse(
parameter_types.iter().map(|t| t.pretty_print()),
m::operator(",") + m::space(),
)
.sum()
+ m::operator(")")
+ m::space()
+ m::operator("->")
+ m::space()
+ return_type.pretty_print()
+ m::operator("]")
}
}
}
}
Expand Down Expand Up @@ -368,6 +386,9 @@ impl ReplaceSpans for TypeAnnotation {
TypeAnnotation::Bool(_) => TypeAnnotation::Bool(Span::dummy()),
TypeAnnotation::String(_) => TypeAnnotation::String(Span::dummy()),
TypeAnnotation::DateTime(_) => TypeAnnotation::DateTime(Span::dummy()),
TypeAnnotation::Fn(_, pt, rt) => {
TypeAnnotation::Fn(Span::dummy(), pt.clone(), rt.clone())
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion numbat/src/bytecode_interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl BytecodeInterpreter {
} else if LAST_RESULT_IDENTIFIERS.contains(&identifier.as_str()) {
self.vm.add_op(Op::GetLastResult);
} else {
unreachable!("Unknown identifier {identifier}")
unreachable!("Unknown identifier '{identifier}'")
}
}
Expression::UnitIdentifier(_span, prefix, unit_name, _full_name, _type) => {
Expand Down
34 changes: 34 additions & 0 deletions numbat/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,40 @@ impl<'a> Parser<'a> {
Ok(TypeAnnotation::String(token.span))
} else if let Some(token) = self.match_exact(TokenKind::DateTime) {
Ok(TypeAnnotation::DateTime(token.span))
} else if self.match_exact(TokenKind::CapitalFn).is_some() {
let span = self.last().unwrap().span;
if self.match_exact(TokenKind::LeftBracket).is_none() {
todo!()
}
if self.match_exact(TokenKind::LeftParen).is_none() {
todo!()
}

let mut params = vec![];
if self.peek().kind != TokenKind::RightParen {
params.push(self.type_annotation()?);
while self.match_exact(TokenKind::Comma).is_some() {
params.push(self.type_annotation()?);
}
}

if self.match_exact(TokenKind::RightParen).is_none() {
todo!()
}

if self.match_exact(TokenKind::Arrow).is_none() {
todo!()
}

let return_type = self.type_annotation()?;

if self.match_exact(TokenKind::RightBracket).is_none() {
todo!()
}

let span = span.extend(&self.last().unwrap().span);

Ok(TypeAnnotation::Fn(span, params, Box::new(return_type)))
} else {
Ok(TypeAnnotation::DimensionExpression(
self.dimension_expression()?,
Expand Down
9 changes: 8 additions & 1 deletion numbat/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub enum TokenKind {
// Brackets
LeftParen,
RightParen,
LeftBracket,
RightBracket,

// Operators and special signs
Plus,
Expand Down Expand Up @@ -80,7 +82,7 @@ pub enum TokenKind {

// Keywords
Let,
Fn,
Fn, // 'fn'
Dimension,
Unit,
Use,
Expand All @@ -97,6 +99,8 @@ pub enum TokenKind {
String,
DateTime,

CapitalFn, // 'Fn'

Long,
Short,
Both,
Expand Down Expand Up @@ -355,6 +359,7 @@ impl Tokenizer {
m.insert("else", TokenKind::Else);
m.insert("String", TokenKind::String);
m.insert("DateTime", TokenKind::DateTime);
m.insert("Fn", TokenKind::CapitalFn);
// Keep this list in sync with keywords::KEYWORDS!
m
});
Expand Down Expand Up @@ -385,6 +390,8 @@ impl Tokenizer {
let kind = match current_char {
'(' => TokenKind::LeftParen,
')' => TokenKind::RightParen,
'[' => TokenKind::LeftBracket,
']' => TokenKind::RightBracket,
'≤' => TokenKind::LessOrEqual,
'<' if self.match_char('=') => TokenKind::LessOrEqual,
'<' => TokenKind::LessThan,
Expand Down
30 changes: 28 additions & 2 deletions numbat/src/typechecker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,9 @@ pub struct TypeChecker {
}

impl TypeChecker {
fn identifier_type(&self, span: Span, name: &str) -> Result<&Type> {
self.identifiers
fn identifier_type(&self, span: Span, name: &str) -> Result<Type> {
let id = self
.identifiers
.get(name)
.ok_or_else(|| {
let suggestion = suggestion::did_you_mean(
Expand All @@ -440,6 +441,24 @@ impl TypeChecker {
TypeCheckError::UnknownIdentifier(span, name.into(), suggestion)
})
.map(|(type_, _)| type_)
.cloned();

if id.is_err() {
if let Some(signature) = self.function_signatures.get(name) {
Ok(Type::Fn(
signature
.parameter_types
.iter()
.map(|(_, t)| t.clone())
.collect(),
Box::new(signature.return_type.clone()),
))
} else {
id
}
} else {
id
}
}

pub(crate) fn check_expression(&self, ast: &ast::Expression) -> Result<typed_ast::Expression> {
Expand Down Expand Up @@ -1479,6 +1498,13 @@ impl TypeChecker {
TypeAnnotation::Bool(_) => Ok(Type::Boolean),
TypeAnnotation::String(_) => Ok(Type::String),
TypeAnnotation::DateTime(_) => Ok(Type::DateTime),
TypeAnnotation::Fn(_, param_types, return_type) => Ok(Type::Fn(
param_types
.iter()
.map(|p| self.type_from_annotation(p))
.collect::<Result<Vec<_>>>()?,
Box::new(self.type_from_annotation(return_type)?),
)),
}
}
}
Expand Down
27 changes: 27 additions & 0 deletions numbat/src/typed_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub enum Type {
Boolean,
String,
DateTime,
Fn(Vec<Type>, Box<Type>),
}

impl std::fmt::Display for Type {
Expand All @@ -67,6 +68,13 @@ impl std::fmt::Display for Type {
Type::Boolean => write!(f, "Bool"),
Type::String => write!(f, "String"),
Type::DateTime => write!(f, "DateTime"),
Type::Fn(param_types, return_type) => {
write!(
f,
"Fn[({ps}) -> {return_type}]",
ps = param_types.iter().map(|p| p.to_string()).join(", ")
)
}
}
}
}
Expand All @@ -78,6 +86,21 @@ impl PrettyPrint for Type {
Type::Boolean => m::keyword("Bool"),
Type::String => m::keyword("String"),
Type::DateTime => m::keyword("DateTime"),
Type::Fn(param_types, return_type) => {
m::type_identifier("Fn")
+ m::operator("[(")
+ Itertools::intersperse(
param_types.iter().map(|t| t.pretty_print()),
m::operator(",") + m::space(),
)
.sum()
+ m::operator(")")
+ m::space()
+ m::operator("->")
+ m::space()
+ return_type.pretty_print()
+ m::operator("]")
}
}
}
}
Expand All @@ -97,6 +120,10 @@ impl Type {
pub fn is_dtype(&self) -> bool {
matches!(self, Type::Dimension(..))
}

pub fn is_fn_type(&self) -> bool {
matches!(self, Type::Fn(..))
}
}

#[derive(Debug, Clone, PartialEq)]
Expand Down

0 comments on commit c42fa17

Please sign in to comment.