Skip to content

Commit

Permalink
Add missing trait quick fix.
Browse files Browse the repository at this point in the history
commit-id:41236fd0
  • Loading branch information
gilbens-starkware committed Jul 31, 2024
1 parent 5ea64d6 commit 8112d4d
Show file tree
Hide file tree
Showing 17 changed files with 822 additions and 133 deletions.
6 changes: 6 additions & 0 deletions crates/cairo-lang-defs/src/ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,12 @@ impl ModuleId {
}
}
}
pub fn name(&self, db: &dyn DefsGroup) -> SmolStr {
match self {
ModuleId::CrateRoot(id) => id.lookup_intern(db).name(),
ModuleId::Submodule(id) => id.name(db),
}
}
pub fn owning_crate(&self, db: &dyn DefsGroup) -> CrateId {
match self {
ModuleId::CrateRoot(crate_id) => *crate_id,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use std::collections::HashMap;

use cairo_lang_defs::ids::{LookupItemId, ModuleId, NamedLanguageElementId};
use cairo_lang_filesystem::ids::FileId;
use cairo_lang_filesystem::span::TextOffset;
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::expr::inference::InferenceId;
use cairo_lang_semantic::items::function_with_body::SemanticExprLookup;
use cairo_lang_semantic::lookup_item::{HasResolverData, LookupItemEx};
use cairo_lang_semantic::resolve::Resolver;
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{ast, SyntaxNode, TypedStablePtr, TypedSyntaxNode};
use cairo_lang_utils::Upcast;
use tower_lsp::lsp_types::{CodeAction, CodeActionKind, Range, TextEdit, Url, WorkspaceEdit};
use tracing::debug;

use crate::ide::utils::find_methods_for_type;
use crate::lang::db::{AnalysisDatabase, LsSemanticGroup};
use crate::lang::lsp::{LsProtoGroup, ToLsp};

/// Create a Quick Fix code action to add a missing trait given a `CannotCallMethod` diagnostic.
#[tracing::instrument(level = "trace", skip_all)]
pub fn add_missing_trait(db: &AnalysisDatabase, node: &SyntaxNode, uri: Url) -> Vec<CodeAction> {
let file_id = db.file_for_url(&uri).unwrap();
let lookup_items = db.collect_lookup_items_stack(node).unwrap();
let unknown_method_name = node.get_text(db.upcast());
missing_traits_actions(db, file_id, lookup_items, node, &unknown_method_name, uri)
.unwrap_or_default()
}

/// Returns a list of code actions to add missing traits to the current module, or `None` if the
/// type is missing.
fn missing_traits_actions(
db: &AnalysisDatabase,
file_id: FileId,
lookup_items: Vec<LookupItemId>,
node: &SyntaxNode,
unknown_method_name: &str,
uri: Url,
) -> Option<Vec<CodeAction>> {
let syntax_db = db.upcast();
// Get a resolver in the current context.
let lookup_item_id = lookup_items.into_iter().next()?;
let function_with_body = lookup_item_id.function_with_body()?;
let resolver_data = lookup_item_id.resolver_data(db).ok()?;
let resolver = Resolver::with_data(
db,
resolver_data.as_ref().clone_with_inference_id(db, InferenceId::NoContext),
);
let mut expr_node = node.clone();
while expr_node.kind(db.upcast()) != SyntaxKind::ExprBinary {
expr_node = expr_node.parent()?;
}
let expr_node = ast::ExprBinary::from_syntax_node(db.upcast(), expr_node).lhs(db.upcast());
let stable_ptr = expr_node.stable_ptr().untyped();
// Get its semantic model.
let expr_id = db.lookup_expr_by_ptr(function_with_body, expr_node.stable_ptr()).ok()?;
let semantic_expr = db.expr_semantic(function_with_body, expr_id);
// Get the type.
let ty = semantic_expr.ty();
if ty.is_missing(db) {
debug!("type is missing");
return None;
}

let module_start_offset =
if let Some(ModuleId::Submodule(submodule_id)) = db.find_module_containing_node(node) {
let module_def_ast = submodule_id.stable_ptr(db.upcast()).lookup(syntax_db);
if let ast::MaybeModuleBody::Some(body) = module_def_ast.body(syntax_db) {
body.items(syntax_db).as_syntax_node().span_start_without_trivia(syntax_db)
} else {
TextOffset::default()
}
} else {
TextOffset::default()
};
let module_start_position =
module_start_offset.position_in_file(db.upcast(), file_id).unwrap().to_lsp();
let relevant_methods = find_methods_for_type(db, resolver, ty, stable_ptr);
let current_module = db.find_module_containing_node(node)?;
let module_visible_traits = db.visible_traits_from_module(current_module);
let mut code_actions = vec![];
for method in relevant_methods {
let method_name = method.name(db.upcast());
if method_name == unknown_method_name {
if let Some(trait_path) = module_visible_traits.get(&method.trait_id(db.upcast())) {
code_actions.push(CodeAction {
title: format!("Import {}", trait_path),
kind: Some(CodeActionKind::QUICKFIX),
edit: Some(WorkspaceEdit {
changes: Some(HashMap::from_iter([(
uri.clone(),
vec![TextEdit {
range: Range::new(module_start_position, module_start_position),
new_text: format!("use {};\n", trait_path),
}],
)])),
document_changes: None,
change_annotations: None,
}),
diagnostics: None,
..Default::default()
});
}
}
}
Some(code_actions)
}
2 changes: 2 additions & 0 deletions crates/cairo-lang-language-server/src/ide/code_actions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tracing::debug;
use crate::lang::db::{AnalysisDatabase, LsSyntaxGroup};
use crate::lang::lsp::{LsProtoGroup, ToCairo};

mod add_missing_trait;
mod rename_unused_variable;

/// Compute commands for a given text document and range. These commands are typically code fixes to
Expand Down Expand Up @@ -70,6 +71,7 @@ fn get_code_actions_for_diagnostic(
params.text_document.uri.clone(),
)]
}
"E0002" => add_missing_trait::add_missing_trait(db, node, params.text_document.uri.clone()),
code => {
debug!("no code actions for diagnostic code: {code}");
vec![]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@ use cairo_lang_filesystem::ids::FileId;
use cairo_lang_filesystem::span::TextOffset;
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::diagnostic::{NotFoundItemType, SemanticDiagnostics};
use cairo_lang_semantic::expr::inference::infers::InferenceEmbeddings;
use cairo_lang_semantic::expr::inference::solver::SolutionSet;
use cairo_lang_semantic::expr::inference::InferenceId;
use cairo_lang_semantic::items::function_with_body::SemanticExprLookup;
use cairo_lang_semantic::items::structure::SemanticStructEx;
use cairo_lang_semantic::items::us::SemanticUseEx;
use cairo_lang_semantic::lookup_item::{HasResolverData, LookupItemEx};
use cairo_lang_semantic::lsp_helpers::TypeFilter;
use cairo_lang_semantic::resolve::{ResolvedConcreteItem, ResolvedGenericItem, Resolver};
use cairo_lang_semantic::types::peel_snapshots;
use cairo_lang_semantic::{ConcreteTypeId, Pattern, TypeLongId};
Expand All @@ -25,6 +21,7 @@ use cairo_lang_utils::{LookupIntern, Upcast};
use tower_lsp::lsp_types::{CompletionItem, CompletionItemKind, Position, Range, TextEdit};
use tracing::debug;

use crate::ide::utils::{find_methods_for_type, module_has_trait};
use crate::lang::db::{AnalysisDatabase, LsSemanticGroup};
use crate::lang::lsp::ToLsp;

Expand Down Expand Up @@ -259,7 +256,7 @@ pub fn dot_completions(

/// Returns a completion item for a method.
#[tracing::instrument(level = "trace", skip_all)]
fn completion_for_method(
pub fn completion_for_method(
db: &AnalysisDatabase,
module_id: ModuleId,
trait_function: TraitFunctionId,
Expand Down Expand Up @@ -292,72 +289,3 @@ fn completion_for_method(
};
Some(completion)
}

/// Checks if a module has a trait in scope.
#[tracing::instrument(level = "trace", skip_all)]
fn module_has_trait(
db: &AnalysisDatabase,
module_id: ModuleId,
trait_id: cairo_lang_defs::ids::TraitId,
) -> Option<bool> {
if db.module_traits_ids(module_id).ok()?.contains(&trait_id) {
return Some(true);
}
for use_id in db.module_uses_ids(module_id).ok()?.iter().copied() {
if db.use_resolved_item(use_id) == Ok(ResolvedGenericItem::Trait(trait_id)) {
return Some(true);
}
}
Some(false)
}

/// Finds all methods that can be called on a type.
#[tracing::instrument(level = "trace", skip_all)]
fn find_methods_for_type(
db: &AnalysisDatabase,
mut resolver: Resolver<'_>,
ty: cairo_lang_semantic::TypeId,
stable_ptr: cairo_lang_syntax::node::ids::SyntaxStablePtrId,
) -> Vec<TraitFunctionId> {
let type_filter = match ty.head(db) {
Some(head) => TypeFilter::TypeHead(head),
None => TypeFilter::NoFilter,
};

let mut relevant_methods = Vec::new();
// Find methods on type.
// TODO(spapini): Look only in current crate dependencies.
for crate_id in db.crates() {
let methods = db.methods_in_crate(crate_id, type_filter.clone());
for trait_function in methods.iter().copied() {
let clone_data =
&mut resolver.inference().clone_with_inference_id(db, InferenceId::NoContext);
let mut inference = clone_data.inference(db);
let lookup_context = resolver.impl_lookup_context();
// Check if trait function signature's first param can fit our expr type.
let Some((concrete_trait_id, _)) = inference.infer_concrete_trait_by_self(
trait_function,
ty,
&lookup_context,
Some(stable_ptr),
|_| {},
) else {
debug!("can't fit");
continue;
};

// Find impls for it.

// ignore the result as nothing can be done with the error, if any.
inference.solve().ok();
if !matches!(
inference.trait_solution_set(concrete_trait_id, lookup_context),
Ok(SolutionSet::Unique(_) | SolutionSet::Ambiguous(_))
) {
continue;
}
relevant_methods.push(trait_function);
}
}
relevant_methods
}
1 change: 1 addition & 0 deletions crates/cairo-lang-language-server/src/ide/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pub mod formatter;
pub mod hover;
pub mod navigation;
pub mod semantic_highlighting;
pub mod utils;
83 changes: 83 additions & 0 deletions crates/cairo-lang-language-server/src/ide/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use cairo_lang_defs::db::DefsGroup;
use cairo_lang_defs::ids::{ModuleId, TraitFunctionId};
use cairo_lang_filesystem::db::FilesGroup;
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::expr::inference::infers::InferenceEmbeddings;
use cairo_lang_semantic::expr::inference::solver::SolutionSet;
use cairo_lang_semantic::expr::inference::InferenceId;
use cairo_lang_semantic::items::us::SemanticUseEx;
use cairo_lang_semantic::lsp_helpers::TypeFilter;
use cairo_lang_semantic::resolve::{ResolvedGenericItem, Resolver};
use tracing::debug;

use crate::lang::db::AnalysisDatabase;

/// Finds all methods that can be called on a type.
#[tracing::instrument(level = "trace", skip_all)]
pub fn find_methods_for_type(
db: &AnalysisDatabase,
mut resolver: Resolver<'_>,
ty: cairo_lang_semantic::TypeId,
stable_ptr: cairo_lang_syntax::node::ids::SyntaxStablePtrId,
) -> Vec<TraitFunctionId> {
let type_filter = match ty.head(db) {
Some(head) => TypeFilter::TypeHead(head),
None => TypeFilter::NoFilter,
};

let mut relevant_methods = Vec::new();
// Find methods on type.
// TODO(spapini): Look only in current crate dependencies.
for crate_id in db.crates() {
let methods = db.methods_in_crate(crate_id, type_filter.clone());
for trait_function in methods.iter().copied() {
let clone_data =
&mut resolver.inference().clone_with_inference_id(db, InferenceId::NoContext);
let mut inference = clone_data.inference(db);
let lookup_context = resolver.impl_lookup_context();
// Check if trait function signature's first param can fit our expr type.
let Some((concrete_trait_id, _)) = inference.infer_concrete_trait_by_self(
trait_function,
ty,
&lookup_context,
Some(stable_ptr),
|_| {},
) else {
debug!("can't fit");
continue;
};

// Find impls for it.

// ignore the result as nothing can be done with the error, if any.
inference.solve().ok();
if !matches!(
inference.trait_solution_set(concrete_trait_id, lookup_context),
Ok(SolutionSet::Unique(_) | SolutionSet::Ambiguous(_))
) {
continue;
}
relevant_methods.push(trait_function);
}
}
relevant_methods
}

/// Checks if a module has a trait in scope.
#[tracing::instrument(level = "trace", skip_all)]
pub fn module_has_trait(
db: &AnalysisDatabase,
module_id: ModuleId,
trait_id: cairo_lang_defs::ids::TraitId,
) -> Option<bool> {
if db.module_traits_ids(module_id).ok()?.contains(&trait_id) {
return Some(true);
}
// TODO(Gil): Check if the trait is visible, and return the path of the visible use item.
for use_id in db.module_uses_ids(module_id).ok()?.iter().copied() {
if db.use_resolved_item(use_id) == Ok(ResolvedGenericItem::Trait(trait_id)) {
return Some(true);
}
}
Some(false)
}
Loading

0 comments on commit 8112d4d

Please sign in to comment.