diff --git a/sway-core/src/semantic_analysis/ast_node/declaration/function.rs b/sway-core/src/semantic_analysis/ast_node/declaration/function.rs index b49780c553f..b516a595110 100644 --- a/sway-core/src/semantic_analysis/ast_node/declaration/function.rs +++ b/sway-core/src/semantic_analysis/ast_node/declaration/function.rs @@ -39,7 +39,6 @@ impl ty::TyFunctionDeclaration { let type_engine = ctx.type_engine; let decl_engine = ctx.decl_engine; - let engines = ctx.engines(); // If functions aren't allowed in this location, return an error. if ctx.functions_disallowed() { @@ -60,7 +59,7 @@ impl ty::TyFunctionDeclaration { // create a namespace for the function let mut fn_namespace = ctx.namespace.clone(); - let mut fn_ctx = ctx + let mut ctx = ctx .by_ref() .scoped(&mut fn_namespace) .with_purity(purity) @@ -70,7 +69,7 @@ impl ty::TyFunctionDeclaration { let mut new_type_parameters = vec![]; for type_parameter in type_parameters.into_iter() { new_type_parameters.push(check!( - TypeParameter::type_check(fn_ctx.by_ref(), type_parameter), + TypeParameter::type_check(ctx.by_ref(), type_parameter), continue, warnings, errors @@ -84,7 +83,7 @@ impl ty::TyFunctionDeclaration { let mut new_parameters = vec![]; for parameter in parameters.into_iter() { new_parameters.push(check!( - ty::TyFunctionParameter::type_check(fn_ctx.by_ref(), parameter, is_method), + ty::TyFunctionParameter::type_check(ctx.by_ref(), parameter, is_method), continue, warnings, errors @@ -97,7 +96,7 @@ impl ty::TyFunctionDeclaration { // type check the return type let initial_return_type = type_engine.insert(decl_engine, return_type); let return_type = check!( - fn_ctx.resolve_type_with_self( + ctx.resolve_type_with_self( initial_return_type, &return_type_span, EnforceTypeArguments::Yes, @@ -113,7 +112,7 @@ impl ty::TyFunctionDeclaration { // If there are no implicit block returns, then we do not want to type check them, so we // stifle the errors. If there _are_ implicit block returns, we want to type_check them. let (body, _implicit_block_return) = { - let fn_ctx = fn_ctx + let fn_ctx = ctx .by_ref() .with_purity(purity) .with_help_text("Function body's return type does not match up with its return type annotation.") @@ -137,7 +136,7 @@ impl ty::TyFunctionDeclaration { .collect(); check!( - unify_return_statements(fn_ctx.by_ref(), &return_statements, return_type), + unify_return_statements(ctx.by_ref(), &return_statements, return_type), return err(warnings, errors), warnings, errors @@ -150,7 +149,7 @@ impl ty::TyFunctionDeclaration { (Visibility::Public, false) } } else { - (visibility, fn_ctx.mode() == Mode::ImplAbiFn) + (visibility, ctx.mode() == Mode::ImplAbiFn) }; let function_decl = ty::TyFunctionDeclaration { @@ -169,20 +168,6 @@ impl ty::TyFunctionDeclaration { purity, }; - // Retrieve the implemented traits for the type of the return type and - // insert them in the broader namespace. We don't want to include any - // type parameters, so we filter them out. - let mut return_type_namespace = fn_ctx - .namespace - .implemented_traits - .filter_by_type(function_decl.return_type, fn_ctx.engines()); - for type_param in function_decl.type_parameters.iter() { - return_type_namespace.filter_against_type(engines, type_param.type_id); - } - ctx.namespace - .implemented_traits - .extend(return_type_namespace, engines); - ok(function_decl, warnings, errors) } } diff --git a/sway-core/src/semantic_analysis/ast_node/declaration/trait_fn.rs b/sway-core/src/semantic_analysis/ast_node/declaration/trait_fn.rs index 7cffde0a0b0..a0f0145832f 100644 --- a/sway-core/src/semantic_analysis/ast_node/declaration/trait_fn.rs +++ b/sway-core/src/semantic_analysis/ast_node/declaration/trait_fn.rs @@ -69,11 +69,8 @@ impl ty::TyTraitFn { // Retrieve the implemented traits for the type of the return type and // insert them in the broader namespace. - let trait_map = fn_ctx - .namespace - .implemented_traits - .filter_by_type(trait_fn.return_type, engines); - ctx.namespace.implemented_traits.extend(trait_map, engines); + ctx.namespace + .insert_trait_implementation_for_type(engines, trait_fn.return_type); ok(trait_fn, warnings, errors) } diff --git a/sway-core/src/semantic_analysis/ast_node/expression/typed_expression/function_application.rs b/sway-core/src/semantic_analysis/ast_node/expression/typed_expression/function_application.rs index 1485925a326..4ebd3a1d82f 100644 --- a/sway-core/src/semantic_analysis/ast_node/expression/typed_expression/function_application.rs +++ b/sway-core/src/semantic_analysis/ast_node/expression/typed_expression/function_application.rs @@ -61,6 +61,11 @@ pub(crate) fn instantiate_function_application( errors ); + // Retrieve the implemented traits for the type of the return type and + // insert them in the broader namespace. + ctx.namespace + .insert_trait_implementation_for_type(engines, function_decl.return_type); + // Handle the trait constraints. This includes checking to see if the trait // constraints are satisfied and replacing old decl ids based on the // constraint with new decl ids based on the new type. diff --git a/sway-core/src/semantic_analysis/ast_node/expression/typed_expression/method_application.rs b/sway-core/src/semantic_analysis/ast_node/expression/typed_expression/method_application.rs index e64f4c55eb0..3d41ff8f775 100644 --- a/sway-core/src/semantic_analysis/ast_node/expression/typed_expression/method_application.rs +++ b/sway-core/src/semantic_analysis/ast_node/expression/typed_expression/method_application.rs @@ -350,6 +350,11 @@ pub(crate) fn type_check_method_application( .map(|(param, arg)| (param.name.clone(), arg)) .collect::>(); + // Retrieve the implemented traits for the type of the return type and + // insert them in the broader namespace. + ctx.namespace + .insert_trait_implementation_for_type(engines, method.return_type); + let exp = ty::TyExpression { expression: ty::TyExpressionVariant::FunctionApplication { call_path, diff --git a/sway-core/src/semantic_analysis/namespace/namespace.rs b/sway-core/src/semantic_analysis/namespace/namespace.rs index 7981e111729..c6d18bab074 100644 --- a/sway-core/src/semantic_analysis/namespace/namespace.rs +++ b/sway-core/src/semantic_analysis/namespace/namespace.rs @@ -225,10 +225,6 @@ impl Namespace { errors ); if &method.name == method_name { - // if we find the method that we are looking for, we also need - // to retrieve the impl definitions for the return type so that - // the user can string together method calls - self.insert_trait_implementation_for_type(engines, method.return_type); return ok(decl_id, warnings, errors); } } diff --git a/sway-core/src/semantic_analysis/namespace/trait_map.rs b/sway-core/src/semantic_analysis/namespace/trait_map.rs index b579b991d5a..4f9edd2260e 100644 --- a/sway-core/src/semantic_analysis/namespace/trait_map.rs +++ b/sway-core/src/semantic_analysis/namespace/trait_map.rs @@ -619,19 +619,6 @@ impl TraitMap { trait_map } - /// Filters the contents of `self` to exclude elements that are superset - /// types of the given `type_id`. This function is used when handling trait - /// constraints and is coupled with `filter_by_type` and - /// `filter_by_type_item_import`. - pub(crate) fn filter_against_type(&mut self, engines: Engines<'_>, type_id: TypeId) { - let type_engine = engines.te(); - self.trait_impls.retain(|e| { - !type_engine - .get(type_id) - .is_subset_of(&type_engine.get(e.key.type_id), engines) - }); - } - /// Find the entries in `self` that are equivalent to `type_id`. /// /// Notes: @@ -656,7 +643,7 @@ impl TraitMap { return methods; } for entry in self.trait_impls.iter() { - if are_equal_minus_dynamic_types(type_engine, type_id, entry.key.type_id) { + if are_equal_minus_dynamic_types(engines, type_id, entry.key.type_id) { let mut trait_methods = entry .value .values() @@ -701,7 +688,7 @@ impl TraitMap { is_absolute: e.key.name.is_absolute, }; if &map_trait_name == trait_name - && are_equal_minus_dynamic_types(type_engine, type_id, e.key.type_id) + && are_equal_minus_dynamic_types(engines, type_id, e.key.type_id) { let mut trait_methods = e.value.values().cloned().into_iter().collect::>(); methods.append(&mut trait_methods); @@ -760,12 +747,8 @@ impl TraitMap { }, }, ); - if are_equal_minus_dynamic_types(type_engine, type_id, key.type_id) - && are_equal_minus_dynamic_types( - type_engine, - constraint_type_id, - map_trait_type_id, - ) + if are_equal_minus_dynamic_types(engines, type_id, key.type_id) + && are_equal_minus_dynamic_types(engines, constraint_type_id, map_trait_type_id) { found_traits.insert(constraint_trait_name.suffix.clone()); } @@ -789,10 +772,13 @@ impl TraitMap { } } -fn are_equal_minus_dynamic_types(type_engine: &TypeEngine, left: TypeId, right: TypeId) -> bool { +fn are_equal_minus_dynamic_types(engines: Engines<'_>, left: TypeId, right: TypeId) -> bool { if left.index() == right.index() { return true; } + + let type_engine = engines.te(); + match (type_engine.get(left), type_engine.get(right)) { // these cases are false because, unless left and right have the same // TypeId, they may later resolve to be different types in the type @@ -811,11 +797,16 @@ fn are_equal_minus_dynamic_types(type_engine: &TypeEngine, left: TypeId, right: (TypeInfo::UnsignedInteger(l), TypeInfo::UnsignedInteger(r)) => l == r, (TypeInfo::RawUntypedPtr, TypeInfo::RawUntypedPtr) => true, (TypeInfo::RawUntypedSlice, TypeInfo::RawUntypedSlice) => true, - (TypeInfo::UnknownGeneric { .. }, TypeInfo::UnknownGeneric { .. }) => { - // return true if left and right were unified previously - type_engine.get_unified_types(left).contains(&right) - || type_engine.get_unified_types(right).contains(&left) - } + ( + TypeInfo::UnknownGeneric { + name: rn, + trait_constraints: rtc, + }, + TypeInfo::UnknownGeneric { + name: en, + trait_constraints: etc, + }, + ) => rn.as_str() == en.as_str() && rtc.eq(&etc, engines), (TypeInfo::Placeholder(_), TypeInfo::Placeholder(_)) => false, // these cases may contain dynamic types @@ -835,11 +826,7 @@ fn are_equal_minus_dynamic_types(type_engine: &TypeEngine, left: TypeId, right: .iter() .zip(r_type_args.unwrap_or_default().iter()) .fold(true, |acc, (left, right)| { - acc && are_equal_minus_dynamic_types( - type_engine, - left.type_id, - right.type_id, - ) + acc && are_equal_minus_dynamic_types(engines, left.type_id, right.type_id) }) } ( @@ -859,22 +846,14 @@ fn are_equal_minus_dynamic_types(type_engine: &TypeEngine, left: TypeId, right: true, |acc, (left, right)| { acc && left.name == right.name - && are_equal_minus_dynamic_types( - type_engine, - left.type_id, - right.type_id, - ) + && are_equal_minus_dynamic_types(engines, left.type_id, right.type_id) }, ) && l_type_parameters.iter().zip(r_type_parameters.iter()).fold( true, |acc, (left, right)| { acc && left.name_ident == right.name_ident - && are_equal_minus_dynamic_types( - type_engine, - left.type_id, - right.type_id, - ) + && are_equal_minus_dynamic_types(engines, left.type_id, right.type_id) }, ) } @@ -896,21 +875,13 @@ fn are_equal_minus_dynamic_types(type_engine: &TypeEngine, left: TypeId, right: .zip(r_fields.iter()) .fold(true, |acc, (left, right)| { acc && left.name == right.name - && are_equal_minus_dynamic_types( - type_engine, - left.type_id, - right.type_id, - ) + && are_equal_minus_dynamic_types(engines, left.type_id, right.type_id) }) && l_type_parameters.iter().zip(r_type_parameters.iter()).fold( true, |acc, (left, right)| { acc && left.name_ident == right.name_ident - && are_equal_minus_dynamic_types( - type_engine, - left.type_id, - right.type_id, - ) + && are_equal_minus_dynamic_types(engines, left.type_id, right.type_id) }, ) } @@ -919,7 +890,7 @@ fn are_equal_minus_dynamic_types(type_engine: &TypeEngine, left: TypeId, right: false } else { l.iter().zip(r.iter()).fold(true, |acc, (left, right)| { - acc && are_equal_minus_dynamic_types(type_engine, left.type_id, right.type_id) + acc && are_equal_minus_dynamic_types(engines, left.type_id, right.type_id) }) } } @@ -937,7 +908,7 @@ fn are_equal_minus_dynamic_types(type_engine: &TypeEngine, left: TypeId, right: && Option::zip(l_address, r_address) .map(|(l_address, r_address)| { are_equal_minus_dynamic_types( - type_engine, + engines, l_address.return_type, r_address.return_type, ) @@ -945,8 +916,7 @@ fn are_equal_minus_dynamic_types(type_engine: &TypeEngine, left: TypeId, right: .unwrap_or(true) } (TypeInfo::Array(l0, l1), TypeInfo::Array(r0, r1)) => { - l1.val() == r1.val() - && are_equal_minus_dynamic_types(type_engine, l0.type_id, r0.type_id) + l1.val() == r1.val() && are_equal_minus_dynamic_types(engines, l0.type_id, r0.type_id) } _ => false, } diff --git a/sway-core/src/type_system/engine.rs b/sway-core/src/type_system/engine.rs index 4591f93217d..56c95c4e06f 100644 --- a/sway-core/src/type_system/engine.rs +++ b/sway-core/src/type_system/engine.rs @@ -22,7 +22,6 @@ pub struct TypeEngine { pub(super) slab: ConcurrentSlab, storage_only_types: ConcurrentSlab, id_map: RwLock>, - unify_map: RwLock>>, } fn make_hasher<'a: 'b, 'b, K>( @@ -62,40 +61,6 @@ impl TypeEngine { } } - pub(crate) fn insert_unified_type(&self, received: TypeId, expected: TypeId) { - let mut unify_map = self.unify_map.write().unwrap(); - if let Some(type_ids) = unify_map.get(&received) { - if type_ids.contains(&expected) { - return; - } - let mut type_ids = type_ids.clone(); - type_ids.push(expected); - unify_map.insert(received, type_ids); - return; - } - - unify_map.insert(received, vec![expected]); - } - - pub(crate) fn get_unified_types(&self, type_id: TypeId) -> Vec { - let mut final_unify_ids: Vec = vec![]; - self.get_unified_types_rec(type_id, &mut final_unify_ids); - final_unify_ids - } - - fn get_unified_types_rec(&self, type_id: TypeId, final_unify_ids: &mut Vec) { - let unify_map = self.unify_map.read().unwrap(); - if let Some(unify_ids) = unify_map.get(&type_id) { - for unify_id in unify_ids { - if final_unify_ids.contains(unify_id) { - continue; - } - final_unify_ids.push(*unify_id); - self.get_unified_types_rec(*unify_id, final_unify_ids); - } - } - } - /// Performs a lookup of `id` into the [TypeEngine]. pub fn get(&self, id: TypeId) -> TypeInfo { self.slab.get(id.index()) diff --git a/sway-core/src/type_system/unify.rs b/sway-core/src/type_system/unify.rs index 479cecd1927..df7ae87d034 100644 --- a/sway-core/src/type_system/unify.rs +++ b/sway-core/src/type_system/unify.rs @@ -225,17 +225,11 @@ impl<'a> Unifier<'a> { name: en, trait_constraints: etc, }, - ) if rn.as_str() == en.as_str() && rtc.eq(&etc, self.engines) => { - self.engines.te().insert_unified_type(received, expected); - self.engines.te().insert_unified_type(expected, received); - (vec![], vec![]) - } + ) if rn.as_str() == en.as_str() && rtc.eq(&etc, self.engines) => (vec![], vec![]), (r @ UnknownGeneric { .. }, e) if !self.occurs_check(r.clone(), &e, span) => { - self.engines.te().insert_unified_type(expected, received); self.replace_received_with_expected(received, expected, &r, e, span) } (r, e @ UnknownGeneric { .. }) if !self.occurs_check(e.clone(), &r, span) => { - self.engines.te().insert_unified_type(received, expected); self.replace_expected_with_received(received, expected, r, &e, span) } diff --git a/test/src/e2e_vm_tests/test_programs/should_pass/language/generic_type_inference/src/main.sw b/test/src/e2e_vm_tests/test_programs/should_pass/language/generic_type_inference/src/main.sw index d147e525ff6..ad58849cc37 100644 --- a/test/src/e2e_vm_tests/test_programs/should_pass/language/generic_type_inference/src/main.sw +++ b/test/src/e2e_vm_tests/test_programs/should_pass/language/generic_type_inference/src/main.sw @@ -42,10 +42,15 @@ fn complex_vec_test() { assert(exp_vec_in_a_vec_in_a_struct_in_a_vec.get(0).unwrap().a.get(0).unwrap().get(2).unwrap() == 2); } +fn simple_option_generics_test() { + assert(get_an_option::().is_none()); +} + fn main() { sell_product(); simple_vec_test(); complex_vec_test(); + simple_option_generics_test(); } fn sell_product() -> MyResult { diff --git a/test/src/e2e_vm_tests/test_programs/should_pass/language/generic_type_inference/src/utils.sw b/test/src/e2e_vm_tests/test_programs/should_pass/language/generic_type_inference/src/utils.sw index 4e5eb06dd2e..d91973b7e12 100644 --- a/test/src/e2e_vm_tests/test_programs/should_pass/language/generic_type_inference/src/utils.sw +++ b/test/src/e2e_vm_tests/test_programs/should_pass/language/generic_type_inference/src/utils.sw @@ -7,3 +7,7 @@ pub fn vec_from(vals: [u32; 3]) -> Vec { vec.push(vals[2]); vec } + +pub fn get_an_option() -> Option { + Option::None +}