From dd195a0f05cd0e4c889bd4b47facee739ba6b6c8 Mon Sep 17 00:00:00 2001 From: Rigidity Date: Tue, 25 Jun 2024 16:28:27 -0400 Subject: [PATCH] Field type guards --- crates/rue-compiler/src/compiler.rs | 30 +++-- .../src/compiler/expr/binary_expr.rs | 22 ++-- .../src/compiler/expr/field_access_expr.rs | 109 ++++++++++++------ .../src/compiler/expr/guard_expr.rs | 4 +- .../src/compiler/expr/path_expr.rs | 10 +- .../src/compiler/item/enum_item.rs | 1 + .../src/compiler/item/struct_item.rs | 5 +- .../rue-compiler/src/compiler/stmt/if_stmt.rs | 4 +- .../rue-compiler/src/database/type_system.rs | 26 +++-- crates/rue-compiler/src/value.rs | 33 +++++- crates/rue-compiler/src/value/guard.rs | 6 +- crates/rue-compiler/src/value/guard_path.rs | 23 ++++ crates/rue-compiler/src/value/ty.rs | 2 + 13 files changed, 192 insertions(+), 83 deletions(-) create mode 100644 crates/rue-compiler/src/value/guard_path.rs diff --git a/crates/rue-compiler/src/compiler.rs b/crates/rue-compiler/src/compiler.rs index 1f1cd1d..0126f23 100644 --- a/crates/rue-compiler/src/compiler.rs +++ b/crates/rue-compiler/src/compiler.rs @@ -11,7 +11,7 @@ use crate::{ database::{Database, HirId, ScopeId, SymbolId, TypeId}, hir::Hir, scope::Scope, - value::{PairType, Type, Value}, + value::{GuardPath, PairType, Type, Value}, ErrorKind, }; @@ -46,7 +46,7 @@ pub struct Compiler<'a> { type_definition_stack: Vec, // The type guard stack is used for overriding types in certain contexts. - type_guard_stack: Vec>, + type_guard_stack: Vec>, // The generic type stack is used for overriding generic types that are being checked against. generic_type_stack: Vec>, @@ -155,13 +155,19 @@ impl<'a> Compiler<'a> { format!("({first}, {rest})") } Type::Struct(struct_type) => { - let fields: Vec = struct_type - .fields - .iter() - .map(|(name, ty)| format!("{}: {}", name, self.type_name_visitor(*ty, stack))) - .collect(); - - format!("{{ {} }}", fields.join(", ")) + if struct_type.original_type_id == ty { + let fields: Vec = struct_type + .fields + .iter() + .map(|(name, ty)| { + format!("{}: {}", name, self.type_name_visitor(*ty, stack)) + }) + .collect(); + + format!("{{ {} }}", fields.join(", ")) + } else { + self.type_name_visitor(struct_type.original_type_id, stack) + } } Type::Enum { .. } => "".to_string(), Type::EnumVariant(enum_variant) => { @@ -181,7 +187,7 @@ impl<'a> Compiler<'a> { enum_type .variants .iter() - .find(|item| *item.1 == ty) + .find(|item| *item.1 == enum_variant.original_type_id) .expect("enum type is missing variant") .0 .clone() @@ -260,9 +266,9 @@ impl<'a> Compiler<'a> { Value::new(self.builtins.unknown_hir, self.builtins.unknown) } - fn symbol_type(&self, symbol_id: SymbolId) -> Option { + fn symbol_type(&self, guard_path: &GuardPath) -> Option { for guards in &self.type_guard_stack { - if let Some(guard) = guards.get(&symbol_id) { + if let Some(guard) = guards.get(guard_path) { return Some(*guard); } } diff --git a/crates/rue-compiler/src/compiler/expr/binary_expr.rs b/crates/rue-compiler/src/compiler/expr/binary_expr.rs index c91f911..7a96ec9 100644 --- a/crates/rue-compiler/src/compiler/expr/binary_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/binary_expr.rs @@ -167,11 +167,12 @@ impl Compiler<'_> { .compare_type(lhs.type_id, self.builtins.nil) .is_equal() { - if let Hir::Reference(symbol_id) = self.db.hir(rhs.hir_id) { - value.guards.insert( - *symbol_id, - Guard::new(self.builtins.nil, self.db.non_optional(rhs.type_id)), - ); + if let Some(guard_path) = rhs.guard_path { + let then_type = self.builtins.nil; + let else_type = self.db.non_optional(rhs.type_id); + value + .guards + .insert(guard_path, Guard::new(then_type, else_type)); } } @@ -180,11 +181,12 @@ impl Compiler<'_> { .compare_type(rhs.type_id, self.builtins.nil) .is_equal() { - if let Hir::Reference(symbol_id) = self.db.hir(lhs.hir_id) { - value.guards.insert( - *symbol_id, - Guard::new(self.builtins.nil, self.db.non_optional(lhs.type_id)), - ); + if let Some(guard_path) = lhs.guard_path.clone() { + let then_type = self.builtins.nil; + let else_type = self.db.non_optional(lhs.type_id); + value + .guards + .insert(guard_path, Guard::new(then_type, else_type)); } } diff --git a/crates/rue-compiler/src/compiler/expr/field_access_expr.rs b/crates/rue-compiler/src/compiler/expr/field_access_expr.rs index 4b28adc..deec0cd 100644 --- a/crates/rue-compiler/src/compiler/expr/field_access_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/field_access_expr.rs @@ -3,14 +3,14 @@ use rue_parser::FieldAccessExpr; use crate::{ compiler::Compiler, hir::Hir, - value::{Guard, PairType, Type, Value}, + value::{Guard, GuardPathItem, PairType, Type, Value}, ErrorKind, }; impl Compiler<'_> { /// Compiles a field access expression, or special properties for certain types. pub fn compile_field_access_expr(&mut self, field_access: &FieldAccessExpr) -> Value { - let Some(value) = field_access + let Some(old_value) = field_access .expr() .map(|expr| self.compile_expr(&expr, None)) else { @@ -21,59 +21,100 @@ impl Compiler<'_> { return self.unknown(); }; - match self.db.ty(value.type_id).clone() { + let mut new_value = match self.db.ty(old_value.type_id).clone() { Type::Struct(struct_type) => { if let Some(field) = struct_type.fields.get_full(field_name.text()) { let (index, _, field_type) = field; - return Value::new(self.compile_index(value.hir_id, index, false), *field_type); + Value::new( + self.compile_index(old_value.hir_id, index, false), + *field_type, + ) + .extend_guard_path(old_value, GuardPathItem::Field(field_name.to_string())) + } else { + self.db.error( + ErrorKind::UndefinedField { + field: field_name.to_string(), + ty: self.type_name(old_value.type_id), + }, + field_name.text_range(), + ); + return self.unknown(); + } + } + Type::EnumVariant(variant_type) => { + if let Some(field) = variant_type.fields.get_full(field_name.text()) { + let (index, _, field_type) = field; + Value::new( + self.compile_index(old_value.hir_id, index, false), + *field_type, + ) + .extend_guard_path(old_value, GuardPathItem::Field(field_name.to_string())) + } else { + self.db.error( + ErrorKind::UndefinedField { + field: field_name.to_string(), + ty: self.type_name(old_value.type_id), + }, + field_name.text_range(), + ); + return self.unknown(); } - self.db.error( - ErrorKind::UndefinedField { - field: field_name.to_string(), - ty: self.type_name(value.type_id), - }, - field_name.text_range(), - ); - return self.unknown(); } Type::Pair(PairType { first, rest }) => match field_name.text() { "first" => { - return Value::new(self.db.alloc_hir(Hir::First(value.hir_id)), first); + return Value::new(self.db.alloc_hir(Hir::First(old_value.hir_id)), first) + .extend_guard_path(old_value, GuardPathItem::First); } "rest" => { - return Value::new(self.db.alloc_hir(Hir::Rest(value.hir_id)), rest); + return Value::new(self.db.alloc_hir(Hir::Rest(old_value.hir_id)), rest) + .extend_guard_path(old_value, GuardPathItem::Rest); + } + _ => { + self.db.error( + ErrorKind::InvalidFieldAccess { + field: field_name.to_string(), + ty: self.type_name(old_value.type_id), + }, + field_name.text_range(), + ); + return self.unknown(); } - _ => {} }, - Type::Bytes | Type::Bytes32 if field_name.text() == "length" => { - return Value::new( - self.db.alloc_hir(Hir::Strlen(value.hir_id)), - self.builtins.int, - ); - } + Type::Bytes | Type::Bytes32 if field_name.text() == "length" => Value::new( + self.db.alloc_hir(Hir::Strlen(old_value.hir_id)), + self.builtins.int, + ), Type::PossiblyUndefined(inner) if field_name.text() == "exists" => { - let maybe_nil_reference = self.db.alloc_hir(Hir::CheckExists(value.hir_id)); + let maybe_nil_reference = self.db.alloc_hir(Hir::CheckExists(old_value.hir_id)); let exists = self.db.alloc_hir(Hir::IsCons(maybe_nil_reference)); let mut new_value = Value::new(exists, self.builtins.bool); - if let Hir::Reference(symbol_id) = self.db.hir(value.hir_id).clone() { + if let Some(guard_path) = old_value.guard_path { new_value .guards - .insert(symbol_id, Guard::new(inner, value.type_id)); + .insert(guard_path, Guard::new(inner, old_value.type_id)); } - return new_value; + new_value + } + _ => { + self.db.error( + ErrorKind::InvalidFieldAccess { + field: field_name.to_string(), + ty: self.type_name(old_value.type_id), + }, + field_name.text_range(), + ); + return self.unknown(); + } + }; + + if let Some(guard_path) = new_value.guard_path.as_ref() { + if let Some(type_id) = self.symbol_type(guard_path) { + new_value.type_id = type_id; } - _ => {} } - self.db.error( - ErrorKind::InvalidFieldAccess { - field: field_name.to_string(), - ty: self.type_name(value.type_id), - }, - field_name.text_range(), - ); - self.unknown() + new_value } } diff --git a/crates/rue-compiler/src/compiler/expr/guard_expr.rs b/crates/rue-compiler/src/compiler/expr/guard_expr.rs index 5f4b521..698ec04 100644 --- a/crates/rue-compiler/src/compiler/expr/guard_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/guard_expr.rs @@ -33,8 +33,8 @@ impl Compiler<'_> { let mut value = Value::new(hir_id, self.builtins.bool); - if let Hir::Reference(symbol_id) = self.db.hir(expr.hir_id) { - value.guards.insert(*symbol_id, guard); + if let Some(guard_path) = expr.guard_path { + value.guards.insert(guard_path, guard); } value diff --git a/crates/rue-compiler/src/compiler/expr/path_expr.rs b/crates/rue-compiler/src/compiler/expr/path_expr.rs index 199b47c..de8f7d0 100644 --- a/crates/rue-compiler/src/compiler/expr/path_expr.rs +++ b/crates/rue-compiler/src/compiler/expr/path_expr.rs @@ -8,7 +8,7 @@ use crate::{ }, hir::Hir, symbol::{Function, Symbol}, - value::{Type, Value}, + value::{GuardPath, Type, Value}, ErrorKind, }; @@ -59,9 +59,9 @@ impl Compiler<'_> { return self.unknown(); } - let override_type_id = self.symbol_type(symbol_id); + let override_type_id = self.symbol_type(&GuardPath::new(symbol_id)); - match self.db.symbol(symbol_id).clone() { + let mut value = match self.db.symbol(symbol_id).clone() { Symbol::Unknown | Symbol::Module(..) => unreachable!(), Symbol::Function(Function { ty, .. }) | Symbol::InlineFunction(Function { ty, .. }) => { let type_id = self.db.alloc_type(Type::Function(ty.clone())); @@ -81,6 +81,8 @@ impl Compiler<'_> { value.hir_id = self.db.alloc_hir(Hir::Reference(symbol_id)); value } - } + }; + value.guard_path = Some(GuardPath::new(symbol_id)); + value } } diff --git a/crates/rue-compiler/src/compiler/item/enum_item.rs b/crates/rue-compiler/src/compiler/item/enum_item.rs index f7deeea..a96bd5a 100644 --- a/crates/rue-compiler/src/compiler/item/enum_item.rs +++ b/crates/rue-compiler/src/compiler/item/enum_item.rs @@ -147,6 +147,7 @@ impl Compiler<'_> { // Update the variant to use the real `EnumVariant` type. *self.db.ty_mut(variant_type_id) = Type::EnumVariant(EnumVariantType { enum_type: enum_type_id, + original_type_id: variant_type_id, fields, discriminant, }); diff --git a/crates/rue-compiler/src/compiler/item/struct_item.rs b/crates/rue-compiler/src/compiler/item/struct_item.rs index 6443296..861a01f 100644 --- a/crates/rue-compiler/src/compiler/item/struct_item.rs +++ b/crates/rue-compiler/src/compiler/item/struct_item.rs @@ -22,7 +22,10 @@ impl Compiler<'_> { pub fn compile_struct_item(&mut self, struct_item: &StructItem, type_id: TypeId) { self.type_definition_stack.push(type_id); let fields = self.compile_struct_fields(struct_item.fields()); - *self.db.ty_mut(type_id) = Type::Struct(StructType { fields }); + *self.db.ty_mut(type_id) = Type::Struct(StructType { + original_type_id: type_id, + fields, + }); self.type_definition_stack.pop().unwrap(); } diff --git a/crates/rue-compiler/src/compiler/stmt/if_stmt.rs b/crates/rue-compiler/src/compiler/stmt/if_stmt.rs index 8a42791..523092e 100644 --- a/crates/rue-compiler/src/compiler/stmt/if_stmt.rs +++ b/crates/rue-compiler/src/compiler/stmt/if_stmt.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use rue_parser::{AstNode, IfStmt}; -use crate::{compiler::Compiler, scope::Scope, ErrorKind, HirId, SymbolId, TypeId}; +use crate::{compiler::Compiler, scope::Scope, value::GuardPath, ErrorKind, HirId, TypeId}; impl Compiler<'_> { /// Compiles an if statement, returning the condition HIR, then block HIR, and else block guards. @@ -10,7 +10,7 @@ impl Compiler<'_> { &mut self, if_stmt: &IfStmt, expected_type: Option, - ) -> (HirId, HirId, HashMap) { + ) -> (HirId, HirId, HashMap) { // Compile the condition expression. let condition = if_stmt .condition() diff --git a/crates/rue-compiler/src/database/type_system.rs b/crates/rue-compiler/src/database/type_system.rs index 157ff94..932c554 100644 --- a/crates/rue-compiler/src/database/type_system.rs +++ b/crates/rue-compiler/src/database/type_system.rs @@ -136,6 +136,7 @@ impl Database { Type::EnumVariant(enum_variant) => { let new_variant = EnumVariantType { enum_type: enum_variant.enum_type, + original_type_id: enum_variant.original_type_id, fields: enum_variant .fields .iter() @@ -166,6 +167,7 @@ impl Database { } Type::Struct(struct_type) => { let new_struct = StructType { + original_type_id: struct_type.original_type_id, fields: struct_type .fields .iter() @@ -713,19 +715,25 @@ mod tests { fn test_struct_types() { let (mut db, ty) = setup(); - let two_ints = db.alloc_type(Type::Struct(StructType { + let two_ints = db.alloc_type(Type::Unknown); + *db.ty_mut(two_ints) = Type::Struct(StructType { + original_type_id: two_ints, fields: fields(&[ty.int, ty.int]), - })); + }); assert_eq!(db.compare_type(two_ints, two_ints), Comparison::Equal); - let one_int = db.alloc_type(Type::Struct(StructType { + let one_int = db.alloc_type(Type::Unknown); + *db.ty_mut(one_int) = Type::Struct(StructType { + original_type_id: one_int, fields: fields(&[ty.int]), - })); + }); assert_eq!(db.compare_type(one_int, two_ints), Comparison::Unrelated); - let empty_struct = db.alloc_type(Type::Struct(StructType { + let empty_struct = db.alloc_type(Type::Unknown); + *db.ty_mut(empty_struct) = Type::Struct(StructType { + original_type_id: empty_struct, fields: fields(&[]), - })); + }); assert_eq!( db.compare_type(empty_struct, empty_struct), Comparison::Equal @@ -738,11 +746,13 @@ mod tests { let enum_type = db.alloc_type(Type::Unknown); - let variant = db.alloc_type(Type::EnumVariant(EnumVariantType { + let variant = db.alloc_type(Type::Unknown); + *db.ty_mut(variant) = Type::EnumVariant(EnumVariantType { enum_type, + original_type_id: variant, fields: fields(&[ty.int]), discriminant: ty.unknown_hir, - })); + }); *db.ty_mut(enum_type) = Type::Enum(EnumType { has_fields: true, diff --git a/crates/rue-compiler/src/value.rs b/crates/rue-compiler/src/value.rs index 4b50df1..18ba279 100644 --- a/crates/rue-compiler/src/value.rs +++ b/crates/rue-compiler/src/value.rs @@ -1,18 +1,21 @@ use std::collections::HashMap; mod guard; +mod guard_path; mod ty; pub use guard::*; +pub use guard_path::*; pub use ty::*; -use crate::{HirId, SymbolId, TypeId}; +use crate::{HirId, TypeId}; #[derive(Debug, Clone)] pub struct Value { pub hir_id: HirId, pub type_id: TypeId, - pub guards: Guards, + pub guards: HashMap, + pub guard_path: Option, } impl Value { @@ -21,14 +24,32 @@ impl Value { hir_id, type_id, guards: HashMap::new(), + guard_path: None, } } - pub fn then_guards(&self) -> HashMap { - self.guards.iter().map(|(k, v)| (*k, v.then_type)).collect() + pub fn then_guards(&self) -> HashMap { + self.guards + .iter() + .map(|(k, v)| (k.clone(), v.then_type)) + .collect() } - pub fn else_guards(&self) -> HashMap { - self.guards.iter().map(|(k, v)| (*k, v.else_type)).collect() + pub fn else_guards(&self) -> HashMap { + self.guards + .iter() + .map(|(k, v)| (k.clone(), v.else_type)) + .collect() + } + + pub fn extend_guard_path(mut self, old_value: Value, item: GuardPathItem) -> Self { + match old_value.guard_path { + Some(mut path) => { + path.items.push(item); + self.guard_path = Some(path); + self + } + None => self, + } } } diff --git a/crates/rue-compiler/src/value/guard.rs b/crates/rue-compiler/src/value/guard.rs index dd6d00b..1389e33 100644 --- a/crates/rue-compiler/src/value/guard.rs +++ b/crates/rue-compiler/src/value/guard.rs @@ -1,8 +1,6 @@ -use std::{collections::HashMap, ops::Not}; +use std::ops::Not; -use crate::{SymbolId, TypeId}; - -pub type Guards = HashMap; +use crate::TypeId; #[derive(Debug, Clone, Copy)] pub struct Guard { diff --git a/crates/rue-compiler/src/value/guard_path.rs b/crates/rue-compiler/src/value/guard_path.rs new file mode 100644 index 0000000..9b1e9a1 --- /dev/null +++ b/crates/rue-compiler/src/value/guard_path.rs @@ -0,0 +1,23 @@ +use crate::SymbolId; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct GuardPath { + pub symbol_id: SymbolId, + pub items: Vec, +} + +impl GuardPath { + pub fn new(symbol_id: SymbolId) -> Self { + Self { + symbol_id, + items: Vec::new(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum GuardPathItem { + Field(String), + First, + Rest, +} diff --git a/crates/rue-compiler/src/value/ty.rs b/crates/rue-compiler/src/value/ty.rs index c682533..7a23c0f 100644 --- a/crates/rue-compiler/src/value/ty.rs +++ b/crates/rue-compiler/src/value/ty.rs @@ -32,6 +32,7 @@ pub struct PairType { #[derive(Debug, Clone, PartialEq, Eq)] pub struct StructType { + pub original_type_id: TypeId, pub fields: IndexMap, } @@ -44,6 +45,7 @@ pub struct EnumType { #[derive(Debug, Clone, PartialEq, Eq)] pub struct EnumVariantType { pub enum_type: TypeId, + pub original_type_id: TypeId, pub fields: IndexMap, pub discriminant: HirId, }