Skip to content

Commit

Permalink
feat: use 'update_transforms' (#219)
Browse files Browse the repository at this point in the history
Update `pbs_for_create`, `pbs_for_set_no_merge`, `pbs_for_set_with_merge`, and `pbs_for_update` to match semantics expected by current versions of [conformance tests](googleapis/conformance-tests@0bb8520):

- Rather than create separate `Write.transform` messages to hold field transforms, inline them as `update_transforms` in the main `Write.update` message (which will always be created now).

Copy in the current version of the conftest JSON files and verify.

Closes #217
  • Loading branch information
tseaver authored Oct 10, 2020
1 parent 9b6c2f3 commit c122e41
Show file tree
Hide file tree
Showing 63 changed files with 1,184 additions and 1,406 deletions.
70 changes: 26 additions & 44 deletions google/cloud/firestore_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,9 @@ def get_update_pb(

return update_pb

def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
def get_field_transform_pbs(
self, document_path
) -> List[types.write.DocumentTransform.FieldTransform]:
def make_array_value(values):
value_list = [encode_value(element) for element in values]
return document.ArrayValue(values=value_list)
Expand Down Expand Up @@ -559,9 +561,10 @@ def make_array_value(values):
for path, value in self.minimums.items()
]
)
field_transforms = [
transform for path, transform in sorted(path_field_transforms)
]
return [transform for path, transform in sorted(path_field_transforms)]

def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
field_transforms = self.get_field_transform_pbs(document_path)
transform_pb = write.Write(
transform=write.DocumentTransform(
document=document_path, field_transforms=field_transforms
Expand Down Expand Up @@ -592,19 +595,13 @@ def pbs_for_create(document_path, document_data) -> List[types.write.Write]:
if extractor.deleted_fields:
raise ValueError("Cannot apply DELETE_FIELD in a create request.")

write_pbs = []

# Conformance tests require skipping the 'update_pb' if the document
# contains only transforms.
if extractor.empty_document or extractor.set_fields:
write_pbs.append(extractor.get_update_pb(document_path, exists=False))
create_pb = extractor.get_update_pb(document_path, exists=False)

if extractor.has_transforms:
exists = None if write_pbs else False
transform_pb = extractor.get_transform_pb(document_path, exists)
write_pbs.append(transform_pb)
field_transform_pbs = extractor.get_field_transform_pbs(document_path)
create_pb.update_transforms.extend(field_transform_pbs)

return write_pbs
return [create_pb]


def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]:
Expand All @@ -627,15 +624,13 @@ def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write
"specifying 'merge=True' or 'merge=[field_paths]'."
)

# Conformance tests require send the 'update_pb' even if the document
# contains only transforms.
write_pbs = [extractor.get_update_pb(document_path)]
set_pb = extractor.get_update_pb(document_path)

if extractor.has_transforms:
transform_pb = extractor.get_transform_pb(document_path)
write_pbs.append(transform_pb)
field_transform_pbs = extractor.get_field_transform_pbs(document_path)
set_pb.update_transforms.extend(field_transform_pbs)

return write_pbs
return [set_pb]


