Skip to content

Commit

Permalink
Merge pull request #37 from skunkteam/update-multiple-docs-in-single-txn
Browse files Browse the repository at this point in the history
feature: implement mask projection for all APIs and test multiple updates to a single document in transactions
  • Loading branch information
pavadeli authored Mar 14, 2024
2 parents 5dddf91 + 7ab3a30 commit bf99748
Show file tree
Hide file tree
Showing 13 changed files with 175 additions and 79 deletions.
2 changes: 0 additions & 2 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
- common-document-store-backend
- UNIMPLEMENTED: run_aggregation_query is not supported yet
- wao-claim-appengine-backend
- UNIMPLEMENTED: mask is not supported yet
- UNIMPLEMENTED: target_id should always be 1 is not supported yet
- polis-soc-reg-backend - ???
10 changes: 6 additions & 4 deletions crates/emulator-grpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{mem, sync::Arc};
use firestore_database::{
event::DatabaseEvent,
get_doc_name_from_write,
projection::{Project, Projection},
read_consistency::ReadConsistency,
reference::{DocumentRef, Ref},
FirestoreDatabase, FirestoreProject,
Expand Down Expand Up @@ -59,18 +60,19 @@ impl firestore_server::Firestore for FirestoreEmulator {
mask,
consistency_selector,
} = request.into_inner();
unimplemented_option!(mask);

let name: DocumentRef = name.parse()?;

let projection = mask.map(Projection::try_from).transpose()?;

let doc = self
.project
.database(&name.collection_ref.root_ref)
.await
.get_doc(&name, &consistency_selector.try_into()?)
.await?
.ok_or_else(|| Status::not_found(Code::NotFound.description()))?;
Ok(Response::new(doc))
Ok(Response::new(projection.project(&doc)))
}

/// Server streaming response type for the BatchGetDocuments method.
Expand All @@ -94,13 +96,13 @@ impl firestore_server::Firestore for FirestoreEmulator {
mask,
consistency_selector,
} = request.into_inner();
unimplemented_option!(mask);

let database = self.project.database(&database.parse()?).await;
let documents: Vec<_> = documents
.into_iter()
.map(|name| name.parse::<DocumentRef>())
.try_collect()?;
let projection = mask.map(Projection::try_from).transpose()?;

