Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use 'update_transforms' #219

Merged
merged 7 commits into from
Oct 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 26 additions & 44 deletions google/cloud/firestore_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,9 @@ def get_update_pb(

return update_pb

def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
def get_field_transform_pbs(
self, document_path
) -> List[types.write.DocumentTransform.FieldTransform]:
def make_array_value(values):
value_list = [encode_value(element) for element in values]
return document.ArrayValue(values=value_list)
Expand Down Expand Up @@ -559,9 +561,10 @@ def make_array_value(values):
for path, value in self.minimums.items()
]
)
field_transforms = [
transform for path, transform in sorted(path_field_transforms)
]
return [transform for path, transform in sorted(path_field_transforms)]

def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
field_transforms = self.get_field_transform_pbs(document_path)
transform_pb = write.Write(
transform=write.DocumentTransform(
document=document_path, field_transforms=field_transforms
Expand Down Expand Up @@ -592,19 +595,13 @@ def pbs_for_create(document_path, document_data) -> List[types.write.Write]:
if extractor.deleted_fields:
raise ValueError("Cannot apply DELETE_FIELD in a create request.")

write_pbs = []

# Conformance tests require skipping the 'update_pb' if the document
# contains only transforms.
if extractor.empty_document or extractor.set_fields:
write_pbs.append(extractor.get_update_pb(document_path, exists=False))
create_pb = extractor.get_update_pb(document_path, exists=False)

if extractor.has_transforms:
exists = None if write_pbs else False
transform_pb = extractor.get_transform_pb(document_path, exists)
write_pbs.append(transform_pb)
field_transform_pbs = extractor.get_field_transform_pbs(document_path)
create_pb.update_transforms.extend(field_transform_pbs)

return write_pbs
return [create_pb]


def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]:
Expand All @@ -627,15 +624,13 @@ def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write
"specifying 'merge=True' or 'merge=[field_paths]'."
)

# Conformance tests require send the 'update_pb' even if the document
# contains only transforms.
write_pbs = [extractor.get_update_pb(document_path)]
set_pb = extractor.get_update_pb(document_path)

if extractor.has_transforms:
transform_pb = extractor.get_transform_pb(document_path)
write_pbs.append(transform_pb)
field_transform_pbs = extractor.get_field_transform_pbs(document_path)
set_pb.update_transforms.extend(field_transform_pbs)

return write_pbs
return [set_pb]


