Skip to content

Commit

Permalink
Field type guards
Browse files Browse the repository at this point in the history
  • Loading branch information
Rigidity committed Jun 25, 2024
1 parent e143485 commit dd195a0
Show file tree
Hide file tree
Showing 13 changed files with 192 additions and 83 deletions.
30 changes: 18 additions & 12 deletions crates/rue-compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
database::{Database, HirId, ScopeId, SymbolId, TypeId},
hir::Hir,
scope::Scope,
value::{PairType, Type, Value},
value::{GuardPath, PairType, Type, Value},
ErrorKind,
};

Expand Down Expand Up @@ -46,7 +46,7 @@ pub struct Compiler<'a> {
type_definition_stack: Vec<TypeId>,

// The type guard stack is used for overriding types in certain contexts.
type_guard_stack: Vec<HashMap<SymbolId, TypeId>>,
type_guard_stack: Vec<HashMap<GuardPath, TypeId>>,

// The generic type stack is used for overriding generic types that are being checked against.
generic_type_stack: Vec<HashMap<TypeId, TypeId>>,
Expand Down Expand Up @@ -155,13 +155,19 @@ impl<'a> Compiler<'a> {
format!("({first}, {rest})")
}
Type::Struct(struct_type) => {
let fields: Vec<String> = struct_type
.fields
.iter()
.map(|(name, ty)| format!("{}: {}", name, self.type_name_visitor(*ty, stack)))
.collect();

format!("{{ {} }}", fields.join(", "))
if struct_type.original_type_id == ty {
let fields: Vec<String> = struct_type
.fields
.iter()
.map(|(name, ty)| {
format!("{}: {}", name, self.type_name_visitor(*ty, stack))
})
.collect();

format!("{{ {} }}", fields.join(", "))
} else {
self.type_name_visitor(struct_type.original_type_id, stack)
}
}
Type::Enum { .. } => "<unnamed enum>".to_string(),
Type::EnumVariant(enum_variant) => {
Expand All @@ -181,7 +187,7 @@ impl<'a> Compiler<'a> {
enum_type
.variants
.iter()
.find(|item| *item.1 == ty)
.find(|item| *item.1 == enum_variant.original_type_id)
.expect("enum type is missing variant")
.0
.clone()
Expand Down Expand Up @@ -260,9 +266,9 @@ impl<'a> Compiler<'a> {
Value::new(self.builtins.unknown_hir, self.builtins.unknown)
}

fn symbol_type(&self, symbol_id: SymbolId) -> Option<TypeId> {
fn symbol_type(&self, guard_path: &GuardPath) -> Option<TypeId> {
for guards in &self.type_guard_stack {
if let Some(guard) = guards.get(&symbol_id) {
if let Some(guard) = guards.get(guard_path) {
return Some(*guard);
}
}
Expand Down
22 changes: 12 additions & 10 deletions crates/rue-compiler/src/compiler/expr/binary_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,12 @@ impl Compiler<'_> {
.compare_type(lhs.type_id, self.builtins.nil)
.is_equal()
{
if let Hir::Reference(symbol_id) = self.db.hir(rhs.hir_id) {
value.guards.insert(
*symbol_id,
Guard::new(self.builtins.nil, self.db.non_optional(rhs.type_id)),
);
if let Some(guard_path) = rhs.guard_path {
let then_type = self.builtins.nil;
let else_type = self.db.non_optional(rhs.type_id);
value
.guards
.insert(guard_path, Guard::new(then_type, else_type));
}
}

Expand All @@ -180,11 +181,12 @@ impl Compiler<'_> {
.compare_type(rhs.type_id, self.builtins.nil)
.is_equal()
{
if let Hir::Reference(symbol_id) = self.db.hir(lhs.hir_id) {
value.guards.insert(
*symbol_id,
Guard::new(self.builtins.nil, self.db.non_optional(lhs.type_id)),
);
if let Some(guard_path) = lhs.guard_path.clone() {
let then_type = self.builtins.nil;
let else_type = self.db.non_optional(lhs.type_id);
value
.guards
.insert(guard_path, Guard::new(then_type, else_type));
}
}

Expand Down
109 changes: 75 additions & 34 deletions crates/rue-compiler/src/compiler/expr/field_access_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ use rue_parser::FieldAccessExpr;
use crate::{
compiler::Compiler,
hir::Hir,
value::{Guard, PairType, Type, Value},
value::{Guard, GuardPathItem, PairType, Type, Value},
ErrorKind,
};

impl Compiler<'_> {
/// Compiles a field access expression, or special properties for certain types.
pub fn compile_field_access_expr(&mut self, field_access: &FieldAccessExpr) -> Value {
let Some(value) = field_access
let Some(old_value) = field_access
.expr()
.map(|expr| self.compile_expr(&expr, None))
else {
Expand All @@ -21,59 +21,100 @@ impl Compiler<'_> {
return self.unknown();
};

match self.db.ty(value.type_id).clone() {
let mut new_value = match self.db.ty(old_value.type_id).clone() {
Type::Struct(struct_type) => {
if let Some(field) = struct_type.fields.get_full(field_name.text()) {
let (index, _, field_type) = field;
return Value::new(self.compile_index(value.hir_id, index, false), *field_type);
Value::new(
self.compile_index(old_value.hir_id, index, false),
*field_type,
)
.extend_guard_path(old_value, GuardPathItem::Field(field_name.to_string()))
} else {
self.db.error(
ErrorKind::UndefinedField {
field: field_name.to_string(),
ty: self.type_name(old_value.type_id),
},
field_name.text_range(),
);
return self.unknown();
}
}
Type::EnumVariant(variant_type) => {
if let Some(field) = variant_type.fields.get_full(field_name.text()) {
let (index, _, field_type) = field;
Value::new(
self.compile_index(old_value.hir_id, index, false),
*field_type,
)
.extend_guard_path(old_value, GuardPathItem::Field(field_name.to_string()))
} else {
self.db.error(
ErrorKind::UndefinedField {
field: field_name.to_string(),
ty: self.type_name(old_value.type_id),
},
field_name.text_range(),
);
return self.unknown();
}
self.db.error(
ErrorKind::UndefinedField {
field: field_name.to_string(),
ty: self.type_name(value.type_id),
},
field_name.text_range(),
);
return self.unknown();
}
Type::Pair(PairType { first, rest }) => match field_name.text() {
"first" => {
return Value::new(self.db.alloc_hir(Hir::First(value.hir_id)), first);
return Value::new(self.db.alloc_hir(Hir::First(old_value.hir_id)), first)
.extend_guard_path(old_value, GuardPathItem::First);
}
"rest" => {
return Value::new(self.db.alloc_hir(Hir::Rest(value.hir_id)), rest);
return Value::new(self.db.alloc_hir(Hir::Rest(old_value.hir_id)), rest)
.extend_guard_path(old_value, GuardPathItem::Rest);
}
_ => {
self.db.error(
ErrorKind::InvalidFieldAccess {
field: field_name.to_string(),
ty: self.type_name(old_value.type_id),
},
field_name.text_range(),
);
return self.unknown();
}
_ => {}
},
Type::Bytes | Type::Bytes32 if field_name.text() == "length" => {
return Value::new(
self.db.alloc_hir(Hir::Strlen(value.hir_id)),
self.builtins.int,
);
}
Type::Bytes | Type::Bytes32 if field_name.text() == "length" => Value::new(
self.db.alloc_hir(Hir::Strlen(old_value.hir_id)),
self.builtins.int,
),
Type::PossiblyUndefined(inner) if field_name.text() == "exists" => {
let maybe_nil_reference = self.db.alloc_hir(Hir::CheckExists(value.hir_id));
let maybe_nil_reference = self.db.alloc_hir(Hir::CheckExists(old_value.hir_id));
let exists = self.db.alloc_hir(Hir::IsCons(maybe_nil_reference));
let mut new_value = Value::new(exists, self.builtins.bool);

if let Hir::Reference(symbol_id) = self.db.hir(value.hir_id).clone() {
if let Some(guard_path) = old_value.guard_path {
new_value
.guards
.insert(symbol_id, Guard::new(inner, value.type_id));
.insert(guard_path, Guard::new(inner, old_value.type_id));
}

return new_value;
new_value
}
_ => {
self.db.error(
ErrorKind::InvalidFieldAccess {
field: field_name.to_string(),
ty: self.type_name(old_value.type_id),
},
field_name.text_range(),
);
return self.unknown();
}
};

if let Some(guard_path) = new_value.guard_path.as_ref() {
if let Some(type_id) = self.symbol_type(guard_path) {
new_value.type_id = type_id;
}
_ => {}
}

self.db.error(
ErrorKind::InvalidFieldAccess {
field: field_name.to_string(),
ty: self.type_name(value.type_id),
},
field_name.text_range(),
);
self.unknown()
new_value
}
}
4 changes: 2 additions & 2 deletions crates/rue-compiler/src/compiler/expr/guard_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ impl Compiler<'_> {

let mut value = Value::new(hir_id, self.builtins.bool);

if let Hir::Reference(symbol_id) = self.db.hir(expr.hir_id) {
value.guards.insert(*symbol_id, guard);
if let Some(guard_path) = expr.guard_path {
value.guards.insert(guard_path, guard);
}

value
Expand Down
10 changes: 6 additions & 4 deletions crates/rue-compiler/src/compiler/expr/path_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
},
hir::Hir,
symbol::{Function, Symbol},
value::{Type, Value},
value::{GuardPath, Type, Value},
ErrorKind,
};

Expand Down Expand Up @@ -59,9 +59,9 @@ impl Compiler<'_> {
return self.unknown();
}

let override_type_id = self.symbol_type(symbol_id);
let override_type_id = self.symbol_type(&GuardPath::new(symbol_id));

match self.db.symbol(symbol_id).clone() {
let mut value = match self.db.symbol(symbol_id).clone() {
Symbol::Unknown | Symbol::Module(..) => unreachable!(),
Symbol::Function(Function { ty, .. }) | Symbol::InlineFunction(Function { ty, .. }) => {
let type_id = self.db.alloc_type(Type::Function(ty.clone()));
Expand All @@ -81,6 +81,8 @@ impl Compiler<'_> {
value.hir_id = self.db.alloc_hir(Hir::Reference(symbol_id));
value
}
}
};
value.guard_path = Some(GuardPath::new(symbol_id));
value
}
}
1 change: 1 addition & 0 deletions crates/rue-compiler/src/compiler/item/enum_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ impl Compiler<'_> {
// Update the variant to use the real `EnumVariant` type.
*self.db.ty_mut(variant_type_id) = Type::EnumVariant(EnumVariantType {
enum_type: enum_type_id,
original_type_id: variant_type_id,
fields,
discriminant,
});
Expand Down
5 changes: 4 additions & 1 deletion crates/rue-compiler/src/compiler/item/struct_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ impl Compiler<'_> {
pub fn compile_struct_item(&mut self, struct_item: &StructItem, type_id: TypeId) {
self.type_definition_stack.push(type_id);
let fields = self.compile_struct_fields(struct_item.fields());
*self.db.ty_mut(type_id) = Type::Struct(StructType { fields });
*self.db.ty_mut(type_id) = Type::Struct(StructType {
original_type_id: type_id,
fields,
});
self.type_definition_stack.pop().unwrap();
}

Expand Down
4 changes: 2 additions & 2 deletions crates/rue-compiler/src/compiler/stmt/if_stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ use std::collections::HashMap;

use rue_parser::{AstNode, IfStmt};

use crate::{compiler::Compiler, scope::Scope, ErrorKind, HirId, SymbolId, TypeId};
use crate::{compiler::Compiler, scope::Scope, value::GuardPath, ErrorKind, HirId, TypeId};

impl Compiler<'_> {
/// Compiles an if statement, returning the condition HIR, then block HIR, and else block guards.
pub fn compile_if_stmt(
&mut self,
if_stmt: &IfStmt,
expected_type: Option<TypeId>,
) -> (HirId, HirId, HashMap<SymbolId, TypeId>) {
) -> (HirId, HirId, HashMap<GuardPath, TypeId>) {
// Compile the condition expression.
let condition = if_stmt
.condition()
Expand Down
26 changes: 18 additions & 8 deletions crates/rue-compiler/src/database/type_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ impl Database {
Type::EnumVariant(enum_variant) => {
let new_variant = EnumVariantType {
enum_type: enum_variant.enum_type,
original_type_id: enum_variant.original_type_id,
fields: enum_variant
.fields
.iter()
Expand Down Expand Up @@ -166,6 +167,7 @@ impl Database {
}
Type::Struct(struct_type) => {
let new_struct = StructType {
original_type_id: struct_type.original_type_id,
fields: struct_type
.fields
.iter()
Expand Down Expand Up @@ -713,19 +715,25 @@ mod tests {
fn test_struct_types() {
let (mut db, ty) = setup();

let two_ints = db.alloc_type(Type::Struct(StructType {
let two_ints = db.alloc_type(Type::Unknown);
*db.ty_mut(two_ints) = Type::Struct(StructType {
original_type_id: two_ints,
fields: fields(&[ty.int, ty.int]),
}));
});
assert_eq!(db.compare_type(two_ints, two_ints), Comparison::Equal);

let one_int = db.alloc_type(Type::Struct(StructType {
let one_int = db.alloc_type(Type::Unknown);
*db.ty_mut(one_int) = Type::Struct(StructType {
original_type_id: one_int,
fields: fields(&[ty.int]),
}));
});
assert_eq!(db.compare_type(one_int, two_ints), Comparison::Unrelated);

let empty_struct = db.alloc_type(Type::Struct(StructType {
let empty_struct = db.alloc_type(Type::Unknown);
*db.ty_mut(empty_struct) = Type::Struct(StructType {
original_type_id: empty_struct,
fields: fields(&[]),
}));
});
assert_eq!(
db.compare_type(empty_struct, empty_struct),
Comparison::Equal
Expand All @@ -738,11 +746,13 @@ mod tests {

let enum_type = db.alloc_type(Type::Unknown);

let variant = db.alloc_type(Type::EnumVariant(EnumVariantType {
let variant = db.alloc_type(Type::Unknown);
*db.ty_mut(variant) = Type::EnumVariant(EnumVariantType {
enum_type,
original_type_id: variant,
fields: fields(&[ty.int]),
discriminant: ty.unknown_hir,
}));
});

*db.ty_mut(enum_type) = Type::Enum(EnumType {
has_fields: true,
Expand Down
Loading

0 comments on commit dd195a0

Please sign in to comment.