Skip to content

Commit

Permalink
feat: add a text_to_sql function
Browse files Browse the repository at this point in the history
  • Loading branch information
jgpruitt committed Dec 17, 2024
1 parent 295c0f0 commit 17e3d28
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 2 deletions.
217 changes: 215 additions & 2 deletions projects/extension/sql/idempotent/905-text-to-sql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ set search_path to pg_catalog, pg_temp
-- _text_to_sql_prompt
create or replace function ai._text_to_sql_prompt
( prompt pg_catalog.text
, "limit" pg_catalog.int8 default 5
, objtypes pg_catalog.text[] default null
, max_dist pg_catalog.float8 default null
, catalog_name pg_catalog.text default 'default'
) returns pg_catalog.text
as $func$
Expand Down Expand Up @@ -130,11 +133,14 @@ begin
from ai._find_relevant_obj
( _catalog_id
, _prompt_emb
, "limit"=>"limit"
, objtypes=>objtypes
, max_dist=>max_dist
) r
;

-- distinct tables
select pg_catalog.array_agg(objid) into _distinct_tables
select pg_catalog.array_agg(distinct objid) into _distinct_tables
from pg_catalog.jsonb_to_recordset(_relevant_obj) x
( objtype pg_catalog.text
, objid pg_catalog.oid
Expand Down Expand Up @@ -185,7 +191,7 @@ begin
;

-- distinct views
select pg_catalog.array_agg(objid) into _distinct_views
select pg_catalog.array_agg(distinct objid) into _distinct_views
from pg_catalog.jsonb_to_recordset(_relevant_obj) x
( objtype pg_catalog.text
, objid pg_catalog.oid
Expand Down Expand Up @@ -269,6 +275,8 @@ begin
from ai._find_relevant_sql
( _catalog_id
, _prompt_emb
, "limit"=>"limit"
, max_dist=>max_dist
) r
;

Expand All @@ -294,4 +302,209 @@ $func$ language plpgsql stable security invoker
set search_path to pg_catalog, pg_temp
;

-------------------------------------------------------------------------------
-- text_to_sql_openai
-- Packs OpenAI chat-completion settings into a jsonb configuration document
-- tagged with provider = 'openai'. Every argument that is left null is omitted
-- from the document entirely (absent on null), so only `model` is mandatory.
-- The document uses the key names that ai.text_to_sql reads for this provider.
create or replace function ai.text_to_sql_openai
( model pg_catalog.text
, api_key pg_catalog.text default null
, api_key_name pg_catalog.text default null
, base_url pg_catalog.text default null
, frequency_penalty pg_catalog.float8 default null
, logit_bias pg_catalog.jsonb default null
, logprobs pg_catalog.bool default null
, top_logprobs pg_catalog.int4 default null
, max_tokens pg_catalog.int4 default null
, n pg_catalog.int4 default null
, presence_penalty pg_catalog.float8 default null
, seed pg_catalog.int4 default null
, stop pg_catalog.text default null
, temperature pg_catalog.float8 default null
, top_p pg_catalog.float8 default null
, openai_user pg_catalog.text default null
) returns pg_catalog.jsonb
as $func$
    -- `value` is the keyword spelling of the `:` key/value separator
    select json_object
    ( 'provider'          value 'openai'
    , 'model'             value model
    , 'api_key'           value api_key
    , 'api_key_name'      value api_key_name
    , 'base_url'          value base_url
    , 'frequency_penalty' value frequency_penalty
    , 'logit_bias'        value logit_bias
    , 'logprobs'          value logprobs
    , 'top_logprobs'      value top_logprobs
    , 'max_tokens'        value max_tokens
    , 'n'                 value n
    , 'presence_penalty'  value presence_penalty
    , 'seed'              value seed
    , 'stop'              value stop
    , 'temperature'       value temperature
    , 'top_p'             value top_p
    , 'openai_user'       value openai_user
      absent on null
    )
$func$ language sql immutable security invoker
set search_path to pg_catalog, pg_temp
;

-------------------------------------------------------------------------------
-- text_to_sql_ollama
-- Packs Ollama chat settings into a jsonb configuration document tagged with
-- provider = 'ollama'. Arguments left null are omitted from the document
-- (absent on null), so only `model` is mandatory.
-- The document uses the key names that ai.text_to_sql reads for this provider.
create or replace function ai.text_to_sql_ollama
( model pg_catalog.text
, host pg_catalog.text default null
, keep_alive pg_catalog.text default null
, chat_options pg_catalog.jsonb default null
) returns pg_catalog.jsonb
as $func$
    -- `value` is the keyword spelling of the `:` key/value separator
    select json_object
    ( 'provider'     value 'ollama'
    , 'model'        value model
    , 'host'         value host
    , 'keep_alive'   value keep_alive
    , 'chat_options' value chat_options
      absent on null
    )
$func$ language sql immutable security invoker
set search_path to pg_catalog, pg_temp
;

-------------------------------------------------------------------------------
-- text_to_sql_anthropic
-- Builds a jsonb configuration document tagged with provider = 'anthropic'.
-- Arguments that are null are omitted from the document (absent on null).
-- `max_tokens` defaults to 1024, so it is present unless explicitly set to null.
-- `stop_sequences` is rendered as a JSON array of strings.
-- Parameter types are schema-qualified with pg_catalog, like the sibling
-- text_to_sql_* functions, so type resolution does not depend on the
-- search_path in effect when this script is run.
create or replace function ai.text_to_sql_anthropic
( model pg_catalog.text
, max_tokens pg_catalog.int4 default 1024
, api_key pg_catalog.text default null
, api_key_name pg_catalog.text default null
, base_url pg_catalog.text default null
, timeout pg_catalog.float8 default null
, max_retries pg_catalog.int4 default null
, user_id pg_catalog.text default null
, stop_sequences pg_catalog.text[] default null
, temperature pg_catalog.float8 default null
, top_k pg_catalog.int4 default null
, top_p pg_catalog.float8 default null
) returns pg_catalog.jsonb
as $func$
    select json_object
    ( 'provider': 'anthropic'
    , 'model': model
    , 'max_tokens': max_tokens
    , 'api_key': api_key
    , 'api_key_name': api_key_name
    , 'base_url': base_url
    , 'timeout': timeout
    , 'max_retries': max_retries
    , 'user_id': user_id
    , 'stop_sequences': stop_sequences
    , 'temperature': temperature
    , 'top_k': top_k
    , 'top_p': top_p
      absent on null
    )
$func$ language sql immutable security invoker
set search_path to pg_catalog, pg_temp
;

-------------------------------------------------------------------------------
-- text_to_sql
-- Asks an LLM to write a SQL statement that answers a natural-language prompt.
--
-- Arguments:
--   prompt       the question to answer
--   config       jsonb document whose 'provider' key selects the backend
--                ('openai', 'ollama' or 'anthropic'); the remaining keys are
--                the ones produced by ai.text_to_sql_openai,
--                ai.text_to_sql_ollama and ai.text_to_sql_anthropic
--   "limit", objtypes, max_dist, catalog_name
--                forwarded unchanged to ai._text_to_sql_prompt, which builds
--                the user prompt
--
-- Returns the text content extracted from the provider's response (null if the
-- response does not contain the expected path).
-- Raises 'unsupported provider' when config has no recognised 'provider' value.
--
-- NOTE(review): the function is declared stable although it calls out to an
-- LLM; confirm that volatility is intended.
create or replace function ai.text_to_sql
( prompt pg_catalog.text
, config pg_catalog.jsonb
, "limit" pg_catalog.int8 default 5
, objtypes pg_catalog.text[] default null
, max_dist pg_catalog.float8 default null
, catalog_name pg_catalog.text default 'default'
) returns pg_catalog.text
as $func$
declare
    _system_prompt pg_catalog.text;
    _user_prompt pg_catalog.text;
    _response pg_catalog.jsonb;
    _stop_sequences pg_catalog.text[];
    _sql pg_catalog.text;
begin
    _system_prompt = trim
($txt$
You are an expert database developer and DBA specializing in PostgreSQL.
You will be provided with context about a database model and a question to be answered.
You respond with nothing but a SQL statement that addresses the question posed.
The SQL statement must be valid syntax for PostgreSQL.
SQL features and functions that are built-in to PostgreSQL may be used.
$txt$);

    -- build the user prompt (context + question) from the semantic catalog
    _user_prompt = ai._text_to_sql_prompt
    ( prompt
    , "limit"=>"limit"
    , objtypes=>objtypes
    , max_dist=>max_dist
    , catalog_name=>catalog_name
    );
    raise log 'prompt: %', _user_prompt;

    case config operator(pg_catalog.->>) 'provider'
        when 'openai' then
            _response = ai.openai_chat_complete
            ( config operator(pg_catalog.->>) 'model'
            , pg_catalog.jsonb_build_array
              ( pg_catalog.jsonb_build_object('role', 'system', 'content', _system_prompt)
              , pg_catalog.jsonb_build_object('role', 'user', 'content', _user_prompt)
              )
            , api_key=>config operator(pg_catalog.->>) 'api_key'
            , api_key_name=>config operator(pg_catalog.->>) 'api_key_name'
            , base_url=>config operator(pg_catalog.->>) 'base_url'
            , frequency_penalty=>(config operator(pg_catalog.->>) 'frequency_penalty')::pg_catalog.float8
            , logit_bias=>(config operator(pg_catalog.->>) 'logit_bias')::pg_catalog.jsonb
            , logprobs=>(config operator(pg_catalog.->>) 'logprobs')::pg_catalog.bool
            , top_logprobs=>(config operator(pg_catalog.->>) 'top_logprobs')::pg_catalog.int4
            , max_tokens=>(config operator(pg_catalog.->>) 'max_tokens')::pg_catalog.int4
            , n=>(config operator(pg_catalog.->>) 'n')::pg_catalog.int4
            , presence_penalty=>(config operator(pg_catalog.->>) 'presence_penalty')::pg_catalog.float8
            , seed=>(config operator(pg_catalog.->>) 'seed')::pg_catalog.int4
            , stop=>(config operator(pg_catalog.->>) 'stop')
            , temperature=>(config operator(pg_catalog.->>) 'temperature')::pg_catalog.float8
            , top_p=>(config operator(pg_catalog.->>) 'top_p')::pg_catalog.float8
            , openai_user=>(config operator(pg_catalog.->>) 'openai_user')
            );
            raise log 'response: %', _response;
            -- OpenAI chat completions: choices[0].message.content
            _sql = pg_catalog.jsonb_extract_path_text(_response, 'choices', '0', 'message', 'content');
        when 'ollama' then
            _response = ai.ollama_chat_complete
            ( config operator(pg_catalog.->>) 'model'
            , pg_catalog.jsonb_build_array
              ( pg_catalog.jsonb_build_object('role', 'system', 'content', _system_prompt)
              , pg_catalog.jsonb_build_object('role', 'user', 'content', _user_prompt)
              )
            , host=>(config operator(pg_catalog.->>) 'host')
            , keep_alive=>(config operator(pg_catalog.->>) 'keep_alive')
            , chat_options=>(config operator(pg_catalog.->) 'chat_options')
            );
            raise log 'response: %', _response;
            -- Ollama's chat response has no "choices" array; the reply is at
            -- message.content (the OpenAI path always yielded null here)
            _sql = pg_catalog.jsonb_extract_path_text(_response, 'message', 'content');
        when 'anthropic' then
            -- the config stores stop_sequences as a JSON array; convert it back
            -- to text[] preserving order. An absent key yields null.
            select pg_catalog.array_agg(x.seq order by x.ord) into _stop_sequences
            from pg_catalog.jsonb_array_elements_text
                 ( config operator(pg_catalog.->) 'stop_sequences'
                 ) with ordinality x(seq, ord)
            ;
            _response = ai.anthropic_generate
            ( config operator(pg_catalog.->>) 'model'
            , pg_catalog.jsonb_build_array
              ( pg_catalog.jsonb_build_object('role', 'user', 'content', _user_prompt)
              )
            , system_prompt=>_system_prompt
              -- an explicit null would override the callee's default, so fall
              -- back to 1024, the default declared by ai.text_to_sql_anthropic
            , max_tokens=>coalesce((config operator(pg_catalog.->>) 'max_tokens')::pg_catalog.int4, 1024)
            , api_key=>(config operator(pg_catalog.->>) 'api_key')
            , api_key_name=>(config operator(pg_catalog.->>) 'api_key_name')
            , base_url=>(config operator(pg_catalog.->>) 'base_url')
            , timeout=>(config operator(pg_catalog.->>) 'timeout')::pg_catalog.float8
            , max_retries=>(config operator(pg_catalog.->>) 'max_retries')::pg_catalog.int4
            , user_id=>(config operator(pg_catalog.->>) 'user_id')
              -- NOTE(review): assumes ai.anthropic_generate accepts
              -- stop_sequences text[] — confirm against its definition
            , stop_sequences=>_stop_sequences
            , temperature=>(config operator(pg_catalog.->>) 'temperature')::pg_catalog.float8
            , top_k=>(config operator(pg_catalog.->>) 'top_k')::pg_catalog.int4
            , top_p=>(config operator(pg_catalog.->>) 'top_p')::pg_catalog.float8
            );
            raise log 'response: %', _response;
            -- Anthropic messages: content[0].text
            _sql = pg_catalog.jsonb_extract_path_text(_response, 'content', '0', 'text');
        else
            raise exception 'unsupported provider';
    end case;
    return _sql;
end
$func$ language plpgsql stable security invoker
set search_path to pg_catalog, pg_temp
;

45 changes: 45 additions & 0 deletions projects/extension/tests/text_to_sql/prompt.expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
Consider the following context when responding.
Any relevant table, view, and functions descriptions and DDL definitions will appear in <table></table>, <view></view>, and <function></function> tags respectively.
Any relevant example SQL statements will appear in <example-sql></example-sql> tags.
<table>
/*
# public.bob
this is a comment about the bob table
## id
this is a comment about the id column
## foo
this is a comment about the foo column
*/
CREATE TABLE public.bob
( id integer NOT NULL
, foo text NOT NULL
, bar timestamp with time zone NOT NULL now()
, PRIMARY KEY (id)
);
CREATE UNIQUE INDEX bob_pkey ON public.bob USING btree (id)
</table>
<view>
/*
# public.bobby

## id
this is a comment about the id column
## foo
this is a comment about the foo column
*/
CREATE VIEW public.bobby AS
SELECT id,
foo,
bar
FROM public.bob;

</view>
<example-sql>
/*
a bogus query against the bobby view using the life function
*/
select * from bobby where id = life(id)
</example-sql>
Respond to the following question with a SQL statement only. Only use syntax and functions that work with PostgreSQL.
Q: Construct a query that gives me the distinct foo where the corresponding ids are evenly divisible life.
A:
26 changes: 26 additions & 0 deletions projects/extension/tests/text_to_sql/test_text_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,32 @@ def test_text_to_sql() -> None:
== "a bogus query against the bobby view using the life function"
)

cur.execute(
"""select ai._text_to_sql_prompt('Construct a query that gives me the distinct foo where the corresponding ids are evenly divisible life.')"""
)
actual = cur.fetchone()[0]
# host_dir().joinpath("prompt.expected").write_text(actual)
expected = file_contents("prompt.expected")
assert actual == expected

anthropic_api_key = os.environ["ANTHROPIC_API_KEY"]
assert anthropic_api_key is not None
cur.execute(
"select set_config('ai.anthropic_api_key', %s, false) is not null",
(anthropic_api_key,),
)
cur.execute(
"""
select ai.text_to_sql
( 'Construct a query that gives me the distinct foo where the corresponding ids are evenly divisible life.'
, ai.text_to_sql_anthropic('claude-3-5-sonnet-20240620')
)
"""
)
actual = cur.fetchone()[0]
assert actual is not None
cur.execute(f"explain {actual}") # make sure it's valid sql

snapshot_catalog("text_to_sql_2")
actual = file_contents("snapshot-catalog.actual")
expected = file_contents("snapshot-catalog.expected")
Expand Down

0 comments on commit 17e3d28

Please sign in to comment.