diff --git a/crates/pgt_completions/src/builder.rs b/crates/pgt_completions/src/builder.rs index bf8eb66a6..ed884ee95 100644 --- a/crates/pgt_completions/src/builder.rs +++ b/crates/pgt_completions/src/builder.rs @@ -6,6 +6,7 @@ use crate::{ use pgt_treesitter::TreesitterContext; +#[derive(Debug)] pub(crate) struct PossibleCompletionItem<'a> { pub label: String, pub description: String, diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index ba3b24813..1f404627c 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -1,13 +1,14 @@ -use pgt_schema_cache::SchemaCache; -use pgt_treesitter::{TreesitterContext, WrappingClause}; +use pgt_schema_cache::{Column, SchemaCache}; +use pgt_treesitter::TreesitterContext; use crate::{ - CompletionItemKind, + CompletionItemKind, CompletionText, builder::{CompletionBuilder, PossibleCompletionItem}, + providers::helper::get_range_to_replace, relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; -use super::helper::{find_matching_alias_for_table, get_completion_text_with_schema_or_alias}; +use super::helper::with_schema_or_alias; pub fn complete_columns<'a>( ctx: &TreesitterContext<'a>, @@ -19,37 +20,39 @@ pub fn complete_columns<'a>( for col in available_columns { let relevance = CompletionRelevanceData::Column(col); - let mut item = PossibleCompletionItem { + let item = PossibleCompletionItem { label: col.name.clone(), score: CompletionScore::from(relevance.clone()), filter: CompletionFilter::from(relevance), description: format!("{}.{}", col.schema_name, col.table_name), kind: CompletionItemKind::Column, - completion_text: None, + completion_text: Some(get_completion_text(ctx, col)), detail: col.type_name.as_ref().map(|t| t.to_string()), }; - // autocomplete with the alias in a join clause if we find one - if matches!( - ctx.wrapping_clause_type, - Some(WrappingClause::Join { .. }) - | Some(WrappingClause::Where) - | Some(WrappingClause::Select) - ) { - item.completion_text = find_matching_alias_for_table(ctx, col.table_name.as_str()) - .and_then(|alias| { - get_completion_text_with_schema_or_alias(ctx, col.name.as_str(), alias.as_str()) - }); - } - builder.add_item(item); } } +fn get_completion_text(ctx: &TreesitterContext, col: &Column) -> CompletionText { + let alias = ctx.get_used_alias_for_table(col.table_name.as_str()); + + let with_schema_or_alias = with_schema_or_alias(ctx, col.name.as_str(), alias.as_deref()); + + let range = get_range_to_replace(ctx); + + CompletionText { + is_snippet: false, + range, + text: with_schema_or_alias, + } +} + #[cfg(test)] mod tests { use std::vec; + use pgt_text_size::TextRange; use sqlx::{Executor, PgPool}; use crate::{ @@ -932,4 +935,343 @@ mod tests { .await; } } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completes_quoted_columns(pool: PgPool) { + let setup = r#" + create schema if not exists private; + + create table private.users ( + id serial primary key, + email text unique not null, + name text not null, + "quoted_column" text + ); + "#; + + pool.execute(setup).await.unwrap(); + + // test completion inside quoted column name + assert_complete_results( + format!( + r#"select "em{}" from "private"."users""#, + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), + vec![CompletionAssertion::LabelAndDesc( + "email".to_string(), + "private.users".to_string(), + )], + None, + &pool, + ) + .await; + + // test completion for already quoted column + assert_complete_results( + format!( + r#"select "quoted_col{}" from "private"."users""#, + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), + vec![CompletionAssertion::LabelAndDesc( + "quoted_column".to_string(), + "private.users".to_string(), + )], + None, + &pool, + ) + .await; + + // test completion with empty quotes + assert_complete_results( + format!( + r#"select "{}" from "private"."users""#, + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), + vec![ + CompletionAssertion::Label("email".to_string()), + CompletionAssertion::Label("id".to_string()), + CompletionAssertion::Label("name".to_string()), + CompletionAssertion::Label("quoted_column".to_string()), + ], + None, + &pool, + ) + .await; + + // test completion with partially opened quote + assert_complete_results( + format!( + r#"select "{} from "private"."users""#, + QueryWithCursorPosition::cursor_marker() + ) + .as_str(), + vec![ + CompletionAssertion::Label("email".to_string()), + CompletionAssertion::Label("id".to_string()), + CompletionAssertion::Label("name".to_string()), + CompletionAssertion::Label("quoted_column".to_string()), + ], + None, + &pool, + ) + .await; + } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn completes_quoted_columns_with_aliases(pool: PgPool) { + let setup = r#" + create schema if not exists private; + + create table private.users ( + id serial primary key, + email text unique not null, + name text not null, + "quoted_column" text + ); + + create table public.names ( + uid serial references private.users(id), + name text + ); + "#; + + pool.execute(setup).await.unwrap(); + + { + // should suggest pr"."email and insert into existing quotes + let query = format!( + r#"select "e{}" from private.users "pr""#, + QueryWithCursorPosition::cursor_marker() + ); + + assert_complete_results( + query.as_str(), + vec![CompletionAssertion::CompletionTextAndRange( + r#"pr"."email"#.into(), + // replaces the full `"e"` + TextRange::new(8.into(), 9.into()), + )], + None, + &pool, + ) + .await; + } + + { + // should suggest pr"."email and insert into existing quotes + let query = format!( + r#"select "{}" from private.users "pr""#, + QueryWithCursorPosition::cursor_marker() + ); + + assert_complete_results( + query.as_str(), + vec![CompletionAssertion::CompletionTextAndRange( + r#"pr"."email"#.into(), + TextRange::new(8.into(), 8.into()), + )], + None, + &pool, + ) + .await; + } + + { + // should suggest email and insert into quotes + let query = format!( + r#"select pr."{}" from private.users "pr""#, + QueryWithCursorPosition::cursor_marker() + ); + assert_complete_results( + query.as_str(), + vec![CompletionAssertion::CompletionTextAndRange( + r#"email"#.into(), + TextRange::new(11.into(), 11.into()), + )], + None, + &pool, + ) + .await; + } + + { + // should suggest email + let query = format!( + r#"select "pr".{} from private.users "pr""#, + QueryWithCursorPosition::cursor_marker() + ); + assert_complete_results( + query.as_str(), + vec![CompletionAssertion::CompletionTextAndRange( + "email".into(), + TextRange::new(12.into(), 12.into()), + )], + None, + &pool, + ) + .await; + } + + { + // should suggest `email` + let query = format!( + r#"select pr.{} from private.users "pr""#, + QueryWithCursorPosition::cursor_marker() + ); + assert_complete_results( + query.as_str(), + vec![CompletionAssertion::CompletionTextAndRange( + "email".into(), + TextRange::new(10.into(), 10.into()), + )], + None, + &pool, + ) + .await; + } + + { + let query = format!( + r#"select {} from private.users "pr" join public.names n on pr.id = n.uid;"#, + QueryWithCursorPosition::cursor_marker() + ); + assert_complete_results( + query.as_str(), + vec![ + CompletionAssertion::CompletionTextAndRange( + "n.name".into(), + TextRange::new(7.into(), 7.into()), + ), + CompletionAssertion::CompletionTextAndRange( + "n.uid".into(), + TextRange::new(7.into(), 7.into()), + ), + CompletionAssertion::CompletionTextAndRange( + r#""pr".email"#.into(), + TextRange::new(7.into(), 7.into()), + ), + CompletionAssertion::CompletionTextAndRange( + r#""pr".id"#.into(), + TextRange::new(7.into(), 7.into()), + ), + ], + None, + &pool, + ) + .await; + } + + { + // should suggest "pr"."email" + let query = format!( + r#"select "{}" from private.users "pr" join public.names "n" on pr.id = n.uid;"#, + QueryWithCursorPosition::cursor_marker() + ); + assert_complete_results( + query.as_str(), + vec![ + CompletionAssertion::CompletionTextAndRange( + r#"n"."name"#.into(), + TextRange::new(8.into(), 8.into()), + ), + CompletionAssertion::CompletionTextAndRange( + r#"n"."uid"#.into(), + TextRange::new(8.into(), 8.into()), + ), + CompletionAssertion::CompletionTextAndRange( + r#"pr"."email"#.into(), + TextRange::new(8.into(), 8.into()), + ), + CompletionAssertion::CompletionTextAndRange( + r#"pr"."id"#.into(), + TextRange::new(8.into(), 8.into()), + ), + ], + None, + &pool, + ) + .await; + } + + { + // should suggest pr"."email" + let query = format!( + r#"select "{} from private.users "pr";"#, + QueryWithCursorPosition::cursor_marker() + ); + assert_complete_results( + query.as_str(), + vec![ + CompletionAssertion::CompletionTextAndRange( + r#"pr"."email""#.into(), + TextRange::new(8.into(), 8.into()), + ), + CompletionAssertion::CompletionTextAndRange( + r#"pr"."id""#.into(), + TextRange::new(8.into(), 8.into()), + ), + ], + None, + &pool, + ) + .await; + } + + { + // should suggest email" + let query = format!( + r#"select pr."{} from private.users "pr";"#, + QueryWithCursorPosition::cursor_marker() + ); + assert_complete_results( + query.as_str(), + vec![CompletionAssertion::CompletionTextAndRange( + r#"email""#.into(), + TextRange::new(11.into(), 11.into()), + )], + None, + &pool, + ) + .await; + } + + { + // should suggest email" + let query = format!( + r#"select "pr"."{} from private.users "pr";"#, + QueryWithCursorPosition::cursor_marker() + ); + assert_complete_results( + query.as_str(), + vec![CompletionAssertion::CompletionTextAndRange( + r#"email""#.into(), + TextRange::new(13.into(), 13.into()), + )], + None, + &pool, + ) + .await; + } + + { + // should suggest "n".name + let query = format!( + r#"select {} from names "n";"#, + QueryWithCursorPosition::cursor_marker() + ); + assert_complete_results( + query.as_str(), + vec![CompletionAssertion::CompletionTextAndRange( + r#""n".name"#.into(), + TextRange::new(7.into(), 7.into()), + )], + None, + &pool, + ) + .await; + } + } } diff --git a/crates/pgt_completions/src/providers/functions.rs b/crates/pgt_completions/src/providers/functions.rs index b2ac2fae8..f4e86509b 100644 --- a/crates/pgt_completions/src/providers/functions.rs +++ b/crates/pgt_completions/src/providers/functions.rs @@ -8,7 +8,7 @@ use crate::{ relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; -use super::helper::get_completion_text_with_schema_or_alias; +use super::helper::with_schema_or_alias; pub fn complete_functions<'a>( ctx: &'a TreesitterContext, @@ -35,10 +35,9 @@ pub fn complete_functions<'a>( } fn get_completion_text(ctx: &TreesitterContext, func: &Function) -> CompletionText { + let mut text = with_schema_or_alias(ctx, func.name.as_str(), Some(func.schema.as_str())); + let range = get_range_to_replace(ctx); - let mut text = get_completion_text_with_schema_or_alias(ctx, &func.name, &func.schema) - .map(|ct| ct.text) - .unwrap_or(func.name.to_string()); if ctx.is_invocation { CompletionText { diff --git a/crates/pgt_completions/src/providers/helper.rs b/crates/pgt_completions/src/providers/helper.rs index cd1046f12..b6547d701 100644 --- a/crates/pgt_completions/src/providers/helper.rs +++ b/crates/pgt_completions/src/providers/helper.rs @@ -1,28 +1,32 @@ use pgt_text_size::{TextRange, TextSize}; use pgt_treesitter::TreesitterContext; -use crate::{CompletionText, remove_sanitized_token}; +use crate::{is_sanitized_token_with_quote, remove_sanitized_token}; -pub(crate) fn find_matching_alias_for_table( - ctx: &TreesitterContext, - table_name: &str, -) -> Option { - for (alias, table) in ctx.mentioned_table_aliases.iter() { - if table == table_name { - return Some(alias.to_string()); - } - } - None +pub(crate) fn node_text_surrounded_by_quotes(ctx: &TreesitterContext) -> bool { + ctx.get_node_under_cursor_content() + .is_some_and(|c| c.starts_with('"') && c.ends_with('"') && c.len() > 1) } pub(crate) fn get_range_to_replace(ctx: &TreesitterContext) -> TextRange { match ctx.node_under_cursor.as_ref() { Some(node) => { let content = ctx.get_node_under_cursor_content().unwrap_or("".into()); - let length = remove_sanitized_token(content.as_str()).len(); + let content = content.as_str(); + + let sanitized = remove_sanitized_token(content); + let length = sanitized.len(); + + let mut start = node.start_byte(); + let mut end = start + length; - let start = node.start_byte(); - let end = start + length; + if sanitized.starts_with('"') && sanitized.ends_with('"') { + start += 1; + + if sanitized.len() > 1 { + end -= 1; + } + } TextRange::new(start.try_into().unwrap(), end.try_into().unwrap()) } @@ -30,22 +34,34 @@ pub(crate) fn get_range_to_replace(ctx: &TreesitterContext) -> TextRange { } } -pub(crate) fn get_completion_text_with_schema_or_alias( +pub(crate) fn with_schema_or_alias( ctx: &TreesitterContext, item_name: &str, - schema_or_alias_name: &str, -) -> Option { + schema_or_alias_name: Option<&str>, +) -> String { let is_already_prefixed_with_schema_name = ctx.schema_or_alias_name.is_some(); - if schema_or_alias_name == "public" || is_already_prefixed_with_schema_name { - None + let with_quotes = node_text_surrounded_by_quotes(ctx); + + let node_under_cursor_txt = ctx.get_node_under_cursor_content().unwrap_or("".into()); + let node_under_cursor_txt = node_under_cursor_txt.as_str(); + let is_quote_sanitized = is_sanitized_token_with_quote(node_under_cursor_txt); + + if schema_or_alias_name.is_none_or(|s| s == "public") || is_already_prefixed_with_schema_name { + if is_quote_sanitized { + format!(r#"{}""#, item_name) + } else { + item_name.to_string() + } } else { - let range = get_range_to_replace(ctx); + let schema_or_als = schema_or_alias_name.unwrap(); - Some(CompletionText { - text: format!("{}.{}", schema_or_alias_name, item_name), - range, - is_snippet: false, - }) + if is_quote_sanitized { + format!(r#"{}"."{}""#, schema_or_als.replace('"', ""), item_name) + } else if with_quotes { + format!(r#"{}"."{}"#, schema_or_als.replace('"', ""), item_name) + } else { + format!("{}.{}", schema_or_als, item_name) + } } } diff --git a/crates/pgt_completions/src/providers/policies.rs b/crates/pgt_completions/src/providers/policies.rs index a5ffdb43e..b903155e6 100644 --- a/crates/pgt_completions/src/providers/policies.rs +++ b/crates/pgt_completions/src/providers/policies.rs @@ -1,10 +1,10 @@ use pgt_schema_cache::SchemaCache; -use pgt_text_size::{TextRange, TextSize}; use pgt_treesitter::TreesitterContext; use crate::{ CompletionItemKind, CompletionText, builder::{CompletionBuilder, PossibleCompletionItem}, + providers::helper::node_text_surrounded_by_quotes, relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; @@ -17,44 +17,30 @@ pub fn complete_policies<'a>( ) { let available_policies = &schema_cache.policies; - let surrounded_by_quotes = ctx - .get_node_under_cursor_content() - .is_some_and(|c| c.starts_with('"') && c.ends_with('"') && c != "\"\""); - for pol in available_policies { - let completion_text = if surrounded_by_quotes { + let text = if node_text_surrounded_by_quotes(ctx) { // If we're within quotes, we want to change the content // *within* the quotes. - // If we attempt to replace outside the quotes, the VSCode - // client won't show the suggestions. - let range = get_range_to_replace(ctx); - Some(CompletionText { - text: pol.name.clone(), - is_snippet: false, - range: TextRange::new( - range.start() + TextSize::new(1), - range.end() - TextSize::new(1), - ), - }) + pol.name.to_string() } else { - // If we aren't within quotes, we want to complete the - // full policy including quotation marks. - Some(CompletionText { - is_snippet: false, - text: format!("\"{}\"", pol.name), - range: get_range_to_replace(ctx), - }) + format!("\"{}\"", pol.name) }; let relevance = CompletionRelevanceData::Policy(pol); + let range = get_range_to_replace(ctx); + let item = PossibleCompletionItem { label: pol.name.chars().take(35).collect::(), score: CompletionScore::from(relevance.clone()), filter: CompletionFilter::from(relevance), description: pol.table_name.to_string(), kind: CompletionItemKind::Policy, - completion_text, + completion_text: Some(CompletionText { + text, + range, + is_snippet: false, + }), detail: None, }; diff --git a/crates/pgt_completions/src/providers/tables.rs b/crates/pgt_completions/src/providers/tables.rs index f78b697c9..20100e01f 100644 --- a/crates/pgt_completions/src/providers/tables.rs +++ b/crates/pgt_completions/src/providers/tables.rs @@ -1,13 +1,15 @@ -use pgt_schema_cache::SchemaCache; +use pgt_schema_cache::{SchemaCache, Table}; use pgt_treesitter::TreesitterContext; use crate::{ + CompletionText, builder::{CompletionBuilder, PossibleCompletionItem}, item::CompletionItemKind, + providers::helper::get_range_to_replace, relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, }; -use super::helper::get_completion_text_with_schema_or_alias; +use super::helper::with_schema_or_alias; pub fn complete_tables<'a>( ctx: &'a TreesitterContext, @@ -34,17 +36,25 @@ pub fn complete_tables<'a>( description: table.schema.to_string(), kind: CompletionItemKind::Table, detail, - completion_text: get_completion_text_with_schema_or_alias( - ctx, - &table.name, - &table.schema, - ), + completion_text: Some(get_completion_text(ctx, table)), }; builder.add_item(item); } } +fn get_completion_text(ctx: &TreesitterContext, table: &Table) -> CompletionText { + let text = with_schema_or_alias(ctx, table.name.as_str(), Some(table.schema.as_str())); + + let range = get_range_to_replace(ctx); + + CompletionText { + text, + range, + is_snippet: false, + } +} + #[cfg(test)] mod tests { diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index 0514a485e..b9a896c49 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -39,12 +39,35 @@ impl CompletionFilter<'_> { if current_node_kind.starts_with("keyword_") || current_node_kind == "=" || current_node_kind == "," - || current_node_kind == "literal" || current_node_kind == "ERROR" { return None; } + // "literal" nodes can be identfiers wrapped in quotes: + // `select "email" from auth.users;` + // Here, "email" is a literal node. + if current_node_kind == "literal" { + match self.data { + CompletionRelevanceData::Column(_) => match ctx.wrapping_clause_type.as_ref() { + Some(WrappingClause::Select) + | Some(WrappingClause::Where) + | Some(WrappingClause::Join { .. }) + | Some(WrappingClause::Update) + | Some(WrappingClause::Delete) + | Some(WrappingClause::Insert) + | Some(WrappingClause::DropColumn) + | Some(WrappingClause::AlterColumn) + | Some(WrappingClause::RenameColumn) + | Some(WrappingClause::PolicyCheck) => { + // the literal is probably a column + } + _ => return None, + }, + _ => return None, + } + } + // No autocompletions if there are two identifiers without a separator. if ctx.node_under_cursor.as_ref().is_some_and(|n| match n { NodeUnderCursor::TsNode(node) => node.prev_sibling().is_some_and(|p| { @@ -232,8 +255,7 @@ impl CompletionFilter<'_> { CompletionRelevanceData::Table(table) => &table.schema == schema_or_alias, CompletionRelevanceData::Function(f) => &f.schema == schema_or_alias, CompletionRelevanceData::Column(col) => ctx - .mentioned_table_aliases - .get(schema_or_alias) + .get_mentioned_table_for_alias(schema_or_alias) .is_some_and(|t| t == &col.table_name), // we should never allow schema suggestions if there already was one. diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 4bbf325f4..ba45e2d0d 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -77,7 +77,7 @@ impl CompletionScore<'_> { Some(ct) => ct, }; - let has_mentioned_tables = !ctx.mentioned_relations.is_empty(); + let has_mentioned_tables = ctx.has_any_mentioned_relations(); let has_mentioned_schema = ctx.schema_or_alias_name.is_some(); self.score += match self.data { @@ -248,14 +248,12 @@ impl CompletionScore<'_> { }; if ctx - .mentioned_relations - .get(&Some(schema.to_string())) + .get_mentioned_relations(&Some(schema.to_string())) .is_some_and(|tables| tables.contains(table_name)) { self.score += 45; } else if ctx - .mentioned_relations - .get(&None) + .get_mentioned_relations(&None) .is_some_and(|tables| tables.contains(table_name)) { self.score += 30; @@ -334,13 +332,12 @@ impl CompletionScore<'_> { * */ if ctx - .mentioned_columns - .get(&ctx.wrapping_clause_type) + .get_mentioned_columns(&ctx.wrapping_clause_type) .is_some_and(|set| { set.iter().any(|mentioned| match mentioned.alias.as_ref() { Some(als) => { - let aliased_table = ctx.mentioned_table_aliases.get(als.as_str()); - column.name == mentioned.column + let aliased_table = ctx.get_mentioned_table_for_alias(als.as_str()); + column.name == mentioned.column.replace('"', "") && aliased_table.is_none_or(|t| t == &column.table_name) } None => mentioned.column == column.name, diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index 155256c8a..5272c75e2 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -5,6 +5,7 @@ use pgt_text_size::TextSize; use crate::CompletionParams; static SANITIZED_TOKEN: &str = "REPLACED_TOKEN"; +static SANITIZED_TOKEN_WITH_QUOTE: &str = r#"REPLACED_TOKEN_WITH_QUOTE""#; #[derive(Debug)] pub(crate) struct SanitizedCompletionParams<'a> { @@ -20,33 +21,23 @@ pub fn benchmark_sanitization(params: CompletionParams) -> String { } pub(crate) fn remove_sanitized_token(it: &str) -> String { - it.replace(SANITIZED_TOKEN, "") + it.replace(SANITIZED_TOKEN_WITH_QUOTE, "") + .replace(SANITIZED_TOKEN, "") } -pub(crate) fn is_sanitized_token(txt: &str) -> bool { - txt == SANITIZED_TOKEN +pub(crate) fn is_sanitized_token(node_under_cursor_txt: &str) -> bool { + node_under_cursor_txt == SANITIZED_TOKEN || is_sanitized_token_with_quote(node_under_cursor_txt) } -#[derive(PartialEq, Eq, Debug)] -pub(crate) enum NodeText { - Replaced, - Original(String), -} - -impl From<&str> for NodeText { - fn from(value: &str) -> Self { - if value == SANITIZED_TOKEN { - NodeText::Replaced - } else { - NodeText::Original(value.into()) - } +pub(crate) fn is_sanitized_token_with_quote(node_under_cursor_txt: &str) -> bool { + if node_under_cursor_txt.len() <= 1 { + return false; } -} -impl From for NodeText { - fn from(value: String) -> Self { - NodeText::from(value.as_str()) - } + // Node under cursor text will be "REPLACED_TOKEN_WITH_QUOTE". + // The SANITIZED_TOKEN_WITH_QUOTE does not have the leading ". + // We need to omit it from the txt. + &node_under_cursor_txt[1..] == SANITIZED_TOKEN_WITH_QUOTE } impl<'larger, 'smaller> From> for SanitizedCompletionParams<'smaller> @@ -60,6 +51,7 @@ where || cursor_before_semicolon(params.tree, params.position) || cursor_on_a_dot(¶ms.text, params.position) || cursor_between_parentheses(¶ms.text, params.position) + || cursor_after_opened_quote(¶ms.text, params.position) { SanitizedCompletionParams::with_adjusted_sql(params) } else { @@ -80,12 +72,25 @@ where let max = max(cursor_pos + 1, params.text.len()); + let has_uneven_quotes = params.text.chars().filter(|c| *c == '"').count() % 2 != 0; + let mut opened_quote = false; + for idx in 0..max { match sql_iter.next() { Some(c) => { + if c == '"' { + opened_quote = !opened_quote; + } + if idx == cursor_pos { - sql.push_str(SANITIZED_TOKEN); + if opened_quote && has_uneven_quotes { + sql.push_str(SANITIZED_TOKEN_WITH_QUOTE); + opened_quote = false; + } else { + sql.push_str(SANITIZED_TOKEN); + } } + sql.push(c); } None => { @@ -268,6 +273,27 @@ fn cursor_between_parentheses(sql: &str, position: TextSize) -> bool { head_of_list || end_of_list || between_list_items || after_and_keyword || after_eq_sign } +fn cursor_after_opened_quote(sql: &str, position: TextSize) -> bool { + let position: usize = position.into(); + let mut opened_quote = false; + let mut preceding_quote = false; + + for (idx, c) in sql.char_indices() { + if idx == position && opened_quote && preceding_quote { + return true; + } + + if c == '"' { + preceding_quote = true; + opened_quote = !opened_quote; + } else { + preceding_quote = false; + } + } + + opened_quote && preceding_quote +} + #[cfg(test)] mod tests { use pgt_schema_cache::SchemaCache; @@ -276,29 +302,36 @@ mod tests { use crate::{ CompletionParams, SanitizedCompletionParams, sanitization::{ - cursor_before_semicolon, cursor_between_parentheses, cursor_inbetween_nodes, - cursor_on_a_dot, cursor_prepared_to_write_token_after_last_node, + cursor_after_opened_quote, cursor_before_semicolon, cursor_between_parentheses, + cursor_inbetween_nodes, cursor_on_a_dot, + cursor_prepared_to_write_token_after_last_node, }, }; - #[test] - fn should_lowercase_everything_except_replaced_token() { - let input = "SELECT FROM users WHERE ts = NOW();"; - - let position = TextSize::new(7); - let cache = SchemaCache::default(); - + fn get_test_params(input: &str, position: TextSize) -> CompletionParams { let mut ts = tree_sitter::Parser::new(); ts.set_language(tree_sitter_sql::language()).unwrap(); - let tree = ts.parse(input, None).unwrap(); - let params = CompletionParams { + let tree = Box::new(ts.parse(input, None).unwrap()); + let cache = Box::new(SchemaCache::default()); + + let leaked_tree = Box::leak(tree); + let leaked_cache = Box::leak(cache); + + CompletionParams { position, - schema: &cache, + schema: leaked_cache, text: input.into(), - tree: &tree, - }; + tree: leaked_tree, + } + } + + #[test] + fn should_lowercase_everything_except_replaced_token() { + let input = "SELECT FROM users WHERE ts = NOW();"; + let position = TextSize::new(7); + let params = get_test_params(input, position); let sanitized = SanitizedCompletionParams::from(params); assert_eq!( @@ -307,6 +340,56 @@ mod tests { ); } + #[test] + fn should_sanitize_with_opened_quotes() { + // select "email", "| from "auth"."users"; + let input = r#"select "email", " from "auth"."users";"#; + let position = TextSize::new(17); + + let params = get_test_params(input, position); + + let sanitized = SanitizedCompletionParams::from(params); + + assert_eq!( + sanitized.text, + r#"select "email", "REPLACED_TOKEN_WITH_QUOTE" from "auth"."users";"# + ); + } + + #[test] + fn should_not_complete_quote_if_we_are_inside_pair() { + // select "email", "| " from "auth"."users"; + // we have opened a quote, but it is already closed a couple of characters later + let input = r#"select "email", " " from "auth"."users";"#; + let position = TextSize::new(17); + + let params = get_test_params(input, position); + + let sanitized = SanitizedCompletionParams::from(params); + + assert_eq!( + sanitized.text, + r#"select "email", "REPLACED_TOKEN " from "auth"."users";"# + ); + } + + #[test] + fn should_not_use_quote_token_if_we_are_not_within_opened_quote() { + // select "users".| from "users" join "public"." + // we have an opened quote at the end, but the cursor is not within an opened quote + let input = r#"select "users". from "users" join "public"." "#; + let position = TextSize::new(15); + + let params = get_test_params(input, position); + + let sanitized = SanitizedCompletionParams::from(params); + + assert_eq!( + sanitized.text, + r#"select "users".REPLACED_TOKEN from "users" join "public"." "# + ); + } + #[test] fn test_cursor_inbetween_nodes() { // note: two spaces between select and from. @@ -467,4 +550,44 @@ mod tests { // does not break if sql is really short assert!(!cursor_between_parentheses("(a)", TextSize::new(2))); } + + #[test] + fn after_single_quote() { + // select "| <-- right after single quote + assert!(cursor_after_opened_quote(r#"select ""#, TextSize::new(8))); + // select "| from something; <-- right after opening quote + assert!(cursor_after_opened_quote( + r#"select " from something;"#, + TextSize::new(8) + )); + + // select "user_id", "| <-- right after opening quote + assert!(cursor_after_opened_quote( + r#"select "user_id", ""#, + TextSize::new(19) + )); + // select "user_id, "| from something; <-- right after opening quote + assert!(cursor_after_opened_quote( + r#"select "user_id", " from something;"#, + TextSize::new(19) + )); + + // select "user_id"| from something; <-- after closing quote + assert!(!cursor_after_opened_quote( + r#"select "user_id" from something;"#, + TextSize::new(16) + )); + + // select ""| from something; <-- after closing quote + assert!(!cursor_after_opened_quote( + r#"select "" from something;"#, + TextSize::new(9) + )); + + // select "user_id, " |from something; <-- one off after opening quote + assert!(!cursor_after_opened_quote( + r#"select "user_id", " from something;"#, + TextSize::new(20) + )); + } } diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index e6c347614..e5fc58014 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -1,5 +1,6 @@ use pgt_schema_cache::SchemaCache; use pgt_test_utils::QueryWithCursorPosition; +use pgt_text_size::TextRange; use sqlx::{Executor, PgPool}; use crate::{CompletionItem, CompletionItemKind, CompletionParams, complete}; @@ -77,6 +78,7 @@ pub(crate) enum CompletionAssertion { LabelAndDesc(String, String), LabelNotExists(String), KindNotExists(CompletionItemKind), + CompletionTextAndRange(String, TextRange), } impl CompletionAssertion { @@ -127,6 +129,29 @@ impl CompletionAssertion { desc, &item.description ); } + CompletionAssertion::CompletionTextAndRange(txt, text_range) => { + assert_eq!( + item.completion_text.as_ref().map(|t| t.text.as_str()), + Some(txt.as_str()), + "Expected completion text to be {}, but got {}", + txt, + item.completion_text + .as_ref() + .map(|t| t.text.clone()) + .unwrap_or("None".to_string()) + ); + + assert_eq!( + item.completion_text.as_ref().map(|t| &t.range), + Some(text_range), + "Expected range to be {:?}, but got {:?}", + text_range, + item.completion_text + .as_ref() + .map(|t| format!("{:?}", &t.range)) + .unwrap_or("None".to_string()) + ); + } } } } @@ -146,6 +171,7 @@ pub(crate) async fn assert_complete_results( CompletionAssertion::LabelNotExists(_) | CompletionAssertion::KindNotExists(_) => true, CompletionAssertion::Label(_) | CompletionAssertion::LabelAndKind(_, _) + | CompletionAssertion::CompletionTextAndRange(_, _) | CompletionAssertion::LabelAndDesc(_, _) => false, }); diff --git a/crates/pgt_hover/src/hoverables/column.rs b/crates/pgt_hover/src/hoverables/column.rs index e4e15ebce..913aae86e 100644 --- a/crates/pgt_hover/src/hoverables/column.rs +++ b/crates/pgt_hover/src/hoverables/column.rs @@ -63,25 +63,25 @@ impl ContextualPriority for Column { // high score if we match the specific alias or table being referenced in the cursor context if let Some(table_or_alias) = ctx.schema_or_alias_name.as_ref() { - if table_or_alias == self.table_name.as_str() { + if table_or_alias.replace('"', "") == self.table_name.as_str() { score += 250.0; - } else if let Some(table_name) = ctx.mentioned_table_aliases.get(table_or_alias) { + } else if let Some(table_name) = ctx.get_mentioned_table_for_alias(table_or_alias) { if table_name == self.table_name.as_str() { score += 250.0; } } } - // medium score if the current column maps to any of the query's mentioned - // "(schema.)table" combinations - for (schema_opt, tables) in &ctx.mentioned_relations { - if tables.contains(&self.table_name) { - if schema_opt.as_deref() == Some(&self.schema_name) { - score += 150.0; - } else { - score += 100.0; - } - } + if ctx + .get_mentioned_relations(&Some(self.schema_name.clone())) + .is_some_and(|t| t.contains(&self.table_name)) + { + score += 150.0; + } else if ctx + .get_mentioned_relations(&None) + .is_some_and(|t| t.contains(&self.table_name)) + { + score += 100.0; } if self.schema_name == "public" && score == 0.0 { diff --git a/crates/pgt_hover/src/hoverables/table.rs b/crates/pgt_hover/src/hoverables/table.rs index 6bd012061..4c457c28d 100644 --- a/crates/pgt_hover/src/hoverables/table.rs +++ b/crates/pgt_hover/src/hoverables/table.rs @@ -57,20 +57,19 @@ impl ContextualPriority for Table { fn relevance_score(&self, ctx: &TreesitterContext) -> f32 { let mut score = 0.0; - for (schema_opt, tables) in &ctx.mentioned_relations { - if tables.contains(&self.name) { - if schema_opt.as_deref() == Some(&self.schema) { - score += 200.0; - } else { - score += 150.0; - } - } - } - if ctx - .mentioned_relations - .keys() - .any(|schema| schema.as_deref() == Some(&self.schema)) + .get_mentioned_relations(&Some(self.schema.clone())) + .is_some_and(|t| t.contains(&self.name)) + { + score += 200.0; + } else if ctx + .get_mentioned_relations(&None) + .is_some_and(|t| t.contains(&self.name)) + { + score += 150.0; + } else if ctx + .get_mentioned_relations(&Some(self.schema.clone())) + .is_some() { score += 50.0; } diff --git a/crates/pgt_hover/src/lib.rs b/crates/pgt_hover/src/lib.rs index 53324bc8a..f3b8a2640 100644 --- a/crates/pgt_hover/src/lib.rs +++ b/crates/pgt_hover/src/lib.rs @@ -58,8 +58,7 @@ pub fn on_hover(params: OnHoverParams) -> Vec { hovered_node::NodeIdentification::SchemaAndName((table_or_alias, column_name)) => { // resolve alias to actual table name if needed let actual_table = ctx - .mentioned_table_aliases - .get(table_or_alias.as_str()) + .get_mentioned_table_for_alias(table_or_alias.as_str()) .map(|s| s.as_str()) .unwrap_or(table_or_alias.as_str()); diff --git a/crates/pgt_lsp/src/capabilities.rs b/crates/pgt_lsp/src/capabilities.rs index 8c8ff6d92..3bbb062de 100644 --- a/crates/pgt_lsp/src/capabilities.rs +++ b/crates/pgt_lsp/src/capabilities.rs @@ -37,7 +37,12 @@ pub(crate) fn server_capabilities(capabilities: &ClientCapabilities) -> ServerCa // The request is used to get more information about a simple CompletionItem. resolve_provider: None, - trigger_characters: Some(vec![".".to_owned(), " ".to_owned(), "(".to_owned()]), + trigger_characters: Some(vec![ + ".".to_owned(), + " ".to_owned(), + "(".to_owned(), + "\"".to_owned(), + ]), // No character will lead to automatically inserting the selected completion-item all_commit_characters: None, diff --git a/crates/pgt_treesitter/src/context/mod.rs b/crates/pgt_treesitter/src/context/mod.rs index 383e4c993..d481af8cb 100644 --- a/crates/pgt_treesitter/src/context/mod.rs +++ b/crates/pgt_treesitter/src/context/mod.rs @@ -183,9 +183,9 @@ pub struct TreesitterContext<'a> { pub is_invocation: bool, pub wrapping_statement_range: Option, - pub mentioned_relations: HashMap, HashSet>, - pub mentioned_table_aliases: HashMap, - pub mentioned_columns: HashMap>, HashSet>, + mentioned_relations: HashMap, HashSet>, + mentioned_table_aliases: HashMap, + mentioned_columns: HashMap>, HashSet>, } impl<'a> TreesitterContext<'a> { @@ -800,6 +800,56 @@ impl<'a> TreesitterContext<'a> { NodeUnderCursor::CustomNode { .. } => false, }) } + + pub fn get_mentioned_relations(&self, key: &Option) -> Option<&HashSet> { + if let Some(key) = key.as_ref() { + let sanitized_key = key.replace('"', ""); + + self.mentioned_relations + .get(&Some(sanitized_key.clone())) + .or(self + .mentioned_relations + .get(&Some(format!(r#""{}""#, sanitized_key)))) + } else { + self.mentioned_relations.get(&None) + } + } + + pub fn get_mentioned_table_for_alias(&self, key: &str) -> Option<&String> { + let sanitized_key = key.replace('"', ""); + + self.mentioned_table_aliases.get(&sanitized_key).or(self + .mentioned_table_aliases + .get(&format!(r#""{}""#, sanitized_key))) + } + + pub fn get_used_alias_for_table(&self, table_name: &str) -> Option { + for (alias, table) in self.mentioned_table_aliases.iter() { + if table == table_name { + return Some(alias.to_string()); + } + } + None + } + + pub fn get_mentioned_columns( + &self, + clause: &Option>, + ) -> Option<&HashSet> { + self.mentioned_columns.get(clause) + } + + pub fn has_any_mentioned_relations(&self) -> bool { + !self.mentioned_relations.is_empty() + } + + pub fn has_mentioned_table_aliases(&self) -> bool { + !self.mentioned_table_aliases.is_empty() + } + + pub fn has_mentioned_columns(&self) -> bool { + !self.mentioned_columns.is_empty() + } } #[cfg(test)] diff --git a/crates/pgt_treesitter/src/queries/relations.rs b/crates/pgt_treesitter/src/queries/relations.rs index cb6a6bea9..664260fb9 100644 --- a/crates/pgt_treesitter/src/queries/relations.rs +++ b/crates/pgt_treesitter/src/queries/relations.rs @@ -44,13 +44,13 @@ pub struct RelationMatch<'a> { impl RelationMatch<'_> { pub fn get_schema(&self, sql: &str) -> Option { - let str = self - .schema - .as_ref()? - .utf8_text(sql.as_bytes()) - .expect("Failed to get schema from RelationMatch"); - - Some(str.to_string()) + Some( + self.schema + .as_ref()? + .utf8_text(sql.as_bytes()) + .expect("Failed to get schema from RelationMatch") + .to_string(), + ) } pub fn get_table(&self, sql: &str) -> String { @@ -162,6 +162,29 @@ mod tests { assert_eq!(results[0].get_table(sql), "users"); } + #[test] + fn finds_table_with_schema_quotes() { + let sql = r#"select * from "public"."users";"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), Some(r#""public""#.to_string())); + assert_eq!(results[0].get_table(sql), r#""users""#); + } + #[test] fn finds_insert_into_with_schema_and_table() { let sql = r#"insert into auth.accounts (id, email) values (1, 'a@b.com');"#; diff --git a/crates/pgt_treesitter/src/queries/select_columns.rs b/crates/pgt_treesitter/src/queries/select_columns.rs index f232abc38..de5016d52 100644 --- a/crates/pgt_treesitter/src/queries/select_columns.rs +++ b/crates/pgt_treesitter/src/queries/select_columns.rs @@ -28,13 +28,13 @@ pub struct SelectColumnMatch<'a> { impl SelectColumnMatch<'_> { pub fn get_alias(&self, sql: &str) -> Option { - let str = self - .alias - .as_ref()? - .utf8_text(sql.as_bytes()) - .expect("Failed to get alias from ColumnMatch"); - - Some(str.to_string()) + Some( + self.alias + .as_ref()? + .utf8_text(sql.as_bytes()) + .expect("Failed to get alias from ColumnMatch") + .to_string(), + ) } pub fn get_column(&self, sql: &str) -> String { diff --git a/crates/pgt_treesitter/src/queries/where_columns.rs b/crates/pgt_treesitter/src/queries/where_columns.rs index b683300b6..03ce90ec3 100644 --- a/crates/pgt_treesitter/src/queries/where_columns.rs +++ b/crates/pgt_treesitter/src/queries/where_columns.rs @@ -29,13 +29,13 @@ pub struct WhereColumnMatch<'a> { impl WhereColumnMatch<'_> { pub fn get_alias(&self, sql: &str) -> Option { - let str = self - .alias - .as_ref()? - .utf8_text(sql.as_bytes()) - .expect("Failed to get alias from ColumnMatch"); - - Some(str.to_string()) + Some( + self.alias + .as_ref()? + .utf8_text(sql.as_bytes()) + .expect("Failed to get alias from ColumnMatch") + .to_string(), + ) } pub fn get_column(&self, sql: &str) -> String {