diff --git a/Cargo.toml b/Cargo.toml index 37404d4..08d8c31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ parking_lot = "0.12.4" [dev-dependencies] tonic-build = { version = "0.12.3", features = ["prost"] } +testcontainers = { version = "0.26", features = ["http_wait"] } [features] default = ["download_snapshots", "serde", "generate-snippets"] diff --git a/src/lib.rs b/src/lib.rs index 4d4d15e..fd6cd85 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -242,6 +242,7 @@ mod tests { } #[tokio::test] + #[ignore = "requires running Qdrant server, use tests/integration.rs instead"] async fn test_qdrant_queries() -> anyhow::Result<()> { let client = Qdrant::from_url("http://localhost:6334") .timeout(10u64) // larger timeout to account for the slow snapshot creation diff --git a/src/qdrant_client/collection.rs b/src/qdrant_client/collection.rs index 1bb1527..68255d8 100644 --- a/src/qdrant_client/collection.rs +++ b/src/qdrant_client/collection.rs @@ -443,6 +443,7 @@ mod tests { }; #[tokio::test] + #[ignore = "requires running Qdrant server, use tests/integration.rs instead"] async fn create_collection_and_do_the_search() -> QdrantResult<()> { let client = Qdrant::from_url("http://localhost:6334").build()?; diff --git a/src/qdrant_client/query.rs b/src/qdrant_client/query.rs index 900c9c9..5602163 100644 --- a/src/qdrant_client/query.rs +++ b/src/qdrant_client/query.rs @@ -145,6 +145,7 @@ mod tests { use crate::Payload; #[tokio::test] + #[ignore = "requires running Qdrant server, use tests/integration.rs instead"] async fn test_query() { let client = Qdrant::from_url("http://localhost:6334").build().unwrap(); let collection_name = "test_collection_query"; diff --git a/tests/integration.rs b/tests/integration.rs new file mode 100644 index 0000000..669c07f --- /dev/null +++ b/tests/integration.rs @@ -0,0 +1,228 @@ +//! Integration tests using testcontainers +//! +//! These tests spin up a Qdrant container and run the full test suite against it. +//! The container is shared across all tests for efficiency. + +mod test_utils; + +use std::collections::HashMap; + +#[cfg(feature = "download_snapshots")] +use qdrant_client::qdrant::SnapshotDownloadBuilder; +use qdrant_client::qdrant::{ + Condition, CreateCollectionBuilder, DeletePayloadPointsBuilder, DeletePointsBuilder, Distance, + Filter, GetPointsBuilder, PointStruct, QueryPointsBuilder, SearchPointsBuilder, + SetPayloadPointsBuilder, UpsertPointsBuilder, Value, VectorParamsBuilder, +}; +use qdrant_client::{Payload, Qdrant}; +use test_utils::get_or_create_container; + +#[tokio::test] +async fn test_qdrant_queries() -> anyhow::Result<()> { + let container = get_or_create_container().await; + + let client = Qdrant::from_url(&container.grpc_url) + .timeout(10u64) // larger timeout to account for the slow snapshot creation + .build()?; + + let health = client.health_check().await?; + println!("{health:?}"); + + let collections_list = client.list_collections().await?; + println!("{collections_list:?}"); + + let collection_name = "test_qdrant_queries"; + client.delete_collection(collection_name).await?; + + client + .create_collection( + CreateCollectionBuilder::new(collection_name) + .vectors_config(VectorParamsBuilder::new(10, Distance::Cosine)), + ) + .await?; + + let exists = client.collection_exists(collection_name).await?; + assert!(exists); + + let collection_info = client.collection_info(collection_name).await?; + println!("{collection_info:#?}"); + + let mut sub_payload = Payload::new(); + sub_payload.insert("foo", "Not bar"); + + let payload: Payload = vec![ + ("foo", "Bar".into()), + ("bar", 12.into()), + ("sub_payload", sub_payload.into()), + ] + .into_iter() + .collect::>() + .into(); + + let points = vec![PointStruct::new(0, vec![12.; 10], payload)]; + client + .upsert_points(UpsertPointsBuilder::new(collection_name, points).wait(true)) + .await?; + + let mut search_points = SearchPointsBuilder::new(collection_name, vec![11.; 10], 10).build(); + + // Keyword filter result + search_points.filter = Some(Filter::all([Condition::matches("foo", "Bar".to_string())])); + let search_result = client.search_points(search_points.clone()).await?; + eprintln!("search_result = {search_result:#?}"); + assert!(!search_result.result.is_empty()); + + // Existing implementations full text search filter result (`Condition::matches`) + search_points.filter = Some(Filter::all([Condition::matches( + "sub_payload.foo", + "Not ".to_string(), + )])); + let search_result = client.search_points(search_points.clone()).await?; + eprintln!("search_result = {search_result:#?}"); + assert!(!search_result.result.is_empty()); + + // Full text search filter result (`Condition::matches_text`) + search_points.filter = Some(Filter::all([Condition::matches_text( + "sub_payload.foo", + "Not", + )])); + let search_result = client.search_points(search_points).await?; + eprintln!("search_result = {search_result:#?}"); + assert!(!search_result.result.is_empty()); + + // Override payload of the existing point + let new_payload: Payload = vec![("foo", "BAZ".into())] + .into_iter() + .collect::>() + .into(); + + let payload_result = client + .set_payload( + SetPayloadPointsBuilder::new(collection_name, new_payload).points_selector([0]), + ) + .await?; + eprintln!("payload_result = {payload_result:#?}"); + + // Delete some payload fields + client + .delete_payload( + DeletePayloadPointsBuilder::new(collection_name, ["sub_payload".into()]) + .points_selector([0]), + ) + .await?; + + let get_points_result = client + .get_points( + GetPointsBuilder::new(collection_name, [0.into()]) + .with_vectors(true) + .with_payload(true), + ) + .await?; + eprintln!("get_points_result = {get_points_result:#?}"); + assert_eq!(get_points_result.result.len(), 1); + let point = get_points_result.result[0].clone(); + assert!(point.payload.contains_key("foo")); + assert!(!point.payload.contains_key("sub_payload")); + + let delete_points_result = client + .delete_points( + DeletePointsBuilder::new(collection_name) + .points([0]) + .wait(true), + ) + .await?; + eprintln!("delete_points_result = {delete_points_result:#?}"); + + // slow operation + let snapshot_result = client.create_snapshot(collection_name).await?; + eprintln!("snapshot_result = {snapshot_result:#?}"); + + #[cfg(feature = "download_snapshots")] + client + .download_snapshot(SnapshotDownloadBuilder::new("test.tar", collection_name)) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_create_collection_and_do_the_search() -> anyhow::Result<()> { + let container = get_or_create_container().await; + + let client = Qdrant::from_url(&container.grpc_url).build()?; + + let health = client.health_check().await?; + println!("{health:?}"); + + let collection_name = "test_create_collection_and_do_the_search"; + + client.delete_collection(collection_name).await?; + + client + .create_collection( + CreateCollectionBuilder::new(collection_name) + .vectors_config(VectorParamsBuilder::new(10, Distance::Cosine)), + ) + .await?; + + let points = vec![PointStruct::new( + 0, + vec![12.; 10], + Payload::try_from(serde_json::json!({ + "field": "value" + })) + .unwrap(), + )]; + + client + .upsert_points(UpsertPointsBuilder::new(collection_name, points).wait(true)) + .await?; + + let search_points = SearchPointsBuilder::new(collection_name, vec![11.; 10], 10).build(); + + let search_result = client.search_points(search_points).await?; + eprintln!("search_result = {search_result:#?}"); + assert!(!search_result.result.is_empty()); + + Ok(()) +} + +#[tokio::test] +async fn test_query() { + let container = get_or_create_container().await; + + let client = Qdrant::from_url(&container.grpc_url).build().unwrap(); + + let collection_name = "test_query"; + + client.delete_collection(collection_name).await.ok(); + + client + .create_collection( + CreateCollectionBuilder::new(collection_name) + .vectors_config(VectorParamsBuilder::new(10, Distance::Cosine)), + ) + .await + .unwrap(); + + let points = vec![PointStruct::new( + 0, + vec![12.; 10], + Payload::try_from(serde_json::json!({ + "field": "value" + })) + .unwrap(), + )]; + + client + .upsert_points(UpsertPointsBuilder::new(collection_name, points).wait(true)) + .await + .unwrap(); + + let query_result = client + .query(QueryPointsBuilder::new(collection_name).query(vec![11.; 10])) + .await + .unwrap(); + eprintln!("query_result = {query_result:#?}"); + assert!(!query_result.result.is_empty()); +} diff --git a/tests/test_utils.rs b/tests/test_utils.rs new file mode 100644 index 0000000..7d1c045 --- /dev/null +++ b/tests/test_utils.rs @@ -0,0 +1,141 @@ +//! Test utilities for integration testing with testcontainers +//! +//! Provides container setup for Qdrant vector database. +//! Uses tmpfs mounts for fast in-memory testing. + +use std::env; +use std::sync::OnceLock; + +use testcontainers::core::wait::HttpWaitStrategy; +use testcontainers::core::{IntoContainerPort, Mount, WaitFor}; +use testcontainers::runners::AsyncRunner; +use testcontainers::{ContainerAsync, GenericImage, ImageExt, TestcontainersError}; + +// Environment variable keys +pub const QDRANT_VERSION_ENV: &str = "QDRANT_VERSION"; + +// Default version - matches the version used in integration-tests.sh +const DEFAULT_QDRANT_VERSION: &str = "v1.16.0"; + +// Qdrant ports +const QDRANT_GRPC_PORT: u16 = 6334; +const QDRANT_HTTP_PORT: u16 = 6333; + +/// Global container instance for test reuse +pub static CONTAINER: OnceLock = OnceLock::new(); + +/// Container for Qdrant +#[allow(dead_code)] +pub struct QdrantContainer { + container: ContainerAsync, + pub grpc_port: u16, + pub http_port: u16, + pub grpc_url: String, + pub http_url: String, +} + +impl QdrantContainer { + /// Create a new Qdrant container + /// + /// # Arguments + /// + /// * `use_tmpfs` - Enable tmpfs mount for storage directory (recommended for tests) + /// + /// # Errors + /// + /// Returns error if container fails to start + pub async fn try_new(use_tmpfs: bool) -> Result { + let version = + env::var(QDRANT_VERSION_ENV).unwrap_or_else(|_| DEFAULT_QDRANT_VERSION.to_string()); + + let grpc_port = QDRANT_GRPC_PORT.tcp(); + let http_port = QDRANT_HTTP_PORT.tcp(); + + let http_strat = HttpWaitStrategy::new("/healthz") + .with_port(testcontainers::core::ports::ContainerPort::Tcp( + QDRANT_HTTP_PORT, + )) + .with_response_matcher(|response| response.status() == 200); + + // Create base image + let image = GenericImage::new("qdrant/qdrant", &version) + .with_exposed_port(grpc_port) + .with_exposed_port(http_port) + .with_wait_for(WaitFor::http(http_strat)); + + // Start container with tmpfs if requested + let container: ContainerAsync = if use_tmpfs { + image + .with_mount(Mount::tmpfs_mount("/qdrant/storage").with_size("5g")) + .start() + .await? + } else { + image.start().await? + }; + + // Get mapped ports + let grpc_port = container.get_host_port_ipv4(QDRANT_GRPC_PORT).await?; + let http_port = container.get_host_port_ipv4(QDRANT_HTTP_PORT).await?; + + let grpc_url = format!("http://localhost:{grpc_port}"); + let http_url = format!("http://localhost:{http_port}"); + + Ok(QdrantContainer { + container, + grpc_port, + http_port, + grpc_url, + http_url, + }) + } +} + +/// Get or create a shared Qdrant container for tests +/// +/// This function ensures only one container is created and reused across all tests. +/// Uses tmpfs for fast in-memory testing. +/// +/// # Panics +/// +/// Panics if container fails to start +pub async fn get_or_create_container() -> &'static QdrantContainer { + if let Some(c) = CONTAINER.get() { + return c; + } + + let container = QdrantContainer::try_new(true) + .await + .expect("Failed to start Qdrant container"); + + CONTAINER.get_or_init(|| container) +} + +/// Create a new standalone Qdrant container +/// +/// Unlike `get_or_create_container`, this creates a fresh container each time. +/// Useful when tests need isolation. +/// +/// # Errors +/// +/// Returns error if container fails to start +#[allow(dead_code)] +pub async fn create_container() -> Result { + QdrantContainer::try_new(true).await +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_constants() { + assert_eq!(DEFAULT_QDRANT_VERSION, "v1.16.0"); + assert_eq!(QDRANT_GRPC_PORT, 6334); + assert_eq!(QDRANT_HTTP_PORT, 6333); + } + + #[test] + fn test_env_var_constants() { + assert_eq!(QDRANT_VERSION_ENV, "QDRANT_VERSION"); + } +}