From 0659017d007ade545d6bb1158e714ec8c94ca79b Mon Sep 17 00:00:00 2001 From: Paco van der Linden Date: Fri, 8 Mar 2024 22:58:21 +0100 Subject: [PATCH] feat: add support for parsing Firestore database references from string This change adds a new enum `Ref` which represents a reference to either the database root, a single collection, or a single document in a Firestore database. It includes methods to extract the different types of references, check if one reference is a parent of another, and implements the `Display` trait for formatting. Additionally, it introduces error handling for invalid references and conversion to `tonic::Status`. The implementation also includes parsing logic from string to create a `Ref` instance. --- Cargo.lock | 65 +++ crates/firestore-database/Cargo.toml | 4 + crates/firestore-database/src/database.rs | 55 +- .../src/database/collection.rs | 16 +- .../src/database/document.rs | 40 +- .../firestore-database/src/database/event.rs | 5 +- .../src/database/listener.rs | 25 +- .../firestore-database/src/database/query.rs | 46 +- .../src/database/reference.rs | 486 ++++++++++++++++++ .../src/database/transaction.rs | 13 +- crates/firestore-database/src/lib.rs | 2 +- crates/googleapis/Cargo.toml | 1 + crates/googleapis/src/timestamp_ext.rs | 39 +- src/emulator.rs | 151 +++--- 14 files changed, 748 insertions(+), 200 deletions(-) create mode 100644 crates/firestore-database/src/database/reference.rs diff --git a/Cargo.lock b/Cargo.lock index 887b202..de16f4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -432,7 +432,9 @@ dependencies = [ "futures", "googleapis", "itertools 0.12.1", + "rstest", "string_cache", + "thiserror", "tokio", "tokio-stream", "tonic 0.11.0", @@ -565,6 +567,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.30" @@ -600,12 +608,19 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "googleapis" version = "0.0.0" dependencies = [ "itertools 0.12.1", "prost", + "rstest", "thiserror", "time", "tonic 0.11.0", @@ -1260,12 +1275,56 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "relative-path" +version = "1.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e898588f33fdd5b9420719948f9f2a32c922a246964576f71ba7f24f80610fbc" + +[[package]] +name = "rstest" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97eeab2f3c0a199bc4be135c36c924b6590b88c377d416494288c14f2db30199" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d428f8247852f894ee1be110b375111b586d4fa431f6c46e64ba5a0dcccbe605" +dependencies = [ + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + [[package]] name = "rustc-demangle" version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc_version" +version = "0.4.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.31" @@ -1297,6 +1356,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "semver" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" + [[package]] name = "serde" version = "1.0.197" diff --git a/crates/firestore-database/Cargo.toml b/crates/firestore-database/Cargo.toml index 6bdcaf7..b106ae3 100644 --- a/crates/firestore-database/Cargo.toml +++ b/crates/firestore-database/Cargo.toml @@ -7,7 +7,11 @@ futures = { workspace = true } googleapis = { workspace = true } itertools = { workspace = true } string_cache = { workspace = true } +thiserror = { workspace = true } tokio = { workspace = true } tokio-stream = { workspace = true } tonic = { workspace = true } tracing = { workspace = true } + +[dev-dependencies] +rstest = "0.18.2" diff --git a/crates/firestore-database/src/database.rs b/crates/firestore-database/src/database.rs index 14626dd..9cadea8 100644 --- a/crates/firestore-database/src/database.rs +++ b/crates/firestore-database/src/database.rs @@ -28,6 +28,7 @@ use self::{ field_path::FieldPath, listener::Listener, query::Query, + reference::{CollectionRef, DocumentRef, Ref, RootRef}, transaction::{RunningTransactions, Transaction, TransactionId}, }; use crate::{ @@ -40,19 +41,22 @@ pub mod event; mod field_path; mod listener; mod query; +pub mod reference; mod transaction; const MAX_EVENT_BACKLOG: usize = 1024; pub struct Database { + pub name: RootRef, collections: RwLock>>, transactions: RunningTransactions, events: broadcast::Sender>, } impl Database { - pub fn new() -> Arc { + pub fn 
new(name: RootRef) -> Arc { Arc::new_cyclic(|database| Database { + name, collections: Default::default(), transactions: RunningTransactions::new(Weak::clone(database)), events: broadcast::channel(MAX_EVENT_BACKLOG).0, @@ -64,10 +68,10 @@ impl Database { #[instrument(skip_all, err, fields(in_txn = consistency.is_transaction(), found))] pub async fn get_doc( &self, - name: &DefaultAtom, + name: &DocumentRef, consistency: &ReadConsistency, ) -> Result> { - info!(name = name.deref()); + info!(%name); let version = if let Some(txn) = self.get_txn_for_consistency(consistency).await? { txn.read_doc(name) .await? @@ -85,28 +89,32 @@ impl Database { Ok(version) } - pub async fn get_collection(&self, collection_name: &DefaultAtom) -> Arc { + pub async fn get_collection(&self, collection_name: &CollectionRef) -> Arc { + debug_assert_eq!(self.name, collection_name.root_ref); Arc::clone( &*self .collections - .get_or_insert(collection_name, || { - Arc::new(Collection::new(collection_name.into())) + .get_or_insert(&collection_name.collection_id, || { + Arc::new(Collection::new(collection_name.clone())) }) .await, ) } #[instrument(skip_all, err)] - pub async fn get_doc_meta(&self, name: &DefaultAtom) -> Result> { - let collection = collection_name(name)?; - let meta = self.get_collection(&collection).await.get_doc(name).await; + pub async fn get_doc_meta(&self, name: &DocumentRef) -> Result> { + let meta = self + .get_collection(&name.collection_ref) + .await + .get_doc(name) + .await; Ok(meta) } #[instrument(skip_all, err)] pub async fn get_doc_meta_mut_no_txn( &self, - name: &DefaultAtom, + name: &DocumentRef, ) -> Result { self.get_doc_meta(name) .await? @@ -134,7 +142,7 @@ impl Database { } #[instrument(skip_all)] - pub async fn get_collection_ids(&self, parent: &DefaultAtom) -> Result> { + pub async fn get_collection_ids(&self, parent_doc: &DocumentRef) -> Result> { // Get all collections asap in order to keep the read lock time minimal. 
let all_collections = self .collections @@ -146,11 +154,7 @@ impl Database { // Cannot use `filter_map` because of the `await`. let mut result = vec![]; for col in all_collections { - let Some(path) = col - .name - .strip_prefix(parent.deref()) - .and_then(|p| p.strip_prefix('/')) - else { + let Some(path) = col.name.strip_document_prefix(parent_doc) else { continue; }; if col.has_doc().await? { @@ -163,13 +167,14 @@ impl Database { #[instrument(skip_all, err)] pub async fn run_query( &self, - parent: String, + parent: Ref, query: StructuredQuery, consistency: ReadConsistency, ) -> Result> { let mut query = Query::from_structured(parent, query, consistency)?; info!(?query); - query.once(self).await + let result = query.once(self).await?; + Ok(result.into_iter().map(|t| t.1).collect()) } #[instrument(skip_all, err)] @@ -183,7 +188,7 @@ impl Database { let mut write_results = vec![]; let mut updates = HashMap::new(); - let mut write_guard_cache = HashMap::::new(); + let mut write_guard_cache = HashMap::::new(); // This must be done in two phases. First acquire the lock on all docs, only then start to // update them. for write in &writes { @@ -391,15 +396,7 @@ fn apply_transform( Ok(result) } -fn collection_name(name: &DefaultAtom) -> Result { - Ok(name - .rsplit_once('/') - .ok_or_else(|| Status::invalid_argument("invalid document path, missing collection-name"))? - .0 - .into()) -} - -pub fn get_doc_name_from_write(write: &Write) -> Result { +pub fn get_doc_name_from_write(write: &Write) -> Result { let operation = write .operation .as_ref() @@ -410,7 +407,7 @@ pub fn get_doc_name_from_write(write: &Write) -> Result { Delete(name) => name, Transform(trans) => &trans.document, }; - Ok(DefaultAtom::from(name)) + Ok(name.parse()?) 
} #[derive(Clone, Debug)] diff --git a/crates/firestore-database/src/database/collection.rs b/crates/firestore-database/src/database/collection.rs index d0d25b8..1bcba36 100644 --- a/crates/firestore-database/src/database/collection.rs +++ b/crates/firestore-database/src/database/collection.rs @@ -5,28 +5,32 @@ use tokio::sync::RwLock; use tonic::Result; use tracing::instrument; -use super::document::DocumentMeta; +use super::{ + document::DocumentMeta, + reference::{CollectionRef, DocumentRef}, +}; use crate::utils::RwLockHashMapExt; pub struct Collection { - pub name: DefaultAtom, + pub name: CollectionRef, documents: RwLock>>, } impl Collection { #[instrument(skip_all)] - pub fn new(name: DefaultAtom) -> Self { + pub fn new(name: CollectionRef) -> Self { Self { name, documents: Default::default(), } } - pub async fn get_doc(self: &Arc, name: &DefaultAtom) -> Arc { + pub async fn get_doc(self: &Arc, name: &DocumentRef) -> Arc { + debug_assert_eq!(self.name, name.collection_ref); Arc::clone( self.documents - .get_or_insert(name, || { - Arc::new(DocumentMeta::new(name.clone(), self.name.clone())) + .get_or_insert(&name.document_id, || { + Arc::new(DocumentMeta::new(name.clone())) }) .await .deref(), diff --git a/crates/firestore-database/src/database/document.rs b/crates/firestore-database/src/database/document.rs index f62ac1d..9c5ad08 100644 --- a/crates/firestore-database/src/database/document.rs +++ b/crates/firestore-database/src/database/document.rs @@ -11,7 +11,6 @@ use googleapis::google::{ firestore::v1::{precondition, Document, Value}, protobuf::Timestamp, }; -use string_cache::DefaultAtom; use tokio::{ sync::{ oneshot, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, OwnedSemaphorePermit, RwLock, @@ -22,17 +21,14 @@ use tokio::{ use tonic::{Code, Result, Status}; use tracing::{instrument, trace, Level}; -use super::ReadConsistency; +use super::{reference::DocumentRef, ReadConsistency}; const WAIT_LOCK_TIMEOUT: Duration = Duration::from_secs(30); pub struct 
DocumentMeta { /// The resource name of the document, for example /// `projects/{project_id}/databases/{database_id}/documents/{document_path}`. - pub name: DefaultAtom, - /// The collection name of the document, i.e. the full name of the document minus the last - /// component. - pub collection_name: DefaultAtom, + pub name: DocumentRef, contents: Arc>, write_permit_shop: Arc, } @@ -46,14 +42,10 @@ impl Debug for DocumentMeta { } impl DocumentMeta { - pub fn new(name: DefaultAtom, collection_name: DefaultAtom) -> Self { + pub fn new(name: DocumentRef) -> Self { Self { - contents: Arc::new(RwLock::new(DocumentContents::new( - name.clone(), - collection_name.clone(), - ))), + contents: Arc::new(RwLock::new(DocumentContents::new(name.clone()))), name, - collection_name, write_permit_shop: Arc::new(Semaphore::new(1)), } } @@ -91,18 +83,14 @@ impl DocumentMeta { pub struct DocumentContents { /// The resource name of the document, for example /// `projects/{project_id}/databases/{database_id}/documents/{document_path}`. - pub name: DefaultAtom, - /// The collection name of the document, i.e. the full name of the document minus the last - /// component. 
- pub collection_name: DefaultAtom, + pub name: DocumentRef, versions: Vec, } impl DocumentContents { - pub fn new(name: DefaultAtom, collection_name: DefaultAtom) -> Self { + pub fn new(name: DocumentRef) -> Self { Self { name, - collection_name, versions: Default::default(), } } @@ -172,7 +160,7 @@ impl DocumentContents { } #[instrument(skip_all, fields( - doc_name = self.name.deref(), + doc_name = %self.name, time = display(&update_time), ), level = Level::DEBUG)] pub async fn add_version( @@ -184,7 +172,6 @@ impl DocumentContents { let create_time = self.create_time().unwrap_or_else(|| update_time.clone()); let version = DocumentVersion::Stored(Arc::new(StoredDocumentVersion { name: self.name.clone(), - collection_name: self.collection_name.clone(), create_time, update_time, fields, @@ -196,7 +183,6 @@ impl DocumentContents { pub async fn delete(&mut self, delete_time: Timestamp) -> DocumentVersion { let version = DocumentVersion::Deleted(Arc::new(DeletedDocumentVersion { name: self.name.clone(), - collection_name: self.collection_name.clone(), delete_time, })); self.versions.push(version.clone()); @@ -250,7 +236,7 @@ pub enum DocumentVersion { } impl DocumentVersion { - pub fn name(&self) -> &DefaultAtom { + pub fn name(&self) -> &DocumentRef { match self { DocumentVersion::Deleted(ver) => &ver.name, DocumentVersion::Stored(ver) => &ver.name, @@ -287,10 +273,7 @@ impl DocumentVersion { pub struct StoredDocumentVersion { /// The resource name of the document, for example /// `projects/{project_id}/databases/{database_id}/documents/{document_path}`. - pub name: DefaultAtom, - /// The collection name of the document, i.e. the full name of the document minus the last - /// component. - pub collection_name: DefaultAtom, + pub name: DocumentRef, /// The time at which the document was created. 
/// /// This value increases monotonically when a document is deleted then @@ -345,10 +328,7 @@ impl StoredDocumentVersion { pub struct DeletedDocumentVersion { /// The resource name of the document, for example /// `projects/{project_id}/databases/{database_id}/documents/{document_path}`. - pub name: DefaultAtom, - /// The collection name of the document, i.e. the full name of the document minus the last - /// component. - pub collection_name: DefaultAtom, + pub name: DocumentRef, /// The time at which the document was deleted. pub delete_time: Timestamp, } diff --git a/crates/firestore-database/src/database/event.rs b/crates/firestore-database/src/database/event.rs index b8ed09a..7cc33e8 100644 --- a/crates/firestore-database/src/database/event.rs +++ b/crates/firestore-database/src/database/event.rs @@ -1,11 +1,10 @@ use std::collections::HashMap; use googleapis::google::protobuf::Timestamp; -use string_cache::DefaultAtom; -use super::document::DocumentVersion; +use super::{document::DocumentVersion, reference::DocumentRef}; pub struct DatabaseEvent { pub update_time: Timestamp, - pub updates: HashMap, + pub updates: HashMap, } diff --git a/crates/firestore-database/src/database/listener.rs b/crates/firestore-database/src/database/listener.rs index b7b34be..d76ba77 100644 --- a/crates/firestore-database/src/database/listener.rs +++ b/crates/firestore-database/src/database/listener.rs @@ -17,18 +17,19 @@ use googleapis::google::{ protobuf::Timestamp, }; use itertools::Itertools; -use string_cache::DefaultAtom; use tokio::sync::{broadcast::error::RecvError, mpsc}; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use tonic::{Result, Status}; use tracing::{debug, error, instrument}; use super::{ - document::DocumentVersion, event::DatabaseEvent, query::Query, target_change, Database, + document::DocumentVersion, event::DatabaseEvent, query::Query, reference::DocumentRef, + target_change, Database, }; use crate::{ - database::ReadConsistency, required_option, 
unimplemented, unimplemented_bool, - unimplemented_collection, unimplemented_option, + database::{reference::Ref, ReadConsistency}, + required_option, unimplemented, unimplemented_bool, unimplemented_collection, + unimplemented_option, }; const TARGET_ID: i32 = 1; @@ -148,6 +149,7 @@ impl Listener { target::TargetType::Query(target::QueryTarget { parent, query_type }) => { required_option!(query_type); let query_target::QueryType::StructuredQuery(query) = query_type; + let parent: Ref = parent.parse()?; let query = Query::from_structured(parent, query, ReadConsistency::Default)?; self.set_query(query).await?; @@ -156,7 +158,7 @@ impl Listener { let Ok(document) = documents.into_iter().exactly_one() else { unimplemented!("multiple documents inside a single listen stream") }; - self.set_document(document.into(), resume_type).await?; + self.set_document(document.parse()?, resume_type).await?; } }; } @@ -193,10 +195,10 @@ impl Listener { self.send_complete(update_time).await } - #[instrument(skip_all, fields(document = &*name), err)] + #[instrument(skip_all, fields(document = %name), err)] async fn set_document( &mut self, - name: DefaultAtom, + name: DocumentRef, resume_type: Option, ) -> Result<()> { // We rely on the fact that this function will complete before any other events are @@ -227,7 +229,7 @@ impl Listener { .await?; let read_time = Timestamp::now(); - debug!(name = &*name); + debug!(name = %name); // Now determine the latest version we can find... 
let doc = database.get_doc(&name, &ReadConsistency::Default).await?; @@ -365,7 +367,7 @@ impl ListenerTarget { } struct DocumentTarget { - name: DefaultAtom, + name: DocumentRef, last_read_time: Timestamp, } @@ -405,7 +407,7 @@ impl DocumentTarget { struct QueryTarget { query: Query, reset_on_update: bool, - doctargets_by_name: HashMap, + doctargets_by_name: HashMap, } impl QueryTarget { fn new(query: Query) -> Self { @@ -461,8 +463,7 @@ impl QueryTarget { })]; self.doctargets_by_name.clear(); - for doc in self.query.once(database).await? { - let name = DefaultAtom::from(&*doc.name); + for (name, doc) in self.query.once(database).await? { self.doctargets_by_name.insert( name.clone(), DocumentTarget { diff --git a/crates/firestore-database/src/database/query.rs b/crates/firestore-database/src/database/query.rs index 307799e..cfd473c 100644 --- a/crates/firestore-database/src/database/query.rs +++ b/crates/firestore-database/src/database/query.rs @@ -7,8 +7,11 @@ use tonic::{Result, Status}; use self::filter::Filter; use super::{ - collection::Collection, document::StoredDocumentVersion, field_path::FieldReference, Database, - ReadConsistency, + collection::Collection, + document::StoredDocumentVersion, + field_path::FieldReference, + reference::{CollectionRef, DocumentRef, Ref}, + Database, ReadConsistency, }; mod filter; @@ -16,7 +19,7 @@ mod filter; /// A Firestore query. #[derive(Debug)] pub struct Query { - parent: String, + parent: Ref, /// Optional sub-set of the fields to return. /// @@ -114,7 +117,7 @@ pub struct Query { impl Query { pub fn from_structured( - parent: String, + parent: Ref, query: StructuredQuery, consistency: ReadConsistency, ) -> Result { @@ -178,7 +181,7 @@ impl Query { self.order_by.iter().any(|o| !o.field.is_document_name()) } - pub async fn once(&mut self, db: &Database) -> Result> { + pub async fn once(&mut self, db: &Database) -> Result> { // First collect all Arcs in a Vec to release the collection lock asap. 
let collections = self.applicable_collections(db).await; @@ -230,7 +233,7 @@ impl Query { buffer .into_iter() .skip(self.offset) - .map(|version| self.project(&version)) + .map(|version| Ok((version.name.clone(), self.project(&version)?))) .try_collect() } @@ -244,29 +247,26 @@ impl Query { .collect_vec() } - fn includes_collection(&mut self, path: &DefaultAtom) -> bool { - if let Some(&r) = self.collection_cache.get(path) { + fn includes_collection(&mut self, collection: &CollectionRef) -> bool { + if let Some(&r) = self.collection_cache.get(&collection.collection_id) { return r; } - let included = match path - .strip_prefix(&self.parent) - .and_then(|path| path.strip_prefix('/')) - { - Some(path) => self.from.iter().any(|selector| { - if selector.all_descendants { - path.starts_with(&selector.collection_id) - } else { - path == selector.collection_id - } - }), - None => false, - }; - self.collection_cache.insert(path.clone(), included); + let included = collection.strip_prefix(&self.parent).is_some_and(|path| { + self.from.iter().any(|selector| { + path == selector.collection_id + || selector.all_descendants + && path + .strip_prefix(&selector.collection_id) + .is_some_and(|rest| rest.starts_with('/')) + }) + }); + self.collection_cache + .insert(collection.collection_id.clone(), included); included } pub fn includes_document(&mut self, doc: &StoredDocumentVersion) -> Result { - if !self.includes_collection(&doc.collection_name) { + if !self.includes_collection(&doc.name.collection_ref) { return Ok(false); } if let Some(filter) = &self.filter { diff --git a/crates/firestore-database/src/database/reference.rs b/crates/firestore-database/src/database/reference.rs new file mode 100644 index 0000000..998a555 --- /dev/null +++ b/crates/firestore-database/src/database/reference.rs @@ -0,0 +1,486 @@ +use std::{fmt::Display, str::FromStr}; + +use string_cache::DefaultAtom; +use thiserror::Error; + +/// A reference to either the database root (including project name and 
database name), a single +/// collection or a single document. +/// +/// # Examples +/// +/// ``` +/// # use firestore_database::reference::*; +/// # fn main() -> Result<(), InvalidReference> { +/// assert_eq!( +/// "projects/my-project/databases/my-database".parse::()?, +/// Ref::Root(RootRef::new("my-project", "my-database")), +/// ); +/// assert_eq!( +/// "projects/my-project/databases/my-database/documents/my-collection".parse::()?, +/// Ref::Collection(CollectionRef::new( +/// RootRef::new("my-project", "my-database"), +/// "my-collection" +/// )) +/// ); +/// assert_eq!( +/// "projects/my-project/databases/my-database/documents/my-collection/my-document" +/// .parse::()?, +/// Ref::Document(DocumentRef::new( +/// CollectionRef::new(RootRef::new("my-project", "my-database"), "my-collection"), +/// "my-document" +/// )) +/// ); +/// # Ok(()) +/// # } +/// ``` +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Ref { + Root(RootRef), + Collection(CollectionRef), + Document(DocumentRef), +} + +impl Ref { + pub fn as_root(&self) -> Option<&RootRef> { + if let Self::Root(v) = self { + Some(v) + } else { + None + } + } + + pub fn as_collection(&self) -> Option<&CollectionRef> { + if let Self::Collection(v) = self { + Some(v) + } else { + None + } + } + + pub fn as_document(&self) -> Option<&DocumentRef> { + if let Self::Document(v) = self { + Some(v) + } else { + None + } + } + + pub fn root(&self) -> &RootRef { + match self { + Ref::Root(root) => root, + Ref::Collection(col) => &col.root_ref, + Ref::Document(doc) => &doc.collection_ref.root_ref, + } + } + + /// Returns whether `self` is a (grand)parent of `other`. 
+ /// + /// # Examples + /// + /// ``` + /// # use firestore_database::reference::*; + /// # fn main() -> Result<(), InvalidReference> { + /// let root: Ref = "projects/p/databases/d/documents".parse()?; + /// let collection: Ref = "projects/p/databases/d/documents/collection".parse()?; + /// let document: Ref = "projects/p/databases/d/documents/collection/document".parse()?; + /// assert!(root.is_parent_of(&collection)); + /// assert!(root.is_parent_of(&document)); + /// assert!(collection.is_parent_of(&document)); + /// assert!(!root.is_parent_of(&root)); + /// assert!(!document.is_parent_of(&root)); + /// assert!(!collection.is_parent_of(&collection)); + /// + /// let doc_in_other_col: Ref = "projects/p/databases/d/documents/OTHER/document".parse()?; + /// assert!(!collection.is_parent_of(&doc_in_other_col)); + /// # Ok(()) + /// # } + /// ``` + pub fn is_parent_of(&self, other: &Ref) -> bool { + debug_assert_eq!(self.root(), other.root()); + match (self, other) { + (_, Ref::Root(_)) => false, + (Ref::Root(_), _) => true, + (Ref::Collection(parent), Ref::Collection(child)) => { + child.strip_collection_prefix(parent).is_some() + } + (Ref::Collection(parent), Ref::Document(child)) => { + &child.collection_ref == parent + || child + .collection_ref + .strip_collection_prefix(parent) + .is_some() + } + (Ref::Document(parent), Ref::Collection(child)) => { + child.strip_document_prefix(parent).is_some() + } + (Ref::Document(parent), Ref::Document(child)) => { + child.collection_ref.strip_document_prefix(parent).is_some() + } + } + } +} + +impl Display for Ref { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Ref::Root(r) => r.fmt(f), + Ref::Collection(r) => r.fmt(f), + Ref::Document(r) => r.fmt(f), + } + } +} + +#[derive(Debug, Error)] +#[error("Invalid {0} reference: {1}")] +pub struct InvalidReference(&'static str, String); + +impl From for tonic::Status { + fn from(value: InvalidReference) -> Self { + 
tonic::Status::invalid_argument(value.to_string()) + } +} + +impl FromStr for Ref { + type Err = InvalidReference; + + fn from_str(s: &str) -> Result { + fn parse_ref(s: &str) -> Option { + let s = s.strip_prefix("projects/")?; + let (project_id, s) = s.split_once('/')?; + let s = s.strip_prefix("databases/")?; + let Some((database_id, s)) = s.split_once('/') else { + return Some(Ref::Root(RootRef::new(project_id, s))); + }; + let s = s.strip_prefix("documents")?; + let root_ref = RootRef::new(project_id, database_id); + if s.is_empty() { + return Some(Ref::Root(root_ref)); + } + let s = s.strip_prefix('/')?; + let slashes = s.chars().filter(|ch| *ch == '/').count(); + let rf = if slashes % 2 == 0 { + Ref::Collection(CollectionRef::new(root_ref, s)) + } else { + let (collection_id, document_id) = s.rsplit_once('/')?; + Ref::Document(DocumentRef::new( + CollectionRef::new(root_ref, collection_id), + document_id, + )) + }; + + // TODO: add checks: + // - Maximum depth of subcollections 100 + // - Maximum size for a document name 6 KiB + // - Constraints on collection IDs + // - Must be valid UTF-8 characters + // - Must be no longer than 1,500 bytes + // - Cannot contain a forward slash (/) + // - Cannot solely consist of a single period (.) or double periods (..) + // - Cannot match the regular expression __.*__ + // - Constraints on document IDs + // - Must be valid UTF-8 characters + // - Must be no longer than 1,500 bytes + // - Cannot contain a forward slash (/) + // - Cannot solely consist of a single period (.) or double periods (..) 
+ // - Cannot match the regular expression __.*__ + // - (If you import Datastore entities into a Firestore database, numeric entity IDs + // are exposed as __id[0-9]+__) + + Some(rf) + } + + parse_ref(s).ok_or_else(|| InvalidReference("database/collection/document", s.to_string())) + } +} + +impl From for Ref { + fn from(v: RootRef) -> Self { + Self::Root(v) + } +} + +impl From for Ref { + fn from(v: DocumentRef) -> Self { + Self::Document(v) + } +} + +impl From for Ref { + fn from(v: CollectionRef) -> Self { + Self::Collection(v) + } +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct RootRef { + pub project_id: DefaultAtom, + pub database_id: DefaultAtom, +} + +impl FromStr for RootRef { + type Err = InvalidReference; + + fn from_str(s: &str) -> Result { + let rf: Ref = s.parse()?; + rf.as_root() + .cloned() + .ok_or_else(|| InvalidReference("database", s.to_string())) + } +} + +impl Display for RootRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "projects/{}/databases/{}/documents", + self.project_id, self.database_id, + ) + } +} + +impl RootRef { + pub fn new(project_id: impl Into, database_id: impl Into) -> Self { + Self { + project_id: project_id.into(), + database_id: database_id.into(), + } + } +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct CollectionRef { + pub root_ref: RootRef, + pub collection_id: DefaultAtom, +} + +impl FromStr for CollectionRef { + type Err = InvalidReference; + + fn from_str(s: &str) -> Result { + let rf: Ref = s.parse()?; + rf.as_collection() + .cloned() + .ok_or_else(|| InvalidReference("collection", s.to_string())) + } +} + +impl Display for CollectionRef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}/{}", self.root_ref, self.collection_id,) + } +} + +impl CollectionRef { + pub fn new(root_ref: RootRef, collection_id: impl Into) -> Self { + Self { + root_ref, + collection_id: 
collection_id.into(), + } + } + + /// Returns the remaining canonical reference string when the given reference `r` has been + /// removed as prefix. Can also be used to efficiently determine whether the given `r` is a + /// parent of this collection. + /// + /// # Examples + /// + /// ``` + /// # use firestore_database::reference::*; + /// # fn main() -> Result<(), InvalidReference> { + /// let parent: Ref = "projects/p/databases/d/documents/parent".parse()?; + /// let child: CollectionRef = "projects/p/databases/d/documents/parent/doc/child".parse()?; + /// assert_eq!(child.strip_prefix(&parent), Some("doc/child")); + /// # Ok(()) + /// # } + /// ``` + pub fn strip_prefix(&self, r: &Ref) -> Option<&str> { + match r { + Ref::Root(root) => Some(self.strip_root_prefix(root)), + Ref::Collection(col) => self.strip_collection_prefix(col), + Ref::Document(doc) => self.strip_document_prefix(doc), + } + } + + /// Returns the remaining canonical reference string when the given `root` has been removed as + /// prefix. Note that this becomes a no-op in release mode, because we assume that references + /// from different databases are never compared to each other. + /// + /// # Examples + /// + /// ``` + /// # use firestore_database::reference::*; + /// # fn main() -> Result<(), InvalidReference> { + /// let parent: RootRef = "projects/p/databases/d/documents".parse()?; + /// let child: CollectionRef = "projects/p/databases/d/documents/parent/doc/child".parse()?; + /// assert_eq!(child.strip_root_prefix(&parent), "parent/doc/child"); + /// # Ok(()) + /// # } + /// ``` + pub fn strip_root_prefix(&self, root: &RootRef) -> &str { + debug_assert_eq!(&self.root_ref, root); + &self.collection_id + } + + /// Returns the remaining canonical reference string when the given `col` has been removed as + /// prefix. Can also be used to efficiently determine whether the given `col` is a parent of + /// this collection. 
+ /// + /// # Examples + /// + /// ``` + /// # use firestore_database::reference::*; + /// # fn main() -> Result<(), InvalidReference> { + /// let parent: CollectionRef = "projects/p/databases/d/documents/parent".parse()?; + /// let child: CollectionRef = "projects/p/databases/d/documents/parent/doc/child".parse()?; + /// assert_eq!(child.strip_collection_prefix(&parent), Some("doc/child")); + /// # Ok(()) + /// # } + /// ``` + pub fn strip_collection_prefix(&self, col: &CollectionRef) -> Option<&str> { + let rest = self + .strip_root_prefix(&col.root_ref) + .strip_prefix(&*col.collection_id)? + .strip_prefix('/')?; + Some(rest) + } + + /// Returns the remaining canonical reference string when the given `doc` has been removed as + /// prefix. Can also be used to efficiently determine whether the given `doc` is a parent of + /// this collection. + /// + /// # Examples + /// + /// ``` + /// # use firestore_database::reference::*; + /// # fn main() -> Result<(), InvalidReference> { + /// let parent: DocumentRef = "projects/p/databases/d/documents/parent/doc".parse()?; + /// let child: CollectionRef = "projects/p/databases/d/documents/parent/doc/child".parse()?; + /// assert_eq!(child.strip_document_prefix(&parent), Some("child")); + /// # Ok(()) + /// # } + /// ``` + pub fn strip_document_prefix(&self, doc: &DocumentRef) -> Option<&str> { + let rest = self + .strip_collection_prefix(&doc.collection_ref)? + .strip_prefix(&*doc.document_id)? 
+            .strip_prefix('/')?;
+        Some(rest)
+    }
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
+pub struct DocumentRef {
+    pub collection_ref: CollectionRef,
+    pub document_id: DefaultAtom,
+}
+
+impl FromStr for DocumentRef {
+    type Err = InvalidReference;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let rf: Ref = s.parse()?;
+        rf.as_document()
+            .cloned()
+            .ok_or_else(|| InvalidReference("document", s.to_string()))
+    }
+}
+impl Display for DocumentRef {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}/{}", self.collection_ref, self.document_id,)
+    }
+}
+
+impl DocumentRef {
+    pub fn new(collection_ref: CollectionRef, document_id: impl Into<DefaultAtom>) -> Self {
+        Self {
+            collection_ref,
+            document_id: document_id.into(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use rstest::rstest;
+
+    use super::*;
+
+    #[rstest]
+    #[case(
+        // With `/documents` suffix
+        "projects/my-demo-project/databases/database/documents",
+        "my-demo-project",
+        "database"
+    )]
+    #[case(
+        // Without `/documents` suffix
+        "projects/demo-project/databases/(default)",
+        "demo-project",
+        "(default)"
+    )]
+    fn parse_root_ref(#[case] input: &str, #[case] project_id: &str, #[case] database_id: &str) {
+        assert_eq!(
+            input.parse::<Ref>().unwrap(),
+            Ref::Root(RootRef::new(project_id, database_id))
+        )
+    }
+
+    #[rstest]
+    #[case(
+        "projects/proj/databases/database/documents/collection",
+        "proj",
+        "database",
+        "collection"
+    )]
+    #[case(
+        "projects/demo-project/databases/(default)/documents/root/doc/sub",
+        "demo-project",
+        "(default)",
+        "root/doc/sub"
+    )]
+    fn parse_collection_refs(
+        #[case] input: &str,
+        #[case] project_id: &str,
+        #[case] database_id: &str,
+        #[case] collection_id: &str,
+    ) {
+        assert_eq!(
+            input.parse::<Ref>().unwrap(),
+            Ref::Collection(CollectionRef::new(
+                RootRef::new(project_id, database_id),
+                collection_id
+            ))
+        );
+    }
+
+    #[rstest]
+    #[case(
+        "projects/proj/databases/database/documents/collection/document",
+        "proj",
+        "database",
+        "collection",
+        "document"
+    )]
+    #[case(
+        "projects/demo-project/databases/(default)/documents/root/doc/sub/doc",
+        "demo-project",
+        "(default)",
+        "root/doc/sub",
+        "doc"
+    )]
+    fn parse_document_refs(
+        #[case] input: &str,
+        #[case] project_id: &str,
+        #[case] database_id: &str,
+        #[case] collection_id: &str,
+        #[case] doc_id: &str,
+    ) {
+        assert_eq!(
+            input.parse::<Ref>().unwrap(),
+            Ref::Document(DocumentRef::new(
+                CollectionRef::new(RootRef::new(project_id, database_id), collection_id),
+                doc_id
+            ))
+        );
+    }
+}
diff --git a/crates/firestore-database/src/database/transaction.rs b/crates/firestore-database/src/database/transaction.rs
index 48ba673..d04a7c1 100644
--- a/crates/firestore-database/src/database/transaction.rs
+++ b/crates/firestore-database/src/database/transaction.rs
@@ -13,6 +13,7 @@ use tracing::instrument;
 
 use super::{
     document::{OwnedDocumentContentsReadGuard, OwnedDocumentContentsWriteGuard},
+    reference::DocumentRef,
     Database,
 };
 
@@ -97,14 +98,14 @@ impl Transaction {
     #[instrument(skip_all)]
     pub async fn read_doc(
         &self,
-        name: &DefaultAtom,
+        name: &DocumentRef,
     ) -> Result<Arc<OwnedDocumentContentsReadGuard>> {
         let mut guards = self.guards.lock().await;
-        if let Some(guard) = guards.get(name) {
+        if let Some(guard) = guards.get(&name.document_id) {
             return Ok(Arc::clone(guard));
         }
         let guard = self.new_read_guard(name).await?.into();
-        guards.insert(name.clone(), Arc::clone(&guard));
+        guards.insert(name.document_id.clone(), Arc::clone(&guard));
         Ok(guard)
     }
 
@@ -114,10 +115,10 @@ impl Transaction {
 
     pub async fn take_write_guard(
         &self,
-        name: &DefaultAtom,
+        name: &DocumentRef,
     ) -> Result<OwnedDocumentContentsWriteGuard> {
         let mut guards = self.guards.lock().await;
-        let read_guard = match guards.remove(name) {
+        let read_guard = match guards.remove(&name.document_id) {
             Some(guard) => Arc::into_inner(guard)
                 .ok_or_else(|| Status::aborted("concurrent reads during txn commit in same txn"))?,
             None => self.new_read_guard(name).await?,
         };
         read_guard.upgrade().await
     }
 
-    async fn new_read_guard(&self, name: &DefaultAtom) -> Result<OwnedDocumentContentsReadGuard> {
+    async fn new_read_guard(&self, name: &DocumentRef) -> Result<OwnedDocumentContentsReadGuard> {
         self.database
             .upgrade()
             .ok_or_else(|| Status::aborted("database was dropped"))?
diff --git a/crates/firestore-database/src/lib.rs b/crates/firestore-database/src/lib.rs
index ba1ed30..6492b35 100644
--- a/crates/firestore-database/src/lib.rs
+++ b/crates/firestore-database/src/lib.rs
@@ -1,4 +1,4 @@
 mod database;
 pub use database::*;
 #[macro_use]
-mod utils;
+pub mod utils;
diff --git a/crates/googleapis/Cargo.toml b/crates/googleapis/Cargo.toml
index 2a88bb0..c86a3fa 100644
--- a/crates/googleapis/Cargo.toml
+++ b/crates/googleapis/Cargo.toml
@@ -16,3 +16,4 @@ tonic-build = "0.11"
 
 [dev-dependencies]
 itertools = "0.12.1"
+rstest = "0.18.2"
diff --git a/crates/googleapis/src/timestamp_ext.rs b/crates/googleapis/src/timestamp_ext.rs
index a944131..a0d2f6e 100644
--- a/crates/googleapis/src/timestamp_ext.rs
+++ b/crates/googleapis/src/timestamp_ext.rs
@@ -106,33 +106,46 @@ mod tests {
     use std::array;
 
     use itertools::Itertools;
+    use rstest::rstest;
     use time::macros::datetime;
-    use tonic::Status;
 
     use super::*;
 
     const TOKEN_LEN: usize = 16;
 
+    #[rstest]
+    #[case(vec![])]
+    // Using 127 to get a high number because the token is signed (0 would mean 1970-01-01,
+    // 255 results in a -1_i128 which would mean end of 1769-12-31)
+    #[case(vec![127; TOKEN_LEN])]
+    fn invalid_tokens(#[case] token: Vec<u8>) {
+        assert!(matches!(
+            Timestamp::from_token(token),
+            Err(InvalidTokenError)
+        ));
+    }
+
+    #[rstest]
+    #[case(Timestamp { seconds: i64::MAX, nanos: 0 })]
+    #[case(Timestamp { seconds: i64::MIN, nanos: 0 })]
+    fn out_of_range_timestamps(#[case] timestamp: Timestamp) {
+        println!("{:?}", OffsetDateTime::try_from(&timestamp));
+        assert!(matches!(
+            OffsetDateTime::try_from(&timestamp),
+            Err(TimestampOutOfRangeError(_))
+        ));
+    }
+
     #[test]
     fn tonic_status_compat() {
-        let invalid_tokens = [
-            Timestamp::from_token(vec![]).unwrap_err(), // Invalid length
-            // Using 127 to get a high number because the token is signed (0 would mean 1970-01-01,
-            // 255 results in a -1_i128 which would mean end of 1769-12-31)
-            Timestamp::from_token(vec![127; TOKEN_LEN]).unwrap_err(), // Out of range contents
-        ];
-        for invalid_token in invalid_tokens {
-            Status::invalid_argument(invalid_token.clone());
-            assert_eq!(invalid_token.to_string(), "Invalid token");
-        }
+        assert_eq!(String::from(InvalidTokenError), "Invalid token");
 
         let out_of_range = OffsetDateTime::try_from(&Timestamp {
             seconds: i64::MAX,
             nanos: 0,
         })
         .unwrap_err();
-        Status::invalid_argument(out_of_range.clone());
-        assert_eq!(out_of_range.to_string(), "Timestamp out of range");
+        assert_eq!(String::from(out_of_range), "Timestamp out of range");
     }
 
     #[test]
diff --git a/src/emulator.rs b/src/emulator.rs
index 5a9bd1f..a618d6b 100644
--- a/src/emulator.rs
+++ b/src/emulator.rs
@@ -1,7 +1,11 @@
-use std::sync::Arc;
+use std::{collections::HashMap, ops::Deref, sync::Arc};
 
 use firestore_database::{
-    event::DatabaseEvent, get_doc_name_from_write, Database, ReadConsistency,
+    event::DatabaseEvent,
+    get_doc_name_from_write,
+    reference::{DocumentRef, Ref, RootRef},
+    utils::RwLockHashMapExt,
+    Database, ReadConsistency,
 };
 use futures::{future::try_join_all, stream::BoxStream, StreamExt};
 use googleapis::google::{
@@ -14,7 +18,7 @@ use googleapis::google::{
 };
 use itertools::Itertools;
 use string_cache::DefaultAtom;
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, RwLock};
 use tokio_stream::wrappers::ReceiverStream;
 use tonic::{async_trait, Code, Request, Response, Result, Status};
 use tracing::{info, info_span, instrument, Instrument};
@@ -22,7 +26,7 @@ use tracing::{info, info_span, instrument, Instrument};
 use crate::{unimplemented, unimplemented_bool, unimplemented_collection, unimplemented_option};
 
 pub struct FirestoreEmulator {
-    pub database: Arc<Database>,
+    databases: RwLock<HashMap<DefaultAtom, Arc<Database>>>,
 }
 
 impl std::fmt::Debug for FirestoreEmulator {
@@ -34,44 +38,22 @@ impl FirestoreEmulator {
     pub fn new() -> Self {
         Self {
-            database: Database::new(),
+            databases: Default::default(),
         }
     }
 
-    async fn eval_command(&self, writes: &[Write]) -> Result<Option<WriteResult>> {
-        let [
-            Write {
-                operation: Some(write::Operation::Update(update)),
-                ..
-            },
-        ] = writes
-        else {
-            return Ok(None);
-        };
-        let Some(command_name) = command_name(&update.name) else {
-            return Ok(None);
-        };
-        info!(command_name, "received command");
-        match command_name {
-            "CLEAR_EMULATOR" => {
-                self.database.clear().await;
-                Ok(Some(WriteResult {
-                    update_time: Some(Timestamp::now()),
-                    transform_results: vec![],
-                }))
-            }
-            _ => Err(Status::invalid_argument(format!(
-                "Unknown COMMAND: {command_name}"
-            ))),
-        }
-    }
-}
+    // pub async fn clear(&mut self) {
+    //     self.databases.write().await.clear();
+    // }
 
-fn command_name(doc_name: &str) -> Option<&str> {
-    let path = doc_name.strip_prefix("projects/")?;
-    let (_project_id, path) = path.split_once("/databases/")?;
-    let (_database_name, path) = path.split_once("/documents/")?;
-    path.strip_prefix("__COMMANDS__/")
+    pub async fn database(&self, name: &RootRef) -> Arc<Database> {
+        Arc::clone(
+            self.databases
+                .get_or_insert(&name.database_id, || Database::new(name.clone()))
+                .await
+                .deref(),
+        )
+    }
 }
 
 #[async_trait]
@@ -89,9 +71,12 @@ impl firestore_server::Firestore for FirestoreEmulator {
         } = request.into_inner();
         unimplemented_option!(mask);
 
+        let name: DocumentRef = name.parse()?;
+
         let doc = self
-            .database
-            .get_doc(&name.into(), &consistency_selector.try_into()?)
+            .database(&name.collection_ref.root_ref)
+            .await
+            .get_doc(&name, &consistency_selector.try_into()?)
             .await?
             .ok_or_else(|| Status::not_found(Code::NotFound.description()))?;
 
         Ok(Response::new(doc))
@@ -113,20 +98,26 @@ impl firestore_server::Firestore for FirestoreEmulator {
         request: Request<BatchGetDocumentsRequest>,
     ) -> Result<Response<Self::BatchGetDocumentsStream>> {
         let BatchGetDocumentsRequest {
-            database: _,
+            database,
             documents,
             mask,
             consistency_selector,
         } = request.into_inner();
         unimplemented_option!(mask);
 
+        let database = self.database(&database.parse()?).await;
+        let documents: Vec<_> = documents
+            .into_iter()
+            .map(|name| name.parse::<DocumentRef>())
+            .try_collect()?;
+
         // Only used for new transactions.
         let (mut new_transaction, read_consistency) = match consistency_selector {
             Some(batch_get_documents_request::ConsistencySelector::NewTransaction(
                 transaction_options,
             )) => {
                 unimplemented_option!(transaction_options.mode);
-                let id = self.database.new_txn().await?;
+                let id = database.new_txn().await?;
                 info!("started new transaction");
                 (Some(id.into()), ReadConsistency::Transaction(id))
             }
@@ -135,19 +126,15 @@ impl firestore_server::Firestore for FirestoreEmulator {
         info!(?read_consistency);
 
         let (tx, rx) = mpsc::channel(16);
-        let database = Arc::clone(&self.database);
         tokio::spawn(
             async move {
                 for name in documents {
                     use batch_get_documents_response::Result::*;
-                    let msg = match database
-                        .get_doc(&DefaultAtom::from(&*name), &read_consistency)
-                        .await
-                    {
+                    let msg = match database.get_doc(&name, &read_consistency).await {
                         Ok(doc) => Ok(BatchGetDocumentsResponse {
                             result: Some(match doc {
-                                None => Missing(name),
-                                Some(doc) => Found(Document::clone(&doc)),
+                                None => Missing(name.to_string()),
+                                Some(doc) => Found(doc),
                             }),
                             read_time: Some(Timestamp::now()),
                             transaction: new_transaction.take().unwrap_or_default(),
@@ -172,24 +159,19 @@ impl firestore_server::Firestore for FirestoreEmulator {
     ), err)]
     async fn commit(&self, request: Request<CommitRequest>) -> Result<Response<CommitResponse>> {
         let CommitRequest {
-            database: _,
+            database,
             writes,
             transaction,
         } = request.into_inner();
 
-        if let Some(write_result) = self.eval_command(&writes).await? {
-            return Ok(Response::new(CommitResponse {
-                write_results: vec![write_result],
-                commit_time: Some(Timestamp::now()),
-            }));
-        }
+        let database = self.database(&database.parse()?).await;
 
         let (commit_time, write_results) = if transaction.is_empty() {
-            perform_writes(&self.database, writes).await?
+            perform_writes(database.as_ref(), writes).await?
         } else {
             let txn_id = transaction.try_into()?;
             info!(?txn_id);
-            self.database.commit(writes, &txn_id).await?
+            database.commit(writes, &txn_id).await?
         };
 
         Ok(Response::new(CommitResponse {
@@ -232,8 +214,10 @@ impl firestore_server::Firestore for FirestoreEmulator {
             unimplemented!("page_size");
         }
 
-        let documents = self
-            .database
+        let parent: Ref = parent.parse()?;
+        let database = self.database(parent.root()).await;
+
+        let documents = database
             .run_query(
                 parent,
                 StructuredQuery {
@@ -289,13 +273,14 @@ impl firestore_server::Firestore for FirestoreEmulator {
         &self,
         request: Request<BeginTransactionRequest>,
     ) -> Result<Response<BeginTransactionResponse>> {
-        let BeginTransactionRequest {
-            database: _,
-            options,
-        } = request.into_inner();
+        let BeginTransactionRequest { database, options } = request.into_inner();
+
+        let database = self.database(&database.parse()?).await;
+
         use transaction_options::Mode;
+
         let txn_id = match options {
-            None => self.database.new_txn().await?,
+            None => database.new_txn().await?,
             Some(TransactionOptions {
                 mode: None | Some(Mode::ReadOnly(_)),
             }) => {
@@ -305,7 +290,7 @@ impl firestore_server::Firestore for FirestoreEmulator {
                 mode: Some(Mode::ReadWrite(ReadWrite { retry_transaction })),
             }) => {
                 let id = retry_transaction.try_into()?;
-                self.database.new_txn_with_id(id).await?;
+                database.new_txn_with_id(id).await?;
                 id
             }
         };
@@ -320,10 +305,11 @@ impl firestore_server::Firestore for FirestoreEmulator {
     #[instrument(skip_all, err)]
     async fn rollback(&self, request: Request<RollbackRequest>) -> Result<Response<Empty>> {
         let RollbackRequest {
-            database: _,
+            database,
             transaction,
         } = request.into_inner();
-        self.database.rollback(&transaction.try_into()?).await?;
+        let database = self.database(&database.parse()?).await;
+        database.rollback(&transaction.try_into()?).await?;
         Ok(Response::new(Empty {}))
     }
 
@@ -349,8 +335,11 @@ impl firestore_server::Firestore for FirestoreEmulator {
             unimplemented!("query without query")
         };
 
+        let parent: Ref = parent.parse()?;
+
         let docs = self
-            .database
+            .database(parent.root())
+            .await
             .run_query(parent, query, consistency_selector.try_into()?)
             .await?;
 
@@ -426,7 +415,12 @@ impl firestore_server::Firestore for FirestoreEmulator {
         &self,
         request: Request<Streaming<ListenRequest>>,
     ) -> Result<Response<Self::ListenStream>> {
-        Ok(Response::new(self.database.listen(request.into_inner())))
+        // TODO: refactor to be able to use database properly
+        Ok(Response::new(
+            self.database(&"projects/whatever/databases/(default)".parse()?)
+                .await
+                .listen(request.into_inner()),
+        ))
     }
 
     /// Lists all the collection IDs underneath a document.
@@ -441,9 +435,11 @@ impl firestore_server::Firestore for FirestoreEmulator {
             page_token,
             consistency_selector,
         } = request.into_inner();
+        let parent: DocumentRef = parent.parse()?;
 
         let collection_ids = self
-            .database
-            .get_collection_ids(&parent.into())
+            .database(&parent.collection_ref.root_ref)
+            .await
+            .get_collection_ids(&parent)
             .await?
             .into_iter()
             .map(|name| name.to_string())
@@ -474,7 +470,7 @@ impl firestore_server::Firestore for FirestoreEmulator {
         request: Request<BatchWriteRequest>,
     ) -> Result<Response<BatchWriteResponse>> {
         let BatchWriteRequest {
-            database: _,
+            database,
             writes,
             labels,
         } = request.into_inner();
 
         let time: Timestamp = Timestamp::now();
 
+        let database = self.database(&database.parse()?).await;
+
         let (status, write_results, updates): (Vec<_>, Vec<_>, Vec<_>) =
             try_join_all(writes.into_iter().map(|write| async {
                 let name = get_doc_name_from_write(&write)?;
-                let mut guard = self.database.get_doc_meta_mut_no_txn(&name).await?;
-                let result = self
-                    .database
+                let mut guard = database.get_doc_meta_mut_no_txn(&name).await?;
+                let result = database
                     .perform_write(write, &mut guard, time.clone())
                     .await;
                 use googleapis::google::rpc;
@@ -508,7 +505,7 @@ impl firestore_server::Firestore for FirestoreEmulator {
             .into_iter()
             .multiunzip();
 
-        self.database.send_event(DatabaseEvent {
+        database.send_event(DatabaseEvent {
             update_time: time.clone(),
             updates: updates.into_iter().flatten().collect(),
         });