diff --git a/compiler/src/passes/validate/error.rs b/compiler/src/passes/validate/error.rs index 04947be..804a317 100644 --- a/compiler/src/passes/validate/error.rs +++ b/compiler/src/passes/validate/error.rs @@ -66,16 +66,41 @@ pub enum TypeError { }, #[error("The program doesn't have a main function.")] NoMain, - #[error("Types did not match.")] MismatchedFnReturn { expect: String, got: String, - #[label = "Expected this function to be of type: `{expect}`"] + //TODO would like this span to be return type if present + #[label = "Expected this function to return: `{expect}`"] span_expected: (usize, usize), - - #[label = "but got this type: `{got}`"] + #[label = "But got this type: `{got}`"] span_got: (usize, usize), }, + #[error("Types did not match.")] + OperandExpect { + expect: String, + got: String, + op: String, + + //TODO would like this span to be operator + #[label = "Arguments of {op} are of type: `{expect}`"] + span_op: (usize, usize), + #[label = "But got this type: `{got}`"] + span_arg: (usize, usize), + }, + #[error("Types did not match.")] + OperandEqual { + lhs: String, + rhs: String, + op: String, + + //TODO would like this span to be operator + #[label = "Arguments of {op} should be of equal types."] + span_op: (usize, usize), + #[label = "Type: `{lhs}`"] + span_lhs: (usize, usize), + #[label = "Type: `{rhs}`"] + span_rhs: (usize, usize), + }, } diff --git a/compiler/src/passes/validate/generate_constraints.rs b/compiler/src/passes/validate/generate_constraints.rs index 8cee96f..e1eb797 100644 --- a/compiler/src/passes/validate/generate_constraints.rs +++ b/compiler/src/passes/validate/generate_constraints.rs @@ -1,11 +1,10 @@ -use crate::passes::parse::{Lit, Meta, Span}; +use crate::passes::parse::types::Type; +use crate::passes::parse::{BinaryOp, Lit, Meta, Span, UnaryOp}; use crate::passes::validate::error::TypeError; -use crate::passes::validate::error::TypeError::MismatchedFnReturn; use crate::passes::validate::uncover_globals::{uncover_globals, Env, EnvEntry}; use crate::passes::validate::uniquify::PrgUniquified; use crate::passes::validate::{ - type_to_index, CMeta, DefConstrained, DefUniquified, ExprConstrained, ExprUniquified, - PrgConstrained, + CMeta, DefConstrained, DefUniquified, ExprConstrained, ExprUniquified, PrgConstrained, }; use crate::utils::gen_sym::UniqueSym; use crate::utils::union_find::{UnionFind, UnionIndex}; @@ -87,32 +86,37 @@ fn constrain_def<'p>( typ, bdy, } => { + // Put function parameters in scope. scope.extend(params.iter().map(|p| { ( p.sym.inner, EnvEntry::Type { mutable: p.mutable, - typ: type_to_index(p.typ.clone(), uf), + typ: uf.type_to_index(p.typ.clone()), }, ) })); - let return_index = type_to_index(typ.clone(), uf); + // Add return type to env and keep it for error handling. + let return_index = uf.type_to_index(typ.clone()); let mut env = Env { uf, scope, return_type: return_index, }; + // Constrain body of function. let bdy = constrain_expr(bdy, &mut env)?; - uf.try_union_by(return_index, bdy.meta.index, combine_partial_types) - .map_err(|_| MismatchedFnReturn { - expect: format!("{typ}"), - got: format!("{}", "bananas"), - span_expected: (0, 0), - span_got: (0, 0), - })?; + // Return error if function body a type differs from its return type. + uf.expect_equal(return_index, bdy.meta.index, |r, b| { + TypeError::MismatchedFnReturn { + expect: r, + got: b, + span_expected: sym.meta, + span_got: bdy.meta.span, + } + })?; DefConstrained::Fn { sym, @@ -131,6 +135,8 @@ fn constrain_expr<'p>( expr: Meta>, env: &mut Env<'_, 'p>, ) -> Result>, TypeError> { + let span = expr.meta; + Ok(match expr.inner { ExprUniquified::Lit { val } => { let typ = match val { @@ -140,10 +146,7 @@ fn constrain_expr<'p>( }; let index = env.uf.add(typ); Meta { - meta: CMeta { - span: expr.meta, - index, - }, + meta: CMeta { span, index }, inner: ExprConstrained::Lit { val }, } } @@ -151,45 +154,98 @@ fn constrain_expr<'p>( let EnvEntry::Type { typ, .. } = env.scope[&sym.inner] else { panic!(); }; + Meta { + meta: CMeta { span, index: typ }, + inner: ExprConstrained::Var { sym }, + } + } + ExprUniquified::UnaryOp { op, expr } => { + let typ = match op { + UnaryOp::Neg => Type::I64, + UnaryOp::Not => Type::Bool, + }; + let expr = constrain_expr(*expr, env)?; + + env.uf.expect_type(expr.meta.index, typ, |got, expect| { + TypeError::OperandExpect { + expect, + got, + op: op.to_string(), + span_op: span, + span_arg: expr.meta.span, + } + })?; + Meta { meta: CMeta { - span: expr.meta, - index: typ, + span, + index: expr.meta.index, + }, + inner: ExprConstrained::UnaryOp { + op, + expr: Box::new(expr), + }, + } + } + ExprUniquified::BinaryOp { op, exprs: [lhs, rhs] } => { + // input: None = Any but equal, Some = expect this + // output: None = Same as input, Some = this + let (input, output) = match op { + BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => (Some(PartialType::Int), None), + BinaryOp::LAnd | BinaryOp::LOr | BinaryOp::Xor => (Some(PartialType::Bool), None), + BinaryOp::GT | BinaryOp::GE | BinaryOp::LE | BinaryOp::LT => (Some(PartialType::Int), Some(PartialType::Bool)), + BinaryOp::EQ | BinaryOp::NE => (None, Some(PartialType::Bool)), + }; + + let e1 = constrain_expr(*lhs, env)?; + let e2 = constrain_expr(*rhs, env)?; + + // Check inputs satisfy constraints + if let Some(input) = input { + let mut check = |expr: &Meta>| { + env.uf.expect_partial_type(expr.meta.index, input.clone(), |got, expect| TypeError::OperandExpect { + expect, + got, + op: op.to_string(), + span_op: span, + span_arg: expr.meta.span, + }) + }; + + check(&e1)?; + check(&e2)?; + } + + // Check inputs equal + let input_index = env.uf.expect_equal(e1.meta.index, e2.meta.index, |lhs, rhs| TypeError::OperandEqual { + lhs, + rhs, + op: op.to_string(), + span_op: span, + span_lhs: e1.meta.span, + span_rhs: e2.meta.span, + })?; + + // Generate output index + let output_index = match output { + None => input_index, + Some(e) => { + env.uf.add(e) + } + }; + + Meta { + meta: CMeta { + span, + index: output_index, + }, + inner: ExprConstrained::BinaryOp { + op, + exprs: [e1, e2].map(Box::new), }, - inner: ExprConstrained::Var { sym }, } } - ExprUniquified::UnaryOp { op, expr } => todo!(), - ExprUniquified::BinaryOp { op, exprs } => todo!(), - // ExprUniquified::Prim { op, args } if args.len() == 2 => { - // let (pt, lhs, rhs) = match op { - // Op::Plus => (PartialType::Int, PartialType::Int, PartialType::Int), - // Op::Minus => PartialType::Int, - // Op::Mul => PartialType::Int, - // Op::Div => PartialType::Int, - // Op::Mod => PartialType::Int, - // Op::LAnd => todo!(), - // Op::LOr => todo!(), - // Op::Xor => todo!(), - // Op::GT => PartialType::Int, - // Op::GE => PartialType::Int, - // Op::EQ => todo!(), - // Op::LE => PartialType::Int, - // Op::LT => PartialType::Int, - // Op::NE => todo!(), - // Op::Read | Op::Print | Op::Not => unreachable!(), - // }; - // - // let index = env.uf.add(pt); - // - // Meta { - // meta: CMeta{ span: expr.meta, index }, - // inner: ExprConstrained::Prim { op, args: args.into_iter().map(|arg| match arg { - // - // })}, - // } - // }, - ExprUniquified::Let { sym, bnd, bdy, .. } => todo!(), + ExprUniquified::Let { .. } => todo!(), ExprUniquified::If { .. } => todo!(), ExprUniquified::Apply { .. } => todo!(), ExprUniquified::Loop { .. } => todo!(), @@ -205,11 +261,10 @@ fn constrain_expr<'p>( }) } -// uf: &mut UnionFind> fn combine_partial_types<'p>( a: PartialType<'p>, b: PartialType<'p>, - uf: &mut UnionFind> + uf: &mut UnionFind>, ) -> Result, ()> { let typ = match (a, b) { (PartialType::I64, PartialType::I64 | PartialType::Int) => PartialType::I64, @@ -235,22 +290,78 @@ fn combine_partial_types<'p>( }, ) => { if params_a.len() != params_b.len() { - return Err(()) + return Err(()); } - let params = params_a.into_iter().zip(params_b).map(|(param_a, param_b)| { - uf.try_union_by(param_a, param_b, combine_partial_types) - }).collect::>()?; + let params = params_a + .into_iter() + .zip(params_b) + .map(|(param_a, param_b)| uf.try_union_by(param_a, param_b, combine_partial_types)) + .collect::>()?; let typ = uf.try_union_by(typ_a, typ_b, combine_partial_types)?; - PartialType::Fn { - params, - typ, - } + PartialType::Fn { params, typ } } _ => return Err(()), }; Ok(typ) } + +impl<'p> UnionFind> { + pub fn expect_equal( + &mut self, + a: UnionIndex, + b: UnionIndex, + map_err: impl FnOnce(String, String) -> TypeError, + ) -> Result { + self.try_union_by(a, b, combine_partial_types).map_err(|_| { + let typ_a = self.get(a).clone(); + let str_a = typ_a.to_string(self); + let typ_b = self.get(b).clone(); + let str_b = typ_b.to_string(self); + map_err(str_a, str_b) + }) + } + + pub fn expect_type( + &mut self, + a: UnionIndex, + t: Type>>, + map_err: impl FnOnce(String, String) -> TypeError, + ) -> Result { + let t_index = self.type_to_index(t); + self.expect_equal(a, t_index, map_err) + } + + pub fn expect_partial_type( + &mut self, + a: UnionIndex, + t: PartialType<'p>, + map_err: impl FnOnce(String, String) -> TypeError, + ) -> Result { + let t_index = self.add(t); + self.expect_equal(a, t_index, map_err) + } + + pub fn type_to_index(&mut self, t: Type>>) -> UnionIndex { + let pt = match t { + Type::I64 => PartialType::I64, + Type::U64 => PartialType::U64, + Type::Bool => PartialType::Bool, + Type::Unit => PartialType::Unit, + Type::Never => PartialType::Never, + Type::Fn { params, typ } => PartialType::Fn { + params: params + .into_iter() + .map(|param| self.type_to_index(param)) + .collect(), + typ: self.type_to_index(*typ), + }, + Type::Var { sym } => PartialType::Var { sym: sym.inner }, + }; + + self.add(pt) + } +} diff --git a/compiler/src/passes/validate/mod.rs b/compiler/src/passes/validate/mod.rs index 2a14350..c3049bc 100644 --- a/compiler/src/passes/validate/mod.rs +++ b/compiler/src/passes/validate/mod.rs @@ -108,26 +108,3 @@ impl FromStr for TLit { }) } } - -pub fn type_to_index<'p>( - t: Type>>, - uf: &mut UnionFind>, -) -> UnionIndex { - let pt = match t { - Type::I64 => PartialType::I64, - Type::U64 => PartialType::U64, - Type::Bool => PartialType::Bool, - Type::Unit => PartialType::Unit, - Type::Never => PartialType::Never, - Type::Fn { params, typ } => PartialType::Fn { - params: params - .into_iter() - .map(|param| type_to_index(param, uf)) - .collect(), - typ: type_to_index(*typ, uf), - }, - Type::Var { sym } => PartialType::Var { sym: sym.inner }, - }; - - uf.add(pt) -} diff --git a/compiler/src/passes/validate/resolve_types.rs b/compiler/src/passes/validate/resolve_types.rs index 1387d7f..7788bcf 100644 --- a/compiler/src/passes/validate/resolve_types.rs +++ b/compiler/src/passes/validate/resolve_types.rs @@ -7,7 +7,7 @@ use crate::passes::validate::{ PrgValidated, TLit, }; use crate::utils::gen_sym::UniqueSym; -use crate::utils::union_find::UnionFind; +use crate::utils::union_find::{UnionFind, UnionIndex}; use functor_derive::Functor; impl<'p> PrgConstrained<'p> { @@ -79,40 +79,76 @@ fn resolve_type(typ: Type>) -> Type { } } +/// Panic if not possible +fn partial_type_to_type<'p>( + value: UnionIndex, + uf: &mut UnionFind>, +) -> Option>> { + Some(match uf.get(value).clone() { + PartialType::I64 => Type::I64, + PartialType::U64 => Type::U64, + PartialType::Int => return None, + PartialType::Bool => Type::Bool, + PartialType::Unit => Type::Unit, + PartialType::Never => Type::Never, + PartialType::Var { sym } => Type::Var { sym }, + PartialType::Fn { params, typ } => Type::Fn { + params: params + .into_iter() + .map(|param| partial_type_to_type(uf.find(param), uf)) + .collect::>()?, + typ: Box::new(partial_type_to_type(typ, uf)?), + }, + }) +} + fn resolve_expr<'p>( expr: Meta>, uf: &mut UnionFind>, ) -> Result>, ExprValidated<'p>>, TypeError> { - let (typ, expr) = match expr.inner { + // Type of the expression, if `None` then type is still ambiguous. + let typ = partial_type_to_type(expr.meta.index, uf); + + let expr: Expr<_, _, _, Type> = match expr.inner { ExprConstrained::Lit { val } => { - let (typ, val) = match val { + let val = match val { Lit::Int { val } => { - match &uf.get(expr.meta.index) { - PartialType::I64 => { - let val = val.parse().map_err(|_| TypeError::IntegerOutOfBounds { - span: expr.meta.span, - typ: "I64", - })?; - (Type::I64, TLit::I64 { val }) - }, - PartialType::U64 => todo!(), - PartialType::Int => { - return Err(dbg!(TypeError::IntegerAmbiguous { - span: expr.meta.span - })) + match &typ { + None => return Err(dbg!(TypeError::IntegerAmbiguous { + span: expr.meta.span + })), + Some(typ) => match typ { + Type::I64 => TLit::I64 { + val: val.parse().map_err(|_| TypeError::IntegerOutOfBounds { + span: expr.meta.span, + typ: "I64", + })?, + }, + Type::U64 => todo!(), + _ => unreachable!(), }, - _ => unreachable!(), } - }, - Lit::Bool { val } => (Type::Bool ,TLit::Bool { val }), - Lit::Unit => (Type::Unit, TLit::Unit), + } + Lit::Bool { val } => TLit::Bool { val }, + Lit::Unit => TLit::Unit, }; - - (typ, Expr::Lit { val }) - }, + Expr::Lit { val } + } ExprConstrained::Var { .. } => todo!(), - ExprConstrained::UnaryOp { .. } => todo!(), - ExprConstrained::BinaryOp { .. } => todo!(), + ExprConstrained::UnaryOp { + op, + expr: expr_inner, + } => Expr::UnaryOp { + op, + expr: Box::new(resolve_expr(*expr_inner, uf)?), + }, + ExprConstrained::BinaryOp { op, exprs: [e1, e2] } => Expr::BinaryOp { + op, + exprs: [ + resolve_expr(*e1, uf)?, + resolve_expr(*e2, uf)? + ].map(Box::new) + }, ExprConstrained::Let { .. } => todo!(), ExprConstrained::If { .. } => todo!(), ExprConstrained::Apply { .. } => todo!(), @@ -129,7 +165,7 @@ fn resolve_expr<'p>( }; Ok(Meta { - meta: typ, + meta: typ.unwrap(), inner: expr, }) } diff --git a/compiler/src/passes/validate/tests.rs b/compiler/src/passes/validate/tests.rs index 677505a..c431a81 100644 --- a/compiler/src/passes/validate/tests.rs +++ b/compiler/src/passes/validate/tests.rs @@ -1,8 +1,8 @@ use crate::passes::parse::parse::parse_program; use crate::passes::validate::error::TypeError; +use crate::passes::validate::PrgValidated; use crate::utils::split_test::split_test; use test_each_file::test_each_file; -use crate::passes::validate::PrgValidated; fn validate([test]: [&str; 1]) { let (_, _, _, expected_error) = split_test(test); @@ -10,9 +10,11 @@ fn validate([test]: [&str; 1]) { let result = parse_program(test).unwrap().validate(); match (result, expected_error) { - (Ok(_), None)=> {} - (Ok(_), Some(expected_error))=> panic!("Expected validation to fail with: {expected_error}."), - (Err(error), None)=> panic!("Expected validation to succeed but got: {error}"), + (Ok(_), None) => {} + (Ok(_), Some(expected_error)) => { + panic!("Expected validation to fail with: {expected_error}.") + } + (Err(error), None) => panic!("Expected validation to succeed but got: {error}"), (Err(_), Some(_)) => {} } } diff --git a/compiler/src/passes/validate/uncover_globals.rs b/compiler/src/passes/validate/uncover_globals.rs index 589488a..a2a726c 100644 --- a/compiler/src/passes/validate/uncover_globals.rs +++ b/compiler/src/passes/validate/uncover_globals.rs @@ -2,7 +2,7 @@ use crate::passes::parse::types::Type; use crate::passes::parse::{Def, Meta, Span, TypeDef}; use crate::passes::validate::generate_constraints::PartialType; use crate::passes::validate::uniquify::PrgUniquified; -use crate::passes::validate::{type_to_index, DefUniquified}; +use crate::passes::validate::DefUniquified; use crate::utils::gen_sym::UniqueSym; use crate::utils::union_find::{UnionFind, UnionIndex}; use std::collections::HashMap; @@ -49,13 +49,10 @@ fn uncover_def<'p>(def: &DefUniquified<'p>, uf: &mut UnionFind>) params: args, typ, .. } => EnvEntry::Type { mutable: false, - typ: type_to_index( - Type::Fn { - typ: Box::new(typ.clone()), - params: args.iter().map(|param| param.typ.clone()).collect(), - }, - uf, - ), + typ: uf.type_to_index(Type::Fn { + typ: Box::new(typ.clone()), + params: args.iter().map(|param| param.typ.clone()).collect(), + }), }, Def::TypeDef { def, .. } => EnvEntry::Def { def: (*def).clone(), diff --git a/compiler/src/utils/union_find.rs b/compiler/src/utils/union_find.rs index bfbd0cf..65bc359 100644 --- a/compiler/src/utils/union_find.rs +++ b/compiler/src/utils/union_find.rs @@ -83,9 +83,13 @@ impl UnionFind { a: UnionIndex, b: UnionIndex, f: impl FnOnce(T, T, &mut Self) -> Result, - ) -> Result where T: Clone { + ) -> Result + where + T: Clone, + { + let v = f(self.data[a.0].clone(), self.data[b.0].clone(), self)?; let root = self.union(a, b); - self.data[root.0] = f(self.data[a.0].clone(), self.data[b.0].clone(), self)?; + self.data[root.0] = v; Ok(root) }