Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FT-874: Fix the unit test and integration test in Audit Search #455

Merged
merged 3 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 70 additions & 17 deletions tests/integration/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ def create_glossary(client: AtlanClient, name: str) -> AtlasGlossary:
return r.assets_created(AtlasGlossary)[0]


@pytest.fixture(scope="module")
def audit_glossary(client: AtlanClient) -> Generator[AtlasGlossary, None, None]:
    """Module-scoped glossary used as the audit-search target.

    Creates a uniquely named glossary once per test module, hands it to the
    tests, then deletes it again on teardown.
    """
    glossary = create_glossary(client, TestId.make_unique("test-audit-glossary"))
    yield glossary
    # Teardown: remove the glossary so repeated runs do not accumulate assets.
    delete_asset(client, guid=glossary.guid, asset_type=AtlasGlossary)


@pytest.fixture(scope="module")
def sl_glossary(
client: AtlanClient,
Expand Down Expand Up @@ -626,64 +635,101 @@ def test_audit_find_by_user(
assert audit_entity != audit_entity_next_page


def _assert_audit_search_results(results, expected_sorts, size, total_count):
@pytest.fixture(scope="module")
def generate_audit_entries(client: AtlanClient, audit_glossary: AtlasGlossary):
    """Generate audit-log entries by repeatedly updating the audit glossary.

    Saves five successive description updates (sleeping 1s between saves so
    each update lands as a distinct audit event), then verifies via an audit
    search that at least that many log entries exist for the glossary's GUID.
    """
    expected_entries = 5
    for index in range(expected_entries):
        revision = AtlasGlossary.updater(
            qualified_name=audit_glossary.qualified_name,
            name=audit_glossary.name,
        )
        revision.description = f"Updated description {index + 1}"
        client.asset.save(revision)
        time.sleep(1)

    # Sanity-check that the updates actually produced audit entries.
    request = AuditSearchRequest.by_guid(
        guid=audit_glossary.guid, size=expected_entries
    )
    response = client.audit.search(request)
    assert response.total_count >= expected_entries, (
        f"Expected at least {expected_entries} logs, "
        f"but got {response.total_count}."
    )


def _assert_audit_search_results(
results, expected_sorts, size, TOTAL_AUDIT_ENTRIES, bulk=False
):
assert results.total_count > size
assert len(results.current_page()) == size
assert results.total_count == total_count
counter = 0
for audit in results:
assert audit
counter += 1
assert counter == TOTAL_AUDIT_ENTRIES
assert results
assert results._bulk is bulk
assert results._criteria.dsl.sort == expected_sorts


@pytest.mark.order(after="test_audit_find_by_user")
@patch.object(AUDIT_LOGGER, "debug")
def test_audit_search_pagination(
mock_logger, audit_info: AuditInfo, client: AtlanClient
mock_logger,
audit_glossary: AtlasGlossary,
generate_audit_entries,
client: AtlanClient,
):
size = 2

# Test audit search by GUID with default offset-based pagination
dsl = DSL(
query=Bool(filter=[Term(field="entityId", value=audit_info.guid)]),
query=Bool(filter=[Term(field="entityId", value=audit_glossary.guid)]),
sort=[],
size=size,
)
request = AuditSearchRequest(dsl=dsl)
results = client.audit.search(criteria=request, bulk=False)
total_count = results.total_count
TOTAL_AUDIT_ENTRIES = results.total_count
expected_sorts = [SortItem(field="entityId", order=SortOrder.ASCENDING)]
_assert_audit_search_results(results, expected_sorts, size, total_count)
_assert_audit_search_results(
results, expected_sorts, size, TOTAL_AUDIT_ENTRIES, False
)

# Test audit search by guid with `bulk` option using timestamp-based pagination
dsl = DSL(
query=Bool(filter=[Term(field="entityId", value=audit_glossary.guid)]),
sort=[],
size=size,
)
request = AuditSearchRequest(dsl=dsl)
results = client.audit.search(criteria=request, bulk=True)
total_count = results.total_count
expected_sorts = [
SortItem("created", order=SortOrder.ASCENDING),
SortItem(field="entityId", order=SortOrder.ASCENDING),
]
_assert_audit_search_results(results, expected_sorts, size, total_count)
_assert_audit_search_results(
results, expected_sorts, size, TOTAL_AUDIT_ENTRIES, True
)
assert mock_logger.call_count == 1
assert "Audit bulk search option is enabled." in mock_logger.call_args_list[0][0][0]
mock_logger.reset_mock()

# When the number of results exceeds the predefined
# threshold and bulk is true and no pre-defined sort.
with patch.object(AuditSearchResults, "_MASS_EXTRACT_THRESHOLD", -1):
username = client.user.get_current().username
assert username
dsl = DSL(
query=Bool(filter=[Term(field="user", value=username)]),
query=Bool(filter=[Term(field="entityId", value=audit_glossary.guid)]),
sort=[],
size=size,
)
request = AuditSearchRequest(dsl=dsl)
results = client.audit.search(criteria=request, bulk=True)
total_count = results.total_count
expected_sorts = [
SortItem("created", order=SortOrder.ASCENDING),
SortItem(field="entityId", order=SortOrder.ASCENDING),
]
_assert_audit_search_results(results, expected_sorts, size, total_count)
assert mock_logger.call_count < total_count
_assert_audit_search_results(
results, expected_sorts, size, TOTAL_AUDIT_ENTRIES, True
)
assert mock_logger.call_count < TOTAL_AUDIT_ENTRIES
assert (
"Audit bulk search option is enabled."
in mock_logger.call_args_list[0][0][0]
Expand All @@ -693,15 +739,22 @@ def test_audit_search_pagination(
# When the number of results exceeds the predefined threshold and bulk is `False` and no pre-defined sort.
# Then SDK automatically switches to a `bulk` search option using timestamp-based pagination
with patch.object(AuditSearchResults, "_MASS_EXTRACT_THRESHOLD", -1):
dsl = DSL(
query=Bool(filter=[Term(field="entityId", value=audit_glossary.guid)]),
sort=[],
size=size,
)
request = AuditSearchRequest(dsl=dsl)
results = client.audit.search(criteria=request, bulk=False)
total_count = results.total_count
results.total_count
expected_sorts = [
SortItem("created", order=SortOrder.ASCENDING),
SortItem(field="entityId", order=SortOrder.ASCENDING),
]
_assert_audit_search_results(results, expected_sorts, size, total_count)
assert mock_logger.call_count < total_count
_assert_audit_search_results(
results, expected_sorts, size, TOTAL_AUDIT_ENTRIES, False
)
assert mock_logger.call_count < TOTAL_AUDIT_ENTRIES
assert (
"Result size (%s) exceeds threshold (%s)."
in mock_logger.call_args_list[0][0][0]
Expand Down
12 changes: 1 addition & 11 deletions tests/unit/data/search_responses/audit_search_paging.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"totalCount": 2,
"totalCount": 1,
"entityAudits": [
{
"entity_qualified_name": "sample_entity_1",
Expand All @@ -10,16 +10,6 @@
"user": "user1",
"action": "ENTITY_CREATE",
"eventKey": "guid_1:1733491479782"
},
{
"entity_qualified_name": "sample_entity_2",
"type_name": "AtlasGlossaryTerm",
"entity_id": "guid_2",
"timestamp": 1733491479783,
"created": 1733491480577,
"user": "user2",
"action": "ENTITY_UPDATE",
"eventKey": "guid_2:1733491479783"
}
]
}
44 changes: 36 additions & 8 deletions tests/unit/test_audit_search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime, timezone
from json import load
from pathlib import Path
from unittest.mock import Mock, patch
Expand Down Expand Up @@ -36,14 +37,23 @@ def load_json(filename):


def _assert_audit_search_results(results, response_json, sorts, bulk=False):
for i, result in enumerate(results.current_page()):
assert result
assert result.entity_id == response_json["entityAudits"][i]["entity_id"]
for audit in results:
assert audit.entity_id == response_json["entityAudits"][0]["entity_id"]
assert (
result.entity_qualified_name
== response_json["entityAudits"][i]["entity_qualified_name"]
audit.entity_qualified_name
== response_json["entityAudits"][0]["entity_qualified_name"]
)
assert result.type_name == response_json["entityAudits"][i]["type_name"]
assert audit.type_name == response_json["entityAudits"][0]["type_name"]
expected_timestamp = datetime.fromtimestamp(
response_json["entityAudits"][0]["timestamp"] / 1000, tz=timezone.utc
)
assert audit.timestamp == expected_timestamp
expected_created = datetime.fromtimestamp(
response_json["entityAudits"][0]["created"] / 1000, tz=timezone.utc
)
assert audit.created == expected_created
assert audit.user == response_json["entityAudits"][0]["user"]
assert audit.action == response_json["entityAudits"][0]["action"]

assert results.total_count == response_json["totalCount"]
assert results._bulk == bulk
Expand All @@ -65,18 +75,24 @@ def test_audit_search_pagination(
dsl = DSL(
query=Bool(filter=[Term(field="entityId", value="some-guid")]),
sort=[],
size=2,
size=1,
from_=0,
)
audit_search_request = AuditSearchRequest(dsl=dsl)
response = client.search(criteria=audit_search_request, bulk=False)
expected_sorts = [SortItem(field="entityId", order=SortOrder.ASCENDING)]

_assert_audit_search_results(response, audit_search_paging_json, expected_sorts)
assert mock_api_caller._call_api.call_count == 1
assert mock_api_caller._call_api.call_count == 3
assert mock_logger.call_count == 0
mock_api_caller.reset_mock()

# Test bulk pagination
mock_api_caller._call_api.side_effect = [
audit_search_paging_json,
audit_search_paging_json,
{},
]
audit_search_request = AuditSearchRequest(dsl=dsl)
response = client.search(criteria=audit_search_request, bulk=True)
expected_sorts = [
Expand All @@ -87,6 +103,14 @@ def test_audit_search_pagination(
_assert_audit_search_results(
response, audit_search_paging_json, expected_sorts, bulk=True
)
# The call count will be 2 because
# audit search entries are processed in the first API call.
# In the second API call, self._entity_audits
# becomes 0, which breaks the pagination.
# This differs from offset-based pagination
# where an additional API call is needed
# to verify if the results are empty
assert mock_api_caller._call_api.call_count == 2
assert mock_logger.call_count == 1
assert "Audit bulk search option is enabled." in mock_logger.call_args_list[0][0][0]
mock_logger.reset_mock()
Expand All @@ -95,6 +119,9 @@ def test_audit_search_pagination(
# Test automatic bulk search conversion when exceeding threshold
with patch.object(AuditSearchResults, "_MASS_EXTRACT_THRESHOLD", -1):
mock_api_caller._call_api.side_effect = [
# Extra call to re-fetch the first page
# results with updated timestamp sorting
audit_search_paging_json,
audit_search_paging_json,
audit_search_paging_json,
{},
Expand All @@ -105,6 +132,7 @@ def test_audit_search_pagination(
response, audit_search_paging_json, expected_sorts, bulk=False
)
assert mock_logger.call_count == 1
assert mock_api_caller._call_api.call_count == 3
assert (
"Result size (%s) exceeds threshold (%s)"
in mock_logger.call_args_list[0][0][0]
Expand Down
Loading