Skip to content

Commit a7db92a

Browse files
authored
handle portkey virtkey in bgw (#138)
* handle portkey virtkey in bgw * bump to 0.18.1 * unused code * move vectorscake statement to db init * add log on test * setup * run on ubuntu 24 * job
1 parent 8c527c2 commit a7db92a

File tree

10 files changed

+70
-51
lines changed

10 files changed

+70
-51
lines changed

.github/workflows/extension_ci.yml

+12-10
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ on:
2525
jobs:
2626
dependencies:
2727
name: Install dependencies
28-
runs-on: ubuntu-22.04
28+
runs-on: ubuntu-24.04
2929
steps:
3030
- uses: actions/checkout@v2
3131

@@ -85,7 +85,7 @@ jobs:
8585
test:
8686
name: Run tests
8787
needs: dependencies
88-
runs-on: ubuntu-22.04
88+
runs-on: ubuntu-24.04
8989
services:
9090
# Label used to access the service container
9191
vector-serve:
@@ -100,10 +100,15 @@ jobs:
100100
toolchain: stable
101101
- uses: Swatinem/rust-cache@v2
102102
with:
103-
prefix-key: "pg-vectorize-extension-test"
104-
workspaces: pg-vectorize
103+
prefix-key: "extension-test"
104+
workspaces: |
105+
vectorize
105106
# Additional directories to cache
106-
cache-directories: /home/runner/.pgrx
107+
cache-directories: |
108+
/home/runner/.pgrx
109+
- name: Install sys dependencies
110+
run: |
111+
sudo apt-get update && sudo apt-get install -y postgresql-server-dev-16 libopenblas-dev libreadline-dev
107112
- uses: ./.github/actions/pgx-init
108113
with:
109114
working-directory: ./extension
@@ -126,10 +131,7 @@ jobs:
126131
${{ runner.os }}-bins-
127132
- name: setup-tests
128133
run: |
129-
make trunk-dependencies
130-
make setup.urls
131-
make setup.shared_preload_libraries
132-
rm -rf ./target/pgrx-test-data-* || true
134+
make setup
133135
- name: unit-test
134136
run: |
135137
make test-unit
@@ -146,7 +148,7 @@ jobs:
146148
publish:
147149
if: github.event_name == 'release'
148150
name: trunk publish
149-
runs-on: ubuntu-22.04
151+
runs-on: ubuntu-24.04
150152
strategy:
151153
matrix:
152154
pg-version: [14, 15, 16]

core/src/types.rs

+1
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ pub struct JobParams {
114114
pub api_key: Option<String>,
115115
#[serde(default = "default_schedule")]
116116
pub schedule: String,
117+
pub args: Option<serde_json::Value>,
117118
}
118119

119120
fn default_schedule() -> String {

core/src/worker/base.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,17 @@ async fn execute_job(
9494
let job_meta: VectorizeMeta = msg.message.job_meta;
9595
let job_params: JobParams = serde_json::from_value(job_meta.params.clone())?;
9696

97+
let virtual_key = if let Some(args) = job_params.args.clone() {
98+
args.get("virtual_key").map(|v| v.to_string())
99+
} else {
100+
None
101+
};
102+
97103
let provider = providers::get_provider(
98104
&job_meta.transformer.source,
99105
job_params.api_key.clone(),
100106
None,
101-
None,
107+
virtual_key,
102108
)?;
103109

104110
let embedding_request =

extension/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "vectorize"
3-
version = "0.18.0"
3+
version = "0.18.1"
44
edition = "2021"
55
publish = false
66

extension/Makefile

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ install-pgvector:
7272
install-pgmq:
7373
git clone https://github.com/tembo-io/pgmq.git && \
7474
cd pgmq/pgmq-extension && \
75-
PG_CONFIG=${PGRX_PG_CONFIG} make && \
7675
PG_CONFIG=${PGRX_PG_CONFIG} make clean && \
76+
PG_CONFIG=${PGRX_PG_CONFIG} make && \
7777
PG_CONFIG=${PGRX_PG_CONFIG} make install && \
78-
cd .. && rm -rf pgmq
78+
cd ../.. && rm -rf pgmq
7979

8080
install-vectorscale:
8181
@ARCH=$$(uname -m); \

extension/Trunk.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description = "The simplest way to orchestrate vector search on Postgres."
66
homepage = "https://github.com/tembo-io/pg_vectorize"
77
documentation = "https://github.com/tembo-io/pg_vectorize"
88
categories = ["orchestration", "machine_learning"]
9-
version = "0.18.0"
9+
version = "0.18.1"
1010
loadable_libraries = [{ library_name = "vectorize", requires_restart = true }]
1111

1212
[build]

extension/sql/vectorize--0.18.0--0.18.1.sql

Whitespace-only changes.

extension/src/search.rs

+33-26
Original file line numberDiff line numberDiff line change
@@ -39,38 +39,19 @@ pub fn init_table(
3939
init::init_pgmq()?;
4040

4141
let guc_configs = get_guc_configs(&transformer.source);
42-
let provider = get_provider(
43-
&transformer.source,
44-
guc_configs.api_key.clone(),
45-
guc_configs.service_url,
46-
None,
47-
)?;
48-
49-
//synchronous
50-
let runtime = tokio::runtime::Builder::new_current_thread()
51-
.enable_io()
52-
.enable_time()
53-
.build()
54-
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
55-
let model_dim =
56-
match runtime.block_on(async { provider.model_dim(&transformer.api_name()).await }) {
57-
Ok(e) => e,
58-
Err(e) => {
59-
error!("error getting model dim: {}", e);
60-
}
61-
};
62-
63-
// validate API key where necessary
42+
info!("guc_configs: {:?}", guc_configs);
43+
// validate API key where necessary and collect any optional arguments
6444
// certain embedding services require an API key, e.g. openAI
6545
// key can be set in a GUC, so if its required but not provided in args, and not in GUC, error
66-
match transformer.source {
46+
let optional_args = match transformer.source {
6747
ModelSource::OpenAI => {
6848
openai::validate_api_key(
6949
&guc_configs
7050
.api_key
7151
.clone()
7252
.context("OpenAI key is required")?,
7353
)?;
54+
None
7455
}
7556
ModelSource::Tembo => {
7657
error!("Tembo not implemented for search yet");
@@ -85,15 +66,40 @@ pub fn init_table(
8566
let res = check_model_host(&url);
8667
match res {
8768
Ok(_) => {
88-
info!("Model host active!")
69+
info!("Model host active!");
70+
None
8971
}
9072
Err(e) => {
9173
error!("Error with model host: {:?}", e)
9274
}
9375
}
9476
}
95-
_ => (),
96-
}
77+
ModelSource::Portkey => Some(serde_json::json!({
78+
"virtual_key": guc_configs.virtual_key.clone().expect("Portkey virtual key is required")
79+
})),
80+
_ => None,
81+
};
82+
83+
let provider = get_provider(
84+
&transformer.source,
85+
guc_configs.api_key.clone(),
86+
guc_configs.service_url.clone(),
87+
guc_configs.virtual_key.clone(),
88+
)?;
89+
90+
// synchronous
91+
let runtime = tokio::runtime::Builder::new_current_thread()
92+
.enable_io()
93+
.enable_time()
94+
.build()
95+
.unwrap_or_else(|e| error!("failed to initialize tokio runtime: {}", e));
96+
let model_dim =
97+
match runtime.block_on(async { provider.model_dim(&transformer.api_name()).await }) {
98+
Ok(e) => e,
99+
Err(e) => {
100+
error!("error getting model dim: {}", e);
101+
}
102+
};
97103

98104
let valid_params = types::JobParams {
99105
schema: schema.to_string(),
@@ -105,6 +111,7 @@ pub fn init_table(
105111
pkey_type,
106112
api_key: guc_configs.api_key.clone(),
107113
schedule: schedule.to_string(),
114+
args: optional_args,
108115
};
109116
let params =
110117
pgrx::JsonB(serde_json::to_value(valid_params.clone()).expect("error serializing params"));

extension/tests/integration_tests.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -788,10 +788,6 @@ async fn test_diskann_cosine() {
788788
common::init_test_table(&test_table_name, &conn).await;
789789
let job_name = format!("job_diskann_{}", test_num);
790790

791-
let _ = sqlx::query("CREATE EXTENSION IF NOT EXISTS vectorscale;")
792-
.execute(&conn)
793-
.await;
794-
795791
common::init_embedding_svc_url(&conn).await;
796792
// initialize a job
797793
let result = sqlx::query(&format!(
@@ -810,9 +806,15 @@ async fn test_diskann_cosine() {
810806
assert!(result.is_ok());
811807

812808
let search_results: Vec<common::SearchJSON> =
813-
util::common::search_with_retry(&conn, "mobile devices", &job_name, 10, 2, 3, None)
809+
match util::common::search_with_retry(&conn, "mobile devices", &job_name, 10, 2, 3, None)
814810
.await
815-
.unwrap();
811+
{
812+
Ok(results) => results,
813+
Err(e) => {
814+
eprintln!("Error: {:?}", e);
815+
panic!("failed to exec search on diskann");
816+
}
817+
};
816818
assert_eq!(search_results.len(), 3);
817819
}
818820

extension/tests/util.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ pub mod common {
5353
.await
5454
.expect("failed to create extension");
5555

56+
// Optional dependencies
57+
let _ = sqlx::query("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
58+
.execute(&conn)
59+
.await
60+
.expect("failed to create vectorscale extension");
5661
conn
5762
}
5863

@@ -63,10 +68,6 @@ pub mod common {
6368
28815
6469
} else if cfg!(feature = "pg14") {
6570
28814
66-
} else if cfg!(feature = "pg13") {
67-
28813
68-
} else if cfg!(feature = "pg12") {
69-
28812
7071
} else {
7172
5432
7273
}

0 commit comments

Comments
 (0)