Skip to content

Commit

Permalink
fix(relocation): Use proper provenance for SAAS -> SAAS (#75355)
Browse files Browse the repository at this point in the history
This ensures that users are merged, rather than having new accounts
created.
  • Loading branch information
azaslavsky authored Jul 31, 2024
1 parent d2d039a commit 1acd0b6
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/sentry/tasks/relocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,7 @@ def importing(uuid: UUID) -> None:
flags=ImportFlags(
import_uuid=str(uuid),
hide_organizations=True,
merge_users=False,
merge_users=relocation.provenance == Relocation.Provenance.SAAS_TO_SAAS,
overwrite_configs=False,
),
org_filter=set(relocation.want_org_slugs),
Expand Down
82 changes: 70 additions & 12 deletions tests/sentry/tasks/test_relocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
unwrap_encrypted_export_tarball,
)
from sentry.backup.dependencies import NormalizedModelName, get_model_name
from sentry.backup.exports import export_in_organization_scope
from sentry.backup.helpers import ImportFlags, Printer
from sentry.backup.imports import import_in_organization_scope
from sentry.models.files.file import File
Expand Down Expand Up @@ -172,8 +173,14 @@ def file(self):

return file

def swap_file(
def swap_relocation_file_with_data_from_fixture(
self, file: File, fixture_name: str, blob_size: int = RELOCATION_BLOB_SIZE
) -> None:
with open(get_fixture_path("backup", fixture_name), "rb") as fp:
return self.swap_relocation_file(file, BytesIO(fp.read()), blob_size)

def swap_relocation_file(
self, file: File, contents: BytesIO, blob_size: int = RELOCATION_BLOB_SIZE
) -> None:
with TemporaryDirectory() as tmp_dir:
tmp_priv_key_path = Path(tmp_dir).joinpath("key")
Expand All @@ -182,13 +189,13 @@ def swap_file(
f.write(self.priv_key_pem)
with open(tmp_pub_key_path, "wb") as f:
f.write(self.pub_key_pem)
with open(get_fixture_path("backup", fixture_name)) as f:
data = json.load(f)
with open(tmp_pub_key_path, "rb") as p:
self.tarball = create_encrypted_export_tarball(
data, LocalFileEncryptor(p)
).getvalue()
file.putfile(BytesIO(self.tarball), blob_size=blob_size)

data = json.load(contents)
with open(tmp_pub_key_path, "rb") as p:
self.tarball = create_encrypted_export_tarball(
data, LocalFileEncryptor(p)
).getvalue()
file.putfile(BytesIO(self.tarball), blob_size=blob_size)

def mock_kms_client(self, fake_kms_client: FakeKeyManagementServiceClient):
fake_kms_client.asymmetric_decrypt.call_count = 0
Expand Down Expand Up @@ -803,7 +810,7 @@ def test_fail_invalid_json(
fake_kms_client: FakeKeyManagementServiceClient,
):
file = RelocationFile.objects.get(relocation=self.relocation).file
self.swap_file(file, "invalid-user.json")
self.swap_relocation_file_with_data_from_fixture(file, "invalid-user.json")
self.mock_message_builder(fake_message_builder)
self.mock_kms_client(fake_kms_client)

Expand All @@ -829,7 +836,7 @@ def test_fail_no_users(
fake_kms_client: FakeKeyManagementServiceClient,
):
file = RelocationFile.objects.get(relocation=self.relocation).file
self.swap_file(file, "single-option.json")
self.swap_relocation_file_with_data_from_fixture(file, "single-option.json")
self.mock_message_builder(fake_message_builder)
self.mock_kms_client(fake_kms_client)

Expand Down Expand Up @@ -880,7 +887,7 @@ def test_fail_no_orgs(
fake_kms_client: FakeKeyManagementServiceClient,
):
file = RelocationFile.objects.get(relocation=self.relocation).file
self.swap_file(file, "user-with-minimum-privileges.json")
self.swap_relocation_file_with_data_from_fixture(file, "user-with-minimum-privileges.json")
self.mock_message_builder(fake_message_builder)
self.mock_kms_client(fake_kms_client)

Expand Down Expand Up @@ -1984,7 +1991,7 @@ def setUp(self):
self.relocation.latest_task = OrderedTask.VALIDATING_COMPLETE.name
self.relocation.save()

def test_success(
def test_success_self_hosted(
self, postprocessing_mock: Mock, fake_kms_client: FakeKeyManagementServiceClient
):
self.mock_kms_client(fake_kms_client)
Expand Down Expand Up @@ -2021,6 +2028,57 @@ def test_success(
"sentry.useremail",
]

def test_success_saas_to_saas(
self, postprocessing_mock: Mock, fake_kms_client: FakeKeyManagementServiceClient
):
org_count = Organization.objects.filter(slug__startswith="testing").count()
with assume_test_silo_mode(SiloMode.CONTROL):
user_count = User.objects.all().count()

# Export the existing state of the `testing` organization, so that we retain exact ids.
export_contents = BytesIO()
export_in_organization_scope(
export_contents,
org_filter=set(self.relocation.want_org_slugs),
printer=Printer(),
)
export_contents.seek(0)

# Convert this into a `SAAS_TO_SAAS` relocation, and use the data we just exported as the
# import blob.
file = RelocationFile.objects.get(relocation=self.relocation).file
self.swap_relocation_file(file, export_contents)
self.mock_kms_client(fake_kms_client)
self.relocation.provenance = Relocation.Provenance.SAAS_TO_SAAS
self.relocation.save()

# Now, try importing again, which should enable user merging.
importing(self.uuid)

with assume_test_silo_mode(SiloMode.CONTROL):
# User counts should NOT change, since `merge_users` should be enabled.
assert User.objects.all().count() == user_count
common_user = User.objects.get(username="existing_org_owner@example.com")

# The existing user should now be in both orgs.
assert OrganizationMember.objects.filter(user_id=common_user.id).count() == 2

assert postprocessing_mock.call_count == 1
assert Organization.objects.filter(slug__startswith="testing").count() == org_count + 1
assert (
Organization.objects.filter(
slug__startswith="testing", status=OrganizationStatus.RELOCATION_PENDING_APPROVAL
).count()
== 1
)

with assume_test_silo_mode(SiloMode.CONTROL):
assert ControlImportChunk.objects.filter(import_uuid=self.uuid).count() == 1
assert sorted(ControlImportChunk.objects.values_list("model", flat=True)) == [
"sentry.user",
# We don't overwrite `sentry.useremail`, retaining the existing value instead.
]

def test_pause(
self,
postprocessing_mock: Mock,
Expand Down

0 comments on commit 1acd0b6

Please sign in to comment.