Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type guard redesign #29

Merged
merged 6 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@ jobs:
- name: Checkout
uses: actions/checkout@v4

- name: Cargo binstall
run: curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash

- name: Instal cargo-workspaces
run: cargo binstall cargo-workspaces --locked -y
- name: Install cargo-workspaces
run: cargo install cargo-workspaces

- name: Run tests
run: cargo test --all-features --workspace
Expand All @@ -31,7 +28,7 @@ jobs:

- name: Unused dependencies
run: |
cargo binstall cargo-machete --locked -y
cargo install cargo-machete --locked
cargo machete

- name: Fmt
Expand Down
54 changes: 46 additions & 8 deletions crates/rue-compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::{
database::{Database, HirId, ScopeId, SymbolId},
hir::{Hir, Op},
scope::Scope,
symbol::{Function, Symbol},
value::{GuardPath, Value},
ErrorKind,
};
Expand Down Expand Up @@ -46,8 +47,8 @@ pub struct Compiler<'a> {
// The type definition stack is used for calculating types referenced in types.
type_definition_stack: Vec<TypeId>,

// The type guard stack is used for overriding types in certain contexts.
type_guard_stack: Vec<HashMap<GuardPath, TypeId>>,
// Overridden symbol types due to type guards.
type_overrides: Vec<HashMap<SymbolId, TypeId>>,

// The generic type stack is used for overriding generic types that are being checked against.
generic_type_stack: Vec<HashMap<TypeId, TypeId>>,
Expand All @@ -74,7 +75,7 @@ impl<'a> Compiler<'a> {
scope_stack: vec![builtins.scope_id],
symbol_stack: Vec::new(),
type_definition_stack: Vec::new(),
type_guard_stack: Vec::new(),
type_overrides: Vec::new(),
generic_type_stack: Vec::new(),
allow_generic_inference_stack: vec![false],
is_callee: false,
Expand Down Expand Up @@ -169,13 +170,50 @@ impl<'a> Compiler<'a> {
Value::new(self.builtins.unknown, self.ty.std().unknown)
}

fn symbol_type(&self, guard_path: &GuardPath) -> Option<TypeId> {
for guards in self.type_guard_stack.iter().rev() {
if let Some(guard) = guards.get(guard_path) {
return Some(*guard);
fn build_overrides(&mut self, guards: HashMap<GuardPath, TypeId>) -> HashMap<SymbolId, TypeId> {
type GuardItem = (Vec<TypePath>, TypeId);

let mut symbol_guards: HashMap<SymbolId, Vec<GuardItem>> = HashMap::new();

for (guard_path, type_id) in guards {
symbol_guards
.entry(guard_path.symbol_id)
.or_default()
.push((guard_path.items, type_id));
}

let mut overrides = HashMap::new();

for (symbol_id, mut items) in symbol_guards {
// Order by length.
items.sort_by_key(|(items, _)| items.len());

let mut type_id = self.symbol_type(symbol_id);

for (path_items, new_type_id) in items {
type_id = self.ty.replace(type_id, new_type_id, &path_items);
}

overrides.insert(symbol_id, type_id);
}

overrides
}

fn symbol_type(&self, symbol_id: SymbolId) -> TypeId {
for guards in self.type_overrides.iter().rev() {
if let Some(type_id) = guards.get(&symbol_id) {
return *type_id;
}
}
None

match self.db.symbol(symbol_id) {
Symbol::Unknown | Symbol::Module(..) => unreachable!(),
Symbol::Function(Function { type_id, .. })
| Symbol::InlineFunction(Function { type_id, .. })
| Symbol::Parameter(type_id) => *type_id,
Symbol::Let(value) | Symbol::Const(value) | Symbol::InlineConst(value) => value.type_id,
}
}

fn scope(&self) -> &Scope {
Expand Down
14 changes: 8 additions & 6 deletions crates/rue-compiler/src/compiler/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ impl Compiler<'_> {

// Push the type guards onto the stack.
// This will be popped in reverse order later after all statements have been lowered.
self.type_guard_stack.push(else_guards);
let overrides = self.build_overrides(else_guards);
self.type_overrides.push(overrides);

statements.push(Statement::If(condition_hir, then_hir));
}
Expand Down Expand Up @@ -103,8 +104,8 @@ impl Compiler<'_> {
// If the condition is false, we raise an error.
// So we can assume that the condition is true from this point on.
// This will be popped in reverse order later after all statements have been lowered.

self.type_guard_stack.push(condition.then_guards());
let overrides = self.build_overrides(condition.then_guards());
self.type_overrides.push(overrides);

let not_condition = self.db.alloc_hir(Hir::Op(Op::Not, condition.hir_id));
let raise = self.db.alloc_hir(Hir::Raise(None));
Expand All @@ -126,7 +127,8 @@ impl Compiler<'_> {
assume_stmt.syntax().text_range(),
);

self.type_guard_stack.push(expr.then_guards());
let overrides = self.build_overrides(expr.then_guards());
self.type_overrides.push(overrides);
statements.push(Statement::Assume);
}
}
Expand Down Expand Up @@ -158,7 +160,7 @@ impl Compiler<'_> {
body = value;
}
Statement::If(condition, then_block) => {
self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();

body = Value::new(
self.db
Expand All @@ -167,7 +169,7 @@ impl Compiler<'_> {
);
}
Statement::Assume => {
self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();
}
}
}
Expand Down
26 changes: 18 additions & 8 deletions crates/rue-compiler/src/compiler/expr/binary_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl Compiler<'_> {
let else_type = self.ty.difference(rhs.type_id, self.ty.std().nil);
value
.guards
.insert(guard_path, Guard::new(then_type, else_type));
.insert(guard_path, Guard::new(Some(then_type), Some(else_type)));
}
}

