Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions crates/pgt_completions/src/providers/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pgt_treesitter::TreesitterContext;
use crate::{
CompletionItemKind, CompletionText,
builder::{CompletionBuilder, PossibleCompletionItem},
providers::helper::get_range_to_replace,
providers::helper::{get_range_to_replace, node_text_surrounded_by_quotes, only_leading_quote},
relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore},
};

Expand Down Expand Up @@ -37,7 +37,7 @@ pub fn complete_functions<'a>(
fn get_completion_text(ctx: &TreesitterContext, func: &Function) -> CompletionText {
let mut text = with_schema_or_alias(ctx, func.name.as_str(), Some(func.schema.as_str()));

let range = get_range_to_replace(ctx);
let mut range = get_range_to_replace(ctx);

if ctx.is_invocation {
CompletionText {
Expand All @@ -46,6 +46,11 @@ fn get_completion_text(ctx: &TreesitterContext, func: &Function) -> CompletionTe
is_snippet: false,
}
} else {
if node_text_surrounded_by_quotes(ctx) && !only_leading_quote(ctx) {
text.push('"');
range = range.checked_expand_end(1.into()).unwrap_or(range);
}

text.push('(');

let num_args = func.args.args.len();
Expand All @@ -68,6 +73,7 @@ fn get_completion_text(ctx: &TreesitterContext, func: &Function) -> CompletionTe

#[cfg(test)]
mod tests {
use pgt_text_size::TextRange;
use sqlx::{Executor, PgPool};

use crate::{
Expand Down Expand Up @@ -294,4 +300,68 @@ mod tests {
)
.await;
}

#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
async fn autocompletes_after_schema_in_quotes(pool: PgPool) {
let setup = r#"
create schema auth;

create or replace function auth.my_cool_foo()
returns trigger
language plpgsql
security invoker
as $$
begin
raise exception 'dont matter';
end;
$$;
"#;

pool.execute(setup).await.unwrap();

assert_complete_results(
format!(
r#"select "auth".{}"#,
QueryWithCursorPosition::cursor_marker()
)
.as_str(),
vec![CompletionAssertion::CompletionTextAndRange(
"my_cool_foo()".into(),
TextRange::new(14.into(), 14.into()),
)],
None,
&pool,
)
.await;

assert_complete_results(
format!(
r#"select "auth"."{}"#,
QueryWithCursorPosition::cursor_marker()
)
.as_str(),
vec![CompletionAssertion::CompletionTextAndRange(
r#"my_cool_foo"()"#.into(),
TextRange::new(15.into(), 15.into()),
)],
None,
&pool,
)
.await;

assert_complete_results(
format!(
r#"select "auth"."{}""#,
QueryWithCursorPosition::cursor_marker()
)
.as_str(),
vec![CompletionAssertion::CompletionTextAndRange(
r#"my_cool_foo"()"#.into(),
TextRange::new(15.into(), 16.into()),
)],
None,
&pool,
)
.await;
}
}
15 changes: 9 additions & 6 deletions crates/pgt_completions/src/providers/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ pub(crate) fn get_range_to_replace(ctx: &TreesitterContext) -> TextRange {
}
}

pub(crate) fn only_leading_quote(ctx: &TreesitterContext) -> bool {
let node_under_cursor_txt = ctx.get_node_under_cursor_content().unwrap_or("".into());
let node_under_cursor_txt = node_under_cursor_txt.as_str();
is_sanitized_token_with_quote(node_under_cursor_txt)
}

pub(crate) fn with_schema_or_alias(
ctx: &TreesitterContext,
item_name: &str,
Expand All @@ -42,21 +48,18 @@ pub(crate) fn with_schema_or_alias(
let is_already_prefixed_with_schema_name = ctx.schema_or_alias_name.is_some();

let with_quotes = node_text_surrounded_by_quotes(ctx);

let node_under_cursor_txt = ctx.get_node_under_cursor_content().unwrap_or("".into());
let node_under_cursor_txt = node_under_cursor_txt.as_str();
let is_quote_sanitized = is_sanitized_token_with_quote(node_under_cursor_txt);
let single_leading_quote = only_leading_quote(ctx);

if schema_or_alias_name.is_none_or(|s| s == "public") || is_already_prefixed_with_schema_name {
if is_quote_sanitized {
if single_leading_quote {
format!(r#"{}""#, item_name)
} else {
item_name.to_string()
}
} else {
let schema_or_als = schema_or_alias_name.unwrap();

if is_quote_sanitized {
if single_leading_quote {
format!(r#"{}"."{}""#, schema_or_als.replace('"', ""), item_name)
} else if with_quotes {
format!(r#"{}"."{}"#, schema_or_als.replace('"', ""), item_name)
Expand Down
87 changes: 87 additions & 0 deletions crates/pgt_completions/src/providers/tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ fn get_completion_text(ctx: &TreesitterContext, table: &Table) -> CompletionText
#[cfg(test)]
mod tests {

use pgt_text_size::TextRange;
use sqlx::{Executor, PgPool};

use crate::{
Expand Down Expand Up @@ -569,4 +570,90 @@ mod tests {
)
.await;
}

#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
async fn after_quoted_schemas(pool: PgPool) {
let setup = r#"
create schema auth;

create table auth.users (
uid serial primary key,
name text not null,
email text unique not null
);

create table auth.posts (
pid serial primary key,
user_id int not null references auth.users(uid),
title text not null,
content text,
created_at timestamp default now()
);
"#;

pool.execute(setup).await.unwrap();

assert_complete_results(
format!(
r#"select * from "auth".{}"#,
QueryWithCursorPosition::cursor_marker()
)
.as_str(),
vec![
CompletionAssertion::CompletionTextAndRange(
"posts".into(),
TextRange::new(21.into(), 21.into()),
),
CompletionAssertion::CompletionTextAndRange(
"users".into(),
TextRange::new(21.into(), 21.into()),
),
],
None,
&pool,
)
.await;

assert_complete_results(
format!(
r#"select * from "auth"."{}""#,
QueryWithCursorPosition::cursor_marker()
)
.as_str(),
vec![
CompletionAssertion::CompletionTextAndRange(
"posts".into(),
TextRange::new(22.into(), 22.into()),
),
CompletionAssertion::CompletionTextAndRange(
"users".into(),
TextRange::new(22.into(), 22.into()),
),
],
None,
&pool,
)
.await;

assert_complete_results(
format!(
r#"select * from "auth"."{}"#,
QueryWithCursorPosition::cursor_marker()
)
.as_str(),
vec![
CompletionAssertion::CompletionTextAndRange(
r#"posts""#.into(),
TextRange::new(22.into(), 22.into()),
),
CompletionAssertion::CompletionTextAndRange(
r#"users""#.into(),
TextRange::new(22.into(), 22.into()),
),
],
None,
&pool,
)
.await;
}
}
8 changes: 4 additions & 4 deletions crates/pgt_completions/src/relevance/filtering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,13 @@ impl CompletionFilter<'_> {
return Some(());
}

let schema_or_alias = ctx.schema_or_alias_name.as_ref().unwrap();
let schema_or_alias = ctx.schema_or_alias_name.as_ref().unwrap().replace('"', "");

let matches = match self.data {
CompletionRelevanceData::Table(table) => &table.schema == schema_or_alias,
CompletionRelevanceData::Function(f) => &f.schema == schema_or_alias,
CompletionRelevanceData::Table(table) => table.schema == schema_or_alias,
CompletionRelevanceData::Function(f) => f.schema == schema_or_alias,
CompletionRelevanceData::Column(col) => ctx
.get_mentioned_table_for_alias(schema_or_alias)
.get_mentioned_table_for_alias(&schema_or_alias)
.is_some_and(|t| t == &col.table_name),

// we should never allow schema suggestions if there already was one.
Expand Down
3 changes: 2 additions & 1 deletion crates/pgt_completions/src/relevance/scoring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,10 @@ impl CompletionScore<'_> {
}

fn check_matches_schema(&mut self, ctx: &TreesitterContext) {
// TODO
let schema_name = match ctx.schema_or_alias_name.as_ref() {
None => return,
Some(n) => n,
Some(n) => n.replace('"', ""),
};

let data_schema = match self.get_schema_name() {
Expand Down
56 changes: 46 additions & 10 deletions crates/pgt_completions/src/sanitization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ where
// we want to push spaces until we arrive at the cursor position.
// we'll then add the SANITIZED_TOKEN
if idx == cursor_pos {
sql.push_str(SANITIZED_TOKEN);
if opened_quote && has_uneven_quotes {
sql.push_str(SANITIZED_TOKEN_WITH_QUOTE);
} else {
sql.push_str(SANITIZED_TOKEN);
}
} else {
sql.push(' ');
}
Expand Down Expand Up @@ -342,18 +346,50 @@ mod tests {

#[test]
fn should_sanitize_with_opened_quotes() {
// select "email", "| from "auth"."users";
let input = r#"select "email", " from "auth"."users";"#;
let position = TextSize::new(17);
{
// select "email", "| from "auth"."users";
let input = r#"select "email", " from "auth"."users";"#;
let position = TextSize::new(17);

let params = get_test_params(input, position);
let params = get_test_params(input, position);

let sanitized = SanitizedCompletionParams::from(params);
let sanitized = SanitizedCompletionParams::from(params);

assert_eq!(
sanitized.text,
r#"select "email", "REPLACED_TOKEN_WITH_QUOTE" from "auth"."users";"#
);
assert_eq!(
sanitized.text,
r#"select "email", "REPLACED_TOKEN_WITH_QUOTE" from "auth"."users";"#
);
}

{
// select * from "auth"."|; <-- with semi
let input = r#"select * from "auth".";"#;
let position = TextSize::new(22);

let params = get_test_params(input, position);

let sanitized = SanitizedCompletionParams::from(params);

assert_eq!(
sanitized.text,
r#"select * from "auth"."REPLACED_TOKEN_WITH_QUOTE";"#
);
}

{
// select * from "auth"."| <-- without semi
let input = r#"select * from "auth".""#;
let position = TextSize::new(22);

let params = get_test_params(input, position);

let sanitized = SanitizedCompletionParams::from(params);

assert_eq!(
sanitized.text,
r#"select * from "auth"."REPLACED_TOKEN_WITH_QUOTE""#
);
}
}

#[test]
Expand Down
9 changes: 7 additions & 2 deletions crates/pgt_hover/src/hovered_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ impl HoveredNode {
pub(crate) fn get(ctx: &pgt_treesitter::context::TreesitterContext) -> Option<Self> {
let node_content = ctx.get_node_under_cursor_content()?;

let under_node = ctx.node_under_cursor.as_ref()?;
let under_cursor = ctx.node_under_cursor.as_ref()?;

match under_node.kind() {
match under_cursor.kind() {
"identifier" if ctx.matches_ancestor_history(&["relation", "object_reference"]) => {
if let Some(schema) = ctx.schema_or_alias_name.as_ref() {
Some(HoveredNode::Table(NodeIdentification::SchemaAndName((
Expand Down Expand Up @@ -64,6 +64,11 @@ impl HoveredNode {
Some(HoveredNode::Role(NodeIdentification::Name(node_content)))
}

// quoted columns
"literal" if ctx.matches_ancestor_history(&["select_expression", "term"]) => {
Some(HoveredNode::Column(NodeIdentification::Name(node_content)))
}

_ => None,
}
}
Expand Down
Loading
Loading