class DocumentExtractorForMerge(DocumentExtractor):
Expand Down Expand Up @@ -799,19 +794,14 @@ def pbs_for_set_with_merge(
extractor.apply_merge(merge)

merge_empty = not document_data
allow_empty_mask = merge_empty or extractor.transform_paths

write_pbs = []

if extractor.has_updates or merge_empty:
write_pbs.append(
extractor.get_update_pb(document_path, allow_empty_mask=merge_empty)
)

set_pb = extractor.get_update_pb(document_path, allow_empty_mask=allow_empty_mask)
if extractor.transform_paths:
transform_pb = extractor.get_transform_pb(document_path)
write_pbs.append(transform_pb)
field_transform_pbs = extractor.get_field_transform_pbs(document_path)
set_pb.update_transforms.extend(field_transform_pbs)

return write_pbs
return [set_pb]


class DocumentExtractorForUpdate(DocumentExtractor):
Expand Down Expand Up @@ -876,22 +866,14 @@ def pbs_for_update(document_path, field_updates, option) -> List[types.write.Wri
if option is None: # Default is to use ``exists=True``.
option = ExistsOption(exists=True)

write_pbs = []

if extractor.field_paths or extractor.deleted_fields:
update_pb = extractor.get_update_pb(document_path)
option.modify_write(update_pb)
write_pbs.append(update_pb)
update_pb = extractor.get_update_pb(document_path)
option.modify_write(update_pb)

if extractor.has_transforms:
transform_pb = extractor.get_transform_pb(document_path)
if not write_pbs:
# NOTE: set the write option on the ``transform_pb`` only if there
# is no ``update_pb``
option.modify_write(transform_pb)
write_pbs.append(transform_pb)

return write_pbs
field_transform_pbs = extractor.get_field_transform_pbs(document_path)
update_pb.update_transforms.extend(field_transform_pbs)

return [update_pb]


def pb_for_delete(document_path, option) -> types.write.Write:
Expand Down
158 changes: 86 additions & 72 deletions tests/unit/v1/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,38 @@ def test_get_update_pb_wo_exists_precondition(self):
self.assertEqual(update_pb.update.fields, encode_dict(document_data))
self.assertFalse(update_pb._pb.HasField("current_document"))

def test_get_field_transform_pbs_miss(self):
document_data = {"a": 1}
inst = self._make_one(document_data)
document_path = (
"projects/project-id/databases/(default)/" "documents/document-id"
)

field_transform_pbs = inst.get_field_transform_pbs(document_path)

self.assertEqual(field_transform_pbs, [])

def test_get_field_transform_pbs_w_server_timestamp(self):
from google.cloud.firestore_v1.types import write
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM

document_data = {"a": SERVER_TIMESTAMP}
inst = self._make_one(document_data)
document_path = (
"projects/project-id/databases/(default)/" "documents/document-id"
)

field_transform_pbs = inst.get_field_transform_pbs(document_path)

self.assertEqual(len(field_transform_pbs), 1)
field_transform_pb = field_transform_pbs[0]
self.assertIsInstance(
field_transform_pb, write.DocumentTransform.FieldTransform
)
self.assertEqual(field_transform_pb.field_path, "a")
self.assertEqual(field_transform_pb.set_to_server_value, REQUEST_TIME_ENUM)

def test_get_transform_pb_w_server_timestamp_w_exists_precondition(self):
from google.cloud.firestore_v1.types import write
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
Expand Down Expand Up @@ -1526,23 +1558,16 @@ def _make_write_w_document(document_path, **data):
)

@staticmethod
def _make_write_w_transform(document_path, fields):
from google.cloud.firestore_v1.types import write
def _add_field_transforms(update_pb, fields):
from google.cloud.firestore_v1 import DocumentTransform

server_val = DocumentTransform.FieldTransform.ServerValue
transforms = [
write.DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
for field in fields
]

return write.Write(
transform=write.DocumentTransform(
document=document_path, field_transforms=transforms
for field in fields:
update_pb.update_transforms.append(
DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
)
)

def _helper(self, do_transform=False, empty_val=False):
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
Expand All @@ -1569,9 +1594,7 @@ def _helper(self, do_transform=False, empty_val=False):
expected_pbs = [update_pb]

if do_transform:
expected_pbs.append(
self._make_write_w_transform(document_path, fields=["butter"])
)
self._add_field_transforms(update_pb, fields=["butter"])

self.assertEqual(write_pbs, expected_pbs)

Expand Down Expand Up @@ -1603,23 +1626,16 @@ def _make_write_w_document(document_path, **data):
)

@staticmethod
def _make_write_w_transform(document_path, fields):
from google.cloud.firestore_v1.types import write
def _add_field_transforms(update_pb, fields):
from google.cloud.firestore_v1 import DocumentTransform

server_val = DocumentTransform.FieldTransform.ServerValue
transforms = [
write.DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
for field in fields
]

return write.Write(
transform=write.DocumentTransform(
document=document_path, field_transforms=transforms
for field in fields:
update_pb.update_transforms.append(
DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
)
)

def test_w_empty_document(self):
document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
Expand All @@ -1640,8 +1656,8 @@ def test_w_only_server_timestamp(self):
write_pbs = self._call_fut(document_path, document_data)

update_pb = self._make_write_w_document(document_path)
transform_pb = self._make_write_w_transform(document_path, ["butter"])
expected_pbs = [update_pb, transform_pb]
self._add_field_transforms(update_pb, fields=["butter"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def _helper(self, do_transform=False, empty_val=False):
Expand Down Expand Up @@ -1669,9 +1685,7 @@ def _helper(self, do_transform=False, empty_val=False):
expected_pbs = [update_pb]

if do_transform:
expected_pbs.append(
self._make_write_w_transform(document_path, fields=["butter"])
)
self._add_field_transforms(update_pb, fields=["butter"])

self.assertEqual(write_pbs, expected_pbs)

Expand Down Expand Up @@ -1904,23 +1918,16 @@ def _make_write_w_document(document_path, **data):
)

@staticmethod
def _make_write_w_transform(document_path, fields):
from google.cloud.firestore_v1.types import write
def _add_field_transforms(update_pb, fields):
from google.cloud.firestore_v1 import DocumentTransform

server_val = DocumentTransform.FieldTransform.ServerValue
transforms = [
write.DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
for field in fields
]

return write.Write(
transform=write.DocumentTransform(
document=document_path, field_transforms=transforms
for field in fields:
update_pb.update_transforms.append(
DocumentTransform.FieldTransform(
field_path=field, set_to_server_value=server_val.REQUEST_TIME
)
)
)

@staticmethod
def _update_document_mask(update_pb, field_paths):
Expand Down Expand Up @@ -1954,6 +1961,20 @@ def test_with_merge_field_wo_transform(self):
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def test_with_merge_true_w_only_transform(self):
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP

document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
document_data = {"butter": SERVER_TIMESTAMP}

write_pbs = self._call_fut(document_path, document_data, merge=True)

update_pb = self._make_write_w_document(document_path)
self._update_document_mask(update_pb, field_paths=())
self._add_field_transforms(update_pb, fields=["butter"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def test_with_merge_true_w_transform(self):
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP

Expand All @@ -1966,8 +1987,8 @@ def test_with_merge_true_w_transform(self):

update_pb = self._make_write_w_document(document_path, **update_data)
self._update_document_mask(update_pb, field_paths=sorted(update_data))
transform_pb = self._make_write_w_transform(document_path, fields=["butter"])
expected_pbs = [update_pb, transform_pb]
self._add_field_transforms(update_pb, fields=["butter"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def test_with_merge_field_w_transform(self):
Expand All @@ -1986,8 +2007,8 @@ def test_with_merge_field_w_transform(self):
document_path, cheese=document_data["cheese"]
)
self._update_document_mask(update_pb, ["cheese"])
transform_pb = self._make_write_w_transform(document_path, fields=["butter"])
expected_pbs = [update_pb, transform_pb]
self._add_field_transforms(update_pb, fields=["butter"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def test_with_merge_field_w_transform_masking_simple(self):
Expand All @@ -2001,10 +2022,9 @@ def test_with_merge_field_w_transform_masking_simple(self):
write_pbs = self._call_fut(document_path, document_data, merge=["butter.pecan"])

update_pb = self._make_write_w_document(document_path)
transform_pb = self._make_write_w_transform(
document_path, fields=["butter.pecan"]
)
expected_pbs = [update_pb, transform_pb]
self._update_document_mask(update_pb, field_paths=())
self._add_field_transforms(update_pb, fields=["butter.pecan"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)

def test_with_merge_field_w_transform_parent(self):
Expand All @@ -2023,10 +2043,8 @@ def test_with_merge_field_w_transform_parent(self):
document_path, cheese=update_data["cheese"], butter={"popcorn": "yum"}
)
self._update_document_mask(update_pb, ["cheese", "butter"])
transform_pb = self._make_write_w_transform(
document_path, fields=["butter.pecan"]
)
expected_pbs = [update_pb, transform_pb]
self._add_field_transforms(update_pb, fields=["butter.pecan"])
expected_pbs = [update_pb]
self.assertEqual(write_pbs, expected_pbs)


Expand Down Expand Up @@ -2134,23 +2152,19 @@ def _helper(self, option=None, do_transform=False, **write_kwargs):
if isinstance(option, _helpers.ExistsOption):
precondition = common.Precondition(exists=False)
expected_update_pb._pb.current_document.CopyFrom(precondition._pb)
expected_pbs = [expected_update_pb]

if do_transform:
transform_paths = FieldPath.from_string(field_path2)
server_val = DocumentTransform.FieldTransform.ServerValue
expected_transform_pb = write.Write(
transform=write.DocumentTransform(
document=document_path,
field_transforms=[
write.DocumentTransform.FieldTransform(
field_path=transform_paths.to_api_repr(),
set_to_server_value=server_val.REQUEST_TIME,
)
],
field_transform_pbs = [
write.DocumentTransform.FieldTransform(
field_path=transform_paths.to_api_repr(),
set_to_server_value=server_val.REQUEST_TIME,
)
)
expected_pbs.append(expected_transform_pb)
self.assertEqual(write_pbs, expected_pbs)
]
expected_update_pb.update_transforms.extend(field_transform_pbs)

self.assertEqual(write_pbs, [expected_update_pb])

def test_without_option(self):
from google.cloud.firestore_v1.types import common
Expand Down
Loading

0 comments on commit c122e41

Please sign in to comment.