diff --git a/Cargo.lock b/Cargo.lock index d8450ce8a..4979259e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,7 +31,7 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" dependencies = [ - "getrandom 0.2.9", + "getrandom", "once_cell", "version_check", ] @@ -370,17 +370,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "contracts" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9424f2ca1e42776615720e5746eed6efa19866fdbaac2923ab51c294ac4d1f2" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "core-foundation-sys" version = "0.8.4" @@ -604,6 +593,12 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +[[package]] +name = "ena" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef6777d97afbda76a79b1edb9127ed1a72a25159e75e12829acee0e3e9f27495" + [[package]] name = "encode_unicode" version = "0.3.6" @@ -691,17 +686,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" -[[package]] -name = "getrandom" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.9.0+wasi-snapshot-preview1", -] - [[package]] name = "getrandom" version = "0.2.9" @@ -710,7 +694,7 @@ checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", ] [[package]] @@ -839,11 +823,11 @@ dependencies = [ "log", "memory-stats", "parking_lot", - "rand 0.8.5", + "rand", "regex", "rustc-hash", "semver", - "vector-map", + "threadpool", "winapi", ] @@ -896,27 +880,25 @@ dependencies = [ [[package]] name = "inkwell" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbac11e485159a525867fb7e6aa61981453e6a72f625fde6a4ab3047b0c6dec9" +version = "0.2.0" +source = "git+https://github.com/TheDan64/inkwell?branch=master#7a09ad8a5f3b1fc416f95b5e1c97d33df0ab3f06" dependencies = [ "either", "inkwell_internals", "libc", "llvm-sys", "once_cell", - "parking_lot", + "thiserror", ] [[package]] name = "inkwell_internals" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87d00c17e264ce02be5bc23d7bff959188ec7137beddd06b8b6b05a7c680ea85" +version = "0.8.0" +source = "git+https://github.com/TheDan64/inkwell?branch=master#7a09ad8a5f3b1fc416f95b5e1c97d33df0ab3f06" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.29", ] [[package]] @@ -1048,8 +1030,8 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "lld_rs" -version = "140.0.0" -source = "git+https://github.com/Pivot-Studio/lld-rs.git?branch=main#d60b0d447fe06923a83f2938b016ea3b443f6c5f" +version = "160.0.0" +source = "git+https://github.com/Pivot-Studio/lld-rs.git?branch=main#d1b3a4a2e0fffdb9d3e9174c72223861f5629933" dependencies = [ "cc", "lazy_static", @@ -1060,9 +1042,9 @@ dependencies = [ [[package]] name = "llvm-sys" -version = "140.0.5" +version = "160.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fe3609d79f74a2e4e158e3eaa61c2e931cecba242a79de5529cec2f92b423b7" +checksum = "0bf51981ac0622b10fe4790763e3de1f3d68a0ee4222e03accaaab6731bd508d" dependencies = [ "cc", "lazy_static", @@ -1322,6 +1304,7 @@ dependencies = [ "dissimilar", "dunce", "dyn-fmt", + "ena", "enum_dispatch", "expect-test", "immix", @@ -1467,19 +1450,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "rand" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" -dependencies = [ - "getrandom 0.1.16", - "libc", - "rand_chacha 0.2.2", - "rand_core 0.5.1", - "rand_hc", -] - [[package]] name = "rand" version = "0.8.5" @@ -1487,18 +1457,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_chacha" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" -dependencies = [ - "ppv-lite86", - "rand_core 0.5.1", + "rand_chacha", + "rand_core", ] [[package]] @@ -1508,16 +1468,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" -dependencies = [ - "getrandom 0.1.16", + "rand_core", ] [[package]] @@ -1526,16 +1477,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.9", -] - -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core 0.5.1", + "getrandom", ] [[package]] @@ -1852,18 +1794,18 @@ checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", @@ -2027,16 +1969,6 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" -[[package]] -name = "vector-map" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "550f72ae94a45c0e2139188709e6c4179f0b5ff9bdaa435239ad19048b0cd68c" -dependencies = [ - "contracts", - "rand 0.7.3", -] - [[package]] name = "vergen" version = "7.5.1" @@ -2071,7 +2003,7 @@ dependencies = [ "internal_macro", "libc", "log", - "rand 0.8.5", + "rand", ] [[package]] @@ -2084,12 +2016,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index b92ca0d6d..aa72cb817 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,8 @@ authors = ["The pivot-lang Authors"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -inkwell = { version = "0.1", optional = true, features = ["llvm14-0", "no-libffi-linking"] } -llvm-sys = { version = "140", optional = true } +inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", optional = true, features = ["llvm16-0", "no-libffi-linking"] } +llvm-sys = { version = "160", optional = true } pl_linker = { path = "./pl_linker", optional = true } immix = { path = "./immix", optional = true, features = ["llvm_gc_plugin", "llvm_stackmap"] } vm = { path = "./vm", optional = true, features = ["jit"] } @@ -45,6 +45,7 @@ derivative = "2.2" console = "0.15" anstyle = "1.0" regex = "1.9" +ena = "0.2" [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen = "0.2" diff --git a/immix/Cargo.toml b/immix/Cargo.toml index a17468144..a3fc2080b 100644 --- a/immix/Cargo.toml +++ b/immix/Cargo.toml @@ -11,10 +11,10 @@ libc = "0.2" parking_lot = "0.12" rustc-hash = "1.1" lazy_static = "1.4" -vector-map = "1.0" backtrace = "0.3" log = "0.4" memory-stats = "1.1" +threadpool = {version = "1.8"} [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["winuser","memoryapi"] } diff --git a/immix/benches/immix_bench.rs b/immix/benches/immix_bench.rs index 0316018d0..96a2c66fc 100644 --- a/immix/benches/immix_bench.rs +++ b/immix/benches/immix_bench.rs @@ -173,7 +173,7 @@ fn test_complecated_multiple_thread_gc(num_iter: usize, threads: usize) -> Durat for h in handles { times.push(h.join().unwrap()); } - times.sort_by(|k1, k2| (k1).cmp(k2)); + times.sort(); times.pop().unwrap() } diff --git a/immix/build.rs b/immix/build.rs index 6396e92ae..afef8eb3f 100644 --- a/immix/build.rs +++ b/immix/build.rs @@ -33,33 +33,33 @@ fn main() { lazy_static! { /// A single path to search for LLVM in (containing bin/llvm-config) static ref ENV_LLVM_PREFIX: String = - format!("LLVM_SYS_{}_PREFIX", 140); + format!("LLVM_SYS_{}_PREFIX", 160); /// If exactly "YES", ignore the version blocklist static ref ENV_IGNORE_BLOCKLIST: String = - format!("LLVM_SYS_{}_IGNORE_BLOCKLIST", 140); + format!("LLVM_SYS_{}_IGNORE_BLOCKLIST", 160); /// If set, enforce precise correspondence between crate and binary versions. static ref ENV_STRICT_VERSIONING: String = - format!("LLVM_SYS_{}_STRICT_VERSIONING", 140); + format!("LLVM_SYS_{}_STRICT_VERSIONING", 160); /// If set, do not attempt to strip irrelevant options for llvm-config --cflags static ref ENV_NO_CLEAN_CFLAGS: String = - format!("LLVM_SYS_{}_NO_CLEAN_CFLAGS", 140); + format!("LLVM_SYS_{}_NO_CLEAN_CFLAGS", 160); /// If set and targeting MSVC, force the debug runtime library static ref ENV_USE_DEBUG_MSVCRT: String = - format!("LLVM_SYS_{}_USE_DEBUG_MSVCRT", 140); + format!("LLVM_SYS_{}_USE_DEBUG_MSVCRT", 160); /// If set, always link against libffi static ref ENV_FORCE_FFI: String = - format!("LLVM_SYS_{}_FFI_WORKAROUND", 140); + format!("LLVM_SYS_{}_FFI_WORKAROUND", 160); } lazy_static! { /// LLVM version used by this version of the crate. static ref CRATE_VERSION: Version = { - let crate_version = Version::parse("140.0.6") + let crate_version = Version::parse("160.0.6") .expect("Crate version is somehow not valid semver"); Version { major: crate_version.major / 10, @@ -244,7 +244,7 @@ fn main() { fn get_system_libraries() -> Vec { llvm_config("--system-libs") .split(&[' ', '\n'] as &[char]) - .filter(|s| !s.is_empty()) + .filter(|s| !s.is_empty() && s.starts_with("-l")) .map(|flag| { if cfg!(target_env = "msvc") { // Same as --libnames, foo.lib @@ -252,11 +252,13 @@ fn main() { &flag[..flag.len() - 4] } else if cfg!(target_os = "macos") { // Linker flags style, -lfoo - assert!(flag.starts_with("-l")); + // assert!(flag.starts_with("-l"), "{}",flag); if flag.ends_with(".tbd") && flag.starts_with("-llib") { &flag[5..flag.len() - 4] + } else if let Some(postfix) = flag.strip_prefix("-l") { + postfix } else { - &flag[2..] + flag } } else { if let Some(f) = flag.strip_prefix("-l") { diff --git a/immix/llvm/CMakeLists.txt b/immix/llvm/CMakeLists.txt index 0e6ad0e61..322eeea58 100644 --- a/immix/llvm/CMakeLists.txt +++ b/immix/llvm/CMakeLists.txt @@ -5,13 +5,18 @@ cmake_minimum_required(VERSION 3.5) # Set the project name project (plimmix) -find_package(LLVM 14.0.0 REQUIRED CONFIG) +find_package(LLVM 16.0.0 REQUIRED CONFIG) separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) add_definitions(${LLVM_DEFINITIONS_LIST}) include_directories(${LLVM_INCLUDE_DIRS}) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") + +if (WIN32) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /std:c++17") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") +endif() set(SOURCES memory_manager.cpp diff --git a/immix/llvm/plimmixprinter.cpp b/immix/llvm/plimmixprinter.cpp index 75bfabd06..d86a02112 100644 --- a/immix/llvm/plimmixprinter.cpp +++ b/immix/llvm/plimmixprinter.cpp @@ -38,7 +38,7 @@ void PLImmixGCPrinter::finishAssembly(Module &M, GCModuleInfo &Info, AsmPrinter unsigned IntPtrSize = AP.getPointerSize(); AP.emitAlignment(llvm::Align(8)); // Put this in the data section. - AP.OutStreamer.get()->SwitchSection(AP.getObjFileLowering().getDataSection()); + AP.OutStreamer.get()->switchSection(AP.getObjFileLowering().getDataSection()); std::string symbol; symbol += "_IMMIX_GC_MAP_"; symbol += M.getSourceFileName(); diff --git a/immix/src/allocator/global_allocator.rs b/immix/src/allocator/global_allocator.rs index b771979f1..669967d26 100644 --- a/immix/src/allocator/global_allocator.rs +++ b/immix/src/allocator/global_allocator.rs @@ -1,6 +1,7 @@ use std::cell::Cell; use parking_lot::ReentrantMutex; +use threadpool::ThreadPool; use crate::{bigobj::BigObj, block::Block, consts::BLOCK_SIZE, mmap::Mmap}; @@ -33,6 +34,8 @@ pub struct GlobalAllocator { /// 周期计数器,每个内存总体增加周期加1,每个内存总体不变/减少周期减1 round: i64, + + pub pool: ThreadPool, } unsafe impl Sync for GlobalAllocator {} @@ -49,6 +52,7 @@ impl GlobalAllocator { let mmap = Mmap::new(size * 3 / 4); // mmap.commit(mmap.aligned(), BLOCK_SIZE); + // let n_workers = available_parallelism().unwrap().get(); Self { current: Cell::new(mmap.aligned(BLOCK_SIZE)), @@ -61,6 +65,7 @@ impl GlobalAllocator { last_get_block_time: std::time::Instant::now(), mem_usage_flag: 0, round: 0, + pool: ThreadPool::default(), } } /// 从big object mmap中分配一个大对象,大小为size diff --git a/immix/src/allocator/thread_local_allocator.rs b/immix/src/allocator/thread_local_allocator.rs index 832af09cf..bd563bd06 100644 --- a/immix/src/allocator/thread_local_allocator.rs +++ b/immix/src/allocator/thread_local_allocator.rs @@ -8,7 +8,7 @@ use std::collections::VecDeque; -use vector_map::VecMap; +use rustc_hash::FxHashMap; use crate::{ bigobj::BigObj, @@ -127,7 +127,7 @@ impl ThreadLocalAllocator { self.recyclable_blocks.len() > 1 } - pub fn fill_available_histogram(&self, histogram: &mut VecMap) -> usize { + pub fn fill_available_histogram(&self, histogram: &mut FxHashMap) -> usize { let mut total_available = 0; self.recyclable_blocks .iter() @@ -343,7 +343,7 @@ impl ThreadLocalAllocator { /// Iterate all blocks, if a block is not marked, free it. /// Correct all remain blocks' headers, and classify them /// into recyclable blocks and unavailable blocks. - pub fn sweep(&mut self, mark_histogram: *mut VecMap) -> usize { + pub fn sweep(&mut self, mark_histogram: *mut FxHashMap) -> usize { let mut recyclable_blocks = VecDeque::new(); let mut unavailable_blocks = Vec::new(); let mut free_blocks = Vec::new(); diff --git a/immix/src/block.rs b/immix/src/block.rs index 5fb499f39..ea30d91ab 100644 --- a/immix/src/block.rs +++ b/immix/src/block.rs @@ -1,7 +1,7 @@ use std::sync::atomic::{AtomicU8, Ordering}; use int_enum::IntEnum; -use vector_map::VecMap; +use rustc_hash::FxHashMap; use crate::consts::{BLOCK_SIZE, LINE_SIZE, NUM_LINES_PER_BLOCK}; @@ -220,7 +220,7 @@ impl Block { /// # correct_header /// 回收的最后阶段,重置block的header - pub unsafe fn correct_header(&mut self, mark_histogram: *mut VecMap) -> usize { + pub unsafe fn correct_header(&mut self, mark_histogram: *mut FxHashMap) -> usize { let mut idx = 3; let mut len = 0; let mut first_hole_line_idx: usize = 3; @@ -585,7 +585,7 @@ mod tests { let l = block.get_nth_line_header(6).get_obj_type(); assert_eq!(l, crate::block::ObjectType::Complex); assert_eq!(start, 6); - assert_eq!(newcursor, false); + assert!(!newcursor); assert_eq!(block.cursor, 4); // assert_eq!(block.limit, 1); let (start, newcursor) = block @@ -606,7 +606,7 @@ mod tests { // ...... // | 255 | 已使用 assert_eq!(start, 4); - assert_eq!(newcursor, false); + assert!(!newcursor); // assert_eq!(block.first_hole_line_idx, 255); 这个时候没hole了,此值无意义,len为0 // assert_eq!(block.limit, 0); diff --git a/immix/src/collector.rs b/immix/src/collector.rs index 19a7bf7fb..c5507f2ba 100644 --- a/immix/src/collector.rs +++ b/immix/src/collector.rs @@ -1,11 +1,15 @@ use std::{ cell::RefCell, ptr::drop_in_place, - sync::atomic::{AtomicPtr, Ordering}, + sync::{ + atomic::{AtomicBool, AtomicPtr, Ordering}, + Arc, + }, }; use libc::malloc; -use vector_map::VecMap; +use parking_lot::{Condvar, Mutex}; +use rustc_hash::{FxHashMap, FxHashSet}; #[cfg(feature = "llvm_stackmap")] use crate::STACK_MAP; @@ -13,8 +17,8 @@ use crate::{ allocator::{GlobalAllocator, ThreadLocalAllocator}, block::{Block, LineHeaderExt, ObjectType}, gc_is_auto_collect_enabled, spin_until, HeaderExt, ENABLE_EVA, GC_COLLECTOR_COUNT, GC_ID, - GC_MARKING, GC_MARK_COND, GC_RUNNING, GC_STW_COUNT, GC_SWEEPING, GC_SWEEPPING_NUM, LINE_SIZE, - NUM_LINES_PER_BLOCK, THRESHOLD_PROPORTION, USE_SHADOW_STACK, + GC_MARKING, GC_MARK_COND, GC_RUNNING, GC_STW_COUNT, GC_SWEEPING, GC_SWEEPPING_NUM, + GLOBAL_ALLOCATOR, LINE_SIZE, NUM_LINES_PER_BLOCK, THRESHOLD_PROPORTION, USE_SHADOW_STACK, }; /// # Collector @@ -37,8 +41,11 @@ pub struct Collector { roots: rustc_hash::FxHashMap<*mut u8, ObjectType>, queue: *mut Vec<(*mut u8, ObjectType)>, id: usize, - mark_histogram: *mut VecMap, + mark_histogram: *mut FxHashMap, status: RefCell, + frames_list: AtomicPtr>, + shadow_thread_running: AtomicBool, + live_set: FxHashSet<*mut u8>, } struct CollectorStatus { @@ -88,9 +95,12 @@ impl Collector { let mem = malloc(core::mem::size_of::()).cast::(); mem.write(tla); - let memvecmap = - malloc(core::mem::size_of::>()).cast::>(); - memvecmap.write(VecMap::with_capacity(NUM_LINES_PER_BLOCK)); + let memvecmap = malloc(core::mem::size_of::>()) + .cast::>(); + memvecmap.write(FxHashMap::with_capacity_and_hasher( + NUM_LINES_PER_BLOCK, + Default::default(), + )); let queue = Vec::new(); let memqueue = malloc(core::mem::size_of::>()) @@ -108,6 +118,9 @@ impl Collector { bytes_allocated_since_last_gc: 0, collecting: false, }), + frames_list: AtomicPtr::default(), + shadow_thread_running: AtomicBool::new(false), + live_set: FxHashSet::default(), } } } @@ -146,6 +159,9 @@ impl Collector { if size == 0 { return std::ptr::null_mut(); } + if !self.frames_list.load(Ordering::SeqCst).is_null() { + panic!("gc stucked, can not alloc") + } if gc_is_auto_collect_enabled() { if GC_RUNNING.load(Ordering::Acquire) { self.collect(); @@ -364,7 +380,12 @@ impl Collector { } self.mark_ptr(ptr as *mut u8); } - + pub fn keep_live(&mut self, gc_ptr: *mut u8) { + self.live_set.insert(gc_ptr); + } + pub fn rm_live(&mut self, gc_ptr: *mut u8) { + self.live_set.remove(&gc_ptr); + } pub fn print_stats(&self) { println!("gc {} states:", self.id); unsafe { @@ -389,7 +410,10 @@ impl Collector { /// /// this mark function is __precise__ pub fn mark(&self) { + let mutex = STUCK_MUTEX.lock(); GC_RUNNING.store(true, Ordering::Release); + STUCK_COND.notify_all(); + drop(mutex); let mut v = GC_COLLECTOR_COUNT.lock(); let (count, mut waiting) = *v; @@ -421,17 +445,17 @@ impl Collector { unsafe { match obj_type { ObjectType::Atomic => {} - ObjectType::Pointer => (*self.queue).push((*root, *obj_type)), - _ => { - if !self.thread_local_allocator.as_mut().unwrap().in_heap(*root) { - continue; - } - (*self.queue).push((*root, *obj_type)) - } + _ => (*self.queue).push((*root, *obj_type)), } } } } + for live in self.live_set.iter() { + unsafe { + self.mark_ptr(live as *const _ as _); + } + } + log::trace!("gc {}: marking...", self.id); #[cfg(feature = "llvm_stackmap")] { if USE_SHADOW_STACK.load(Ordering::Relaxed) { @@ -442,10 +466,23 @@ impl Collector { } else { // println!("{:?}", &STACK_MAP.map.borrow()); let mut depth = 0; - backtrace::trace(|frame| { - let addr = frame.ip() as *mut u8; + let fl = self.frames_list.load(Ordering::SeqCst); + let frames = if !fl.is_null() { + log::trace!("gc {}: tracing stucked frames", self.id); + unsafe { fl.as_mut().unwrap().clone() } + } else { + let mut frames: Vec<(*mut libc::c_void, *mut libc::c_void)> = vec![]; + backtrace::trace(|frame| { + frames.push((frame.ip(), frame.sp())); + true + }); + frames + }; + + frames.iter().for_each(|(ip, sp)| unsafe { + let addr = *ip as *mut u8; let const_addr = addr as *const u8; - let map = STACK_MAP.map.borrow(); + let map = STACK_MAP.map.as_ref().unwrap(); let f = map.get(&const_addr); // backtrace::resolve_frame(frame, // |s| @@ -455,15 +492,14 @@ impl Collector { // ); if let Some(f) = f { // println!("found fn in stackmap, f: {:?} sp: {:p}", f,frame.sp()); - f.iter_roots().for_each(|(offset, _obj_type)| unsafe { + f.iter_roots().for_each(|(offset, _obj_type)| { // println!("offset: {}", offset); - let sp = frame.sp() as *mut u8; + let sp = *sp as *mut u8; let root = sp.offset(offset as isize); self.mark_ptr(root); }); } depth += 1; - true }); self.mark_globals(); } @@ -498,28 +534,32 @@ impl Collector { } else { GC_MARKING.store(false, Ordering::Release); GC_MARK_COND.notify_all(); + GC_RUNNING.store(false, Ordering::Release); drop(v); } } #[cfg(feature = "llvm_stackmap")] fn mark_globals(&self) { - STACK_MAP - .global_roots - .borrow() - .iter() - .for_each(|root| unsafe { - self.mark_ptr((*root) as usize as *mut u8); - }); + unsafe { + STACK_MAP + .global_roots + .as_mut() + .unwrap() + .iter() + .for_each(|root| { + self.mark_ptr((*root) as usize as *mut u8); + }); + } } /// # sweep /// /// since we did synchronization in mark, we don't need to do synchronization again in sweep pub fn sweep(&self) -> usize { - GC_RUNNING.store(false, Ordering::Release); GC_SWEEPPING_NUM.fetch_add(1, Ordering::AcqRel); GC_SWEEPING.store(true, Ordering::Release); + log::trace!("gc {}: sweeping...", self.id); let used = unsafe { self.thread_local_allocator .as_mut() @@ -534,11 +574,21 @@ impl Collector { used } + pub fn safepoint(&self) { + if GC_RUNNING.load(Ordering::Acquire) { + self.collect(); + } + } + /// # collect + /// /// Collect garbage. pub fn collect(&self) { - // let start_time = std::time::Instant::now(); - log::info!("gc {} collecting...", self.id); + log::info!( + "gc {} collecting... stucked: {}", + self.id, + !self.frames_list.load(Ordering::SeqCst).is_null() + ); // self.print_stats(); let mut status = self.status.borrow_mut(); // println!("gc {} collecting... {}", self.id,status.bytes_allocated_since_last_gc); @@ -558,8 +608,8 @@ impl Collector { { // 如果需要驱逐,首先计算驱逐阀域 let mut eva_threshold = 0; - let mut available_histogram: VecMap = - VecMap::with_capacity(NUM_LINES_PER_BLOCK); + let mut available_histogram: FxHashMap = + FxHashMap::with_capacity_and_hasher(NUM_LINES_PER_BLOCK, Default::default()); let mut available_lines = self .thread_local_allocator .as_mut() @@ -624,8 +674,75 @@ impl Collector { .get_bigobjs_size() } } + + pub fn stuck(&mut self) { + log::trace!("gc {}: stucking...", self.id); + let mut frames: Box> = Box::default(); + backtrace::trace(|frame| { + frames.push((frame.ip(), frame.sp())); + true + }); + unsafe { + let ptr = Box::leak(frames) as *mut _; + self.frames_list.store(ptr, Ordering::SeqCst); + let c: *mut Collector = self as *mut _; + let c = c.as_mut().unwrap(); + c.shadow_thread_running.store(true, Ordering::SeqCst); + GLOBAL_ALLOCATOR.0.as_ref().unwrap().pool.execute(move || { + log::info!("gc {}: stucked, waiting for unstuck...", c.id); + loop { + let mut mutex = STUCK_MUTEX.lock(); + if c.frames_list.load(Ordering::SeqCst).is_null() { + log::trace!("gc {}: unstucking break...", c.id); + c.shadow_thread_running.store(false, Ordering::SeqCst); + drop(mutex); + break; + } else if GC_RUNNING.load(Ordering::Acquire) { + drop(mutex); + c.collect(); + } else { + STUCK_COND.wait(&mut mutex); + drop(mutex); + c.safepoint(); + } + } + }); + } + STUCK_COND.notify_all(); + // FRAMES_LIST.0.lock().borrow_mut().insert( self as _,frames); + } + + pub fn unstuck(&mut self) { + log::trace!("gc {}: unstucking...", self.id); + let mutex = STUCK_MUTEX.lock(); + let old = self + .frames_list + .swap(std::ptr::null_mut(), Ordering::SeqCst); + if !old.is_null() { + STUCK_COND.notify_all(); + drop(mutex); + unsafe { + drop_in_place(old as *mut Vec<(*mut libc::c_void, *mut libc::c_void)>); + } + } else { + STUCK_COND.notify_all(); + drop(mutex); + } + // wait until the shadow thread exit + spin_until!(!self.shadow_thread_running.load(Ordering::SeqCst)); + } } #[cfg(test)] #[cfg(feature = "shadow_stack")] mod tests; + +// static STUCK_GCED: AtomicBool = AtomicBool::new(false); + +lazy_static::lazy_static! { + static ref STUCK_COND: Arc = Arc::new(Condvar::new()); + static ref STUCK_MUTEX:Mutex<()> = Mutex::new(()); +} + +unsafe impl Sync for Collector {} +unsafe impl Send for Collector {} diff --git a/immix/src/lib.rs b/immix/src/lib.rs index ef4307159..4e0ff70c8 100644 --- a/immix/src/lib.rs +++ b/immix/src/lib.rs @@ -44,17 +44,16 @@ thread_local! { #[cfg(feature = "llvm_stackmap")] lazy_static! { static ref STACK_MAP: StackMapWrapper = { - StackMapWrapper { - map: RefCell::new(FxHashMap::default()), - global_roots: RefCell::new(vec![]), - } + let map = Box::into_raw(Box::default()); + let global_roots = Box::into_raw(Box::default()); + StackMapWrapper { map, global_roots } }; } #[cfg(feature = "llvm_stackmap")] pub struct StackMapWrapper { - pub map: RefCell>, - pub global_roots: RefCell>, + pub map: *mut FxHashMap<*const u8, Function>, + pub global_roots: *mut Vec<*const u8>, } #[cfg(feature = "llvm_stackmap")] unsafe impl Sync for StackMapWrapper {} @@ -131,6 +130,24 @@ pub fn gc_add_root(root: *mut u8, obj_type: u8) { }) } +pub fn gc_keep_live(gc_ptr: *mut u8) { + SPACE.with(|gc| { + // println!("start add_root"); + let mut gc = gc.borrow_mut(); + gc.keep_live(gc_ptr); + // println!("add_root") + }) +} + +pub fn gc_rm_live(gc_ptr: *mut u8) { + SPACE.with(|gc| { + // println!("start add_root"); + let mut gc = gc.borrow_mut(); + gc.rm_live(gc_ptr); + // println!("add_root") + }) +} + #[cfg(feature = "shadow_stack")] pub fn gc_remove_root(root: *mut u8) { SPACE.with(|gc| { @@ -162,35 +179,45 @@ pub fn no_gc_thread() { #[cfg(feature = "llvm_stackmap")] pub fn gc_init(ptr: *mut u8) { // println!("stackmap: {:?}", &STACK_MAP.map.borrow()); - build_root_maps( - ptr, - &mut STACK_MAP.map.borrow_mut(), - &mut STACK_MAP.global_roots.borrow_mut(), - ); + build_root_maps(ptr, unsafe { STACK_MAP.map.as_mut().unwrap() }, unsafe { + STACK_MAP.global_roots.as_mut().unwrap() + }); } -/// notify gc if a thread is going to stuck e.g. -/// lock a mutex or doing sync io +/// notify gc current thread is going to stuck e.g. +/// lock a mutex or doing sync io or sleep etc. +/// +/// during thread stucking, gc will start a nanny thread to +/// do gc works that original thread should do. /// -/// during thread stucking, if a gc is triggered, it will skip waiting for this thread to -/// reach a safe point +/// ## Note +/// +/// During thread stucking, the stucking thread should not +/// request any memory from gc, or it will cause a panic. pub fn thread_stuck_start() { - let mut v = GC_COLLECTOR_COUNT.lock(); - v.0 -= 1; - GC_MARK_COND.notify_all(); - drop(v); + // v.0 -= 1; + SPACE.with(|gc| { + // println!("start add_root"); + let mut gc = gc.borrow_mut(); + gc.stuck() + // println!("add_root") + }); } -/// notify gc a thread is not stuck anymore +/// notify gc current thread is not stuck anymore /// /// if a gc is triggered during thread stucking, this function /// will block until the gc is finished pub fn thread_stuck_end() { - let mut v = GC_COLLECTOR_COUNT.lock(); - GC_MARK_COND.wait_while(&mut v, |_| GC_RUNNING.load(Ordering::SeqCst)); - v.0 += 1; - GC_MARK_COND.notify_all(); - drop(v); + log::trace!("unstucking..."); + spin_until!(!GC_RUNNING.load(Ordering::SeqCst)); + // v.0 += 1; + SPACE.with(|gc| { + // println!("start add_root"); + let mut gc = gc.borrow_mut(); + gc.unstuck() + // println!("add_root") + }); } /// # set evacuation diff --git a/internal_macro/Cargo.toml b/internal_macro/Cargo.toml index 234d9fb04..a47884b8a 100644 --- a/internal_macro/Cargo.toml +++ b/internal_macro/Cargo.toml @@ -9,7 +9,7 @@ edition = "2021" ctor = "0.1" log = "0.4" -llvm-sys = { version = "140", optional = true } +llvm-sys = { version = "160", optional = true } [dependencies.add_symbol_macro] path = "src/add_symbol_macro" [dependencies.range_macro] diff --git a/pl_linker/Cargo.toml b/pl_linker/Cargo.toml index 096fbfbc7..7aea6749d 100644 --- a/pl_linker/Cargo.toml +++ b/pl_linker/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -lld_rs = { version = "140.0.0", default-features = false, git="https://github.com/Pivot-Studio/lld-rs.git", branch="main" } +lld_rs = { version = "160.0.0", default-features = false, git="https://github.com/Pivot-Studio/lld-rs.git", branch="main" } thiserror = "1.0.38" mun_target = { git="https://github.com/mun-lang/mun.git" } once_cell = { version = "1.4.0" } diff --git a/pl_linker/src/linker.rs b/pl_linker/src/linker.rs index 115d3de68..e94d95b61 100644 --- a/pl_linker/src/linker.rs +++ b/pl_linker/src/linker.rs @@ -128,6 +128,7 @@ impl Linker for LdLinker { "/lib/x86_64-linux-gnu/crti.o", // "/usr/lib/gcc/x86_64-linux-gnu//crtbeginS.o", "-lc", + "-lm", "-lpthread", "-lunwind", "--no-as-needed", diff --git a/planglib/core/gc.pi b/planglib/core/gc.pi index 80e74a26d..113937f45 100644 --- a/planglib/core/gc.pi +++ b/planglib/core/gc.pi @@ -4,6 +4,13 @@ pub fn DioGC__collect() void; // pub fn DioGC__malloc_no_collect(size: i64, obj_type: u8) *u8; +pub fn DioGC__disable_auto_collect() void; + +pub fn DioGC__enable_auto_collect() void; + +pub fn DioGC__stuck_begin() void; + +pub fn DioGC__stuck_end() void; pub fn malloc() *u8 { return DioGC__malloc(sizeof(), gc_type()); diff --git a/planglib/std/__private.pi b/planglib/std/__private.pi index 9b38bf0a8..2f7b68dfe 100644 --- a/planglib/std/__private.pi +++ b/planglib/std/__private.pi @@ -2,6 +2,9 @@ use std::io; use std::stdbuiltin; use std::iter; use std::cols::arr; +use std::future; +use std::chan; +use std::mutex; use std::cols::hashtable; use std::libc; use std::buf; diff --git a/planglib/std/chan.pi b/planglib/std/chan.pi new file mode 100644 index 000000000..8f95b4123 --- /dev/null +++ b/planglib/std/chan.pi @@ -0,0 +1,68 @@ +use std::mutex::*; +pub struct Chan { + buffer: Queue; + count: u64; + capacity: u64; + mtx: *Mutex; +} +struct Node { + data: T; + next: *Node; +} + + +struct Queue { + head: *Node; + tail: *Node; +} + +impl Queue{ + pub fn push(t: T) void { + let node = Node{}; + node.data = t; + self.tail.next = &node; + self.tail = &node; + return; + } + + pub fn pop() T { + let node = self.head.next; + self.head = node; + return node.data; + } + +} +pub fn channel(sz: u64) Chan { + let node = Node{}; + let ch = Chan {}; + ch.buffer = Queue{}; + ch.buffer.head = &node; + ch.buffer.tail = &node; + ch.count = 0; + ch.capacity = sz; + create_mutex(&ch.mtx); + return ch; +} +impl Chan { + pub fn send(s: S) void { + gc::DioGC__stuck_begin(); + while self.capacity <= self.count { + } + gc::DioGC__stuck_end(); + lock_mutexWrap(self.mtx); + self.buffer.push(s); + self.count = self.count + 1; + unlock_mutex(self.mtx); + return; + } + pub fn recv() S { + gc::DioGC__stuck_begin(); + while self.count==0 {} + gc::DioGC__stuck_end(); + lock_mutexWrap(self.mtx); + let s = self.buffer.pop(); + self.count = self.count - 1; + unlock_mutex(self.mtx); + return s; + } +} \ No newline at end of file diff --git a/planglib/std/future.pi b/planglib/std/future.pi new file mode 100644 index 000000000..71e9f5a1b --- /dev/null +++ b/planglib/std/future.pi @@ -0,0 +1,14 @@ +pub trait Waker { + fn wake()void; +} + +pub struct Pending {} + +pub struct Ready { + v: T; +} + +pub type Poll = Ready | Pending; +pub trait Future { + fn poll(wk: Waker) Poll; +} \ No newline at end of file diff --git a/planglib/std/mutex.pi b/planglib/std/mutex.pi new file mode 100644 index 000000000..5f97aa4cf --- /dev/null +++ b/planglib/std/mutex.pi @@ -0,0 +1,18 @@ +pub struct Mutex{} + +pub fn create_mutex(mutex: **Mutex) u64; + +fn lock_mutex(mutex: *Mutex) u64; + + +pub fn lock_mutexWrap(mutex: *Mutex) u64{ + let res:u64; + gc::DioGC__stuck_begin(); + res = lock_mutex(mutex); + gc::DioGC__stuck_end(); + return res; +} + +pub fn unlock_mutex(mutex: *Mutex) u64; + +pub fn drop_mutex(mutex: *Mutex) u64; \ No newline at end of file diff --git a/src/ast/builder/llvmbuilder.rs b/src/ast/builder/llvmbuilder.rs index 89f7278d0..5d75a87dc 100644 --- a/src/ast/builder/llvmbuilder.rs +++ b/src/ast/builder/llvmbuilder.rs @@ -22,12 +22,12 @@ use inkwell::{ module::{FlagBehavior, Linkage, Module}, targets::{InitializationConfig, Target, TargetMachine}, types::{ - AnyType, AsTypeRef, BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FunctionType, - PointerType, StructType, VoidType, + AsTypeRef, BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FunctionType, PointerType, + StructType, VoidType, }, values::{ - AnyValue, AnyValueEnum, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallableValue, - FunctionValue, PointerValue, + AnyValue, AnyValueEnum, BasicMetadataValueEnum, BasicValue, BasicValueEnum, FunctionValue, + PointerValue, }, AddressSpace, FloatPredicate, IntPredicate, OptimizationLevel, }; @@ -73,12 +73,8 @@ fn get_dw_ate_encoding(pritp: &PriType) -> u32 { } } -fn get_nth_mark_fn(f: FunctionValue, n: u32) -> CallableValue { - f.get_nth_param(n) - .unwrap() - .into_pointer_value() - .try_into() - .unwrap() +fn get_nth_mark_fn(f: FunctionValue, n: u32) -> PointerValue { + f.get_nth_param(n).unwrap().into_pointer_value() } pub fn create_llvm_deps<'ctx>( @@ -171,6 +167,16 @@ pub fn get_target_machine(level: OptimizationLevel) -> TargetMachine { } impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { + fn build_load_raw(&self, ptr: ValueHandle, name: &str, tp: BasicTypeEnum<'ctx>) -> ValueHandle { + let llvm_type = tp; + let ptr = self.get_llvm_value(ptr).unwrap(); + let ptr = ptr.into_pointer_value(); + let ptr = self.builder.build_load(llvm_type, ptr, name).unwrap(); + if ptr.is_pointer_value() { + self.create_root_for(ptr); + } + self.get_llvm_value_handle(&ptr.as_any_value_enum()) + } pub fn new( context: &'ctx Context, module: &'a Module<'ctx>, @@ -211,24 +217,37 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { let builder = self.builder; builder.unset_current_debug_location(); let (p, stack_root, _) = self.gc_malloc(name, ctx, pltype, malloc_fn); + let llvm_tp = self.get_basic_type_op(pltype, ctx).unwrap(); if let PLType::Struct(tp) = pltype { let f = self.get_or_insert_st_visit_fn_handle(&p, tp); - let i = self.builder.build_ptr_to_int( - f.as_global_value().as_pointer_value(), - self.context.i64_type(), - "_vtable", - ); - let vtable = self.builder.build_struct_gep(p, 0, "vtable").unwrap(); - self.builder.build_store(vtable, i); + let i = self + .builder + .build_ptr_to_int( + f.as_global_value().as_pointer_value(), + self.context.i64_type(), + "_vtable", + ) + .unwrap(); + let vtable = self + .builder + .build_struct_gep(llvm_tp, p, 0, "vtable") + .unwrap(); + self.builder.build_store(vtable, i).unwrap(); } else if let PLType::Arr(tp) = pltype { let f = self.gen_or_get_arr_visit_function(ctx, tp); - let i = self.builder.build_ptr_to_int( - f.as_global_value().as_pointer_value(), - self.context.i64_type(), - "_vtable", - ); - let vtable = self.builder.build_struct_gep(p, 0, "vtable").unwrap(); - self.builder.build_store(vtable, i); + let i = self + .builder + .build_ptr_to_int( + f.as_global_value().as_pointer_value(), + self.context.i64_type(), + "_vtable", + ) + .unwrap(); + let vtable = self + .builder + .build_struct_gep(llvm_tp, p, 0, "vtable") + .unwrap(); + self.builder.build_store(vtable, i).unwrap(); } if let Some(p) = declare { self.build_dbg_location(p); @@ -280,7 +299,8 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { self.builder.position_at_end(alloca); let stack_ptr = self .builder - .build_alloca(self.context.i64_type(), "ctx_tp_ptr"); + .build_alloca(self.context.i64_type(), "ctx_tp_ptr") + .unwrap(); ctx.generator_data .as_ref() .unwrap() @@ -291,7 +311,8 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { size = self .builder - .build_load(stack_ptr, "ctx_tp") + .build_load(self.context.i64_type(), stack_ptr, "ctx_tp") + .unwrap() .into_int_value(); } let heapptr = self @@ -301,15 +322,19 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { &[size.into(), immix_tp.into()], &format!("heapptr_{}", name), ) + .unwrap() .try_as_basic_value() .left() .unwrap(); - let casted_result = self.builder.build_bitcast( - heapptr.into_pointer_value(), - llvmtp.ptr_type(AddressSpace::default()), - name, - ); + let casted_result = self + .builder + .build_bitcast( + heapptr.into_pointer_value(), + llvmtp.ptr_type(AddressSpace::default()), + name, + ) + .unwrap(); // TODO: force user to manually init all structs, so we can remove this memset self.builder @@ -327,13 +352,24 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { if alloca.get_terminator().is_some() { panic!("alloca block should not have terminator yet") } - let stack_ptr = self.builder.build_alloca( - llvmtp.ptr_type(AddressSpace::default()), - &format!("stack_ptr_{}", name), - ); + let stack_ptr = self + .builder + .build_alloca( + llvmtp.ptr_type(AddressSpace::default()), + &format!("stack_ptr_{}", name), + ) + .unwrap(); + // self.builder + // .build_memset( + // stack_ptr, + // td.get_abi_alignment(&self.i8ptr()), + // self.context.i8_type().const_zero(), + // self.context.i64_type().const_int( td.get_store_size(&self.i8ptr()),false), + // ) + // .unwrap(); self.gc_add_root(stack_ptr.as_basic_value_enum(), obj_type); self.builder.position_at_end(lb); - self.builder.build_store(stack_ptr, casted_result); + self.builder.build_store(stack_ptr, casted_result).unwrap(); self.builder.position_at_end(cb); @@ -351,20 +387,22 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { .get_llvm_value(arr.size_handle) .unwrap() .into_int_value(); - let arr_size = self.builder.build_int_mul(arr_len, size, "arr_size"); - let arr_size = self.builder.build_int_z_extend_or_bit_cast( - arr_size, - self.context.i64_type(), - "arr_size", - ); + let arr_size = self + .builder + .build_int_mul(arr_len, size, "arr_size") + .unwrap(); + let arr_size = self + .builder + .build_int_z_extend_or_bit_cast(arr_size, self.context.i64_type(), "arr_size") + .unwrap(); let len_ptr = self .builder - .build_struct_gep(casted_result.into_pointer_value(), 2, "arr_len") + .build_struct_gep(llvmtp, casted_result.into_pointer_value(), 2, "arr_len") .unwrap(); - self.builder.build_store(len_ptr, arr_len); + self.builder.build_store(len_ptr, arr_len).unwrap(); let arr_ptr = self .builder - .build_struct_gep(casted_result.into_pointer_value(), 1, "arr_ptr") + .build_struct_gep(llvmtp, casted_result.into_pointer_value(), 1, "arr_ptr") .unwrap(); let arr_space = self .builder @@ -379,15 +417,27 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { ], "arr_space", ) + .unwrap() .try_as_basic_value() .left() .unwrap(); - let arr_space = self.builder.build_bitcast( - arr_space.into_pointer_value(), - etp.ptr_type(AddressSpace::default()), - "arr_space", - ); - self.builder.build_store(arr_ptr, arr_space); + self.builder + .build_memset( + arr_space.into_pointer_value(), + 8, + self.context.i8_type().const_zero(), + arr_size, + ) + .unwrap(); + let arr_space = self + .builder + .build_bitcast( + arr_space.into_pointer_value(), + etp.ptr_type(AddressSpace::default()), + "arr_space", + ) + .unwrap(); + self.builder.build_store(arr_ptr, arr_space).unwrap(); } } @@ -443,13 +493,16 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { .get_first_basic_block() .unwrap(); self.builder.position_at_end(alloca); - let stack_ptr = self.builder.build_alloca(heap_ptr.get_type(), "stack_ptr"); + let stack_ptr = self + .builder + .build_alloca(heap_ptr.get_type(), "stack_ptr") + .unwrap(); self.gc_add_root( stack_ptr.as_basic_value_enum(), ObjectType::Pointer.int_value(), ); self.builder.position_at_end(lb); - self.builder.build_store(stack_ptr, heap_ptr); + self.builder.build_store(stack_ptr, heap_ptr).unwrap(); self.heap_stack_map.borrow_mut().insert( self.get_llvm_value_handle(&heap_ptr.as_any_value_enum()), self.get_llvm_value_handle(&stack_ptr.as_any_value_enum()), @@ -472,11 +525,14 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { }) .and_then(|f| { let i8ptr = self.context.i8_type().ptr_type(AddressSpace::default()); - let stackptr = self.builder.build_bitcast( - stackptr.into_pointer_value(), - i8ptr.ptr_type(AddressSpace::default()), - "stackptr", - ); + let stackptr = self + .builder + .build_bitcast( + stackptr.into_pointer_value(), + i8ptr.ptr_type(AddressSpace::default()), + "stackptr", + ) + .unwrap(); let tp = ObjectType::from_int(obj_type).expect("invalid object type"); let tp_const_name = format!( "@{}_@IMMIX_OBJTYPE_{}", @@ -500,14 +556,16 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { Some(g) }) .map(|g| { - self.builder.build_call( - f, - &[ - stackptr.into_pointer_value().into(), - g.as_pointer_value().into(), - ], - "add_root", - ); + self.builder + .build_call( + f, + &[ + stackptr.into_pointer_value().into(), + g.as_pointer_value().into(), + ], + "add_root", + ) + .unwrap(); }) }); } @@ -524,6 +582,18 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { } nh } + #[allow(dead_code)] + fn get_or_insert_print_fn(&self, name: &str) -> FunctionValue<'ctx> { + if let Some(f) = self.module.get_function(name) { + return f; + } + let ftp = self + .context + .void_type() + .fn_type(&[self.context.i64_type().into()], false); + let f = self.module.add_function(name, ftp, None); + f + } fn visit_f_tp(&self) -> PointerType<'ctx> { let i8ptrtp = self.context.i8_type().ptr_type(AddressSpace::default()); @@ -553,20 +623,18 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { let currentbb = self.builder.get_insert_block(); self.builder.unset_current_debug_location(); let ptrtp = self.arr_type(v, ctx).ptr_type(AddressSpace::default()); - let ty = ptrtp.get_element_type().into_struct_type(); + let ty = self.arr_type(v, ctx).into_struct_type(); let ftp = self.mark_fn_tp(ptrtp); let arr_tp = ty.get_field_type_at_index(1).unwrap(); // windows linker won't recognize flags with special caracters (llvm.used will add linker flags // to prevent symbol trim), so we need to do a hash here to remove the special caracters let mut hasher = DefaultHasher::new(); (arr_tp.to_string() + "@" + &ctx.plmod.path).hash(&mut hasher); - let fname = &format!("{:x}", hasher.finish()); + let fname = &format!("arr_visit{:x}", hasher.finish()); if let Some(f) = self.module.get_function(fname) { return f; } - let f = self - .module - .add_function(fname, ftp, Some(Linkage::LinkOnceAny)); + let f = self.module.add_function(fname, ftp, Some(Linkage::Private)); self.used.borrow_mut().push(f); // the array is a struct, the first field is the visit function, // the second field is the real array, the third field is it's length @@ -575,81 +643,162 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { let bb = self.context.append_basic_block(f, "entry"); self.builder.position_at_end(bb); let arr = f.get_nth_param(0).unwrap().into_pointer_value(); - let real_arr_raw = self.builder.build_struct_gep(arr, 1, "arr").unwrap(); + let real_arr_raw = self.builder.build_struct_gep(ty, arr, 1, "arr").unwrap(); let real_arr = self .builder - .build_load(real_arr_raw, "loaded_arr") + .build_load( + ty.get_field_type_at_index(1).unwrap(), + real_arr_raw, + "loaded_arr", + ) + .unwrap() .into_pointer_value(); - let loop_var = self.builder.build_alloca(self.context.i64_type(), "i"); + let loop_var = self + .builder + .build_alloca(self.context.i64_type(), "i") + .unwrap(); self.builder - .build_store(loop_var, self.context.i64_type().const_zero()); + .build_store(loop_var, self.context.i64_type().const_zero()) + .unwrap(); // arr is the real array - let arr_len = self.builder.build_struct_gep(arr, 2, "arr_len").unwrap(); - let arr_len = self.builder.build_load(arr_len, "arr_len").into_int_value(); + let arr_len = self + .builder + .build_struct_gep(ty, arr, 2, "arr_len") + .unwrap(); + let arr_len = self + .builder + .build_load(self.context.i64_type(), arr_len, "arr_len") + .unwrap() + .into_int_value(); // generate a loop, iterate the real array, and do nothing let condbb = self.context.append_basic_block(f, "cond"); - self.builder.build_unconditional_branch(condbb); + self.builder.build_unconditional_branch(condbb).unwrap(); self.builder.position_at_end(condbb); - let i = self.builder.build_load(loop_var, "i").into_int_value(); + let i = self + .builder + .build_load(self.context.i64_type(), loop_var, "i") + .unwrap() + .into_int_value(); let cond = self .builder - .build_int_compare(IntPredicate::ULT, i, arr_len, "cond"); + .build_int_compare(IntPredicate::ULT, i, arr_len, "cond") + .unwrap(); let loopbb = self.context.append_basic_block(f, "loop"); let endbb = self.context.append_basic_block(f, "end"); - self.builder.build_conditional_branch(cond, loopbb, endbb); + self.builder + .build_conditional_branch(cond, loopbb, endbb) + .unwrap(); self.builder.position_at_end(loopbb); - let i = self.builder.build_load(loop_var, "i").into_int_value(); - let elm = unsafe { self.builder.build_in_bounds_gep(real_arr, &[i], "elm") }; + let i = self + .builder + .build_load(self.context.i64_type(), loop_var, "i") + .unwrap() + .into_int_value(); + let elm_tp = get_type_deep(v.element_type.clone()); + let llvm_elm_tp = self.get_basic_type_op(&elm_tp.borrow(), ctx).unwrap(); + let elm = unsafe { + self.builder + .build_in_bounds_gep(llvm_elm_tp, real_arr, &[i], "elm") + } + .unwrap(); let visitor = f.get_nth_param(1).unwrap().into_pointer_value(); let visit_ptr_f = get_nth_mark_fn(f, 2); // complex type needs to provide a visit function by itself // which is stored in the first field of the struct let visit_complex_f = get_nth_mark_fn(f, 3); let visit_trait_f = get_nth_mark_fn(f, 4); - match &*get_type_deep(v.element_type.clone()).borrow() { + match &*elm_tp.borrow() { PLType::Arr(_) | PLType::Struct(_) => { - let casted = self.builder.build_bitcast(elm, i8ptrtp, "casted_arg"); + let casted = self + .builder + .build_bitcast(elm, i8ptrtp, "casted_arg") + .unwrap(); // call the visit_complex function self.builder - .build_call(visit_complex_f, &[visitor.into(), casted.into()], "call"); + .build_indirect_call( + self.context.void_type().fn_type( + &[visitor.get_type().into(), casted.get_type().into()], + false, + ), + visit_complex_f, + &[visitor.into(), casted.into()], + "call", + ) + .unwrap(); } PLType::Pointer(_) => { // call the visit_ptr function - let casted = self.builder.build_bitcast(elm, i8ptrtp, "casted_arg"); + let casted = self + .builder + .build_bitcast(elm, i8ptrtp, "casted_arg") + .unwrap(); self.builder - .build_call(visit_ptr_f, &[visitor.into(), casted.into()], "call"); + .build_indirect_call( + self.context.void_type().fn_type( + &[visitor.get_type().into(), casted.get_type().into()], + false, + ), + visit_ptr_f, + &[visitor.into(), casted.into()], + "call", + ) + .unwrap(); } PLType::Trait(_) | PLType::Union(_) | PLType::Closure(_) => { // call the visit_trait function - let casted = self.builder.build_bitcast(elm, i8ptrtp, "casted_arg"); - + let casted = self + .builder + .build_bitcast(elm, i8ptrtp, "casted_arg") + .unwrap(); self.builder - .build_call(visit_trait_f, &[visitor.into(), casted.into()], "call"); + .build_indirect_call( + self.context.void_type().fn_type( + &[visitor.get_type().into(), casted.get_type().into()], + false, + ), + visit_trait_f, + &[visitor.into(), casted.into()], + "call", + ) + .unwrap(); } PLType::Fn(_) | PLType::Primitive(_) | PLType::Void | PLType::Generic(_) - | PLType::PlaceHolder(_) => (), + | PLType::PlaceHolder(_) + | PLType::Unknown => (), } - let i = self.builder.build_load(loop_var, "i").into_int_value(); let i = self .builder - .build_int_add(i, self.context.i64_type().const_int(1, false), "i"); - self.builder.build_store(loop_var, i); - self.builder.build_unconditional_branch(condbb); + .build_load(self.context.i64_type(), loop_var, "i") + .unwrap() + .into_int_value(); + let i = self + .builder + .build_int_add(i, self.context.i64_type().const_int(1, false), "i") + .unwrap(); + self.builder.build_store(loop_var, i).unwrap(); + self.builder.build_unconditional_branch(condbb).unwrap(); self.builder.position_at_end(endbb); // call the visit_ptr function let casted = self .builder - .build_bitcast(real_arr_raw, i8ptrtp, "casted_arg"); - self.builder.build_call( - get_nth_mark_fn(f, 2), - &[visitor.into(), casted.into()], - "call", - ); - self.builder.build_return(None); + .build_bitcast(real_arr_raw, i8ptrtp, "casted_arg") + .unwrap(); + self.builder + .build_indirect_call( + self.context.void_type().fn_type( + &[visitor.get_type().into(), casted.get_type().into()], + false, + ), + get_nth_mark_fn(f, 2), + &[visitor.into(), casted.into()], + "call", + ) + .unwrap(); + self.builder.build_return(None).unwrap(); if let Some(currentbb) = currentbb { self.builder.position_at_end(currentbb); } @@ -660,21 +809,14 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { if let Some(root) = self.heap_stack_map.borrow().get(&handle) { self.handle_table.borrow().get(root).map(|v| { let handle = self.handle_table.borrow().get(&handle).copied().unwrap(); - if v.into_pointer_value().get_type().get_element_type() != handle.get_type() { - let bt: BasicTypeEnum = handle.get_type().try_into().unwrap(); - let v = self.builder.build_bitcast( + self.builder + .build_load::( + handle.get_type().try_into().unwrap(), v.into_pointer_value(), - bt.ptr_type(AddressSpace::default()), - "get_root_cast", - ); - self.builder - .build_load(v.into_pointer_value(), "load_stack") - .as_any_value_enum() - } else { - self.builder - .build_load(v.into_pointer_value(), "load_stack") - .as_any_value_enum() - } + "load_stack", + ) + .unwrap() + .as_any_value_enum() }) } else { self.handle_table.borrow().get(&handle).copied() @@ -874,6 +1016,7 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { ]; Some(self.context.struct_type(&fields, false).into()) } + PLType::Unknown => None, } } /// # get_ret_type @@ -1267,7 +1410,8 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { .insert(u.get_full_name(), st.as_type()); Some(st.as_type()) } - PLType::Closure(_) => self.get_ditype(&PLType::Primitive(PriType::I64), ctx), // TODO + PLType::Closure(_) => self.get_ditype(&PLType::Primitive(PriType::I64), ctx), + PLType::Unknown => None, // TODO } } @@ -1358,6 +1502,7 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { .map(|v| { self.builder .build_bitcast(v.as_global_value().as_pointer_value(), self.i8ptr(), "") + .unwrap() .into_pointer_value() }) .collect::>(), @@ -1390,7 +1535,7 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { *self.optimized.borrow_mut() = true; } - fn try_load2var_inner(&self, v: usize) -> Result { + fn try_load2var_inner(&self, v: usize, tp: &PLType, ctx: &mut Ctx<'a>) -> Result { let handle = v; let v = self.get_llvm_value(handle).unwrap(); if !v.is_pointer_value() { @@ -1410,6 +1555,32 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { Ok(self.build_load( self.get_llvm_value_handle(&v.into_pointer_value().as_any_value_enum()), "loadtmp", + tp, + ctx, + )) + } + } + fn try_load2var_inner_raw(&self, v: usize, tp: BasicTypeEnum<'ctx>) -> Result { + let handle = v; + let v = self.get_llvm_value(handle).unwrap(); + if !v.is_pointer_value() { + Ok(match v { + AnyValueEnum::ArrayValue(_) + | AnyValueEnum::IntValue(_) + | AnyValueEnum::FloatValue(_) + | AnyValueEnum::PointerValue(_) + | AnyValueEnum::StructValue(_) + | AnyValueEnum::VectorValue(_) => handle, + AnyValueEnum::FunctionValue(f) => { + return Ok(self.get_llvm_value_handle(&f.as_global_value().as_any_value_enum())); + } + _ => return Err(()), + }) + } else { + Ok(self.build_load_raw( + self.get_llvm_value_handle(&v.into_pointer_value().as_any_value_enum()), + "loadtmp", + tp, )) } } @@ -1424,19 +1595,23 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { ) -> ValueHandle { let lv = self.get_llvm_value(from).unwrap(); let re = if lv.is_function_value() { - self.builder.build_bitcast( - lv.into_function_value() - .as_global_value() - .as_pointer_value(), - self.get_basic_type_op(to, ctx).unwrap(), - name, - ) + self.builder + .build_bitcast( + lv.into_function_value() + .as_global_value() + .as_pointer_value(), + self.get_basic_type_op(to, ctx).unwrap(), + name, + ) + .unwrap() } else { - self.builder.build_bitcast( - lv.into_pointer_value(), - self.get_basic_type_op(to, ctx).unwrap(), - name, - ) + self.builder + .build_bitcast( + lv.into_pointer_value(), + self.get_basic_type_op(to, ctx).unwrap(), + name, + ) + .unwrap() }; let new_handle = self.get_llvm_value_handle(&re.as_any_value_enum()); let root = self.heap_stack_map.borrow().get(&from).copied(); @@ -1454,11 +1629,14 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { ) -> ValueHandle { let lv = self.get_llvm_value(from).unwrap(); - let re = self.builder.build_pointer_cast( - lv.into_pointer_value(), - self.get_basic_type_op(to, ctx).unwrap().into_pointer_type(), - name, - ); + let re = self + .builder + .build_pointer_cast( + lv.into_pointer_value(), + self.get_basic_type_op(to, ctx).unwrap().into_pointer_type(), + name, + ) + .unwrap(); self.get_llvm_value_handle(&re.as_any_value_enum()) } fn get_global_var_handle(&self, name: &str) -> Option { @@ -1513,22 +1691,24 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { ) } - fn build_load(&self, ptr: ValueHandle, name: &str) -> ValueHandle { - let ptr = self.get_llvm_value(ptr).unwrap(); - let ptr = ptr.into_pointer_value(); - let ptr = self.builder.build_load(ptr, name); - if ptr.is_pointer_value() { - self.create_root_for(ptr); - } - self.get_llvm_value_handle(&ptr.as_any_value_enum()) + fn build_load( + &self, + ptr: ValueHandle, + name: &str, + tp: &PLType, + ctx: &mut Ctx<'a>, + ) -> ValueHandle { + let llvm_type = self.get_basic_type_op(tp, ctx).unwrap(); + self.build_load_raw(ptr, name, llvm_type) } fn try_load2var( &self, range: Range, v: ValueHandle, + tp: &PLType, ctx: &mut Ctx<'a>, ) -> Result { - match self.try_load2var_inner(v) { + match self.try_load2var_inner(v, tp, ctx) { Ok(value) => Ok(value), Err(_) => Err(range.new_err(ErrorCode::EXPECT_VALUE).add_to_ctx(ctx)), } @@ -1550,23 +1730,24 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { ) -> Option { let builder = self.builder; let f = self.get_llvm_value(f).unwrap(); - let f: CallableValue = if f.is_function_value() { - f.into_function_value().into() + let f = if f.is_function_value() { + f.into_function_value().as_global_value().as_pointer_value() } else { - f.into_pointer_value().try_into().unwrap() + f.into_pointer_value() }; - let args = args + let (args, tys): (Vec<_>, Vec<_>) = args .iter() .map(|v| { let be: BasicValueEnum = self.get_llvm_value(*v).unwrap().try_into().unwrap(); + let ty: BasicMetadataTypeEnum = be.get_type().into(); let bme: BasicMetadataValueEnum = be.into(); - bme + (bme, ty) }) - .collect::>(); + .unzip(); let dbg = builder.get_current_debug_location(); let bb = builder.get_insert_block().unwrap(); // malloc ret after call is not safe, as malloc may trigger collection - let alloca = if ret_type == &PLType::Void { + let alloca = if matches!(ret_type, PLType::Void | PLType::Primitive(_)) { 0 } else { self.alloc_raw("ret_alloca", ret_type, ctx, None, "DioGC__malloc") @@ -1588,19 +1769,30 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { } } builder.position_at_end(bb); + let fntp = match self.get_basic_type_op(ret_type, ctx) { + Some(r) => r.fn_type(&tys, false), + None => self.context.void_type().fn_type(&tys, false), + }; - let v = builder.build_call(f, &args, "calltmp").try_as_basic_value(); + let v = builder + .build_indirect_call(fntp, f, &args, "calltmp") + .unwrap() + .try_as_basic_value(); if v.right().is_some() { return None; } let ret = v.left().unwrap(); builder.unset_current_debug_location(); - - self.builder.build_store( - self.get_llvm_value(alloca).unwrap().into_pointer_value(), - ret, - ); + if alloca == 0 { + return Some(self.get_llvm_value_handle(&ret.as_any_value_enum())); + } + self.builder + .build_store( + self.get_llvm_value(alloca).unwrap().into_pointer_value(), + ret, + ) + .unwrap(); Some(alloca) } fn add_function( @@ -1673,28 +1865,59 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { self.builder.position_at_end(alloca); let f_v = ctx.function.unwrap(); - let f = self.get_llvm_value(f_v).unwrap().into_function_value(); - let yield_ctx = f.get_nth_param(0).unwrap(); let bt = self.get_basic_type_op(pltype, ctx).unwrap(); - let count = add_field( - yield_ctx.as_any_value_enum(), - bt.ptr_type(Default::default()).into(), - ); + let tp = self + .get_basic_type_op(&data.borrow().ctx_tp.as_ref().unwrap().borrow(), ctx) + .unwrap() + .into_struct_type(); + let count = add_field(tp, bt.ptr_type(Default::default()).into()); let i = count - 1; let data_ptr = self - .build_struct_gep(self.get_nth_param(f_v, 0), i, name) + .build_struct_gep( + self.get_nth_param(f_v, 0), + i, + name, + &data.borrow().ctx_tp.as_ref().unwrap().borrow(), + ctx, + ) .unwrap(); - let load = self.build_load(data_ptr, "data_load"); + // 我们现在在alloca block上(第一个block),egnerator yield函数每次进入都会执行这个 + // block,所以在这里我们要设置好所有的变量初始值,将他们从generator ctx中取出。 let stack_root = self.get_stack_root(ret_handle); + let load = self.build_load_raw( + data_ptr, + &format!("data_load_{}", name), + self.i8ptr().as_basic_type_enum(), + ); self.build_store(stack_root, load); self.builder.position_at_end(lb); - let load_again = self.build_load(load, "data_load"); - data.borrow_mut().param_tmp = load_again; + // 到目前正在生成代码的block上,这里是malloc函数所在的地方。 + // malloc之后要将生成的内存保存到generator ctx中,以便下次进入时可以取出来。 + // 但是函数参数除外,在一般函数的逻辑中函数进入后要将参数保存到堆中,然而 + // generator yield函数的参数在generator init的时候就已经分配好了内存 + // 而且保存在了generator ctx中,所以这里的malloc其实是不需要的,而且不能回 + // 存到ctx里,如果回存会导致参数被覆盖。 + if !data.borrow_mut().is_para { + self.build_store(data_ptr, ret_handle); + } + let load = self.build_load_raw( + data_ptr, + &format!("data_load_{}", name), + self.i8ptr().as_basic_type_enum(), + ); + + let load_again = self.build_load_raw( + load, + &format!("data_load_again_{}", name), + self.i8ptr().as_basic_type_enum(), + ); + data.borrow_mut().para_tmp = load_again; // self.build_store(ret_handle, load_again); - self.build_store(data_ptr, ret_handle); + // self.build_store(stack_root, load); + // self.build_store(data_ptr, ret_handle); self.set_root(load, stack_root); ret_handle = load; @@ -1719,20 +1942,30 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { structv: ValueHandle, index: u32, name: &str, - ) -> Result { + tp: &PLType, + ctx: &mut Ctx<'a>, + ) -> Result { let structv = self.get_llvm_value(structv).unwrap(); let structv = structv.into_pointer_value(); - let gep = self.builder.build_struct_gep(structv, index, name); + let sttp = self.get_basic_type_op(tp, ctx).unwrap(); + let gep = self.builder.build_struct_gep(sttp, structv, index, name); if let Ok(gep) = gep { - if gep.get_type().get_element_type().is_pointer_type() { - let loadgep = self.builder.build_load(gep, "field_heap_ptr"); + let geptp = sttp + .into_struct_type() + .get_field_type_at_index(index) + .unwrap(); + if geptp.is_pointer_type() { + let loadgep = self + .builder + .build_load(geptp, gep, "field_heap_ptr") + .unwrap(); self.create_root_for(loadgep); return Ok(self.get_llvm_value_handle(&gep.as_any_value_enum())); } else { return Ok(self.get_llvm_value_handle(&gep.as_any_value_enum())); } } else { - Err(()) + Err(format!("{:?}\ntp: {:?}\nindex: {}", gep, tp, index)) } } fn build_store(&self, ptr: ValueHandle, value: ValueHandle) { @@ -1747,28 +1980,37 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { } else { value.try_into().unwrap() }; - self.builder.build_store(ptr, value); + self.builder.build_store(ptr, value).unwrap(); } fn build_const_in_bounds_gep( &self, ptr: ValueHandle, index: &[u64], name: &str, + tp: &PLType, + ctx: &mut Ctx<'a>, ) -> ValueHandle { + let tp = self.get_basic_type_op(tp, ctx).unwrap(); let ptr = self.get_llvm_value(ptr).unwrap(); let ptr = ptr.into_pointer_value(); let gep = unsafe { - self.builder.build_in_bounds_gep( - ptr, - &index - .iter() - .map(|i| self.context.i64_type().const_int(*i, false)) - .collect::>(), - name, - ) + self.builder + .build_in_bounds_gep( + tp, + ptr, + &index + .iter() + .map(|i| self.context.i64_type().const_int(*i, false)) + .collect::>(), + name, + ) + .unwrap() }; - if gep.get_type().get_element_type().is_pointer_type() { - let loadgep = self.builder.build_load(gep, "field_heap_ptr"); + if tp.is_pointer_type() { + let loadgep = self + .builder + .build_load(self.i8ptr(), gep, "field_heap_ptr") + .unwrap(); self.create_root_for(loadgep); return self.get_llvm_value_handle(&gep.as_any_value_enum()); } else { @@ -1780,31 +2022,43 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { ptr: ValueHandle, index: &[ValueHandle], name: &str, + tp: &PLType, + ctx: &mut Ctx<'a>, ) -> ValueHandle { let ptr = self.get_llvm_value(ptr).unwrap(); let ptr = ptr.into_pointer_value(); + let tp = self.get_basic_type_op(tp, ctx).unwrap(); let gep = unsafe { - self.builder.build_in_bounds_gep( - ptr, - &index - .iter() - .map(|i| self.get_llvm_value(*i).unwrap().try_into().unwrap()) - .collect::>(), - name, - ) + self.builder + .build_in_bounds_gep( + tp, + ptr, + &index + .iter() + .map(|i| self.get_llvm_value(*i).unwrap().try_into().unwrap()) + .collect::>(), + name, + ) + .unwrap() }; self.get_llvm_value_handle(&gep.as_any_value_enum()) } fn const_string(&self, s: &str) -> ValueHandle { - let s = self.builder.build_global_string_ptr( - s, - format!(".str_{}", ID.fetch_add(1, Ordering::Relaxed)).as_str(), - ); - let s = self.builder.build_bitcast( - s, - self.context.i8_type().ptr_type(Default::default()), - ".str", - ); + let s = self + .builder + .build_global_string_ptr( + s, + format!(".str_{}", ID.fetch_add(1, Ordering::Relaxed)).as_str(), + ) + .unwrap(); + let s = self + .builder + .build_bitcast( + s, + self.context.i8_type().ptr_type(Default::default()), + ".str", + ) + .unwrap(); self.get_llvm_value_handle(&s.as_any_value_enum()) } @@ -1879,7 +2133,8 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { ) -> ValueHandle { let phi = self .builder - .build_phi(self.get_basic_type_op(pltype, ctx).unwrap(), ""); + .build_phi(self.get_basic_type_op(pltype, ctx).unwrap(), "") + .unwrap(); for (value, block) in vbs { let value = self.get_llvm_value(*value).unwrap().into_int_value(); let block = self.get_llvm_block(*block).unwrap(); @@ -1890,7 +2145,7 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { fn build_unconditional_branch(&self, bb: BlockHandle) { let bb = self.get_llvm_block(bb).unwrap(); - self.builder.build_unconditional_branch(bb); + self.builder.build_unconditional_branch(bb).unwrap(); } fn get_first_instruction(&self, bb: BlockHandle) -> Option { @@ -1941,19 +2196,19 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { fn build_int_z_extend(&self, v: ValueHandle, ty: &PriType, name: &str) -> ValueHandle { let v = self.get_llvm_value(v).unwrap().into_int_value(); let ty = self.get_pri_basic_type(ty).into_int_type(); - let v = self.builder.build_int_z_extend(v, ty, name); + let v = self.builder.build_int_z_extend(v, ty, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_or(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_int_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_int_value(); - let v = self.builder.build_or(lhs, rhs, name); + let v = self.builder.build_or(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_and(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_int_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_int_value(); - let v = self.builder.build_and(lhs, rhs, name); + let v = self.builder.build_and(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_float_compare( @@ -1966,7 +2221,10 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let lhs = self.get_llvm_value(lhs).unwrap().into_float_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_float_value(); - let v = self.builder.build_float_compare(op.into(), lhs, rhs, name); + let v = self + .builder + .build_float_compare(op.into(), lhs, rhs, name) + .unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_int_compare( @@ -1978,36 +2236,39 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { ) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_int_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_int_value(); - let v = self.builder.build_int_compare(op.into(), lhs, rhs, name); + let v = self + .builder + .build_int_compare(op.into(), lhs, rhs, name) + .unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_int_neg(&self, v: ValueHandle, name: &str) -> ValueHandle { let v = self.get_llvm_value(v).unwrap().into_int_value(); - let v = self.builder.build_int_neg(v, name); + let v = self.builder.build_int_neg(v, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_int_add(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_int_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_int_value(); - let v = self.builder.build_int_add(lhs, rhs, name); + let v = self.builder.build_int_add(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_int_sub(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_int_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_int_value(); - let v = self.builder.build_int_sub(lhs, rhs, name); + let v = self.builder.build_int_sub(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_int_mul(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_int_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_int_value(); - let v = self.builder.build_int_mul(lhs, rhs, name); + let v = self.builder.build_int_mul(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_int_signed_div(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_int_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_int_value(); - let v = self.builder.build_int_signed_div(lhs, rhs, name); + let v = self.builder.build_int_signed_div(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } @@ -2019,14 +2280,14 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { ) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_int_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_int_value(); - let v = self.builder.build_int_unsigned_div(lhs, rhs, name); + let v = self.builder.build_int_unsigned_div(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_int_signed_srem(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_int_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_int_value(); - let v = self.builder.build_int_signed_rem(lhs, rhs, name); + let v = self.builder.build_int_signed_rem(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } @@ -2038,37 +2299,37 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { ) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_int_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_int_value(); - let v = self.builder.build_int_unsigned_rem(lhs, rhs, name); + let v = self.builder.build_int_unsigned_rem(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_float_neg(&self, v: ValueHandle, name: &str) -> ValueHandle { let v = self.get_llvm_value(v).unwrap().into_float_value(); - let v = self.builder.build_float_neg(v, name); + let v = self.builder.build_float_neg(v, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_float_add(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_float_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_float_value(); - let v = self.builder.build_float_add(lhs, rhs, name); + let v = self.builder.build_float_add(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_float_sub(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_float_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_float_value(); - let v = self.builder.build_float_sub(lhs, rhs, name); + let v = self.builder.build_float_sub(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_float_mul(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_float_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_float_value(); - let v = self.builder.build_float_mul(lhs, rhs, name); + let v = self.builder.build_float_mul(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_float_div(&self, lhs: ValueHandle, rhs: ValueHandle, name: &str) -> ValueHandle { let lhs = self.get_llvm_value(lhs).unwrap().into_float_value(); let rhs = self.get_llvm_value(rhs).unwrap().into_float_value(); - let v = self.builder.build_float_div(lhs, rhs, name); + let v = self.builder.build_float_div(lhs, rhs, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn append_basic_block(&self, func: ValueHandle, name: &str) -> BlockHandle { @@ -2081,7 +2342,7 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { fn build_int_truncate(&self, v: ValueHandle, dest_ty: &PriType, name: &str) -> ValueHandle { let v = self.get_llvm_value(v).unwrap().into_int_value(); let dest_ty = self.get_pri_basic_type(dest_ty).into_int_type(); - let v = self.builder.build_int_truncate(v, dest_ty, name); + let v = self.builder.build_int_truncate(v, dest_ty, name).unwrap(); self.get_llvm_value_handle(&v.as_any_value_enum()) } fn build_conditional_branch( @@ -2094,7 +2355,8 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let then_bb = self.get_llvm_block(then_bb).unwrap(); let else_bb = self.get_llvm_block(else_bb).unwrap(); self.builder - .build_conditional_branch(cond, then_bb, else_bb); + .build_conditional_branch(cond, then_bb, else_bb) + .unwrap(); } fn rm_curr_debug_location(&self) { self.builder.unset_current_debug_location(); @@ -2215,9 +2477,9 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { if let Some(v) = v { let v = self.get_llvm_value(v).unwrap(); let v: BasicValueEnum = v.try_into().unwrap(); - self.builder.build_return(Some(&v)); + self.builder.build_return(Some(&v)).unwrap(); } else { - self.builder.build_return(None); + self.builder.build_return(None).unwrap(); } } #[allow(clippy::too_many_arguments)] @@ -2250,8 +2512,8 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { .unwrap(); let raw_tp = v.get_type(); // self.builder.position_at_end(allocab); - let alloca = self.builder.build_alloca(raw_tp, "para"); - self.builder.build_store(alloca, v); + let alloca = self.builder.build_alloca(raw_tp, "para").unwrap(); + self.builder.build_store(alloca, v).unwrap(); self.dibuilder.insert_declare_at_end( alloca, Some(divar), @@ -2308,7 +2570,13 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { self.position_at_end_block(bb_v); let ctx_v = data.borrow().ctx_handle; let para_ptr = self - .build_struct_gep(ctx_v, (i + 2) as u32, "para") + .build_struct_gep( + ctx_v, + (i + 2) as u32, + "para", + &data.borrow().ctx_tp.as_ref().unwrap().borrow(), + child, + ) .unwrap(); child.ctx_flag = CtxFlag::Normal; let ptr = self.alloc("param_ptr", tp, child, None); @@ -2325,7 +2593,7 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { self.build_store(para_ptr, ptr); self.position_at_end_block(origin_bb); - self.build_store(alloca, data.borrow().param_tmp); + self.build_store(alloca, data.borrow().para_tmp); return; } let funcvalue = self @@ -2396,8 +2664,8 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let currentbb = ctx.block; self.builder.unset_current_debug_location(); let i8ptrtp = self.context.i8_type().ptr_type(AddressSpace::default()); - let ptrtp = self.struct_type(v, ctx).ptr_type(AddressSpace::default()); - let ty = ptrtp.get_element_type().into_struct_type(); + let ty = self.struct_type(v, ctx); + let ptrtp = ty.ptr_type(AddressSpace::default()); let ftp = self.mark_fn_tp(ptrtp); let name = v.get_full_name() + "@"; // if !name.starts_with(&ctx.get_root_ctx().get_file()) { @@ -2428,42 +2696,73 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { // which is stored in the first field of the struct let visit_complex_f = get_nth_mark_fn(f, 3); let visit_trait_f = get_nth_mark_fn(f, 4); - let f = self.builder.build_struct_gep(st, i, "gep").unwrap(); + let f = self.builder.build_struct_gep(ty, st, i, "gep").unwrap(); // 指针类型,递归调用visit函数 match field_pltp { PLType::Pointer(_) => { let ptr = f; - let casted = self.builder.build_bitcast(ptr, i8ptrtp, "casted_arg"); + let casted = self + .builder + .build_bitcast(ptr, i8ptrtp, "casted_arg") + .unwrap(); self.builder - .build_call(visit_ptr_f, &[visitor.into(), casted.into()], "call"); + .build_indirect_call( + self.context.void_type().fn_type( + &[visitor.get_type().into(), casted.get_type().into()], + false, + ), + visit_ptr_f, + &[visitor.into(), casted.into()], + "call", + ) + .unwrap(); } PLType::Struct(_) | PLType::Arr(_) => { let ptr = f; - let casted = self.builder.build_bitcast(ptr, i8ptrtp, "casted_arg"); - self.builder.build_call( - visit_complex_f, - &[visitor.into(), casted.into()], - "call", - ); + let casted = self + .builder + .build_bitcast(ptr, i8ptrtp, "casted_arg") + .unwrap(); + self.builder + .build_indirect_call( + self.context.void_type().fn_type( + &[visitor.get_type().into(), casted.get_type().into()], + false, + ), + visit_complex_f, + &[visitor.into(), casted.into()], + "call", + ) + .unwrap(); } PLType::Trait(_) | PLType::Union(_) | PLType::Closure(_) => { let ptr = f; - let casted = self.builder.build_bitcast(ptr, i8ptrtp, "casted_arg"); - self.builder.build_call( - visit_trait_f, - &[visitor.into(), casted.into()], - "call", - ); + let casted = self + .builder + .build_bitcast(ptr, i8ptrtp, "casted_arg") + .unwrap(); + self.builder + .build_indirect_call( + self.context.void_type().fn_type( + &[visitor.get_type().into(), casted.get_type().into()], + false, + ), + visit_trait_f, + &[visitor.into(), casted.into()], + "call", + ) + .unwrap(); } PLType::Fn(_) | PLType::Primitive(_) | PLType::Void | PLType::Generic(_) - | PLType::PlaceHolder(_) => (), + | PLType::PlaceHolder(_) + | PLType::Unknown => (), } // 其他为原子类型,跳过 } - self.builder.build_return(None); + self.builder.build_return(None).unwrap(); if let Some(currentbb) = currentbb { self.builder .position_at_end(self.get_llvm_block(currentbb).unwrap()); @@ -2484,35 +2783,44 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let target = target.into_int_type(); let val = self .builder - .build_int_cast_sign_flag(val, target, signed, "cast"); + .build_int_cast_sign_flag(val, target, signed, "cast") + .unwrap(); self.get_llvm_value_handle(&val.into()) } else if tp.is_float_type() && target.is_float_type() { let val = val.into_float_value(); let target = target.into_float_type(); - let val = self.builder.build_float_cast(val, target, "cast"); + let val = self.builder.build_float_cast(val, target, "cast").unwrap(); self.get_llvm_value_handle(&val.into()) } else if tp.is_int_type() && target.is_float_type() { let val = val.into_int_value(); let target = target.into_float_type(); if signed { - let val = self.builder.build_signed_int_to_float(val, target, "cast"); + let val = self + .builder + .build_signed_int_to_float(val, target, "cast") + .unwrap(); self.get_llvm_value_handle(&val.into()) } else { let val = self .builder - .build_unsigned_int_to_float(val, target, "cast"); + .build_unsigned_int_to_float(val, target, "cast") + .unwrap(); self.get_llvm_value_handle(&val.into()) } } else if tp.is_float_type() && target.is_int_type() { let val = val.into_float_value(); let target = target.into_int_type(); if signed { - let val = self.builder.build_float_to_signed_int(val, target, "cast"); + let val = self + .builder + .build_float_to_signed_int(val, target, "cast") + .unwrap(); self.get_llvm_value_handle(&val.into()) } else { let val = self .builder - .build_float_to_unsigned_int(val, target, "cast"); + .build_float_to_unsigned_int(val, target, "cast") + .unwrap(); self.get_llvm_value_handle(&val.into()) } } else { @@ -2602,19 +2910,22 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let old_bb = self.builder.get_insert_block(); self.builder.position_at_end(bb); let args = f.get_params(); - let re = self.builder.build_call( - ori_f, - &args - .iter() - .skip(1) - .map(|a| a.to_owned().into()) - .collect::>(), - "re", - ); + let re = self + .builder + .build_call( + ori_f, + &args + .iter() + .skip(1) + .map(|a| a.to_owned().into()) + .collect::>(), + "re", + ) + .unwrap(); if let Some(ret) = re.try_as_basic_value().left() { - self.builder.build_return(Some(&ret)); + self.builder.build_return(Some(&ret)).unwrap(); } else { - self.builder.build_return(None); + self.builder.build_return(None).unwrap(); } if let Some(old_bb) = old_bb { self.builder.position_at_end(old_bb); @@ -2626,8 +2937,8 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let funcvalue = self.get_llvm_value(f).unwrap().into_function_value(); self.get_llvm_value_handle(&funcvalue.get_nth_param(i).unwrap().into()) } - fn add_closure_st_field(&self, st: ValueHandle, field: ValueHandle) { - let st_v = self.handle_table.borrow().get(&st).copied().unwrap(); + fn add_closure_st_field(&self, st: &STType, field: ValueHandle, ctx: &mut Ctx<'a>) { + let st_tp = self.struct_type(st, ctx); let field_tp = self .handle_table .borrow() @@ -2635,7 +2946,7 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { .copied() .unwrap() .get_type(); - add_field(st_v, field_tp); + add_field(st_tp, field_tp); } fn add_generator_yield_fn( @@ -2673,20 +2984,22 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { fn build_indirect_br(&self, block: ValueHandle, ctx: &Ctx<'a>) { let block = self.get_llvm_value(block).unwrap(); let bv = self.get_llvm_block(ctx.block.unwrap()).unwrap(); - self.builder.build_indirect_branch::( - block.try_into().unwrap(), - &bv.get_parent() - .unwrap() - .get_basic_blocks() - .iter() - .skip(1) - .copied() - .filter(|b| { - b.get_name().to_str().unwrap().to_string().contains("yield") - || b.get_name().to_str().unwrap().to_string().contains("entry") - }) - .collect::>(), - ); + self.builder + .build_indirect_branch::( + block.try_into().unwrap(), + &bv.get_parent() + .unwrap() + .get_basic_blocks() + .iter() + .skip(1) + .copied() + .filter(|b| { + b.get_name().to_str().unwrap().to_string().contains("yield") + || b.get_name().to_str().unwrap().to_string().contains("entry") + }) + .collect::>(), + ) + .unwrap(); } unsafe fn store_with_aoto_cast(&self, ptr: ValueHandle, value: ValueHandle) { let v_ptr = self.get_llvm_value(ptr).unwrap(); @@ -2698,18 +3011,9 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { } else { v.try_into().unwrap() }; - let ptr_tp = v_ptr.get_type().into_pointer_type(); - let value_tp = v.get_type(); - if ptr_tp.get_element_type() != value_tp.as_any_type_enum() { - let casted = self.builder.build_bitcast::<_, BasicValueEnum>( - v_ptr.try_into().unwrap(), - value_tp.ptr_type(Default::default()), - "cast", - ); - self.builder.build_store(casted.into_pointer_value(), v); - } else { - self.build_store(ptr, value); - } + let _ptr_tp = v_ptr.get_type().into_pointer_type(); + let _value_tp = v.get_type(); + self.build_store(ptr, value); } fn stack_alloc(&self, name: &str, ctx: &mut Ctx<'a>, tp: &PLType) -> ValueHandle { let lb = self.builder.get_insert_block().unwrap(); @@ -2723,7 +3027,7 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { .get_first_basic_block() .unwrap(); self.builder.position_at_end(alloca); - let stack_ptr = self.builder.build_alloca(llvmtp, name); + let stack_ptr = self.builder.build_alloca(llvmtp, name).unwrap(); self.builder.position_at_end(lb); self.get_llvm_value_handle(&stack_ptr.as_any_value_enum()) } @@ -2743,27 +3047,34 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { .get_llvm_value(data.borrow().ctx_size_handle) .unwrap() .into_pointer_value(); - self.builder.build_store(v, size); + self.builder.build_store(v, size).unwrap(); self.builder.position_at_end(cur_bb); } - fn build_memcpy(&self, from: ValueHandle, to: ValueHandle, len: ValueHandle) { - let from = self - .get_llvm_value(self.try_load2var_inner(from).unwrap()) - .unwrap() - .into_pointer_value(); - let to = self - .get_llvm_value(self.try_load2var_inner(to).unwrap()) - .unwrap() - .into_pointer_value(); + fn build_memcpy( + &self, + from: ValueHandle, + from_tp: &PLType, + to: ValueHandle, + len: ValueHandle, + ctx: &mut Ctx<'a>, + ) { + let from = self.get_llvm_value(from).unwrap().into_pointer_value(); + let to = self.get_llvm_value(to).unwrap().into_pointer_value(); let td = self.targetmachine.get_target_data(); - let unit_size = td.get_store_size(&from.get_type().get_element_type()); + let unit_size = td.get_store_size(&self.get_basic_type_op(from_tp, ctx).unwrap()); let i64_size = self.context.i64_type().const_int(unit_size, true); let len = self - .get_llvm_value(self.try_load2var_inner(len).unwrap()) + .get_llvm_value( + self.try_load2var_inner_raw(len, self.context.i64_type().as_basic_type_enum()) + .unwrap(), + ) .unwrap() .into_int_value(); - let arg_len = self.builder.build_int_mul(len, i64_size, "arg_len"); + let arg_len = self + .builder + .build_int_mul(len, i64_size, "arg_len") + .unwrap(); self.builder.build_memcpy(to, 8, from, 8, arg_len).unwrap(); } fn build_bit_not(&self, v: ValueHandle) -> ValueHandle { @@ -2771,7 +3082,8 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let v = v.into_int_value(); let v = self .builder - .build_xor(v, v.get_type().const_all_ones(), "not"); + .build_xor(v, v.get_type().const_all_ones(), "not") + .unwrap(); self.get_llvm_value_handle(&v.into()) } fn build_bit_and(&self, lhs: ValueHandle, rhs: ValueHandle) -> ValueHandle { @@ -2779,7 +3091,7 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let rhs = self.get_llvm_value(rhs).unwrap(); let lhs = lhs.into_int_value(); let rhs = rhs.into_int_value(); - let v = self.builder.build_and(lhs, rhs, "and"); + let v = self.builder.build_and(lhs, rhs, "and").unwrap(); self.get_llvm_value_handle(&v.into()) } fn build_bit_or(&self, lhs: ValueHandle, rhs: ValueHandle) -> ValueHandle { @@ -2787,7 +3099,7 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let rhs = self.get_llvm_value(rhs).unwrap(); let lhs = lhs.into_int_value(); let rhs = rhs.into_int_value(); - let v = self.builder.build_or(lhs, rhs, "or"); + let v = self.builder.build_or(lhs, rhs, "or").unwrap(); self.get_llvm_value_handle(&v.into()) } fn build_bit_xor(&self, lhs: ValueHandle, rhs: ValueHandle) -> ValueHandle { @@ -2795,7 +3107,7 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let rhs = self.get_llvm_value(rhs).unwrap(); let lhs = lhs.into_int_value(); let rhs = rhs.into_int_value(); - let v = self.builder.build_xor(lhs, rhs, "xor"); + let v = self.builder.build_xor(lhs, rhs, "xor").unwrap(); self.get_llvm_value_handle(&v.into()) } fn build_bit_left_shift(&self, lhs: ValueHandle, rhs: ValueHandle) -> ValueHandle { @@ -2803,7 +3115,10 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let rhs = self.get_llvm_value(rhs).unwrap(); let lhs = lhs.into_int_value(); let rhs = rhs.into_int_value(); - let v = self.builder.build_left_shift(lhs, rhs, "left_shift"); + let v = self + .builder + .build_left_shift(lhs, rhs, "left_shift") + .unwrap(); self.get_llvm_value_handle(&v.into()) } fn build_bit_right_shift(&self, lhs: ValueHandle, rhs: ValueHandle) -> ValueHandle { @@ -2813,7 +3128,8 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let rhs = rhs.into_int_value(); let v = self .builder - .build_right_shift(lhs, rhs, false, "right_shift"); + .build_right_shift(lhs, rhs, false, "right_shift") + .unwrap(); self.get_llvm_value_handle(&v.into()) } fn build_bit_right_shift_arithmetic(&self, lhs: ValueHandle, rhs: ValueHandle) -> ValueHandle { @@ -2823,17 +3139,18 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let rhs = rhs.into_int_value(); let v = self .builder - .build_right_shift(lhs, rhs, true, "right_shift"); + .build_right_shift(lhs, rhs, true, "right_shift") + .unwrap(); self.get_llvm_value_handle(&v.into()) } } -fn add_field(st_v: AnyValueEnum<'_>, field_tp: inkwell::types::AnyTypeEnum<'_>) -> u32 { - let st_tp = st_v - .get_type() - .into_pointer_type() - .get_element_type() - .into_struct_type(); +fn add_field(st_tp: StructType, field_tp: inkwell::types::AnyTypeEnum<'_>) -> u32 { + // let st_tp = st_v + // .get_type() + // .into_pointer_type() + // .get_element_type() + // .into_struct_type(); let mut closure_data_tps = st_tp.get_field_types(); closure_data_tps.push(field_tp.try_into().unwrap()); set_body(&st_tp, &closure_data_tps, false); diff --git a/src/ast/builder/mod.rs b/src/ast/builder/mod.rs index fd281b623..83d70d093 100644 --- a/src/ast/builder/mod.rs +++ b/src/ast/builder/mod.rs @@ -52,14 +52,22 @@ pub trait IRBuilder<'a, 'ctx> { then_bb: BlockHandle, else_bb: BlockHandle, ); - fn build_const_in_bounds_gep(&self, ptr: ValueHandle, index: &[u64], name: &str) - -> ValueHandle; + fn build_const_in_bounds_gep( + &self, + ptr: ValueHandle, + index: &[u64], + name: &str, + tp: &PLType, + ctx: &mut Ctx<'a>, + ) -> ValueHandle; fn build_dbg_location(&self, pos: Pos); fn build_in_bounds_gep( &self, ptr: ValueHandle, index: &[ValueHandle], name: &str, + tp: &PLType, + ctx: &mut Ctx<'a>, ) -> ValueHandle; fn build_return(&self, v: Option); fn build_store(&self, ptr: ValueHandle, value: ValueHandle); @@ -68,7 +76,9 @@ pub trait IRBuilder<'a, 'ctx> { structv: ValueHandle, index: u32, name: &str, - ) -> Result; + tp: &PLType, + ctx: &mut Ctx<'a>, + ) -> Result; fn build_sub_program( &self, paralist: Vec>, @@ -129,11 +139,18 @@ pub trait IRBuilder<'a, 'ctx> { ctx: &mut Ctx<'a>, constant: bool, ) -> ValueHandle; - fn build_load(&self, ptr: ValueHandle, name: &str) -> ValueHandle; + fn build_load( + &self, + ptr: ValueHandle, + name: &str, + tp: &PLType, + ctx: &mut Ctx<'a>, + ) -> ValueHandle; fn try_load2var( &self, range: Range, v: ValueHandle, + tp: &PLType, ctx: &mut Ctx<'a>, ) -> Result; fn get_function(&self, name: &str) -> Option; @@ -211,7 +228,7 @@ pub trait IRBuilder<'a, 'ctx> { fn get_closure_trampoline(&self, f: ValueHandle) -> ValueHandle; fn create_closure_parameter_variable(&self, i: u32, f: ValueHandle, alloca: ValueHandle); fn get_nth_param(&self, f: ValueHandle, i: u32) -> ValueHandle; - fn add_closure_st_field(&self, st: ValueHandle, field: ValueHandle); + fn add_closure_st_field(&self, st: &STType, field: ValueHandle, ctx: &mut Ctx<'a>); fn build_sub_program_by_pltp( &self, paralist: &[Arc>], @@ -245,7 +262,14 @@ pub trait IRBuilder<'a, 'ctx> { fn stack_alloc(&self, name: &str, ctx: &mut Ctx<'a>, tp: &PLType) -> ValueHandle; fn correct_generator_ctx_malloc_inst(&self, ctx: &mut Ctx<'a>, name: &str); fn sizeof(&self, pltype: &PLType, ctx: &mut Ctx<'a>) -> u64; - fn build_memcpy(&self, from: ValueHandle, to: ValueHandle, len: ValueHandle); + fn build_memcpy( + &self, + from: ValueHandle, + from_tp: &PLType, + to: ValueHandle, + len: ValueHandle, + ctx: &mut Ctx<'a>, + ); fn build_bit_not(&self, v: ValueHandle) -> ValueHandle; fn build_bit_and(&self, lhs: ValueHandle, rhs: ValueHandle) -> ValueHandle; fn build_bit_or(&self, lhs: ValueHandle, rhs: ValueHandle) -> ValueHandle; diff --git a/src/ast/builder/no_op_builder.rs b/src/ast/builder/no_op_builder.rs index 7572f3f94..4eb843ae3 100644 --- a/src/ast/builder/no_op_builder.rs +++ b/src/ast/builder/no_op_builder.rs @@ -75,6 +75,8 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for NoOpBuilder<'a, 'ctx> { _ptr: super::ValueHandle, _index: &[u64], _name: &str, + _tp: &PLType, + _ctx: &mut Ctx<'a>, ) -> super::ValueHandle { 0 } @@ -86,6 +88,8 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for NoOpBuilder<'a, 'ctx> { _ptr: super::ValueHandle, _index: &[super::ValueHandle], _name: &str, + _tp: &PLType, + _ctx: &mut Ctx<'a>, ) -> super::ValueHandle { 0 } @@ -99,7 +103,9 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for NoOpBuilder<'a, 'ctx> { _structv: super::ValueHandle, _index: u32, _name: &str, - ) -> Result { + _tp: &PLType, + _ctx: &mut Ctx<'a>, + ) -> Result { Ok(0) } @@ -225,7 +231,13 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for NoOpBuilder<'a, 'ctx> { 0 } - fn build_load(&self, _ptr: super::ValueHandle, _name: &str) -> super::ValueHandle { + fn build_load( + &self, + _ptr: super::ValueHandle, + _name: &str, + _tp: &PLType, + _ctx: &mut Ctx<'a>, + ) -> super::ValueHandle { 0 } @@ -233,7 +245,8 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for NoOpBuilder<'a, 'ctx> { &self, _range: crate::ast::range::Range, _v: super::ValueHandle, - _ctx: &mut crate::ast::ctx::Ctx<'a>, + _tp: &PLType, + _ctx: &mut Ctx<'a>, ) -> Result { Ok(0) } @@ -450,7 +463,7 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for NoOpBuilder<'a, 'ctx> { f } - fn add_closure_st_field(&self, _st: ValueHandle, _field: ValueHandle) {} + fn add_closure_st_field(&self, _st: &STType, _field: ValueHandle, _ctx: &mut Ctx<'a>) {} fn build_sub_program_by_pltp( &self, @@ -501,7 +514,15 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for NoOpBuilder<'a, 'ctx> { fn sizeof(&self, _pltype: &PLType, _ctx: &mut Ctx<'a>) -> u64 { 0 } - fn build_memcpy(&self, _from: ValueHandle, _to: ValueHandle, _len: ValueHandle) {} + fn build_memcpy( + &self, + _from: ValueHandle, + _from_tp: &PLType, + _to: ValueHandle, + _len: ValueHandle, + _ctx: &mut Ctx<'a>, + ) { + } fn build_bit_not(&self, _v: ValueHandle) -> ValueHandle { 0 diff --git a/src/ast/compiler.rs b/src/ast/compiler.rs index f777d3857..c67bc5ce5 100644 --- a/src/ast/compiler.rs +++ b/src/ast/compiler.rs @@ -109,7 +109,7 @@ pub fn run_pass(llvmmod: &Module, op: OptimizationLevel) { pass_manager_builder.set_size_level(0); pass_manager_builder.populate_function_pass_manager(&fpm); pass_manager_builder.populate_module_pass_manager(&mpm); - pass_manager_builder.populate_lto_pass_manager(&mpm, false, true); + // pass_manager_builder.populate_lto_pass_manager(&mpm, false, true); } let b = fpm.initialize(); trace!("fpm init: {}", b); diff --git a/src/ast/ctx.rs b/src/ast/ctx.rs index b46ae421f..655fcdd2f 100644 --- a/src/ast/ctx.rs +++ b/src/ast/ctx.rs @@ -29,11 +29,13 @@ use crate::ast::builder::BuilderEnum; use crate::ast::builder::IRBuilder; use crate::format_label; +use crate::inference::TyVariable; use crate::utils::read_config::Config; use crate::Db; use crate::ast::node::function::generator::GeneratorCtxData; +use ena::unify::UnificationTable; use indexmap::IndexMap; use lsp_types::CompletionItem; @@ -120,7 +122,7 @@ pub struct Ctx<'a> { pub macro_loop_len: usize, pub temp_source: Option, pub in_macro: bool, - pub closure_data: Option>, + pub closure_data: Option>>, pub expect_ty: Option>>, pub self_ref_map: FxHashMap>, // used to recognize self reference pub ctx_flag: CtxFlag, @@ -131,6 +133,8 @@ pub struct Ctx<'a> { is_active_file: bool, as_root: bool, macro_expand_depth: Arc>, + pub unify_table: Arc>>, + pub disable_diag: bool, } #[derive(Clone, PartialEq, Eq, Debug)] @@ -162,18 +166,24 @@ impl<'a, 'ctx> Ctx<'a> { pub fn is_active_file(&self) -> bool { self.is_active_file } - pub fn add_term_to_previous_yield( - &self, - builder: &'a BuilderEnum<'a, 'ctx>, + pub fn add_term_to_previous_yield<'b>( + &'b mut self, + builder: &'b BuilderEnum<'a, 'ctx>, curbb: usize, ) -> Arc> { let ctx = self; - let data = ctx.generator_data.as_ref().unwrap(); + let data = ctx.generator_data.as_ref().unwrap().clone(); if let Some(prev_bb) = data.borrow().prev_yield_bb { builder.position_at_end_block(prev_bb); let ctx_handle = builder.get_nth_param(ctx.function.unwrap(), 0); let ptr = builder - .build_struct_gep(ctx_handle, 1, "block_ptr") + .build_struct_gep( + ctx_handle, + 1, + "block_ptr", + &data.borrow().ctx_tp.as_ref().unwrap().borrow(), + ctx, + ) .unwrap(); let addr = builder.get_block_address(curbb); @@ -281,6 +291,8 @@ impl<'a, 'ctx> Ctx<'a> { is_active_file, as_root: false, macro_expand_depth: Default::default(), + unify_table: Arc::new(RefCell::new(UnificationTable::new())), + disable_diag: false, } } pub fn new_child(&'a self, start: Pos, builder: &'a BuilderEnum<'a, 'ctx>) -> Ctx<'a> { @@ -324,6 +336,8 @@ impl<'a, 'ctx> Ctx<'a> { is_active_file: self.is_active_file, as_root: false, macro_expand_depth: self.macro_expand_depth.clone(), + unify_table: self.unify_table.clone(), + disable_diag: self.disable_diag, }; add_primitive_types(&mut ctx); if start != Default::default() { @@ -570,7 +584,16 @@ impl<'a, 'ctx> Ctx<'a> { // captured by closure let new_symbol = symbol.clone(); let len = data.table.len(); - builder.add_closure_st_field(data.data_handle, new_symbol.value); + let st_r = data.data_tp.as_ref().unwrap().borrow(); + let st = match &*st_r { + PLType::Struct(s) => s, + _ => unreachable!(), + }; + let ptr = father as *const _; + let ptr = ptr as usize; + let ptr = ptr as *mut Ctx<'_>; + builder.add_closure_st_field(st, new_symbol.value, unsafe { &mut *ptr }); + drop(st_r); let new_symbol = PLSymbolData { value: builder.build_load( builder @@ -578,9 +601,13 @@ impl<'a, 'ctx> Ctx<'a> { data.data_handle, len as u32 + 1, "closure_tmp", + &data.data_tp.as_ref().unwrap().borrow(), + unsafe { &mut *ptr }, ) .unwrap(), "closure_loaded", + &PLType::Pointer(Arc::new(RefCell::new(PLType::new_i8_ptr()))), + unsafe { &mut *ptr }, ), ..new_symbol }; @@ -840,6 +867,9 @@ impl<'a, 'ctx> Ctx<'a> { } pub fn add_diag(&self, mut dia: PLDiag) -> PLDiag { + if self.disable_diag { + return dia; + } if let Some(src) = &self.temp_source { dia.set_source(src); } @@ -853,8 +883,9 @@ impl<'a, 'ctx> Ctx<'a> { range: Range, v: ValueHandle, builder: &'b BuilderEnum<'a, 'ctx>, + tp: &PLType, ) -> Result { - builder.try_load2var(range, v, self) + builder.try_load2var(range, v, tp, self) } fn set_mod(&mut self, plmod: Mod) -> Mod { let m = self.plmod.clone(); @@ -982,7 +1013,7 @@ impl<'a, 'ctx> Ctx<'a> { /// # auto_deref /// 自动解引用,有几层解几层 pub fn auto_deref<'b>( - &'b self, + &'b mut self, tp: Arc>, value: ValueHandle, builder: &'b BuilderEnum<'a, 'ctx>, @@ -990,8 +1021,9 @@ impl<'a, 'ctx> Ctx<'a> { let mut tp = tp; let mut value = value; while let PLType::Pointer(p) = &*get_type_deep(tp.clone()).borrow() { + let old_tp = tp.clone(); tp = p.clone(); - value = builder.build_load(value, "load"); + value = builder.build_load(value, "load", &old_tp.borrow(), self); } (tp, value) } diff --git a/src/ast/ctx/builtins.rs b/src/ast/ctx/builtins.rs index 39be0edca..bbe4c3f35 100644 --- a/src/ast/ctx/builtins.rs +++ b/src/ast/ctx/builtins.rs @@ -342,11 +342,25 @@ fn emit_arr_from_raw<'a, 'b>( size_handle: 0, }))); let arr = builder.alloc("array_alloca", &arr_tp.borrow(), ctx, None); - let arr_raw = builder.build_struct_gep(arr, 1, "arr_raw").unwrap(); - let loaded = ctx.try_load2var(f.paralist[0].range(), v.get_value(), builder)?; + let arr_raw = builder + .build_struct_gep(arr, 1, "arr_raw", &arr_tp.borrow(), ctx) + .unwrap(); + let loaded = ctx.try_load2var( + f.paralist[0].range(), + v.get_value(), + builder, + &v.get_ty().borrow(), + )?; builder.build_store(arr_raw, loaded); - let arr_len = builder.build_struct_gep(arr, 2, "arr_len").unwrap(); - let loaded = ctx.try_load2var(f.paralist[1].range(), v2.get_value(), builder)?; + let arr_len = builder + .build_struct_gep(arr, 2, "arr_len", &arr_tp.borrow(), ctx) + .unwrap(); + let loaded = ctx.try_load2var( + f.paralist[1].range(), + v2.get_value(), + builder, + &v2.get_ty().borrow(), + )?; builder.build_store(arr_len, loaded); arr.new_output(arr_tp).to_result() @@ -379,7 +393,7 @@ fn emit_arr_len<'a, 'b>( .add_to_ctx(ctx)); } let len = builder - .build_struct_gep(v.get_value(), 2, "arr_len") + .build_struct_gep(v.get_value(), 2, "arr_len", &v.get_ty().borrow(), ctx) .unwrap(); len.new_output(Arc::new(RefCell::new(PLType::Primitive(PriType::I64)))) .to_result() @@ -439,14 +453,32 @@ fn emit_arr_copy<'a, 'b>( } let from_raw = builder - .build_struct_gep(v.get_value(), 1, "arr_raw") + .build_struct_gep(v.get_value(), 1, "arr_raw", &v.get_ty().borrow(), ctx) .unwrap(); let to_raw = builder - .build_struct_gep(to.get_value(), 1, "arr_raw") + .build_struct_gep(to.get_value(), 1, "arr_raw", &to.get_ty().borrow(), ctx) .unwrap(); + let len_raw = len.get_value(); - builder.build_memcpy(from_raw, to_raw, len_raw); - Ok(Default::default()) + match (&*v.get_ty().borrow(), &*to.get_ty().borrow()) { + (PLType::Arr(a1), PLType::Arr(a2)) => { + let from_raw = builder.build_load( + from_raw, + "arr_load_field", + &PLType::Pointer(a1.element_type.clone()), + ctx, + ); + let to_raw = builder.build_load( + to_raw, + "arr_load_field", + &PLType::Pointer(a2.element_type.clone()), + ctx, + ); + builder.build_memcpy(from_raw, &a1.element_type.borrow(), to_raw, len_raw, ctx); + Ok(Default::default()) + } + _ => unreachable!(), + } } fn emit_name_of<'a, 'b>( @@ -604,7 +636,9 @@ fn emit_for_fields<'a, 'b>( if let Some(s) = stp { if let PLType::Struct(sttp) = &*s.borrow() { for (name, field) in sttp.fields.iter() { - let gep = builder.build_struct_gep(v, field.index, "tmp_gep").unwrap(); + let gep = builder + .build_struct_gep(v, field.index, "tmp_gep", &s.borrow(), ctx) + .unwrap(); ctx.run_in_origin_mod(|ctx| { ctx.run_in_type_mod(sttp, |ctx, _| { let field_tp = field.typenode.get_type(ctx, builder, true)?; @@ -798,11 +832,16 @@ fn emit_if_union<'a, 'b>( if let Some(s) = stp { if let PLType::Union(u) = &*s.borrow() { - let gep = builder.build_struct_gep(v, 0, "tmp_gep").unwrap(); - let ptr = builder.build_struct_gep(v, 1, "inner_ptr").unwrap(); - let ptr = builder.build_load(ptr, "inner_ptr"); - - let u_tp_i = builder.build_load(gep, "u_tp_i"); + let gep = builder + .build_struct_gep(v, 0, "tmp_gep", &s.borrow(), ctx) + .unwrap(); + let ptr = builder + .build_struct_gep(v, 1, "inner_ptr", &s.borrow(), ctx) + .unwrap(); + let ptr = builder.build_load(ptr, "inner_ptr", &PLType::Pointer(s.clone()), ctx); + + let u_tp_i = + builder.build_load(gep, "u_tp_i", &PLType::Primitive(PriType::I64), ctx); let after_bb = builder.append_basic_block(ctx.function.unwrap(), "after"); for (i, tp) in u.get_sum_types(ctx, builder).iter().enumerate() { diff --git a/src/ast/ctx/cast.rs b/src/ast/ctx/cast.rs index f41be83f6..f47547cb2 100644 --- a/src/ast/ctx/cast.rs +++ b/src/ast/ctx/cast.rs @@ -40,17 +40,20 @@ pub(crate) fn store_trait_hash_and_ptr<'a, T: TraitImplAble>( st_value: usize, trait_handle: usize, st: &T, + pltype: Arc>, ) -> Result { let st_value = builder.bitcast( ctx, st_value, - &PLType::Pointer(Arc::new(RefCell::new(PLType::Primitive(PriType::I64)))), + &PLType::Pointer(pltype.clone()), "traitcast_tmp", ); - let v_ptr = builder.build_struct_gep(trait_handle, 1, "v_tmp").unwrap(); + let v_ptr = builder + .build_struct_gep(trait_handle, 1, "v_tmp", &pltype.borrow(), ctx) + .unwrap(); builder.build_store(v_ptr, st_value); let type_hash = builder - .build_struct_gep(trait_handle, 0, "tp_hash") + .build_struct_gep(trait_handle, 0, "tp_hash", &pltype.borrow(), ctx) .unwrap(); let hash = st.get_type_code(); let hash = builder.int_value(&PriType::U64, hash, false); @@ -80,7 +83,7 @@ pub(crate) fn set_mthd_fields<'a, T: ImplAble>( } unreachable!() }); - set_mthd_field(mthd, ctx, st_pltype, builder, trait_handle, f)?; + set_mthd_field(mthd, ctx, t, st_pltype, builder, trait_handle, f)?; } Ok(()) } @@ -107,7 +110,7 @@ pub(crate) fn set_trait_impl_mthd_fields<'a, T: TraitImplAble>( } unreachable!() }); - set_mthd_field(mthd, ctx, st_pltype, builder, trait_handle, f)?; + set_mthd_field(mthd, ctx, t, st_pltype, builder, trait_handle, f)?; } Ok(()) } @@ -115,6 +118,7 @@ pub(crate) fn set_trait_impl_mthd_fields<'a, T: TraitImplAble>( pub(crate) fn set_mthd_field<'a>( mthd: Arc>, ctx: &mut Ctx<'a>, + t: &STType, st_pltype: &Arc>, builder: &BuilderEnum<'a, '_>, trait_handle: usize, @@ -127,7 +131,13 @@ pub(crate) fn set_mthd_field<'a>( let mthd = gen_mthd(m, ctx, st_pltype, builder, mthd)?; let fnhandle = builder.get_or_insert_fn_handle(&mthd.borrow(), ctx).0; let f_ptr = builder - .build_struct_gep(trait_handle, f.index, "field_tmp") + .build_struct_gep( + trait_handle, + f.index, + "field_tmp", + &PLType::Struct(t.clone()), + ctx, + ) .unwrap(); unsafe { builder.store_with_aoto_cast(f_ptr, fnhandle); @@ -218,7 +228,14 @@ impl<'a, 'ctx> Ctx<'a> { let trait_handle = builder.alloc("tmp_traitv", &target_pltype.borrow(), ctx, None); set_mthd_fields(t, st, ctx, st_pltype, builder, trait_handle)?; - store_trait_hash_and_ptr(builder, ctx, st_value, trait_handle, st) + store_trait_hash_and_ptr( + builder, + ctx, + st_value, + trait_handle, + st, + target_pltype.clone(), + ) }) }) } @@ -256,8 +273,11 @@ impl<'a, 'ctx> Ctx<'a> { .add_to_ctx(self)); } let closure_v = builder.alloc("tmp", &target_pltype.borrow(), self, None); - let closure_f = builder.build_struct_gep(closure_v, 0, "closure_f").unwrap(); - let ori_value = builder.try_load2var(ori_range, ori_value, self)?; + let closure_f = builder + .build_struct_gep(closure_v, 0, "closure_f", &target_pltype.borrow(), self) + .unwrap(); + let ori_value = + builder.try_load2var(ori_range, ori_value, &ori_pltype.borrow(), self)?; builder.build_store(closure_f, builder.get_closure_trampoline(ori_value)); return Ok(closure_v); } @@ -274,10 +294,22 @@ impl<'a, 'ctx> Ctx<'a> { let union_handle = builder.alloc("tmp_unionv", &target_pltype.borrow(), self, None); let union_value = builder - .build_struct_gep(union_handle, 1, "union_value") + .build_struct_gep( + union_handle, + 1, + "union_value", + &target_pltype.borrow(), + self, + ) .unwrap(); let union_type_field = builder - .build_struct_gep(union_handle, 0, "union_type") + .build_struct_gep( + union_handle, + 0, + "union_type", + &target_pltype.borrow(), + self, + ) .unwrap(); let union_type = builder.int_value(&PriType::U64, i as u64, false); builder.build_store(union_type_field, union_type); @@ -391,30 +423,63 @@ impl<'a, 'ctx> Ctx<'a> { let field = fs.get(&f.name).unwrap(); let fnhandle = builder - .build_struct_gep(st_value, field.index, "trait_mthd") + .build_struct_gep( + st_value, + field.index, + "trait_mthd", + &st_pltype.borrow(), + ctx, + ) .unwrap(); - let fnhandle = builder.build_load(fnhandle, "trait_mthd"); + let fnhandle = builder.build_load( + fnhandle, + "trait_mthd", + &PLType::new_i8_ptr(), + ctx, + ); // let targetftp = f.typenode.get_type(ctx, builder, true).unwrap(); // let casted = // builder.bitcast(ctx, fnhandle, &targetftp.borrow(), "fncast_tmp"); let f_ptr = builder - .build_struct_gep(trait_handle, f.index, "field_tmp") + .build_struct_gep( + trait_handle, + f.index, + "field_tmp", + &target_pltype.borrow(), + ctx, + ) .unwrap(); unsafe { builder.store_with_aoto_cast(f_ptr, fnhandle); } } - let st = builder.build_struct_gep(st_value, 1, "src_v_tmp").unwrap(); - let st = builder.build_load(st, "src_v"); - let v_ptr = builder.build_struct_gep(trait_handle, 1, "v_tmp").unwrap(); + let st = builder + .build_struct_gep(st_value, 1, "src_v_tmp", &st_pltype.borrow(), ctx) + .unwrap(); + let st = builder.build_load(st, "src_v", &PLType::new_i8_ptr(), ctx); + let v_ptr = builder + .build_struct_gep( + trait_handle, + 1, + "v_tmp", + &target_pltype.borrow(), + ctx, + ) + .unwrap(); builder.build_store(v_ptr, st); let type_hash = builder - .build_struct_gep(trait_handle, 0, "tp_hash") + .build_struct_gep( + trait_handle, + 0, + "tp_hash", + &target_pltype.borrow(), + ctx, + ) .unwrap(); let hash = builder - .build_struct_gep(st_value, 0, "src_tp_hash") + .build_struct_gep(st_value, 0, "src_tp_hash", &st_pltype.borrow(), ctx) .unwrap(); - let hash = builder.build_load(hash, "src_tp_hash"); + let hash = builder.build_load(hash, "src_tp_hash", &PLType::new_i64(), ctx); builder.build_store(type_hash, hash); Ok(trait_handle) }) @@ -457,7 +522,14 @@ impl<'a, 'ctx> Ctx<'a> { let trait_handle = builder.alloc("tmp_traitv", &target_pltype.borrow(), ctx, None); set_trait_impl_mthd_fields(t, st, ctx, st_pltype, builder, trait_handle)?; - store_trait_hash_and_ptr(builder, ctx, st_value, trait_handle, st) + store_trait_hash_and_ptr( + builder, + ctx, + st_value, + trait_handle, + st, + target_pltype.clone(), + ) }) }) } diff --git a/src/ast/diag.rs b/src/ast/diag.rs index 6f0eaeb47..9a3d29e0b 100644 --- a/src/ast/diag.rs +++ b/src/ast/diag.rs @@ -77,6 +77,7 @@ define_diag!( REF_CONST = "try referencing to a const value", INVALID_STRUCT_DEF = "invalid struct definition", UNDEFINED_TYPE = "undefined type", + UNKNOWN_TYPE = "unknown type", RETURN_VALUE_IN_VOID_FUNCTION = "return value in void function", RETURN_TYPE_MISMATCH = "return type mismatch", NO_RETURN_VALUE_IN_NON_VOID_FUNCTION = "non void function must have a return value", @@ -99,7 +100,6 @@ define_diag!( TYPE_MISMATCH = "type mismatch", ILLEGAL_GET_FIELD_OPERATION = "illegal get field operation", NOT_A_POINTER = "not a pointer", - CAN_NOT_REF_CONSTANT = "can not ref constant", ILLEGAL_SELF_RECURSION = "illegal self recursion, please use pointer", GENERIC_CANNOT_BE_INFER = "generic can not be infer", RECEIVER_CANNOT_BE_INFER = "receiver can not be infer", @@ -165,6 +165,7 @@ define_diag!( GENERIC_NOT_ALLOWED_IN_TRAIT_METHOD = "generic not allowed in trait method", THE_TARGET_TRAIT_CANNOT_BE_INSTANTIATED = "the target trait type cannot be instantiated", MACRO_EXPAND_DEPTH_TOO_DEEP = "macro expand depth too deep", + GLOBAL_MUST_BE_POINTER = "global must be pointer type", ); define_diag! { diff --git a/src/ast/node/cast.rs b/src/ast/node/cast.rs index 65d0ba6b1..de36ae2fe 100644 --- a/src/ast/node/cast.rs +++ b/src/ast/node/cast.rs @@ -82,9 +82,9 @@ impl<'a, 'ctx> Ctx<'a> { ) -> Result<(ValueHandle, Arc>), PLDiag> { let target_rc = target_ty.clone(); match (ty, &*target_ty.clone().borrow()) { - (PLType::Primitive(ty), PLType::Primitive(target_ty)) => { - let val = builder.try_load2var(node.expr.range(), val, self)?; - Ok((builder.cast_primitives(val, ty, target_ty), target_rc)) + (PLType::Primitive(tyi), PLType::Primitive(target_ty)) => { + let val = builder.try_load2var(node.expr.range(), val, ty, self)?; + Ok((builder.cast_primitives(val, tyi, target_ty), target_rc)) } (PLType::Union(union), target_ty) => { if node.tail.is_none() { @@ -103,17 +103,16 @@ impl<'a, 'ctx> Ctx<'a> { if let Some(tag) = union.has_type(target_ty, self, builder) { let (token, _) = node.tail.unwrap(); if token == TokenType::QUESTION { - Ok(self.cast_union_to(builder, val, tag, target_rc)) + Ok(self.cast_union_to(builder, val, tag, ty, target_rc)) } else { - Ok( - self.force_cast_union_to( - builder, - val, - tag, - target_rc, - node.range.start, - ), - ) + Ok(self.force_cast_union_to( + builder, + val, + tag, + ty, + target_rc, + node.range.start, + )) } } else { Err(node @@ -152,9 +151,9 @@ impl<'a, 'ctx> Ctx<'a> { } let (token, _) = node.tail.unwrap(); if token == TokenType::QUESTION { - Ok(self.cast_trait_to(builder, val, target_rc)) + Ok(self.cast_trait_to(builder, val, ty, target_rc)) } else { - Ok(self.force_cast_trait_to(builder, val, target_rc, node.range.start)) + Ok(self.force_cast_trait_to(builder, val, ty, target_rc, node.range.start)) } } _ => { @@ -191,10 +190,13 @@ impl<'a, 'ctx> Ctx<'a> { &mut self, builder: &'b BuilderEnum<'a, 'ctx>, val: ValueHandle, + ori_ty: &PLType, target_ty: Arc>, ) -> (ValueHandle, Arc>) { - let hash = builder.build_struct_gep(val, 0, "tp_hash").unwrap(); - let hash = builder.build_load(hash, "tp_hash"); + let hash = builder + .build_struct_gep(val, 0, "tp_hash", ori_ty, self) + .unwrap(); + let hash = builder.build_load(hash, "tp_hash", &PLType::new_i64(), self); let hasn_code = get_hash_code(target_ty.borrow().get_full_elm_name()); let hash_code = builder.int_value(&PriType::U64, hasn_code, false); let cond_block = builder.append_basic_block(self.function.unwrap(), "if.cond"); @@ -206,7 +208,12 @@ impl<'a, 'ctx> Ctx<'a> { self.position_at_end(cond_block, builder); let cond = builder.build_int_compare(IntPredicate::EQ, hash, hash_code, "hash.eq"); let cond = builder - .try_load2var(Default::default(), cond, self) + .try_load2var( + Default::default(), + cond, + &PLType::Primitive(PriType::BOOL), + self, + ) .unwrap(); let cond = builder.build_int_truncate(cond, &PriType::BOOL, "trunctemp"); @@ -218,6 +225,7 @@ impl<'a, 'ctx> Ctx<'a> { cond, after_block, else_block, + ori_ty, ) } /// -> Option<_> @@ -226,12 +234,15 @@ impl<'a, 'ctx> Ctx<'a> { builder: &'b BuilderEnum<'a, 'ctx>, val: ValueHandle, union_tag: usize, + ori_ty: &PLType, target_ty: Arc>, ) -> (ValueHandle, Arc>) { - let tag = builder.build_struct_gep(val, 0, "tag").unwrap(); + let tag = builder + .build_struct_gep(val, 0, "tag", ori_ty, self) + .unwrap(); // check if the tag is the same - let tag = builder.build_load(tag, "tag"); + let tag = builder.build_load(tag, "tag", &PLType::new_i64(), self); let cond_block = builder.append_basic_block(self.function.unwrap(), "if.cond"); let then_block = builder.append_basic_block(self.function.unwrap(), "if.then"); let else_block = builder.append_basic_block(self.function.unwrap(), "if.else"); @@ -245,7 +256,12 @@ impl<'a, 'ctx> Ctx<'a> { "tag.eq", ); let cond = builder - .try_load2var(Default::default(), cond, self) + .try_load2var( + Default::default(), + cond, + &PLType::Primitive(PriType::BOOL), + self, + ) .unwrap(); let cond = builder.build_int_truncate(cond, &PriType::BOOL, "trunctemp"); @@ -257,6 +273,7 @@ impl<'a, 'ctx> Ctx<'a> { cond, after_block, else_block, + ori_ty, ) } @@ -270,16 +287,23 @@ impl<'a, 'ctx> Ctx<'a> { cond: usize, after_block: usize, else_block: usize, + ori_ty: &PLType, ) -> (usize, Arc>) { - let result_tp = get_option_type(self, builder, target_ty).unwrap(); + let result_tp = get_option_type(self, builder, target_ty.clone()).unwrap(); let result = builder.alloc("cast_result", &result_tp.borrow(), self, None); - let result_tag_field = builder.build_struct_gep(result, 0, "tag").unwrap(); - let result_data_field = builder.build_struct_gep(result, 1, "data").unwrap(); + let result_tag_field = builder + .build_struct_gep(result, 0, "tag", &result_tp.borrow(), self) + .unwrap(); + let result_data_field = builder + .build_struct_gep(result, 1, "data", &result_tp.borrow(), self) + .unwrap(); builder.build_conditional_branch(cond, then_block, else_block); // then block self.position_at_end(then_block, builder); - let data = builder.build_struct_gep(val, 1, "data").unwrap(); - let data = builder.build_load(data, "data"); + let data = builder + .build_struct_gep(val, 1, "data", ori_ty, self) + .unwrap(); + let data = builder.build_load(data, "data", &target_ty.borrow(), self); builder.build_store(result_data_field, data); builder.build_store(result_tag_field, builder.int_value(&PriType::U64, 0, false)); builder.build_unconditional_branch(after_block); @@ -297,11 +321,14 @@ impl<'a, 'ctx> Ctx<'a> { &mut self, builder: &'b BuilderEnum<'a, 'ctx>, val: ValueHandle, + ori_ty: &PLType, target_ty: Arc>, pos: Pos, ) -> (ValueHandle, Arc>) { - let hash = builder.build_struct_gep(val, 0, "tp_hash").unwrap(); - let hash = builder.build_load(hash, "tp_hash"); + let hash = builder + .build_struct_gep(val, 0, "tp_hash", ori_ty, self) + .unwrap(); + let hash = builder.build_load(hash, "tp_hash", &PLType::new_i64(), self); let hasn_code = get_hash_code(target_ty.borrow().get_full_elm_name()); let hash_code = builder.int_value(&PriType::U64, hasn_code, false); let cond_block = builder.append_basic_block(self.function.unwrap(), "if.cond"); @@ -316,11 +343,17 @@ impl<'a, 'ctx> Ctx<'a> { self.position_at_end(cond_block, builder); let cond = builder.build_int_compare(IntPredicate::EQ, hash, hash_code, "hash.eq"); let cond = builder - .try_load2var(Default::default(), cond, self) + .try_load2var( + Default::default(), + cond, + &PLType::Primitive(PriType::BOOL), + self, + ) .unwrap(); let cond = builder.build_int_truncate(cond, &PriType::BOOL, "trunctemp"); builder.build_conditional_branch(cond, then_block, else_block); self.build_cast_ret( + ori_ty, target_ty, builder, then_block, @@ -337,12 +370,15 @@ impl<'a, 'ctx> Ctx<'a> { builder: &'b BuilderEnum<'a, 'ctx>, val: ValueHandle, union_tag: usize, + ori_ty: &PLType, target_ty: Arc>, pos: Pos, ) -> (ValueHandle, Arc>) { - let tag = builder.build_struct_gep(val, 0, "tag").unwrap(); + let tag = builder + .build_struct_gep(val, 0, "tag", ori_ty, self) + .unwrap(); // check if the tag is the same - let tag = builder.build_load(tag, "tag"); + let tag = builder.build_load(tag, "tag", &PLType::new_i64(), self); let cond_block = builder.append_basic_block(self.function.unwrap(), "force.if.cond"); let then_block = builder.append_basic_block(self.function.unwrap(), "force.if.then"); let else_block = builder.append_basic_block(self.function.unwrap(), "force.if.else"); @@ -360,11 +396,17 @@ impl<'a, 'ctx> Ctx<'a> { "tag.eq", ); let cond = builder - .try_load2var(Default::default(), cond, self) + .try_load2var( + Default::default(), + cond, + &PLType::Primitive(PriType::BOOL), + self, + ) .unwrap(); let cond = builder.build_int_truncate(cond, &PriType::BOOL, "trunctemp"); builder.build_conditional_branch(cond, then_block, else_block); self.build_cast_ret( + ori_ty, target_ty, builder, then_block, @@ -380,6 +422,7 @@ impl<'a, 'ctx> Ctx<'a> { #[allow(clippy::too_many_arguments)] fn build_cast_ret<'b>( &mut self, + ori_ty: &PLType, target_ty: Arc>, builder: &'b BuilderEnum<'a, 'ctx>, then_block: usize, @@ -392,10 +435,17 @@ impl<'a, 'ctx> Ctx<'a> { // then block self.position_at_end(then_block, builder); - let data = builder.build_struct_gep(val, 1, "data").unwrap(); - let data = builder.build_load(data, "data"); - let data = builder.bitcast(self, data, &PLType::Pointer(target_ty), "bitcasttemp"); - let data = builder.build_load(data, "data"); + let data = builder + .build_struct_gep(val, 1, "data", ori_ty, self) + .unwrap(); + let data = builder.build_load(data, "data", &PLType::new_i8_ptr(), self); + let data = builder.bitcast( + self, + data, + &PLType::Pointer(target_ty.clone()), + "bitcasttemp", + ); + let data = builder.build_load(data, "data", &target_ty.borrow(), self); builder.build_store(result, data); builder.build_unconditional_branch(after_block); // else block @@ -455,15 +505,22 @@ impl Node for IsNode { match tp { PLType::Union(u) => { if let Some(tag) = u.has_type(&target_tp.borrow(), ctx, builder) { - let tag_v = builder.build_struct_gep(val, 0, "tag").unwrap(); - let tag_v = builder.build_load(tag_v, "tag"); + let tag_v = builder.build_struct_gep(val, 0, "tag", tp, ctx).unwrap(); + let tag_v = builder.build_load(tag_v, "tag", &PLType::new_i64(), ctx); let cond = builder.build_int_compare( IntPredicate::EQ, tag_v, builder.int_value(&PriType::U64, tag as u64, false), "tag.eq", ); - let cond = builder.try_load2var(Default::default(), cond, ctx).unwrap(); + let cond = builder + .try_load2var( + Default::default(), + cond, + &PLType::Primitive(PriType::BOOL), + ctx, + ) + .unwrap(); let cond = builder.build_int_truncate(cond, &PriType::BOOL, "trunctemp"); cond.new_output(ctx.get_type("bool", Default::default()).unwrap().tp) .set_const() @@ -480,10 +537,19 @@ impl Node for IsNode { let name = target_tp.borrow().get_full_elm_name(); let hash_code = get_hash_code(name); let hash_code = builder.int_value(&PriType::U64, hash_code, false); - let hash = builder.build_struct_gep(val, 0, "tp_hash").unwrap(); - let hash = builder.build_load(hash, "tp_hash"); + let hash = builder + .build_struct_gep(val, 0, "tp_hash", tp, ctx) + .unwrap(); + let hash = builder.build_load(hash, "tp_hash", &PLType::new_i64(), ctx); let cond = builder.build_int_compare(IntPredicate::EQ, hash, hash_code, "hash.eq"); - let cond = builder.try_load2var(Default::default(), cond, ctx).unwrap(); + let cond = builder + .try_load2var( + Default::default(), + cond, + &PLType::Primitive(PriType::BOOL), + ctx, + ) + .unwrap(); let cond = builder.build_int_truncate(cond, &PriType::BOOL, "trunctemp"); cond.new_output(ctx.get_type("bool", Default::default()).unwrap().tp) .set_const() diff --git a/src/ast/node/control.rs b/src/ast/node/control.rs index 7cc91658c..e9a314728 100644 --- a/src/ast/node/control.rs +++ b/src/ast/node/control.rs @@ -49,7 +49,7 @@ impl Node for IfNode { check_bool(&v, ctx, condrange, code)?; let v = v.unwrap(); let cond = v.get_value(); - let cond = ctx.try_load2var(condrange, cond, builder)?; + let cond = ctx.try_load2var(condrange, cond, builder, &v.get_ty().borrow())?; let cond = builder.build_int_truncate(cond, &PriType::BOOL, "trunctemp"); builder.build_conditional_branch(cond, then_block, else_block); // then block @@ -136,8 +136,9 @@ impl Node for WhileNode { let v = self.cond.emit(ctx, builder)?.get_value(); check_bool(&v, ctx, condrange, ErrorCode::WHILE_CONDITION_MUST_BE_BOOL)?; - let cond = v.unwrap().get_value(); - let cond = ctx.try_load2var(condrange, cond, builder)?; + let v = v.unwrap(); + let cond = v.get_value(); + let cond = ctx.try_load2var(condrange, cond, builder, &v.get_ty().borrow())?; let cond = builder.build_int_truncate(cond, &PriType::BOOL, "trunctemp"); builder.build_conditional_branch(cond, body_block, after_block); ctx.position_at_end(body_block, builder); @@ -207,8 +208,9 @@ impl Node for ForNode { let cond_start = self.cond.range().start; let v = self.cond.emit(ctx, builder)?.get_value(); check_bool(&v, ctx, condrange, ErrorCode::FOR_CONDITION_MUST_BE_BOOL)?; - let cond = v.unwrap().get_value(); - let cond = ctx.try_load2var(condrange, cond, builder)?; + let node_value = &v.unwrap(); + let cond = node_value.get_value(); + let cond = ctx.try_load2var(condrange, cond, builder, &node_value.get_ty().borrow())?; let cond = builder.build_int_truncate(cond, &PriType::BOOL, "trunctemp"); builder.build_dbg_location(self.body.range().start); builder.build_conditional_branch(cond, body_block, after_block); diff --git a/src/ast/node/function.rs b/src/ast/node/function.rs index f2bc979f4..63a1b0328 100644 --- a/src/ast/node/function.rs +++ b/src/ast/node/function.rs @@ -14,6 +14,7 @@ use crate::ast::node::{deal_line, tab}; use crate::ast::pltype::{get_type_deep, ClosureType, FNValue, Field, FnType, PLType, STType}; use crate::ast::tokens::TokenType; +use crate::inference::{InferenceCtx, TyVariable}; use indexmap::IndexMap; use internal_macro::node; use linked_hash_map::LinkedHashMap; @@ -54,8 +55,11 @@ impl FuncCallNode { c: &ClosureType, v: ValueHandle, ) -> NodeResult { - let data = builder.build_struct_gep(v, 1, "closure_data").unwrap(); - let data = builder.build_load(data, "loaded_closure_data"); + let ct = PLType::Closure(c.clone()); + let data = builder + .build_struct_gep(v, 1, "closure_data", &ct, ctx) + .unwrap(); + let data = builder.build_load(data, "loaded_closure_data", &PLType::new_i8_ptr(), ctx); let mut para_values = vec![data]; let mut value_pltypes = vec![]; if self.paralist.len() != c.arg_types.len() { @@ -80,12 +84,12 @@ impl FuncCallNode { let v = v.unwrap(); let value_pltype = v.get_ty(); let value_pltype = get_type_deep(value_pltype); - let load = ctx.try_load2var(pararange, v.get_value(), builder)?; + let load = ctx.try_load2var(pararange, v.get_value(), builder, &v.get_ty().borrow())?; para_values.push(load); value_pltypes.push((value_pltype, pararange)); } - let re = builder.build_struct_gep(v, 0, "real_fn").unwrap(); - let re = builder.build_load(re, "real_fn"); + let re = builder.build_struct_gep(v, 0, "real_fn", &ct, ctx).unwrap(); + let re = builder.build_load(re, "real_fn", &PLType::new_i8_ptr(), ctx); let ret = builder.build_call(re, ¶_values, &c.ret_type.borrow(), ctx); builder.try_set_fn_dbg(self.range.start, ctx.function.unwrap()); handle_ret(ret, c.ret_type.clone()) @@ -122,7 +126,7 @@ impl FuncCallNode { let v = v.unwrap(); let value_pltype = v.get_ty(); let value_pltype = get_type_deep(value_pltype); - let load = ctx.try_load2var(pararange, v.get_value(), builder)?; + let load = ctx.try_load2var(pararange, v.get_value(), builder, &v.get_ty().borrow())?; para_values.push(load); value_pltypes.push((value_pltype, pararange)); } @@ -311,14 +315,14 @@ fn check_and_cast_params<'a, 'b>( builder.build_store(ptr2v, value); let trait_pltype = param_types[i + skip as usize].get_type(ctx, builder, true)?; value = ctx.up_cast( - trait_pltype, + trait_pltype.clone(), value_pltype.clone(), param_types[i + skip as usize].range(), *pararange, ptr2v, builder, )?; - value = ctx.try_load2var(*pararange, value, builder)?; + value = ctx.try_load2var(*pararange, value, builder, &trait_pltype.borrow())?; para_values[i + skip as usize] = value; } } @@ -630,12 +634,9 @@ impl FuncDefNode { let ret_value_ptr = if self.generator { generator::build_generator_ret(builder, child, &fnvalue, entry)? } else { - match &*fnvalue - .fntype - .ret_pltype - .get_type(child, builder, true)? - .borrow() - { + let tp = fnvalue.fntype.ret_pltype.get_type(child, builder, true)?; + child.rettp = Some(tp.clone()); + match &*tp.clone().borrow() { PLType::Void => None, other => { builder.rm_curr_debug_location(); @@ -647,7 +648,12 @@ impl FuncDefNode { child.position_at_end(return_block, builder); child.return_block = Some((return_block, ret_value_ptr)); if let Some(ptr) = ret_value_ptr { - let value = builder.build_load(ptr, "load_ret_tmp"); + let value = builder.build_load( + ptr, + "load_ret_tmp", + &child.rettp.clone().unwrap().borrow(), + child, + ); builder.build_return(Some(value)); } else { builder.build_return(None); @@ -656,6 +662,7 @@ impl FuncDefNode { if self.generator { // 设置flag,该flag影响alloc逻辑 child.ctx_flag = CtxFlag::InGeneratorYield; + child.generator_data.as_ref().unwrap().borrow_mut().is_para = true; } // alloc para for (i, para) in fnvalue.fntype.param_pltypes.iter().enumerate() { @@ -686,6 +693,9 @@ impl FuncDefNode { ) .unwrap(); } + if self.generator { + child.generator_data.as_ref().unwrap().borrow_mut().is_para = false; + } // emit body builder.rm_curr_debug_location(); if self.id.name == "main" { @@ -733,6 +743,11 @@ impl FuncDefNode { return Ok(()); } // body generation + let mut infer_ctx = InferenceCtx::new(ctx.unify_table.clone()); + infer_ctx.import_global_symbols(child); + let mut infer_ctx = infer_ctx.new_child(); + infer_ctx.import_symbols(child); + infer_ctx.inference_statements(self.body.as_mut().unwrap(), child, builder); let terminator = self .body .as_mut() @@ -835,6 +850,7 @@ pub struct ClosureNode { pub paralist: Vec<(Box, Option>)>, pub body: StatementsNode, pub ret: Option>, + pub ret_id: Option, } static CLOSURE_COUNT: AtomicI32 = AtomicI32::new(0); @@ -874,6 +890,17 @@ impl Node for ClosureNode { let tp = if let Some(typenode) = typenode { typenode.emit_highlight(ctx); typenode.get_type(ctx, builder, true)? + } else if let Some(id) = v.id { + let vv = ctx.unify_table.borrow_mut().probe(id); + let tp = vv.get_type(&mut ctx.unify_table.borrow_mut()); + if *tp.borrow() == PLType::Unknown { + v.range() + .new_err(ErrorCode::CLOSURE_PARAM_TYPE_UNKNOWN) + .add_help("try manually specify the parameter type of the closure") + .add_to_ctx(ctx); + } + ctx.push_type_hints(v.range(), tp.clone()); + tp } else if let Some(exp_ty) = &ctx.expect_ty { match &*exp_ty.borrow() { PLType::Closure(c) => { @@ -907,6 +934,10 @@ impl Node for ClosureNode { let ret_tp = if let Some(ret) = &self.ret { ret.emit_highlight(ctx); ret.get_type(ctx, builder, true)? + } else if let Some(ty) = self.ret_id { + let v = ctx.unify_table.borrow_mut().probe(ty); + let tp = v.get_type(&mut ctx.unify_table.borrow_mut()); + tp } else if let Some(exp_ty) = &ctx.expect_ty { match &*exp_ty.borrow() { PLType::Closure(c) => c.ret_type.clone(), @@ -933,7 +964,8 @@ impl Node for ClosureNode { child.ctx_flag = CtxFlag::Normal; child.function = Some(f); let stpltp = PLType::Struct(st_tp.clone()); - let ptr_tp = PLType::Pointer(Arc::new(RefCell::new(stpltp))); + let stp = Arc::new(RefCell::new(stpltp)); + let ptr_tp = PLType::Pointer(stp.clone()); let mut all_tps = vec![Arc::new(RefCell::new(i8ptr.clone()))]; all_tps.extend(paratps.clone()); builder.build_sub_program_by_pltp( @@ -961,7 +993,7 @@ impl Node for ClosureNode { child.position_at_end(return_block, builder); child.return_block = Some((return_block, ret_value_ptr)); if let Some(ptr) = ret_value_ptr { - let value = builder.build_load(ptr, "load_ret_tmp"); + let value = builder.build_load(ptr, "load_ret_tmp", &ret_tp.borrow(), child); builder.build_return(Some(value)); } else { builder.build_return(None); @@ -973,11 +1005,12 @@ impl Node for ClosureNode { // let alloca = builder.alloc("closure_data",&ptr_tp, child, None); // builder.build_store(alloca, casted_data); child.position_at_end(entry, builder); - child.closure_data = Some(RefCell::new(ClosureCtxData { + child.closure_data = Some(Arc::new(RefCell::new(ClosureCtxData { table: LinkedHashMap::default(), data_handle: casted_data, alloca_bb: None, - })); + data_tp: Some(stp), + }))); // alloc para for (i, tp) in paratps.iter().enumerate() { let b = tp.clone(); @@ -1060,16 +1093,22 @@ impl Node for ClosureNode { for (k, (_, ori_v)) in &closure_table.borrow().table { let field = st_tp.fields.get(k).unwrap(); let alloca: usize = builder - .build_struct_gep(closure_data_alloca, field.index, k) + .build_struct_gep( + closure_data_alloca, + field.index, + k, + &PLType::Struct(st_tp.clone()), + ctx, + ) .unwrap(); builder.build_store(alloca, *ori_v); } let f_field = builder - .build_struct_gep(closure_alloca, 0, "closure_f") + .build_struct_gep(closure_alloca, 0, "closure_f", &closure_f_tp, ctx) .unwrap(); builder.build_store(f_field, f); let d_field = builder - .build_struct_gep(closure_alloca, 1, "closure_d") + .build_struct_gep(closure_alloca, 1, "closure_d", &closure_f_tp, ctx) .unwrap(); let d_casted = builder.bitcast(ctx, closure_data_alloca, &i8ptr, "casted_closure_d"); builder.build_store(d_field, d_casted); diff --git a/src/ast/node/function/generator.rs b/src/ast/node/function/generator.rs index 9a69a3786..a4abc39c0 100644 --- a/src/ast/node/function/generator.rs +++ b/src/ast/node/function/generator.rs @@ -35,7 +35,10 @@ pub struct GeneratorCtxData { pub ret_handle: ValueHandle, //handle in setup function pub prev_yield_bb: Option, pub ctx_size_handle: ValueHandle, - pub param_tmp: ValueHandle, + pub is_para: bool, + pub para_tmp: ValueHandle, + pub ctx_tp: Option>>, + pub ret_type: Option>>, } /// # CtxFlag @@ -52,6 +55,7 @@ pub struct ClosureCtxData { pub table: LinkedHashMap, pub data_handle: ValueHandle, pub alloca_bb: Option, + pub data_tp: Option>>, } pub(crate) fn end_generator<'a>( @@ -90,8 +94,15 @@ pub(crate) fn end_generator<'a>( builder.gen_st_visit_function(child, st_tp, &tps); builder.position_at_end_block(data.borrow().entry_bb); // 3. 在setup函数中给返回值(接口)的对应字段赋值,并返回 + let b = data.borrow(); let ptr = builder - .build_struct_gep(data.borrow().ret_handle, 1, "ctx_handle_gep") + .build_struct_gep( + b.ret_handle, + 1, + "ctx_handle_gep", + &b.ret_type.as_ref().unwrap().borrow(), + child, + ) .unwrap(); let ptr = builder.bitcast( child, @@ -103,10 +114,21 @@ pub(crate) fn end_generator<'a>( ); builder.build_store(ptr, data.borrow().ctx_handle); let ptr = builder - .build_struct_gep(data.borrow().ret_handle, 2, "ctx_handle_gep") + .build_struct_gep( + data.borrow().ret_handle, + 2, + "ctx_handle_gep", + &b.ret_type.as_ref().unwrap().borrow(), + child, + ) .unwrap(); unsafe { builder.store_with_aoto_cast(ptr, funcvalue) }; - let ret_load = builder.build_load(data.borrow().ret_handle, "ret_load"); + let ret_load = builder.build_load( + data.borrow().ret_handle, + "ret_load", + &b.ret_type.as_ref().unwrap().borrow(), + child, + ); builder.build_return(Some(ret_load)); // 4. 生成yield函数的done分支代码 @@ -114,7 +136,13 @@ pub(crate) fn end_generator<'a>( builder.build_unconditional_branch(data.borrow().entry_bb); builder.position_at_end_block(done); let flag = builder - .build_struct_gep(child.return_block.unwrap().1.unwrap(), 0, "flag") + .build_struct_gep( + child.return_block.unwrap().1.unwrap(), + 0, + "flag", + &b.ret_type.as_ref().unwrap().borrow(), + child, + ) .unwrap(); builder.build_store(flag, builder.int_value(&PriType::U64, 1, false)); builder.build_unconditional_branch(child.return_block.unwrap().0); @@ -122,8 +150,16 @@ pub(crate) fn end_generator<'a>( // 5. 生成yield函数的跳转代码 builder.position_at_end_block(allocab); let ctx_v = builder.get_nth_param(child.function.unwrap(), 0); - let address = builder.build_struct_gep(ctx_v, 1, "block_address").unwrap(); - let address = builder.build_load(address, "block_address"); + let address = builder + .build_struct_gep( + ctx_v, + 1, + "block_address", + &data.borrow().ctx_tp.as_ref().unwrap().borrow(), + child, + ) + .unwrap(); + let address = builder.build_load(address, "block_address", &PLType::new_i8_ptr(), child); builder.build_indirect_br(address, child); // 6. 用最终的generator_ctx大小修正之前的malloc语句 @@ -214,6 +250,8 @@ pub(crate) fn init_generator<'a>( .unwrap() .borrow_mut() .ctx_handle = ctx_handle; + child.generator_data.as_ref().unwrap().borrow_mut().ctx_tp = + Some(Arc::new(RefCell::new(PLType::Struct(st_tp.clone())))); *sttp_opt = Some(st_tp); Ok(()) } @@ -227,7 +265,13 @@ pub(crate) fn save_generator_init_block<'a>( let address = builder.get_block_address(entry); let data = child.generator_data.as_ref().unwrap().clone(); let address_ptr = builder - .build_struct_gep(data.borrow().ctx_handle, 1, "block_address") + .build_struct_gep( + data.borrow().ctx_handle, + 1, + "block_address", + &data.borrow().ctx_tp.as_ref().unwrap().borrow(), + child, + ) .unwrap(); builder.build_store(address_ptr, address); } @@ -242,16 +286,13 @@ pub(crate) fn build_generator_ret<'a>( let data = child.generator_data.as_ref().unwrap().clone(); child.position_at_end(data.borrow().entry_bb, builder); let tp = child.rettp.clone().unwrap(); - match &*fnvalue - .fntype - .ret_pltype - .get_type(child, builder, true)? - .borrow() - { + let r = fnvalue.fntype.ret_pltype.get_type(child, builder, true)?; + match &*r.clone().borrow() { PLType::Void => unreachable!(), other => { builder.rm_curr_debug_location(); data.borrow_mut().ret_handle = builder.alloc("retvalue", other, child, None); + data.borrow_mut().ret_type = Some(r); } } child.position_at_end(entry, builder); diff --git a/src/ast/node/global.rs b/src/ast/node/global.rs index f7e151a14..a3a81e979 100644 --- a/src/ast/node/global.rs +++ b/src/ast/node/global.rs @@ -76,7 +76,8 @@ impl Node for GlobalNode { let v = self.exp.emit(ctx, builder)?.get_value(); let v = v.unwrap(); ctx.push_type_hints(self.var.range, v.get_ty()); - let base_value = ctx.try_load2var(exp_range, v.get_value(), builder)?; + let base_value = + ctx.try_load2var(exp_range, v.get_value(), builder, &v.get_ty().borrow())?; let res = ctx.get_symbol(&self.var.name, builder); if res.is_none() { return Ok(Default::default()); @@ -116,11 +117,15 @@ impl GlobalNode { ctx.add_symbol( self.var.name.clone(), globalptr, - pltype, + pltype.clone(), self.var.range, false, false, )?; + // for gc reason, globals must be pointer + if !matches!(&*pltype.borrow(), PLType::Pointer(_)) { + return Err(ctx.add_diag(self.var.range.new_err(ErrorCode::GLOBAL_MUST_BE_POINTER))); + } Ok(()) } } diff --git a/src/ast/node/interface.rs b/src/ast/node/interface.rs index c8b985b2e..f3d68612b 100644 --- a/src/ast/node/interface.rs +++ b/src/ast/node/interface.rs @@ -220,6 +220,7 @@ fn new_selfptr_tf_with_name(n: &str) -> TypedIdentifierNode { id: VarNode { name: n.to_string(), range: Default::default(), + id: None, }, typenode: Box::new(TypeNodeEnum::Pointer(PointerTypeNode { elm: Box::new(TypeNameNode::new_from_str("i64").into()), diff --git a/src/ast/node/operator.rs b/src/ast/node/operator.rs index 28f329d3b..49b7f0926 100644 --- a/src/ast/node/operator.rs +++ b/src/ast/node/operator.rs @@ -51,7 +51,7 @@ impl Node for UnaryOpNode { } let rv = rv.unwrap(); let pltype = rv.get_ty(); - let exp = ctx.try_load2var(exp_range, rv.get_value(), builder)?; + let exp = ctx.try_load2var(exp_range, rv.get_value(), builder, &rv.get_ty().borrow())?; return Ok(match (&*pltype.borrow(), self.op.0) { ( PLType::Primitive( @@ -113,7 +113,7 @@ impl Node for BinOpNode { let lv = lv.unwrap(); let lpltype = lv.get_ty(); let lv = lv.get_value(); - let left = ctx.try_load2var(lrange, lv, builder)?; + let left = ctx.try_load2var(lrange, lv, builder, &lpltype.borrow())?; if self.op.0 == TokenType::AND || self.op.0 == TokenType::OR { return Ok(match *lpltype.clone().borrow() { PLType::Primitive(PriType::BOOL) => { @@ -146,7 +146,9 @@ impl Node for BinOpNode { if rv.is_none() { return Err(ctx.add_diag(self.range.new_err(ErrorCode::EXPECT_VALUE))); } - let right = ctx.try_load2var(rrange, rv.unwrap().get_value(), builder)?; + let rv = rv.unwrap(); + let right = + ctx.try_load2var(rrange, rv.get_value(), builder, &rv.get_ty().borrow())?; let incoming_bb2 = builder.get_cur_basic_block(); // get incoming block 2 builder.build_unconditional_branch(merge_bb); // merge bb @@ -177,7 +179,8 @@ impl Node for BinOpNode { if re.is_none() { return Err(ctx.add_diag(self.range.new_err(ErrorCode::EXPECT_VALUE))); } - let right = ctx.try_load2var(rrange, re.unwrap().get_value(), builder)?; + let re = re.unwrap(); + let right = ctx.try_load2var(rrange, re.get_value(), builder, &re.get_ty().borrow())?; let lpltype = get_type_deep(lpltype); Ok(match self.op.0 { TokenType::BIT_AND => { @@ -436,16 +439,30 @@ impl Node for TakeOpNode { if let Some(field) = field { _ = s.expect_field_pub(ctx, &field, id_range); ctx.push_semantic_token(id_range, SemanticTokenType::METHOD, 0); - ctx.set_field_refs(head_pltype, &field, id_range); + ctx.set_field_refs(head_pltype.clone(), &field, id_range); ctx.send_if_go_to_def(id_range, field.range, s.path.clone()); let re = field.typenode.get_type(ctx, builder, true)?; let fnv = builder - .build_struct_gep(headptr, field.index, "mthd_ptr") + .build_struct_gep( + headptr, + field.index, + "mthd_ptr", + &head_pltype.borrow(), + ctx, + ) .unwrap(); - let fnv = builder.build_load(fnv, "mthd_ptr_load"); - let headptr = builder.build_struct_gep(headptr, 1, "traitptr").unwrap(); - let headptr = builder.build_load(headptr, "traitptr_load"); + let fnv = + builder.build_load(fnv, "mthd_ptr_load", &PLType::new_i8_ptr(), ctx); + let headptr = builder + .build_struct_gep(headptr, 1, "traitptr", &head_pltype.borrow(), ctx) + .unwrap(); + let headptr = builder.build_load( + headptr, + "traitptr_load", + &PLType::new_i8_ptr(), + ctx, + ); ctx.emit_comment_highlight(&self.comments[0]); Ok(NodeOutput::new_value(NodeValue::new_receiver( fnv, re, headptr, None, @@ -459,14 +476,20 @@ impl Node for TakeOpNode { if let Some(field) = s.fields.get(&id.name) { _ = s.expect_field_pub(ctx, field, id_range); ctx.push_semantic_token(id_range, SemanticTokenType::PROPERTY, 0); - ctx.set_field_refs(head_pltype, field, id_range); + ctx.set_field_refs(head_pltype.clone(), field, id_range); if field.range != Default::default() { // walkaround for tuple types ctx.send_if_go_to_def(id_range, field.range, s.path.clone()); } return Ok(NodeOutput::new_value(NodeValue::new( builder - .build_struct_gep(headptr, field.index, "structgep") + .build_struct_gep( + headptr, + field.index, + "structgep", + &head_pltype.borrow(), + ctx, + ) .unwrap(), field.typenode.get_type(ctx, builder, true)?, ))); diff --git a/src/ast/node/pointer.rs b/src/ast/node/pointer.rs index d0d045f2e..57db77f6d 100644 --- a/src/ast/node/pointer.rs +++ b/src/ast/node/pointer.rs @@ -44,18 +44,22 @@ impl Node for PointerOpNode { PointerOpEnum::Deref => { if let PLType::Pointer(tp1) = &*btp.borrow() { tp = tp1.clone(); - builder.build_load(value, "deref") + builder.build_load(value, "deref", &btp.borrow(), ctx) } else { return Err(ctx.add_diag(self.range.new_err(ErrorCode::NOT_A_POINTER))); } } PointerOpEnum::Addr => { // let old_tp = tp.clone().unwrap(); + let oldtp = tp.clone(); tp = Arc::new(RefCell::new(PLType::Pointer(tp))); - if v.is_const() { - return Err(ctx.add_diag(self.range.new_err(ErrorCode::CAN_NOT_REF_CONSTANT))); + let mut val = value; + if !builder.is_ptr(v.get_value()) { + // if not a pointer, then alloc a new tmp var + let var = builder.alloc("var", &oldtp.borrow(), ctx, None); + builder.build_store(var, value); + val = var; } - let val = value; let v = builder.alloc("addr", &tp.borrow(), ctx, None); builder.build_store(v, val); v diff --git a/src/ast/node/primary.rs b/src/ast/node/primary.rs index 30730d45c..b07a7c697 100644 --- a/src/ast/node/primary.rs +++ b/src/ast/node/primary.rs @@ -10,6 +10,7 @@ use crate::ast::ctx::MacroReplaceNode; use crate::ast::ctx::BUILTIN_FN_NAME_MAP; use crate::ast::diag::ErrorCode; use crate::ast::pltype::{PLType, PriType}; +use crate::inference::TyVariable; use crate::modifier_set; use internal_macro::node; use lsp_types::SemanticTokenType; @@ -104,6 +105,7 @@ impl Node for NumNode { #[node] pub struct VarNode { pub name: String, + pub id: Option, } impl Node for VarNode { @@ -295,15 +297,23 @@ impl Node for ArrayElementNode { let index_range = self.index.range(); let v = self.index.emit(ctx, builder)?.get_value().unwrap(); let index = v.get_value(); - let index = ctx.try_load2var(index_range, index, builder)?; + let index = ctx.try_load2var(index_range, index, builder, &v.get_ty().borrow())?; if !v.get_ty().borrow().is(&PriType::I64) { return Err(ctx.add_diag(self.range.new_err(ErrorCode::ARRAY_INDEX_MUST_BE_INT))); } let elemptr = { let index = &[index]; - let real_arr = builder.build_struct_gep(arr, 1, "real_arr").unwrap(); - let real_arr = builder.build_load(real_arr, "load_arr"); - builder.build_in_bounds_gep(real_arr, index, "element_ptr") + let real_arr = builder + .build_struct_gep(arr, 1, "real_arr", &pltype.borrow(), ctx) + .unwrap(); + let real_arr = builder.build_load(real_arr, "load_arr", &PLType::new_i8_ptr(), ctx); + builder.build_in_bounds_gep( + real_arr, + index, + "element_ptr", + &arrtp.element_type.borrow(), + ctx, + ) }; ctx.emit_comment_highlight(&self.comments[0]); return elemptr.new_output(arrtp.element_type.clone()).to_result(); diff --git a/src/ast/node/program.rs b/src/ast/node/program.rs index 31c37ff4e..611cbac41 100644 --- a/src/ast/node/program.rs +++ b/src/ast/node/program.rs @@ -120,6 +120,7 @@ fn new_var(name: &str) -> Box { Box::new(VarNode { name: name.to_string(), range: Default::default(), + id: None, }) } fn new_use(ns: &[&str]) -> Box { diff --git a/src/ast/node/ret.rs b/src/ast/node/ret.rs index 0719715cb..e527c7017 100644 --- a/src/ast/node/ret.rs +++ b/src/ast/node/ret.rs @@ -54,14 +54,14 @@ impl Node for RetNode { } // let value = ctx.try_load2var(self.range, v.get_value(), builder)?; let value = ctx.up_cast( - ret_pltype, + ret_pltype.clone(), value_pltype, ret_node.range(), ret_node.range(), v.get_value(), builder, )?; - let value = ctx.try_load2var(self.range, value, builder)?; + let value = ctx.try_load2var(self.range, value, builder, &ret_pltype.borrow())?; builder.build_store(ctx.return_block.unwrap().1.unwrap(), value); let curbb = builder.get_cur_basic_block(); @@ -93,7 +93,8 @@ impl Node for RetNode { let v = ret_node.emit(ctx, builder)?.get_value().unwrap(); ctx.emit_comment_highlight(&self.comments[0]); let value_pltype = v.get_ty(); - let mut value = ctx.try_load2var(self.range, v.get_value(), builder)?; + let mut value = + ctx.try_load2var(self.range, v.get_value(), builder, &v.get_ty().borrow())?; let eqres = ctx.eq(ret_pltype.clone(), value_pltype.clone()); if !eqres.eq { let err = ctx.add_diag(self.range.new_err(ErrorCode::RETURN_TYPE_MISMATCH)); @@ -103,14 +104,14 @@ impl Node for RetNode { let ptr2v = builder.alloc("tmp_up_cast_ptr", &value_pltype.borrow(), ctx, None); builder.build_store(ptr2v, value); value = ctx.up_cast( - ret_pltype, + ret_pltype.clone(), value_pltype.clone(), ret_node.range(), ret_node.range(), ptr2v, builder, )?; - value = ctx.try_load2var(self.range, value, builder)?; + value = ctx.try_load2var(self.range, value, builder, &ret_pltype.borrow())?; } if ctx.return_block.unwrap().1.is_none() { return Err(self diff --git a/src/ast/node/statement.rs b/src/ast/node/statement.rs index e4d33282c..9715e9892 100644 --- a/src/ast/node/statement.rs +++ b/src/ast/node/statement.rs @@ -6,6 +6,7 @@ use crate::ast::builder::IRBuilder; use crate::ast::ctx::Ctx; use crate::ast::diag::{ErrorCode, WarnCode}; use crate::format_label; + use crate::modifier_set; use indexmap::IndexMap; @@ -89,11 +90,8 @@ impl PrintTrait for DefNode { self.var.print(tabs + 1, false, line.clone()); if let Some(tp) = &self.tp { tp.print(tabs + 1, true, line.clone()); - } else { - self.exp - .as_ref() - .unwrap() - .print(tabs + 1, true, line.clone()); + } else if let Some(e) = self.exp.as_ref() { + e.print(tabs + 1, true, line.clone()); } } } @@ -170,16 +168,51 @@ impl Node for DefNode { builder: &'b BuilderEnum<'a, '_>, ) -> NodeResult { ctx.push_semantic_token(self.var.range(), SemanticTokenType::VARIABLE, 0); - if self.exp.is_none() && self.tp.is_none() { - return Err(ctx.add_diag(self.range.new_err(ErrorCode::UNDEFINED_TYPE))); - } let mut pltype = None; + if self.tp.is_none() { + let mut tp = Arc::new(RefCell::new(PLType::Unknown)); + if let DefVar::Identifier(i) = &*self.var { + if let Some(id) = i.id { + let v = ctx.unify_table.borrow_mut().probe(id); + tp = v.get_type(&mut ctx.unify_table.borrow_mut()); + if self.exp.is_none() { + ctx.push_type_hints(self.var.range(), tp.clone()); + } + } + } + if self.exp.is_none() && matches!(&*tp.borrow(), PLType::Unknown) { + match builder { + BuilderEnum::LLVM(_) => { + return Err(ctx.add_diag( + self.var + .range() + .new_err(ErrorCode::UNKNOWN_TYPE) + .add_to_ctx(ctx), + )); + } + BuilderEnum::NoOp(_) => { + ctx.add_diag( + self.var + .range() + .new_err(ErrorCode::UNKNOWN_TYPE) + .add_to_ctx(ctx), + ); + } + } + } + pltype = Some(tp); + } let mut expv = None; if let Some(tp) = &self.tp { tp.emit_highlight(ctx); let pltp = tp.get_type(ctx, builder, true)?; pltype = Some(pltp); } + if self.exp.is_some() + && matches!(pltype.clone(), Some(tp) if matches!(&*tp.borrow(), PLType::Unknown)) + { + pltype = None; + } if let Some(exp) = &mut self.exp { let re = if let Some(pltype) = pltype.clone() { ctx.emit_with_expectation(exp, pltype, self.var.range(), builder)? @@ -190,7 +223,7 @@ impl Node for DefNode { // for err tolerate if re.is_none() { - return Err(ctx.add_diag(self.range.new_err(ErrorCode::UNDEFINED_TYPE))); + return Err(ctx.add_diag(self.var.range().new_err(ErrorCode::UNKNOWN_TYPE))); } let re = re.unwrap(); let mut tp = re.get_ty(); @@ -213,6 +246,8 @@ impl Node for DefNode { if pltype.is_none() { ctx.push_type_hints(self.var.range(), tp.clone()); pltype = Some(tp); + } else if self.tp.is_none() { + ctx.push_type_hints(self.var.range(), pltype.clone().unwrap()); } expv = Some(v); } @@ -262,7 +297,7 @@ fn handle_deconstruct<'a, 'b>( ctx.add_symbol( var.name.clone(), ptr2value, - pltype, + pltype.clone(), def_var.range(), false, false, @@ -284,7 +319,10 @@ fn handle_deconstruct<'a, 'b>( }; if let Some(exp) = expv { builder.build_dbg_location(def_var.range().start); - builder.build_store(ptr2value, ctx.try_load2var(range, exp, builder)?); + builder.build_store( + ptr2value, + ctx.try_load2var(range, exp, builder, &pltype.borrow())?, + ); } } DefVar::TupleDeconstruct(TupleDeconstructNode { @@ -324,7 +362,7 @@ fn handle_deconstruct<'a, 'b>( for (i, (_, f)) in st.fields.iter().enumerate() { let ftp = f.typenode.get_type(ctx, builder, false)?; let expv = builder - .build_struct_gep(expv, f.index, "_deconstruct") + .build_struct_gep(expv, f.index, "_deconstruct", &pltype.borrow(), ctx) .unwrap(); let deconstruct_v = var[i].as_ref(); handle_deconstruct( @@ -389,7 +427,13 @@ fn handle_deconstruct<'a, 'b>( } let f = st.fields.get(&v.name).unwrap(); let expv = builder - .build_struct_gep(expv, f.index, "_deconstruct") + .build_struct_gep( + expv, + f.index, + "_deconstruct", + &pltype.borrow(), + ctx, + ) .unwrap(); let ftp = f.typenode.get_type(ctx, builder, false)?; (expv, ftp) @@ -503,14 +547,19 @@ impl Node for AssignNode { let lpltype = rel.get_ty(); // 要走转换逻辑,所以不和下方分支统一 let value = ctx - .emit_with_expectation(&mut self.exp, lpltype, self.var.range(), builder)? + .emit_with_expectation( + &mut self.exp, + lpltype.clone(), + self.var.range(), + builder, + )? .get_value() .unwrap() .get_value(); if rel.is_const() { return Err(ctx.add_diag(self.range.new_err(ErrorCode::ASSIGN_CONST))); } - let load = ctx.try_load2var(exp_range, value, builder)?; + let load = ctx.try_load2var(exp_range, value, builder, &lpltype.borrow())?; builder.build_store(ptr, load); Ok(Default::default()) } diff --git a/src/ast/node/string_literal.rs b/src/ast/node/string_literal.rs index 994792c21..6c72b2323 100644 --- a/src/ast/node/string_literal.rs +++ b/src/ast/node/string_literal.rs @@ -39,9 +39,15 @@ impl Node for StringNode { .map(|m| m.types.get("string").unwrap().clone()) .unwrap_or_else(|| ctx.plmod.types.get("string").unwrap().clone()); let alloca = builder.alloc("string", &tp.borrow(), ctx, None); - let len = builder.build_struct_gep(alloca, 1, "len").unwrap(); - let byte_len = builder.build_struct_gep(alloca, 2, "byte_len").unwrap(); - let read_arr = builder.build_struct_gep(alloca, 3, "real_arr").unwrap(); + let len = builder + .build_struct_gep(alloca, 1, "len", &tp.borrow(), ctx) + .unwrap(); + let byte_len = builder + .build_struct_gep(alloca, 2, "byte_len", &tp.borrow(), ctx) + .unwrap(); + let read_arr = builder + .build_struct_gep(alloca, 3, "real_arr_str", &tp.borrow(), ctx) + .unwrap(); builder.build_store(read_arr, v); builder.build_store( diff --git a/src/ast/node/tuple.rs b/src/ast/node/tuple.rs index 54a225ebf..876789a31 100644 --- a/src/ast/node/tuple.rs +++ b/src/ast/node/tuple.rs @@ -71,10 +71,10 @@ impl Node for TupleInitNode { // 初始化赋值 for (i, value) in expr_values.into_iter().enumerate() { let field_ptr = builder - .build_struct_gep(v, i as u32 + 1, &i.to_string()) + .build_struct_gep(v, i as u32 + 1, &i.to_string(), &stu.borrow(), ctx) .unwrap(); - let v = - builder.try_load2var(self.range, value.get_value().unwrap().get_value(), ctx)?; + let vv = value.get_value().unwrap(); + let v = builder.try_load2var(self.range, vv.get_value(), &vv.get_ty().borrow(), ctx)?; builder.build_store(field_ptr, v); } v.new_output(stu).to_result() diff --git a/src/ast/node/types.rs b/src/ast/node/types.rs index be0f9f37b..2c5b9562f 100644 --- a/src/ast/node/types.rs +++ b/src/ast/node/types.rs @@ -37,6 +37,7 @@ impl TypeNameNode { id: Box::new(VarNode { name: s.to_string(), range: Default::default(), + id: None, }), range: Default::default(), ns: vec![], @@ -684,7 +685,12 @@ impl Node for StructInitNode { return Err(ctx.add_diag(field_exp_range.new_err(ErrorCode::EXPECT_VALUE))); } let v = v.unwrap(); - let value = ctx.try_load2var(field_exp_range, v.get_value(), builder)?; + let value = ctx.try_load2var( + field_exp_range, + v.get_value(), + builder, + &v.get_ty().borrow(), + )?; let value_pltype = v.get_ty(); ctx.protect_generic_context(&sttype.generic_map, |ctx| { if !field @@ -730,7 +736,7 @@ impl Node for StructInitNode { let struct_pointer = builder.alloc("initstruct", &pltype.borrow(), ctx, None); //alloc(ctx, tp, "initstruct"); field_init_values.iter().for_each(|(index, value)| { let fieldptr = builder - .build_struct_gep(struct_pointer, *index, "fieldptr") + .build_struct_gep(struct_pointer, *index, "fieldptr", &pltype.borrow(), ctx) .unwrap(); builder.build_store(fieldptr, *value); }); @@ -781,7 +787,10 @@ impl Node for ArrayInitNode { } let v = v.unwrap(); let tp = v.get_ty(); - exps.push((ctx.try_load2var(range, v.get_value(), builder)?, tp)); + exps.push(( + ctx.try_load2var(range, v.get_value(), builder, &v.get_ty().borrow())?, + tp, + )); } let sz = exps.len() as u64; let (tp, size_handle) = if let Some((tp, len_v)) = &mut self.tp { @@ -791,7 +800,12 @@ impl Node for ArrayInitNode { if !matches!(&*len.get_ty().borrow(), PLType::Primitive(PriType::I64)) { return Err(ctx.add_diag(len_v.range().new_err(ErrorCode::ARRAY_LEN_MUST_BE_I64))); } - let len = ctx.try_load2var(len_v.range(), len.get_value(), builder)?; + let len = ctx.try_load2var( + len_v.range(), + len.get_value(), + builder, + &len.get_ty().borrow(), + )?; if let Some(tp0) = &tp0 { if !ctx.eq(tp.clone(), tp0.clone()).total_eq() { return Err(ctx.add_diag(self.range.new_err(ErrorCode::ARRAY_TYPE_NOT_MATCH))); @@ -806,15 +820,23 @@ impl Node for ArrayInitNode { (tp, builder.int_value(&PriType::I64, sz, true)) }; let arr_tp = Arc::new(RefCell::new(PLType::Arr(ARRType { - element_type: tp, + element_type: tp.clone(), size_handle, }))); let arr = builder.alloc("array_alloca", &arr_tp.borrow(), ctx, None); - let real_arr = builder.build_struct_gep(arr, 1, "real_arr").unwrap(); + let real_arr = builder + .build_struct_gep(arr, 1, "real_arr", &arr_tp.borrow(), ctx) + .unwrap(); - let real_arr = builder.build_load(real_arr, "load_arr"); + let real_arr = builder.build_load(real_arr, "load_arr", &PLType::Pointer(tp.clone()), ctx); for (i, (v, _)) in exps.into_iter().enumerate() { - let ptr = builder.build_const_in_bounds_gep(real_arr, &[i as u64], "elem_ptr"); + let ptr = builder.build_const_in_bounds_gep( + real_arr, + &[i as u64], + "elem_ptr", + &tp.borrow(), + ctx, + ); builder.build_store(ptr, v); } arr.new_output(arr_tp).to_result() diff --git a/src/ast/plmod.rs b/src/ast/plmod.rs index d23fcb150..968466671 100644 --- a/src/ast/plmod.rs +++ b/src/ast/plmod.rs @@ -484,6 +484,7 @@ impl Mod { PLType::PlaceHolder(_) => continue, PLType::Union(_) => CompletionItemKind::ENUM, PLType::Closure(_) => unreachable!(), + PLType::Unknown => unreachable!(), }; if k.starts_with('|') { // skip method diff --git a/src/ast/pltype.rs b/src/ast/pltype.rs index f992ca294..0ec0b1dfd 100644 --- a/src/ast/pltype.rs +++ b/src/ast/pltype.rs @@ -67,6 +67,7 @@ pub enum PLType { Trait(STType), Union(UnionType), Closure(ClosureType), + Unknown, } impl TraitImplAble for PriType { @@ -87,7 +88,16 @@ pub struct ClosureType { impl PartialEq for ClosureType { fn eq(&self, other: &Self) -> bool { - self.arg_types == other.arg_types && self.ret_type == other.ret_type + if self.arg_types.len() != other.arg_types.len() { + return false; + } + for i in 0..self.arg_types.len() { + if get_type_deep(self.arg_types[i].clone()) != get_type_deep(other.arg_types[i].clone()) + { + return false; + } + } + get_type_deep(self.ret_type.clone()) == get_type_deep(other.ret_type.clone()) } } @@ -343,12 +353,14 @@ fn new_typename_node(name: &str, range: Range, ns: &[String]) -> Box Resu .add_to_ctx(ctx)) } impl PLType { + pub fn new_i8_ptr() -> PLType { + PLType::Pointer(Arc::new(RefCell::new(PLType::Primitive(PriType::I8)))) + } + pub fn new_i64() -> PLType { + PLType::Primitive(PriType::I64) + } #[cfg(feature = "llvm")] pub fn get_immix_type(&self) -> ObjectType { match self { @@ -421,8 +439,17 @@ impl PLType { PLType::Struct(s) => s.implements_trait(tp, ctx), PLType::Union(u) => u.implements_trait(tp, ctx), _ => { - let plmod = &ctx.db.get_module(&tp.path).unwrap(); - if impl_in_mod(plmod, name, tp) { + let plmod = if tp.path == ctx.plmod.path { + ctx.plmod.clone() + } else { + ctx.db.get_module(&tp.path).unwrap_or_else(|| { + panic!( + "expect module {} exists, trait name: {}", + &tp.path, &tp.name + ) + }) + }; + if impl_in_mod(&plmod, name, tp) { return true; } else { for m in plmod.submods.values() { @@ -455,6 +482,7 @@ impl PLType { PLType::Trait(_) => "trait".to_string(), PLType::Union(_) => "union".to_string(), PLType::Closure(_) => "closure".to_string(), + PLType::Unknown => "unknown".to_string(), } } @@ -482,9 +510,10 @@ impl PLType { } PLType::PlaceHolder(p) => new_typename_node(&p.name, Default::default(), &[]), PLType::Trait(t) => Self::new_custom_tp_node(t, path), - PLType::Fn(_) => unreachable!(), + PLType::Fn(_) => new_typename_node("Unknown", Default::default(), &[]), PLType::Union(u) => Self::new_custom_tp_node(u, path), PLType::Closure(c) => Box::new(c.to_type_node(path)), + PLType::Unknown => new_typename_node("Unknown", Default::default(), &[]), } } pub fn is(&self, pri_type: &PriType) -> bool { @@ -506,6 +535,7 @@ impl PLType { PLType::Generic(g) => f_local(g), PLType::PlaceHolder(_) => (), PLType::Closure(_) => (), + PLType::Unknown => (), } } @@ -539,6 +569,7 @@ impl PLType { PLType::Trait(t) => t.name.clone(), PLType::Union(u) => u.name.clone(), PLType::Closure(c) => c.get_name(), + PLType::Unknown => "Unknown".to_string(), } } pub fn get_llvm_name(&self) -> String { @@ -562,6 +593,7 @@ impl PLType { PLType::PlaceHolder(p) => p.get_place_holder_name(), PLType::Union(u) => u.name.clone(), PLType::Closure(c) => c.get_name(), + PLType::Unknown => "Unknown".to_string(), } } @@ -586,6 +618,7 @@ impl PLType { PLType::PlaceHolder(p) => p.name.clone(), PLType::Union(u) => u.get_full_name(), PLType::Closure(c) => c.get_name(), + PLType::Unknown => "Unknown".to_string(), } } pub fn get_full_elm_name_without_generic(&self) -> String { @@ -603,6 +636,7 @@ impl PLType { PLType::PlaceHolder(p) => p.name.clone(), PLType::Union(u) => u.get_full_name_except_generic(), PLType::Closure(c) => c.get_name(), + PLType::Unknown => "Unknown".to_string(), } } pub fn get_ptr_depth(&self) -> usize { @@ -701,6 +735,7 @@ impl PLType { PLType::Trait(t) => Some(t.range), PLType::Union(u) => Some(u.range), PLType::Closure(c) => Some(c.range), + PLType::Unknown => None, } } @@ -1088,28 +1123,8 @@ impl PartialEq for STType { fn eq(&self, other: &Self) -> bool { if self.is_tuple && other.is_tuple { self.fields == other.fields - } else if self.is_trait && other.is_trait { - self.name == other.name - && self.path == other.path - && self.range == other.range - && self.derives == other.derives - && self.modifier == other.modifier - && self.body_range == other.body_range - && self.is_trait == other.is_trait - && self.is_tuple == other.is_tuple - && self.generic_map == other.generic_map } else { - self.name == other.name - && self.path == other.path - && self.fields == other.fields - && self.range == other.range - && self.doc == other.doc - && self.generic_map == other.generic_map - && self.derives == other.derives - && self.modifier == other.modifier - && self.body_range == other.body_range - && self.is_trait == other.is_trait - && self.is_tuple == other.is_tuple + self.name == other.name && self.path == other.path } } } diff --git a/src/inference/mod.rs b/src/inference/mod.rs new file mode 100644 index 000000000..9b7285200 --- /dev/null +++ b/src/inference/mod.rs @@ -0,0 +1,830 @@ +//! # Inference +//! +//! This module is used to do type inference. +//! +//! ## How it works +//! +//! The basic idea is that most of statements +//! have type constraints, for example, +//! +//! ```pl +//! a = b +//! ``` +//! +//! The type of `a` and `b` shall be the same. +//! +//! So we can use a unify table to record the type relationship. +//! It's not always necessary to generate all the type constraints, +//! as type inference will only take effect when the type is unknown. +//! +//! > What is unify table? +//! > +//! > Unify table is very much like a hashtable, but it can map +//! > multiple keys to the same value. In type inference, one +//! > variable has many constraints, and different variables' +//! > constraints may be the same. So we use unify table to +//! > record the constraints. +use std::{cell::RefCell, sync::Arc}; + +use ena::unify::{UnificationTable, UnifyKey, UnifyValue}; +use rustc_hash::FxHashMap; + +use crate::ast::{ + builder::BuilderEnum, + ctx::Ctx, + node::{ + pointer::PointerOpEnum, + statement::{DefVar, StatementsNode}, + NodeEnum, TypeNode, + }, + pltype::{get_type_deep, ClosureType, PLType}, + tokens::TokenType, +}; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct TyVariable { + id: u32, +} + +impl UnifyKey for TyVariable { + type Value = TyInfer; + + fn index(&self) -> u32 { + self.id + } + + fn from_index(u: u32) -> Self { + Self { id: u } + } + + fn tag() -> &'static str { + "TyVariable" + } +} + +/// # The type inference result +/// +/// ## Term +/// +/// A `Term`` is a `PLType`. When it's `PLType::Unknown`, +/// it means that the type is not inferred yet. An `Unknown` +/// type unify with any other type will become the other type. +/// +/// ## Err +/// +/// If an inference error occurs, the type will be `Err`. +/// +/// `Err`'s type is `PLType::Unknown`. +/// However, it's not the same as `Term(PLType::Unknown)`. +/// When `Err` unify with any other type, it will always become `Err`. +/// +/// ## Closure +/// +/// A `Closure` is a function type. It contains a list of argument types, +/// and a return type. +/// +/// As the function type may not be inferred yet, the argument types and return type +/// are all `TyVariable`, which allows them to unify with other `TyVariable`s. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TyInfer { + Err, + Term(Arc>), + Closure((Vec, TyVariable)), +} + +impl UnifyValue for TyInfer { + fn unify_values(value1: &Self, value2: &Self) -> Result { + if matches!(value1, TyInfer::Err) || matches!(value2, TyInfer::Err) { + return Ok(TyInfer::Err); + } + + // if there's no error, then set unknown to the real type + if value1 == value2 { + Ok(value1.clone()) + } else if matches!(value1, TyInfer::Term(ty) if *get_type_deep(ty.clone()).borrow()== PLType::Unknown) + { + Ok(value2.clone()) + } else if matches!(value2, TyInfer::Term(ty) if *get_type_deep(ty.clone()).borrow()== PLType::Unknown) + || matches!(value2, TyInfer::Closure(_)) + || matches!(value1, TyInfer::Closure(_)) + { + Ok(value1.clone()) + } else { + Ok(TyInfer::Err) + } + } +} + +impl TyInfer { + pub fn get_type(&self, unify_table: &mut UnificationTable) -> Arc> { + match self { + TyInfer::Term(ty) => ty.clone(), + TyInfer::Closure((args, ty)) => { + let mut argtys = vec![]; + for arg in args { + argtys.push(unify_table.probe(*arg).get_type(unify_table)); + } + let ret_ty = unify_table.probe(*ty).get_type(unify_table); + Arc::new(RefCell::new(PLType::Closure(ClosureType { + arg_types: argtys, + ret_type: ret_ty, + range: Default::default(), + }))) + } + _ => unknown_arc(), + } + } +} + +pub struct InferenceCtx<'ctx> { + unify_table: Arc>>, + symbol_table: FxHashMap, + father: Option<&'ctx InferenceCtx<'ctx>>, +} + +fn unknown() -> SymbolType { + SymbolType::PLType(Arc::new(RefCell::new(PLType::Unknown))) +} + +fn unknown_arc() -> Arc> { + Arc::new(RefCell::new(PLType::Unknown)) +} +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SymbolType { + Var(TyVariable), + PLType(Arc>), +} + +impl<'ctx> InferenceCtx<'ctx> { + pub fn new(table: Arc>>) -> Self { + Self { + unify_table: table, + symbol_table: FxHashMap::default(), + father: None, + } + } + + pub fn new_child(&'ctx self) -> Self { + Self { + unify_table: self.unify_table.clone(), + symbol_table: FxHashMap::default(), + father: Some(self), + } + } + + pub fn add_symbol(&mut self, name: &str, ty: TyVariable) { + if !self.symbol_table.contains_key(name) { + self.symbol_table.insert(name.to_string(), ty); + } + } + + pub fn get_symbol(&self, name: &str) -> Option { + if let Some(ty) = self.symbol_table.get(name) { + return Some(*ty); + } + if let Some(father) = self.father { + return father.get_symbol(name); + } + None + } + + pub fn unify<'a, 'b>( + &self, + var: TyVariable, + value: SymbolType, + ctx: &'b mut Ctx<'a>, + builder: &'b BuilderEnum<'a, '_>, + ) { + self.unify_two_symbol(SymbolType::Var(var), value, ctx, builder) + } + + pub fn unify_two_symbol<'a, 'b>( + &self, + var1: SymbolType, + var2: SymbolType, + ctx: &'b mut Ctx<'a>, + builder: &'b BuilderEnum<'a, '_>, + ) { + match (var1, var2) { + (SymbolType::Var(v), SymbolType::Var(v2)) => { + let v_1 = self.unify_table.borrow_mut().probe(v); + let v_2 = self.unify_table.borrow_mut().probe(v2); + match (v_1, v_2) { + (TyInfer::Closure(c1), TyInfer::Closure(c2)) => { + if c1.0.len() == c2.0.len() { + for (i, arg) in c1.0.iter().enumerate() { + self.unify_two_symbol( + SymbolType::Var(*arg), + SymbolType::Var(c2.0[i]), + ctx, + builder, + ); + } + } + self.unify_two_symbol( + SymbolType::Var(c1.1), + SymbolType::Var(c2.1), + ctx, + builder, + ); + } + (TyInfer::Closure(_), TyInfer::Term(t)) => { + self.unify_var_tp(v, t, ctx, builder); + } + (TyInfer::Term(t), TyInfer::Closure(_)) => { + self.unify_var_tp(v2, t, ctx, builder); + } + _ => (), + }; + self.unify_table.borrow_mut().unify_var_var(v, v2).unwrap(); + } + (SymbolType::Var(v), SymbolType::PLType(tp)) => { + self.unify_var_tp(v, tp, ctx, builder); + } + (SymbolType::PLType(tp), SymbolType::Var(v)) => { + self.unify_var_tp(v, tp, ctx, builder); + } + (SymbolType::PLType(_), SymbolType::PLType(_)) => (), + } + } + + pub fn unify_var_tp<'a, 'b>( + &self, + var: TyVariable, + tp: Arc>, + ctx: &'b mut Ctx<'a>, + builder: &'b BuilderEnum<'a, '_>, + ) { + let ty = self.unify_table.borrow_mut().probe(var); + // check if closure + match (ty, &*tp.borrow()) { + (TyInfer::Closure(c1), PLType::Closure(c2)) => { + if c1.0.len() == c2.arg_types.len() { + for (i, arg) in c1.0.iter().enumerate() { + self.unify_var_tp(*arg, c2.arg_types[i].clone(), ctx, builder); + } + } + self.unify_var_tp(c1.1, c2.ret_type.clone(), ctx, builder); + } + (TyInfer::Closure(c1), PLType::Fn(c2)) => { + if c1.0.len() == c2.fntype.param_pltypes.len() { + for (i, arg) in c1.0.iter().enumerate() { + self.unify_var_tp( + *arg, + c2.fntype.param_pltypes[i] + .get_type(ctx, builder, true) + .unwrap_or(unknown_arc()), + ctx, + builder, + ); + } + } + self.unify_var_tp( + c1.1, + c2.fntype + .ret_pltype + .get_type(ctx, builder, true) + .unwrap_or(unknown_arc()), + ctx, + builder, + ); + } + _ => (), + } + self.unify_table + .borrow_mut() + .unify_var_value(var, TyInfer::Term(tp)) + .unwrap(); + } + + pub fn import_symbols(&mut self, ctx: &Ctx) { + for (name, ty) in &ctx.table { + // self.add_symbol(name, ); + let key = self.new_key(); + self.unify_table + .borrow_mut() + .unify_var_value(key, TyInfer::Term(ty.pltype.clone())) + .unwrap(); + self.add_symbol(name, key); + } + } + #[allow(dead_code)] + pub fn add_unknown_variable(&mut self, name: &str) { + let key = self.new_key(); + self.add_symbol(name, key); + } + + pub fn import_global_symbols(&mut self, ctx: &Ctx) { + let ctx = ctx.get_root_ctx(); + for (name, ty) in &ctx.plmod.global_table { + // self.add_symbol(name, ); + let key = self.new_key(); + self.unify_table + .borrow_mut() + .unify_var_value(key, TyInfer::Term(ty.tp.clone())) + .unwrap(); + self.add_symbol(name, key); + } + } + + pub fn inference_statements<'a, 'b>( + &mut self, + node: &mut StatementsNode, + ctx: &'b mut Ctx<'a>, + builder: &'b BuilderEnum<'a, '_>, + ) { + let prev = ctx.disable_diag; + ctx.disable_diag = true; + for s in &mut node.statements { + self.inference(&mut *s, ctx, builder); + } + ctx.disable_diag = prev; + } + + pub fn new_key(&self) -> TyVariable { + self.unify_table + .borrow_mut() + .new_key(TyInfer::Term(unknown_arc())) + } + + pub fn inference<'a, 'b>( + &mut self, + node: &mut NodeEnum, + ctx: &'b mut Ctx<'a>, + builder: &'b BuilderEnum<'a, '_>, + ) -> SymbolType { + match node { + NodeEnum::Def(d) => { + let mut ty = unknown(); + if let Some(exp) = &mut d.exp { + ty = self.inference(&mut *exp, ctx, builder); + } + if let Some(tp) = &d.tp { + let new_ty = SymbolType::PLType( + tp.get_type(ctx, builder, true).unwrap_or(unknown_arc()), + ); + ty = new_ty; + } + match &mut *d.var { + DefVar::Identifier(v) => { + let id = self.new_key(); + self.unify(id, ty, ctx, builder); + v.id = Some(id); + self.add_symbol(&v.name, id); + } + DefVar::TupleDeconstruct(_) => (), + DefVar::StructDeconstruct(_) => (), + } + } + NodeEnum::Assign(a) => { + let ty = self.inference(&mut a.exp, ctx, builder); + match &mut a.var { + crate::ast::node::statement::AssignVar::Pointer(p) => { + let re = self.inference(&mut *p, ctx, builder); + self.unify_two_symbol(re, ty, ctx, builder); + } + crate::ast::node::statement::AssignVar::Raw(d) => match &mut **d { + DefVar::Identifier(v) => { + let id = self.new_key(); + v.id = Some(id); + if let Some(ty) = self.get_symbol(&v.name) { + self.unify_two_symbol( + SymbolType::Var(id), + SymbolType::Var(ty), + ctx, + builder, + ); + } + self.unify(id, ty, ctx, builder); + } + DefVar::TupleDeconstruct(_) => (), + DefVar::StructDeconstruct(_) => (), + }, + } + } + NodeEnum::Expr(e) => match e.op.0 { + TokenType::EQ + | TokenType::NE + | TokenType::LEQ + | TokenType::GEQ + | TokenType::GREATER + | TokenType::LESS => { + let i1 = self.inference(&mut e.left, ctx, builder); + let i2 = self.inference(&mut e.right, ctx, builder); + self.unify_two_symbol(i1, i2, ctx, builder); + return SymbolType::PLType(new_arc_refcell(PLType::Primitive( + crate::ast::pltype::PriType::BOOL, + ))); + } + _ => { + let i1 = self.inference(&mut e.left, ctx, builder); + let i2 = self.inference(&mut e.right, ctx, builder); + self.unify_two_symbol(i1.clone(), i2, ctx, builder); + return i1; + } + }, + NodeEnum::ExternIdNode(ex) => { + if ex.ns.is_empty() { + let id = self.new_key(); + ex.id.id = Some(id); + if let Some(t) = self.get_symbol(&ex.id.name) { + self.unify_two_symbol( + SymbolType::Var(id), + SymbolType::Var(t), + ctx, + builder, + ); + return SymbolType::Var(id); + } + if let Some(r) = ctx.get_root_ctx().plmod.types.get(&ex.id.name) { + if let PLType::Fn(f) = &*r.tp.clone().borrow() { + if f.fntype.generic { + return unknown(); + } + let mut argtys = vec![]; + for arg in &f.fntype.param_pltypes { + let arg = arg.get_type(ctx, builder, true).unwrap_or(unknown_arc()); + let arg_key = self.new_key(); + self.unify(arg_key, SymbolType::PLType(arg), ctx, builder); + argtys.push(arg_key); + } + let ret_ty = self.new_key(); + self.unify( + ret_ty, + SymbolType::PLType( + f.fntype + .ret_pltype + .get_type(ctx, builder, true) + .unwrap_or(unknown_arc()), + ), + ctx, + builder, + ); + self.unify_table + .borrow_mut() + .unify_var_value(id, TyInfer::Closure((argtys, ret_ty))) + .unwrap(); + return SymbolType::Var(id); + } + } + } + } + NodeEnum::Bool(_) => { + return SymbolType::PLType(new_arc_refcell(PLType::Primitive( + crate::ast::pltype::PriType::BOOL, + ))) + } + NodeEnum::Num(n) => match n.value { + crate::ast::node::Num::Int(_) => { + return SymbolType::PLType(new_arc_refcell(PLType::Primitive( + crate::ast::pltype::PriType::I64, + ))) + } + crate::ast::node::Num::Float(_) => { + return SymbolType::PLType(new_arc_refcell(PLType::Primitive( + crate::ast::pltype::PriType::F64, + ))) + } + }, + NodeEnum::Primary(p) => { + return self.inference(&mut p.value, ctx, builder); + } + NodeEnum::AsNode(a) => { + if a.tail.is_none() || a.tail.unwrap().0 == TokenType::NOT { + let tp = a.ty.get_type(ctx, builder, true).unwrap_or(unknown_arc()); + return SymbolType::PLType(tp); + } + } + NodeEnum::ClosureNode(c) => { + let mut child = self.new_child(); + let mut argtys = vec![]; + for (var, ty) in &mut c.paralist { + let key_id = child.new_key(); + match ty { + Some(ty) => { + let arg_ty = ty.get_type(ctx, builder, true).unwrap_or(unknown_arc()); + child.unify(key_id, SymbolType::PLType(arg_ty), ctx, builder); + } + None => {} + } + argtys.push(key_id); + child.add_symbol(&var.name, key_id); + var.id = Some(key_id); + } + let ret_ty = child.new_key(); + child.unify( + ret_ty, + SymbolType::PLType( + c.ret + .as_ref() + .and_then(|r| r.get_type(ctx, builder, true).ok()) + .unwrap_or(unknown_arc()), + ), + ctx, + builder, + ); + let id = child.new_key(); + child + .unify_table + .borrow_mut() + .unify_var_value(id, TyInfer::Closure((argtys, ret_ty))) + .unwrap(); + child.add_symbol("@ret", ret_ty); + c.ret_id = Some(ret_ty); + child.inference_statements(&mut c.body, ctx, builder); + return SymbolType::Var(id); + } + NodeEnum::FuncCall(fc) => { + let mut argtys = vec![]; + for arg in &mut fc.paralist { + let arg_ty = self.inference(&mut *arg, ctx, builder); + argtys.push(arg_ty); + } + let func_ty = self.inference(&mut fc.callee, ctx, builder); + match func_ty { + SymbolType::Var(id) => { + let k = self.unify_table.borrow_mut().probe(id); + match k { + TyInfer::Err => (), + TyInfer::Term(t) => match &*t.borrow() { + PLType::Closure(c) => { + if c.arg_types.len() != argtys.len() { + return unknown(); + } + + for (i, arg) in c.arg_types.iter().enumerate() { + self.unify_two_symbol( + argtys[i].clone(), + SymbolType::PLType(arg.clone()), + ctx, + builder, + ); + } + return SymbolType::PLType(c.ret_type.clone()); + } + PLType::Fn(f) => { + if f.fntype.param_pltypes.len() != argtys.len() { + return unknown(); + } + + for (i, arg) in f.fntype.param_pltypes.iter().enumerate() { + self.unify_two_symbol( + argtys[i].clone(), + SymbolType::PLType( + arg.get_type(ctx, builder, true) + .unwrap_or(unknown_arc()), + ), + ctx, + builder, + ); + } + return SymbolType::PLType( + f.fntype + .ret_pltype + .get_type(ctx, builder, true) + .unwrap_or(unknown_arc()), + ); + } + PLType::Unknown => { + let mut arg_keys = vec![]; + for arg in argtys { + let arg_key = self.new_key(); + self.unify(arg_key, arg, ctx, builder); + arg_keys.push(arg_key); + } + let ret_ty = self.new_key(); + self.unify_table + .borrow_mut() + .unify_var_value(id, TyInfer::Closure((arg_keys, ret_ty))) + .unwrap(); + return SymbolType::Var(ret_ty); + } + _ => (), + }, + TyInfer::Closure((args, ret)) => { + if args.len() != argtys.len() { + return unknown(); + } + for (i, arg) in args.iter().enumerate() { + self.unify(*arg, argtys[i].clone(), ctx, builder); + } + return SymbolType::Var(ret); + } + } + } + SymbolType::PLType(tp) => match &*tp.borrow() { + PLType::Fn(f) => { + if f.fntype.param_pltypes.len() != argtys.len() { + return unknown(); + } + for (i, arg) in f.fntype.param_pltypes.iter().enumerate() { + self.unify_two_symbol( + argtys[i].clone(), + SymbolType::PLType( + arg.get_type(ctx, builder, true) + .unwrap_or(unknown_arc()) + .clone(), + ), + ctx, + builder, + ); + } + return SymbolType::PLType( + f.fntype + .ret_pltype + .get_type(ctx, builder, true) + .unwrap_or(unknown_arc()), + ); + } + PLType::Closure(cl) => { + if cl.arg_types.len() != argtys.len() { + return unknown(); + } + for (i, arg) in cl.arg_types.iter().enumerate() { + self.unify_two_symbol( + argtys[i].clone(), + SymbolType::PLType(arg.clone()), + ctx, + builder, + ); + } + return SymbolType::PLType(cl.ret_type.clone()); + } + _ => (), + }, + } + } + NodeEnum::If(i) => { + let cond_ty = self.inference(&mut i.cond, ctx, builder); + self.unify_two_symbol( + cond_ty, + SymbolType::PLType(new_arc_refcell(PLType::Primitive( + crate::ast::pltype::PriType::BOOL, + ))), + ctx, + builder, + ); + let mut child_then = self.new_child(); + child_then.inference_statements(&mut i.then, ctx, builder); + if let Some(else_) = &mut i.els { + self.inference(else_, ctx, builder); + } + } + NodeEnum::Sts(sts) => { + let mut child = self.new_child(); + child.inference_statements(sts, ctx, builder); + } + NodeEnum::Ret(r) => { + let ret = self.get_symbol("@ret"); + if r.yiel.is_some() { + return unknown(); + } + if let Some(ret) = ret { + if let Some(r) = &mut r.value { + let ty = self.inference(&mut *r, ctx, builder); + self.unify(ret, ty, ctx, builder); + } else { + self.unify( + ret, + SymbolType::PLType(new_arc_refcell(PLType::Void)), + ctx, + builder, + ); + } + } + } + NodeEnum::While(w) => { + let cond_ty = self.inference(&mut w.cond, ctx, builder); + self.unify_two_symbol( + cond_ty, + SymbolType::PLType(new_arc_refcell(PLType::Primitive( + crate::ast::pltype::PriType::BOOL, + ))), + ctx, + builder, + ); + let mut child = self.new_child(); + child.inference_statements(&mut w.body, ctx, builder); + } + NodeEnum::For(f) => { + let mut child = self.new_child(); + if let Some(pre) = &mut f.pre { + child.inference(&mut *pre, ctx, builder); + } + let cond_ty = child.inference(&mut f.cond, ctx, builder); + child.unify_two_symbol( + cond_ty, + SymbolType::PLType(new_arc_refcell(PLType::Primitive( + crate::ast::pltype::PriType::BOOL, + ))), + ctx, + builder, + ); + if let Some(post) = &mut f.opt { + child.inference(&mut *post, ctx, builder); + } + + child.inference_statements(&mut f.body, ctx, builder); + } + NodeEnum::Take(tk) => { + let head = self.inference(&mut tk.head, ctx, builder); + if tk.field.is_none() { + return head; + } + match head { + SymbolType::Var(v) => { + let k = self.unify_table.borrow_mut().probe(v); + match k { + TyInfer::Term(t) => { + let tp = ctx.auto_deref_tp(t); + match &*tp.borrow() { + PLType::Struct(a) => { + let f = a.fields.get(&tk.field.as_ref().unwrap().name); + if let Some(f) = f { + return SymbolType::PLType( + f.typenode + .get_type(ctx, builder, true) + .unwrap_or(unknown_arc()), + ); + } + } + _ => (), + }; + } + _ => (), + } + } + SymbolType::PLType(_) => (), + } + } + NodeEnum::StructInit(si) => { + let ty = si + .typename + .get_type(ctx, builder, true) + .unwrap_or(unknown_arc()); + match &*ty.clone().borrow() { + PLType::Struct(s) => { + if s.generic_map.is_empty() { + return SymbolType::PLType(ty); + } + } + _ => (), + }; + } + NodeEnum::Un(u) => { + return self.inference(&mut u.exp, ctx, builder); + } + NodeEnum::PointerOpNode(p) => { + let ty = self.inference(&mut p.value, ctx, builder); + match ty { + SymbolType::Var(v) => { + let k = self.unify_table.borrow_mut().probe(v); + match k { + TyInfer::Term(t) => { + if p.op == PointerOpEnum::Addr { + if *t.borrow() != PLType::Unknown { + return SymbolType::PLType(new_arc_refcell( + PLType::Pointer(t), + )); + } + } else if p.op == PointerOpEnum::Deref { + match &*t.borrow() { + PLType::Pointer(p) => { + return SymbolType::PLType(p.clone()); + } + _ => (), + } + } + } + _ => (), + } + } + SymbolType::PLType(t) => { + if p.op == PointerOpEnum::Addr { + if *t.borrow() != PLType::Unknown { + return SymbolType::PLType(new_arc_refcell(PLType::Pointer(t))); + } + } else if p.op == PointerOpEnum::Deref { + match &*t.borrow() { + PLType::Pointer(p) => { + return SymbolType::PLType(p.clone()); + } + _ => (), + } + } + } + } + } + NodeEnum::ParanthesesNode(p) => { + return self.inference(&mut p.node, ctx, builder); + } + + _ => (), + } + unknown() + } +} + +fn new_arc_refcell(t: T) -> Arc> { + Arc::new(RefCell::new(t)) +} diff --git a/src/lib.rs b/src/lib.rs index 65e4dc9e5..8caf5353f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,3 +14,6 @@ mod lsp; mod nomparser; #[cfg(target_arch = "wasm32")] mod utils; + +#[cfg(target_arch = "wasm32")] +mod inference; diff --git a/src/main.rs b/src/main.rs index 149a8dc15..87c104568 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ #![allow(suspicious_double_ref_op)] #![allow(clippy::derive_ord_xor_partial_ord)] #![allow(clippy::missing_safety_doc)] +#![allow(clippy::single_match)] mod jar; pub use jar::*; @@ -9,6 +10,7 @@ pub use jar::*; mod ast; mod db; mod flow; +mod inference; mod lsp; mod nomparser; mod utils; diff --git a/src/nomparser/control.rs b/src/nomparser/control.rs index 6e9e08499..f1cd0cdf7 100644 --- a/src/nomparser/control.rs +++ b/src/nomparser/control.rs @@ -119,7 +119,7 @@ pub fn while_statement(input: Span) -> IResult> { delspace(tuple(( tag_token_word(TokenType::WHILE), alt_except( - general_exp, + parse_with_ex(general_exp, true), "{", "failed to parse while condition", ErrorCode::WHILE_CONDITION_MUST_BE_BOOL, diff --git a/src/nomparser/expression.rs b/src/nomparser/expression.rs index ecd285649..7a5e4acfa 100644 --- a/src/nomparser/expression.rs +++ b/src/nomparser/expression.rs @@ -22,6 +22,7 @@ use internal_macro::{test_parser, test_parser_error}; use super::{cast::as_exp, macro_parse::macro_call_op, string_literal::string_literal, *}; +#[test_parser("a")] pub fn general_exp(input: Span) -> IResult> { logic_exp(input) } @@ -359,6 +360,7 @@ fn closure(input: Span) -> IResult> { paralist: args, body, ret, + ret_id: None, } .into(), ) diff --git a/src/nomparser/identifier.rs b/src/nomparser/identifier.rs index 7450bb21f..05ec1bbfb 100644 --- a/src/nomparser/identifier.rs +++ b/src/nomparser/identifier.rs @@ -81,6 +81,7 @@ pub fn identifier(input: Span) -> IResult> { Ok(Box::new(VarNode { name: out.to_string(), range: Range::new(out, out.take_split(out.len()).0), + id: None, })) }, ))(input) @@ -93,6 +94,7 @@ pub fn tuple_field_identifier(input: Span) -> IResult> { Ok::<_, ()>(Box::new(VarNode { name: out.to_string(), range: Range::new(out, out.take_split(out.len()).0), + id: None, })) }, ))(input) diff --git a/src/nomparser/program.rs b/src/nomparser/program.rs index 4dc81b4fd..e7c60f0a8 100644 --- a/src/nomparser/program.rs +++ b/src/nomparser/program.rs @@ -72,6 +72,7 @@ pub fn program(input: Span) -> IResult> { id: VarNode { name: "self".to_string(), range: Default::default(), + id: None, }, typenode: Box::new(TypeNodeEnum::Pointer(PointerTypeNode { elm: Box::new(target.clone()), diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ff3065f65..dd211f40f 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,6 +1,6 @@ pub mod plc_new; pub mod read_config; -pub mod test_symbol; +// pub mod test_symbol; use std::{ collections::hash_map::DefaultHasher, diff --git a/test/main.pi b/test/main.pi index aaf4d1767..2b0368bb7 100644 --- a/test/main.pi +++ b/test/main.pi @@ -28,6 +28,7 @@ use project1::test::iter; use std::libc; use std::io; use std::cols::hashtable; +use project1::test::inference; use core::hash::pl_hash::*; @@ -57,6 +58,7 @@ pub fn main() i64 { generic::ret_generic1(); test_compile_time_reflection(); iter::test_generator(); + inference::test_inference(); let fd = io::open_read("./Cargo.toml"); @@ -102,4 +104,3 @@ pub fn main() i64 { return 0; } - diff --git a/test/test/closure.pi b/test/test/closure.pi index bafdb49f1..01cc66d5c 100644 --- a/test/test/closure.pi +++ b/test/test/closure.pi @@ -2,9 +2,9 @@ use core::panic; pub fn test_closure() i64 { let b = 1; let bb = 100; - test_type_infer(|a| => { + test_type_infer(|a| => { let c = b; - let aa = |a: i64| => i64 { + let aa = |a| => { let c = bb; return c; }; @@ -12,6 +12,25 @@ pub fn test_closure() i64 { panic::assert(re == 100); return c; }); + let fff = |a:|i64| => i64| => { + return; + }; + let ddd = |a| => { + return; + }; + fff = ddd; + let eee = |a| => { + let c = b; + let aa = |a| => { + let c = bb; + return c; + }; + let re = aa(2); + panic::assert(re == 100); + return c; + }; + fff(eee); + let f = test_ret_closure(); let d = f(2); panic::assert(d == 1); diff --git a/test/test/fixed_point.pi b/test/test/fixed_point.pi index 3245b3cd6..19e893c0a 100644 --- a/test/test/fixed_point.pi +++ b/test/test/fixed_point.pi @@ -1,6 +1,6 @@ use core::panic; pub fn test_fixed_point() void { - let g = |f: |i64| => i64, x: i64| => i64 { + let g = |f, x| => { if x == 0 { return 1; } @@ -30,7 +30,7 @@ fn Y(g: ||A| => R, A| => R) |A| => R { return f.call(f, x); }(Func{ f: |f: Func, x: A| => R { - return g(|x: A| => R { + return g(|x| => { return f.call(f, x); }, x); } diff --git a/test/test/inference.pi b/test/test/inference.pi new file mode 100644 index 000000000..0a1602c6a --- /dev/null +++ b/test/test/inference.pi @@ -0,0 +1,37 @@ +pub fn test_inference() void { + let x; + x = 100; + let y; + y = x; + + let z; + test_f_infer(z); + let d; + d = test_f_infer(z); + let a = t{}; + let l; + l = a.a; + let h = |a| =>{ + return a==a; + }; + h(100); + let hh = |a,b,c| =>{ + return b; + }(1,1.2,t{}); + let yy; + yy = ~y; + let xx ; + xx = &(~y); + + return; +} + +fn test_f_infer(x:i32) i128 { + + return x as i128; +} + + +struct t { + a:i16; +} \ No newline at end of file diff --git a/test/test/iter.pi b/test/test/iter.pi index 92ea437fb..f65ad191d 100644 --- a/test/test/iter.pi +++ b/test/test/iter.pi @@ -13,6 +13,7 @@ pub fn test_generator() void { next = iterator.next(); assert(next is i64); i = next as i64!; + println!(i); assert(i == 100); next = iterator.next(); assert(next is i64); diff --git a/test/test/simple.pi b/test/test/simple.pi index 13b07814e..ef60f95a2 100644 --- a/test/test/simple.pi +++ b/test/test/simple.pi @@ -29,7 +29,7 @@ pub fn test_primitives() void { panic::assert(f as i64 == -64); let g = test >>> 1; panic::assert(g as i64 == 64); - let h = ~test; + let h = ~test ; panic::assert(h as i64 == 127); return; } diff --git a/vm/src/gc/mod.rs b/vm/src/gc/mod.rs index 1323fb799..cdf084c25 100644 --- a/vm/src/gc/mod.rs +++ b/vm/src/gc/mod.rs @@ -45,7 +45,6 @@ mod _immix { trace!("malloc: {} {}", size, obj_type); #[cfg(any(test, debug_assertions))] // enable eager gc in test mode immix::gc_collect(); - // immix::gc_disable_auto_collect(); let re = gc_malloc(size as usize, obj_type); if re.is_null() && size != 0 { eprintln!("gc malloc failed! (OOM)"); @@ -56,6 +55,21 @@ mod _immix { re } + pub unsafe fn disable_auto_collect() { + immix::gc_disable_auto_collect(); + } + + pub unsafe fn enable_auto_collect() { + immix::gc_enable_auto_collect(); + } + + pub unsafe fn stuck_begin() { + immix::thread_stuck_start(); + } + pub unsafe fn stuck_end() { + immix::thread_stuck_end(); + } + pub unsafe fn collect() { trace!("manual collect"); immix::gc_collect() diff --git a/vm/src/lib.rs b/vm/src/lib.rs index 55cd4cc1a..c8b12b9b9 100644 --- a/vm/src/lib.rs +++ b/vm/src/lib.rs @@ -1,13 +1,14 @@ #![allow(improper_ctypes_definitions)] #![allow(clippy::missing_safety_doc)] -use std::process::exit; +use std::{process::exit, sync::mpsc::channel, thread}; use backtrace::Backtrace; use internal_macro::is_runtime; pub mod gc; pub mod libcwrap; pub mod logger; +pub mod mutex; #[is_runtime] fn test_vm_link() -> i64 { @@ -62,12 +63,44 @@ fn print_i64(i: i64) { } #[is_runtime] -fn print_i128(i: i128) { - print!("{}", i); +fn new_thread(f: *mut i128) { + // f's first 8 byte is fn pointer, next 8 byte is data pointer + let ptr = f as *const i64; + let f_ptr = ptr as *const extern "C" fn(i64); + let data_ptr = unsafe { *ptr.offset(1) }; + let func = unsafe { *f_ptr }; + let (s, r) = channel::<()>(); + let ptr_i = ptr as i64; + // immix::gc_add_root(data_ptr as *mut _, ObjectType::Pointer.int_value()); + let c = move || { + // thread::sleep(std::time::Duration::from_secs(1)); + immix::gc_keep_live(ptr_i as _); + // immix::gc_add_root(&mut f as *mut _ as *mut _, ObjectType::Trait.int_value()); + s.send(()).unwrap(); + func(data_ptr); + // immix::gc_remove_root(&mut f as *mut _ as *mut _); + immix::gc_rm_live(ptr_i as _); + immix::no_gc_thread(); + }; + thread::spawn(c); + r.recv().unwrap(); +} + +#[is_runtime] +fn sleep(secs: u64) { + gc::DioGC__stuck_begin(); + println!("sleeping for {} secs", secs); + thread::sleep(std::time::Duration::from_secs(secs)); + gc::DioGC__stuck_end(); + println!("sleeping done"); } #[is_runtime] -fn print_u64(i: u64) { +fn print_u64(u: u64) { + println!("u64( {} )", u); +} +#[is_runtime] +fn print_i128(i: i128) { print!("{}", i); } diff --git a/vm/src/mutex/mod.rs b/vm/src/mutex/mod.rs new file mode 100644 index 000000000..f364b9cfb --- /dev/null +++ b/vm/src/mutex/mod.rs @@ -0,0 +1,51 @@ +use std::{ + cell::Cell, + mem, + sync::{Mutex, MutexGuard}, +}; + +use internal_macro::is_runtime; + +struct MutexContainer { + mutex: Mutex<()>, + guard: Cell>>, +} +pub struct OpaqueMutex { + _data: [usize; 0], +} + +#[is_runtime] +fn create_mutex(mutex: *mut *mut OpaqueMutex) -> u64 { + *mutex = Box::into_raw(Box::new(MutexContainer { + mutex: Mutex::new(()), + guard: Cell::new(None), + })) + .cast(); + 0 +} + +#[is_runtime] +fn lock_mutex(mutex: *mut OpaqueMutex) -> u64 { + let container: &MutexContainer = &*mutex.cast(); + let lock: MutexGuard<'static, _> = mem::transmute(container.mutex.lock().unwrap()); + container.guard.set(Some(lock)); + 0 +} + +#[is_runtime] +fn unlock_mutex(mutex: *mut OpaqueMutex) -> u64 { + let container: &MutexContainer = &*mutex.cast(); + if container.mutex.try_lock().is_ok() { + return !0; //can't unlock an unlocked mutex + } else { + container.guard.set(None); + } + 0 +} + +#[is_runtime] +fn drop_mutex(mutex: *mut OpaqueMutex) -> u64 { + unlock_mutex(mutex); + drop(Box::from_raw(mutex.cast::())); + 0 +}