Expand All @@ -182,7 +182,7 @@ impl Compiler<'_> {
let else_type = self.ty.difference(lhs.type_id, self.ty.std().nil);
value
.guards
.insert(guard_path, Guard::new(then_type, else_type));
.insert(guard_path, Guard::new(Some(then_type), Some(else_type)));
}
}

Expand Down Expand Up @@ -250,13 +250,14 @@ impl Compiler<'_> {
}

fn op_and(&mut self, lhs: Value, rhs: Option<&Expr>, text_range: TextRange) -> Value {
self.type_guard_stack.push(lhs.then_guards());
let overrides = self.build_overrides(lhs.then_guards());
self.type_overrides.push(overrides);

let rhs = rhs
.map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool)))
.unwrap_or_else(|| self.unknown());

self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();

self.type_check(lhs.type_id, self.ty.std().bool, text_range);
self.type_check(rhs.type_id, self.ty.std().bool, text_range);
Expand All @@ -267,19 +268,28 @@ impl Compiler<'_> {
rhs.hir_id,
self.ty.std().bool,
);
value.guards.extend(lhs.guards);
value.guards.extend(rhs.guards);
value.guards.extend(
lhs.guards
.into_iter()
.map(|(path, guard)| (path, Guard::new(guard.then_type, None))),
);
value.guards.extend(
rhs.guards
.into_iter()
.map(|(path, guard)| (path, Guard::new(guard.then_type, None))),
);
value
}

fn op_or(&mut self, lhs: &Value, rhs: Option<&Expr>, text_range: TextRange) -> Value {
self.type_guard_stack.push(lhs.then_guards());
let overrides = self.build_overrides(lhs.else_guards());
self.type_overrides.push(overrides);

let rhs = rhs
.map(|rhs| self.compile_expr(rhs, Some(self.ty.std().bool)))
.unwrap_or_else(|| self.unknown());

self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();

self.type_check(lhs.type_id, self.ty.std().bool, text_range);
self.type_check(rhs.type_id, self.ty.std().bool, text_range);
Expand Down
18 changes: 5 additions & 13 deletions crates/rue-compiler/src/compiler/expr/field_access_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ impl Compiler<'_> {
return self.unknown();
};

let mut new_value = match self.ty.get(old_value.type_id).clone() {
Type::Unknown => return self.unknown(),
match self.ty.get(old_value.type_id).clone() {
Type::Unknown => self.unknown(),
Type::Struct(ty) => {
let Some(value) = self.compile_struct_field_access(old_value, &ty, &name) else {
return self.unknown();
Expand Down Expand Up @@ -55,17 +55,9 @@ impl Compiler<'_> {
),
name.text_range(),
);
return self.unknown();
}
};

if let Some(guard_path) = new_value.guard_path.as_ref() {
if let Some(type_override) = self.symbol_type(guard_path) {
new_value.type_id = type_override;
self.unknown()
}
}

new_value
}