class DocumentExtractorForMerge(DocumentExtractor):
Expand Down Expand Up @@ -799,19 +794,14 @@ def pbs_for_set_with_merge(
extractor.apply_merge(merge)

merge_empty = not document_data
allow_empty_mask = merge_empty or extractor.transform_paths

write_pbs = []

if extractor.has_updates or merge_empty:
write_pbs.append(
extractor.get_update_pb(document_path, allow_empty_mask=merge_empty)
)

set_pb = extractor.get_update_pb(document_path, allow_empty_mask=allow_empty_mask)
if extractor.transform_paths:
transform_pb = extractor.get_transform_pb(document_path)
write_pbs.append(transform_pb)
field_transform_pbs = extractor.get_field_transform_pbs(document_path)
set_pb.update_transforms.extend(field_transform_pbs)

return write_pbs
return [set_pb]


class DocumentExtractorForUpdate(DocumentExtractor):
Expand Down Expand Up @@ -876,22 +866,14 @@ def pbs_for_update(document_path, field_updates, option) -> List[types.write.Wri
if option is None: # Default is to use ``exists=True``.
option = ExistsOption(exists=True)

write_pbs = []

if extractor.field_paths or extractor.deleted_fields:
update_pb = extractor.get_update_pb(document_path)
option.modify_write(update_pb)
write_pbs.append(update_pb)
update_pb = extractor.get_update_pb(document_path)
option.modify_write(update_pb)

if extractor.has_transforms:
transform_pb = extractor.get_transform_pb(document_path)
if not write_pbs:
# NOTE: set the write option on the ``transform_pb`` only if there
# is no ``update_pb``
option.modify_write(transform_pb)
write_pbs.append(transform_pb)

return write_pbs
field_transform_pbs = extractor.get_field_transform_pbs(document_path)
update_pb.update_transforms.extend(field_transform_pbs)

return [update_pb]


def pb_for_delete(document_path, option) -> types.write.Write:
Expand Down
158 changes: 86 additions & 72 deletions tests/unit/v1/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,38 @@ def test_get_update_pb_wo_exists_precondition(self):
self.assertEqual(update_pb.update.fields, encode_dict(document_data))
self.assertFalse(update_pb._pb.HasField("current_document"))

def test_get_field_transform_pbs_miss(self):
document_data = {"a": 1}
inst = self._make_one(document_data)
document_path = (
"projects/project-id/databases/(default)/" "documents/document-id"
)

field_transform_pbs = inst.get_field_transform_pbs(document_path)

self.assertEqual(field_transform_pbs, [])

def test_get_field_transform_pbs_w_server_timestamp(self):
from google.cloud.firestore_v1.types import write
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM

document_data = {"a": SERVER_TIMESTAMP}
inst = self._make_one(document_data)
document_path = (
"projects/project-id/databases/(default)/" "documents/document-id"
)

field_transform_pbs = inst.get_field_transform_pbs(document_path)

self.assertEqual(len(field_transform_pbs), 1)
field_transform_pb = field_transform_pbs[0]
self.assertIsInstance(
field_transform_pb, write.DocumentTransform.FieldTransform
)
self.assertEqual(field_transform_pb.field_path, "a")
self.assertEqual(field_transform_pb.set_to_server_value, REQUEST_TIME_ENUM)

def test_get_transform_pb_w_server_timestamp_w_exists_precondition(self):
from google.cloud.firestore_v1.types import write
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
Expand Down Expand Up @@ -1526,23 +1558,16 @@ def _make_write_w_document(document_path, **data):
)

@staticmethod
def _make_write_w_transform(document_path, fields):
from google.cloud.firestore_v1.types import write
def _add_field_transforms(update_pb, fields):
from google.cloud.firestore_v1 import DocumentTransform

server_val = DocumentTransform.FieldTransform.ServerValue
transforms = [
write.DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
for field in fields
]

return write.Write(
transform=write.DocumentTransform(
document=document_path, field_transforms=transforms
for field in fields:
update_pb.update_transforms.append(
DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
)
)

def _helper(self, do_transform=False, empty_val=False):
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
Expand All @@ -1569,9 +1594,7 @@ def _helper(self, do_transform=False, empty_val=False):
expected_pbs = [update_pb]

if do_transform:
expected_pbs.append(
self._make_write_w_transform(document_path, fields=["butter"])
)
self._add_field_transforms(update_pb, fields=["butter"])

self.assertEqual(write_pbs, expected_pbs)

Expand Down Expand Up @@ -1603,23 +1626,16 @@ def _make_write_w_document(document_path, **data):
)

@staticmethod
def _make_write_w_transform(document_path, fields):
from google.cloud.firestore_v1.types import write
def _add_field_transforms(update_pb, fields):
from google.cloud.firestore_v1 import DocumentTransform

server_val = DocumentTransform.FieldTransform.ServerValue
transforms = [
write.DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
for field in fields
]

return write.Write(
transform=write.DocumentTransform(
document=document_path, field_transforms=transforms
for field in fields:
update_pb.update_transforms.append(
DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
)
)

def test_w_empty_document(self):
document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
Expand All @@ -1640,8 +1656,8 @@ def test_w_only_server_timestamp(self):
write_pbs = self._call_fut(document_path, document_data)

update_pb = self._make_write_w_document(document_path)
transform_pb = self._make_write_w_transform(document_path, ["butter"])
expected_pbs = [update_pb, transform_pb]
self._add_field_transforms(update_pb, fields=["butter"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def _helper(self, do_transform=False, empty_val=False):
Expand Down Expand Up @@ -1669,9 +1685,7 @@ def _helper(self, do_transform=False, empty_val=False):
expected_pbs = [update_pb]

if do_transform:
expected_pbs.append(
self._make_write_w_transform(document_path, fields=["butter"])
)
self._add_field_transforms(update_pb, fields=["butter"])

self.assertEqual(write_pbs, expected_pbs)

Expand Down Expand Up @@ -1904,23 +1918,16 @@ def _make_write_w_document(document_path, **data):
)

@staticmethod
def _make_write_w_transform(document_path, fields):
from google.cloud.firestore_v1.types import write
def _add_field_transforms(update_pb, fields):
from google.cloud.firestore_v1 import DocumentTransform

server_val = DocumentTransform.FieldTransform.ServerValue
transforms = [
write.DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
for field in fields
]

return write.Write(
transform=write.DocumentTransform(
document=document_path, field_transforms=transforms
for field in fields:
update_pb.update_transforms.append(
DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
)
)

@staticmethod
def _update_document_mask(update_pb, field_paths):
Expand Down Expand Up @@ -1954,6 +1961,20 @@ def test_with_merge_field_wo_transform(self):
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def test_with_merge_true_w_only_transform(self):
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP

document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
document_data = {"butter": SERVER_TIMESTAMP}

write_pbs = self._call_fut(document_path, document_data, merge=True)

update_pb = self._make_write_w_document(document_path)
self._update_document_mask(update_pb, field_paths=())
self._add_field_transforms(update_pb, fields=["butter"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def test_with_merge_true_w_transform(self):
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP

Expand All @@ -1966,8 +1987,8 @@ def test_with_merge_true_w_transform(self):

update_pb = self._make_write_w_document(document_path, **update_data)
self._update_document_mask(update_pb, field_paths=sorted(update_data))
transform_pb = self._make_write_w_transform(document_path, fields=["butter"])
expected_pbs = [update_pb, transform_pb]
self._add_field_transforms(update_pb, fields=["butter"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def test_with_merge_field_w_transform(self):
Expand All @@ -1986,8 +2007,8 @@ def test_with_merge_field_w_transform(self):
document_path, cheese=document_data["cheese"]
)
self._update_document_mask(update_pb, ["cheese"])
transform_pb = self._make_write_w_transform(document_path, fields=["butter"])
expected_pbs = [update_pb, transform_pb]
self._add_field_transforms(update_pb, fields=["butter"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def test_with_merge_field_w_transform_masking_simple(self):
Expand All @@ -2001,10 +2022,9 @@ def test_with_merge_field_w_transform_masking_simple(self):
write_pbs = self._call_fut(document_path, document_data, merge=["butter.pecan"])

update_pb = self._make_write_w_document(document_path)
transform_pb = self._make_write_w_transform(
document_path, fields=["butter.pecan"]
)
expected_pbs = [update_pb, transform_pb]
self._update_document_mask(update_pb, field_paths=())
self._add_field_transforms(update_pb, fields=["butter.pecan"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def test_with_merge_field_w_transform_parent(self):
Expand All @@ -2023,10 +2043,8 @@ def test_with_merge_field_w_transform_parent(self):
document_path, cheese=update_data["cheese"], butter={"popcorn": "yum"}
)
self._update_document_mask(update_pb, ["cheese", "butter"])
transform_pb = self._make_write_w_transform(
document_path, fields=["butter.pecan"]
)
expected_pbs = [update_pb, transform_pb]
self._add_field_transforms(update_pb, fields=["butter.pecan"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)


Expand Down Expand Up @@ -2134,23 +2152,19 @@ def _helper(self, option=None, do_transform=False, **write_kwargs):
if isinstance(option, _helpers.ExistsOption):
precondition = common.Precondition(exists=False)
expected_update_pb._pb.current_document.CopyFrom(precondition._pb)
expected_pbs = [expected_update_pb]

if do_transform:
transform_paths = FieldPath.from_string(field_path2)
server_val = DocumentTransform.FieldTransform.ServerValue
expected_transform_pb = write.Write(
transform=write.DocumentTransform(
document=document_path,
field_transforms=[
write.DocumentTransform.FieldTransform(
field_path=transform_paths.to_api_repr(),
set_to_server_value=server_val.REQUEST_TIME,
)
],
field_transform_pbs = [
write.DocumentTransform.FieldTransform(
field_path=transform_paths.to_api_repr(),
set_to_server_value=server_val.REQUEST_TIME,
)
)
expected_pbs.append(expected_transform_pb)
self.assertEqual(write_pbs, expected_pbs)
]
expected_update_pb.update_transforms.extend(field_transform_pbs)

self.assertEqual(write_pbs, [expected_update_pb])

def test_without_option(self):
from google.cloud.firestore_v1.types import common
Expand Down
Loading