let (
// Only used for new transactions.
Expand All @@ -125,7 +127,7 @@ impl firestore_server::Firestore for FirestoreEmulator {
Ok(doc) => Ok(BatchGetDocumentsResponse {
result: Some(match doc {
None => Missing(name.to_string()),
Some(doc) => Found(doc),
Some(doc) => Found(projection.project(&doc)),
}),
read_time: Some(Timestamp::now()),
transaction: mem::take(&mut new_transaction),
Expand Down
2 changes: 1 addition & 1 deletion crates/emulator-ui/src/routes/emulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async fn get_by_ref(
.into_response()),
Ref::Document(r) => Ok(Json(json!({
"type": "document",
"document": database.get_doc(&r, &ReadConsistency::Default).await?,
"document": database.get_doc(&r, &ReadConsistency::Default).await?.map(|d| d.to_document()),
"collections": database.get_collection_ids(&Ref::Document(r)).await?,
}))
.into_response()),
Expand Down
27 changes: 17 additions & 10 deletions crates/firestore-database/src/database.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::{
collections::{hash_map::Entry, HashMap, HashSet},
ops::Deref,
sync::{Arc, Weak},
};

Expand All @@ -25,7 +24,10 @@ use tracing::{info, instrument, Span};

use self::{
collection::Collection,
document::{DocumentContents, DocumentMeta, DocumentVersion, OwnedDocumentContentsWriteGuard},
document::{
DocumentContents, DocumentMeta, DocumentVersion, OwnedDocumentContentsWriteGuard,
StoredDocumentVersion,
},
event::DatabaseEvent,
field_path::FieldPath,
query::Query,
Expand All @@ -42,6 +44,7 @@ mod collection;
pub(crate) mod document;
pub mod event;
mod field_path;
pub mod projection;
pub(crate) mod query;
pub mod read_consistency;
pub mod reference;
Expand Down Expand Up @@ -78,20 +81,20 @@ impl FirestoreDatabase {
&self,
name: &DocumentRef,
consistency: &ReadConsistency,
) -> Result<Option<Document>> {
) -> Result<Option<Arc<StoredDocumentVersion>>> {
info!(%name);
let version = if let Some(txn) = self.get_txn_for_consistency(consistency).await? {
txn.read_doc(name)
.await?
.version_for_consistency(consistency)?
.map(|version| version.to_document())
.map(Arc::clone)
} else {
self.get_doc_meta(name)
.await?
.read()
.await?
.version_for_consistency(consistency)?
.map(|version| version.to_document())
.map(Arc::clone)
};
Span::current().record("found", version.is_some());
Ok(version)
Expand Down Expand Up @@ -230,10 +233,10 @@ impl FirestoreDatabase {
info!(?query);
let result = query.once(self).await?;
info!(result_count = result.len());
result
Ok(result
.into_iter()
.map(|version| query.project(&version))
.try_collect()
.collect())
}

#[instrument(skip_all, err)]
Expand Down Expand Up @@ -394,7 +397,7 @@ impl FirestoreDatabase {
operation,
} = write;

let operation = operation.ok_or(GenericDatabaseError::not_implemented(
let operation = operation.ok_or(GenericDatabaseError::invalid_argument(
"missing operation in write",
))?;
let condition = current_document
Expand All @@ -408,6 +411,8 @@ impl FirestoreDatabase {
use write::Operation::*;
let document_version = match operation {
Update(doc) => {
info!(name = %contents.name, "Update");

let mut fields = if let Some(mask) = update_mask {
apply_updates(contents, mask, &doc.fields)?
} else {
Expand All @@ -428,6 +433,8 @@ impl FirestoreDatabase {
contents.add_version(fields, commit_time.clone()).await
}
Delete(_) => {
info!(name = %contents.name, "Delete");

unimplemented_option!(update_mask);
unimplemented_collection!(update_transforms);
contents.delete(commit_time.clone()).await
Expand Down Expand Up @@ -495,7 +502,7 @@ fn apply_updates(
.map(|v| v.fields.clone())
.unwrap_or_default();
for field_path in mask.field_paths {
let field_path: FieldPath = field_path.deref().try_into()?;
let field_path: FieldPath = field_path.parse()?;
match field_path.get_value(updated_values) {
Some(new_value) => field_path.set_value(&mut fields, new_value.clone()),
None => {
Expand All @@ -512,7 +519,7 @@ fn apply_transform(
transform: TransformType,
commit_time: &Timestamp,
) -> Result<Value> {
let field_path: FieldPath = path.deref().try_into()?;
let field_path: FieldPath = path.parse()?;
let result = match transform {
TransformType::SetToServerValue(code) => {
match ServerValue::try_from(code).map_err(|_| {
Expand Down
13 changes: 12 additions & 1 deletion crates/firestore-database/src/database/document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tokio::{
},
time::{error::Elapsed, timeout},
};
use tracing::{instrument, trace, Level};
use tracing::{info, instrument, trace, Level};

use super::{read_consistency::ReadConsistency, reference::DocumentRef};
use crate::{error::Result, GenericDatabaseError};
Expand Down Expand Up @@ -182,6 +182,16 @@ impl DocumentContents {
update_time,
fields,
}));
if let Some(last) = self.versions.last_mut() {
if last.update_time() == version.update_time() {
last.clone_from(&version);
return version;
}
assert!(
last.update_time() < version.update_time(),
"update or commit time earlier than last version"
);
}
self.versions.push(version.clone());
version
}
Expand Down Expand Up @@ -209,6 +219,7 @@ pub type OwnedDocumentContentsWriteGuard = OwnedRwLockWriteGuard<DocumentContent
impl OwnedDocumentContentsReadGuard {
#[instrument(skip_all, err)]
pub async fn upgrade(self) -> Result<OwnedDocumentContentsWriteGuard> {
info!(name = %self.meta.name);
let check_time = self.guard.last_updated();
let OwnedDocumentContentsReadGuard {
meta,
Expand Down
18 changes: 9 additions & 9 deletions crates/firestore-database/src/database/field_path.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{borrow::Cow, collections::HashMap, convert::Infallible, mem::take, ops::Deref};
use std::{borrow::Cow, collections::HashMap, convert::Infallible, mem::take, str::FromStr};

use googleapis::google::firestore::v1::*;

Expand Down Expand Up @@ -35,17 +35,17 @@ impl TryFrom<&structured_query::FieldReference> for FieldReference {
type Error = GenericDatabaseError;

fn try_from(value: &structured_query::FieldReference) -> Result<Self, Self::Error> {
value.field_path.deref().try_into()
value.field_path.parse()
}
}

impl TryFrom<&str> for FieldReference {
type Error = GenericDatabaseError;
impl FromStr for FieldReference {
type Err = GenericDatabaseError;

fn try_from(path: &str) -> Result<Self, Self::Error> {
fn from_str(path: &str) -> Result<Self, Self::Err> {
match path {
DOC_NAME => Ok(Self::DocumentName),
path => Ok(Self::FieldPath((path.try_into())?)),
path => Ok(Self::FieldPath((path.parse())?)),
}
}
}
Expand Down Expand Up @@ -124,10 +124,10 @@ impl FieldPath {
}
}

impl TryFrom<&str> for FieldPath {
type Error = GenericDatabaseError;
impl FromStr for FieldPath {
type Err = GenericDatabaseError;

fn try_from(path: &str) -> Result<Self, Self::Error> {
fn from_str(path: &str) -> Result<Self, Self::Err> {
if path.is_empty() {
return Err(GenericDatabaseError::invalid_argument(
"invalid empty field path",
Expand Down
63 changes: 63 additions & 0 deletions crates/firestore-database/src/database/projection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use googleapis::google::firestore::v1::{structured_query, Document, DocumentMask};
use itertools::Itertools;

use super::field_path::FieldReference;
use crate::{document::StoredDocumentVersion, GenericDatabaseError};

#[derive(Debug)]
pub struct Projection {
fields: Vec<FieldReference>,
}

pub trait Project {
fn project(&self, version: &StoredDocumentVersion) -> Document;
}

impl Project for Projection {
fn project(&self, version: &StoredDocumentVersion) -> Document {
let mut doc = Document {
fields: Default::default(),
create_time: Some(version.create_time.clone()),
update_time: Some(version.update_time.clone()),
name: version.name.to_string(),
};
for field in &self.fields {
match field {
FieldReference::DocumentName => continue,
FieldReference::FieldPath(path) => {
if let Some(val) = path.get_value(&version.fields) {
path.set_value(&mut doc.fields, val.clone());
}
}
}
}
doc
}
}

impl Project for Option<Projection> {
fn project(&self, version: &StoredDocumentVersion) -> Document {
match self {
Some(projection) => projection.project(version),
None => version.to_document(),
}
}
}

impl TryFrom<structured_query::Projection> for Projection {
type Error = GenericDatabaseError;

fn try_from(value: structured_query::Projection) -> Result<Self, Self::Error> {
let fields = value.fields.iter().map(TryInto::try_into).try_collect()?;
Ok(Self { fields })
}
}

impl TryFrom<DocumentMask> for Projection {
type Error = GenericDatabaseError;

fn try_from(value: DocumentMask) -> Result<Self, Self::Error> {
let fields = value.field_paths.iter().map(|s| s.parse()).try_collect()?;
Ok(Self { fields })
}
}
Loading

0 comments on commit bf99748

Please sign in to comment.