fn compile_pair_field_access(
Expand Down Expand Up @@ -113,7 +105,7 @@ impl Compiler<'_> {
) -> Option<Value> {
let fields =
deconstruct_items(self.ty, ty.type_id, ty.field_names.len(), ty.nil_terminated)
.expect("invalid struct type");
.unwrap();

let Some(index) = ty.field_names.get_index_of(name.text()) else {
self.db
Expand Down Expand Up @@ -157,7 +149,7 @@ impl Compiler<'_> {
.as_ref()
.map(|field_names| {
deconstruct_items(self.ty, type_id, field_names.len(), ty.nil_terminated)
.expect("invalid struct type")
.unwrap()
})
.unwrap_or_default()
} else {
Expand Down
4 changes: 3 additions & 1 deletion crates/rue-compiler/src/compiler/expr/guard_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ impl Compiler<'_> {

if let Some(guard_path) = expr.guard_path {
let difference = self.ty.difference(expr.type_id, rhs);
value.guards.insert(guard_path, Guard::new(rhs, difference));
value
.guards
.insert(guard_path, Guard::new(Some(rhs), Some(difference)));
}

value
Expand Down
10 changes: 6 additions & 4 deletions crates/rue-compiler/src/compiler/expr/if_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@ impl Compiler<'_> {
.map(|condition| self.compile_expr(&condition, Some(self.ty.std().bool)));

if let Some(condition) = condition.as_ref() {
self.type_guard_stack.push(condition.then_guards());
let overrides = self.build_overrides(condition.then_guards());
self.type_overrides.push(overrides);
}

let then_block = if_expr
.then_block()
.map(|then_block| self.compile_block_expr(&then_block, expected_type));

if condition.is_some() {
self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();
}

if let Some(condition) = condition.as_ref() {
self.type_guard_stack.push(condition.else_guards());
let overrides = self.build_overrides(condition.else_guards());
self.type_overrides.push(overrides);
}

let expected_type =
Expand All @@ -33,7 +35,7 @@ impl Compiler<'_> {
.map(|else_block| self.compile_block_expr(&else_block, expected_type));

if condition.is_some() {
self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();
}

if let Some(condition_type) = condition.as_ref().map(|condition| condition.type_id) {
Expand Down
4 changes: 2 additions & 2 deletions crates/rue-compiler/src/compiler/expr/initializer_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ impl Compiler<'_> {
.path()
.map(|path| self.compile_path_type(&path.items(), path.syntax().text_range()));

match ty.map(|ty| self.ty.get(ty)).cloned() {
match ty.map(|ty| self.ty.get_unaliased(ty)).cloned() {
Some(Type::Struct(struct_type)) => {
let fields = deconstruct_items(
self.ty,
Expand Down Expand Up @@ -85,7 +85,7 @@ impl Compiler<'_> {
self.unknown()
}
}
Some(_) => {
Some(..) => {
self.db.error(
ErrorKind::UninitializableType(self.type_name(ty.unwrap())),
initializer.path().unwrap().syntax().text_range(),
Expand Down
14 changes: 6 additions & 8 deletions crates/rue-compiler/src/compiler/expr/path_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
Compiler,
},
hir::Hir,
symbol::{Function, Symbol},
symbol::Symbol,
value::{GuardPath, Value},
ErrorKind,
};
Expand Down Expand Up @@ -75,18 +75,16 @@ impl Compiler<'_> {
return self.unknown();
}

let type_override = self.symbol_type(&GuardPath::new(symbol_id));
let type_id = self.symbol_type(symbol_id);
let reference = self.db.alloc_hir(Hir::Reference(symbol_id, text_range));

let mut value = match self.db.symbol(symbol_id).clone() {
Symbol::Unknown | Symbol::Module(..) => unreachable!(),
Symbol::Function(Function { type_id, .. })
| Symbol::InlineFunction(Function { type_id, .. })
| Symbol::Parameter(type_id) => Value::new(reference, type_override.unwrap_or(type_id)),
Symbol::Function(..) | Symbol::InlineFunction(..) | Symbol::Parameter(..) => {
Value::new(reference, type_id)
}
Symbol::Let(mut value) | Symbol::Const(mut value) | Symbol::InlineConst(mut value) => {
if let Some(type_id) = type_override {
value.type_id = type_id;
}
value.type_id = type_id;
value.hir_id = reference;
value
}
Expand Down
5 changes: 3 additions & 2 deletions crates/rue-compiler/src/compiler/stmt/if_stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ impl Compiler<'_> {
let scope_id = self.db.alloc_scope(Scope::default());

// We can apply any type guards from the condition.
self.type_guard_stack.push(condition.then_guards());
let overrides = self.build_overrides(condition.then_guards());
self.type_overrides.push(overrides);

// Compile the then block.
self.scope_stack.push(scope_id);
let summary = self.compile_block(&then_block, expected_type);
self.scope_stack.pop().unwrap();

// Pop the type guards, since we've left the scope.
self.type_guard_stack.pop().unwrap();
self.type_overrides.pop().unwrap();

// If there's an implicit return, we want to raise an error.
// This could technically work but makes the intent of the code unclear.
Expand Down
Loading