diff --git a/api_app/_version.py b/api_app/_version.py index 4151b9f93e..496906afb5 100644 --- a/api_app/_version.py +++ b/api_app/_version.py @@ -1 +1 @@ -__version__ = "0.5.16" +__version__ = "0.5.17" diff --git a/api_app/api/routes/costs.py b/api_app/api/routes/costs.py index ef1fcb13fc..1505949735 100644 --- a/api_app/api/routes/costs.py +++ b/api_app/api/routes/costs.py @@ -1,6 +1,7 @@ from datetime import datetime from dateutil.relativedelta import relativedelta from fastapi import APIRouter, Depends, Query, HTTPException, status +from fastapi.responses import JSONResponse import logging from typing import Optional @@ -16,8 +17,7 @@ from models.domain.costs import CostReport, GranularityEnum, WorkspaceCostReport from resources import strings from services.authentication import get_current_admin_user, get_current_workspace_owner_or_tre_admin -from services.cost_service import CostService, ServiceUnavailable, SubscriptionNotSupported, TooManyRequests, WorkspaceDoesNotExist -from starlette.responses import JSONResponse +from services.cost_service import CostService, ServiceUnavailable, SubscriptionNotSupported, TooManyRequests, WorkspaceDoesNotExist, cost_service_factory costs_core_router = APIRouter(dependencies=[Depends(get_current_admin_user)]) costs_workspace_router = APIRouter(dependencies=[Depends(get_current_workspace_owner_or_tre_admin)]) @@ -54,7 +54,7 @@ def __init__( responses=get_cost_report_responses()) async def costs( params: CostsQueryParams = Depends(), - cost_service=Depends(CostService), + cost_service: CostService = Depends(cost_service_factory), workspace_repo=Depends(get_repository(WorkspaceRepository)), shared_services_repo=Depends(get_repository(SharedServiceRepository))) -> CostReport: @@ -80,7 +80,7 @@ async def costs( }}, status_code=503, headers={"Retry-After": str(e.retry_after)}) except Exception as e: logging.error("Failed to query Azure TRE costs", exc_info=e) - raise e + raise 
HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=strings.API_GET_COSTS_INTERNAL_SERVER_ERROR) @costs_workspace_router.get("/workspaces/{workspace_id}/costs", response_model=WorkspaceCostReport, @@ -88,7 +88,7 @@ async def costs( dependencies=[Depends(get_current_workspace_owner_or_tre_admin)], responses=get_workspace_cost_report_responses()) async def workspace_costs(workspace_id: UUID4, params: CostsQueryParams = Depends(), - cost_service=Depends(CostService), + cost_service: CostService = Depends(cost_service_factory), workspace_repo=Depends(get_repository(WorkspaceRepository)), workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository)), user_resource_repo=Depends(get_repository(UserResourceRepository))) -> WorkspaceCostReport: @@ -117,5 +117,5 @@ async def workspace_costs(workspace_id: UUID4, params: CostsQueryParams = Depend "retry-after": str(e.retry_after) }}, status_code=503, headers={"Retry-After": str(e.retry_after)}) except Exception as e: - logging.error("Failed to query Azure TRE workspace costs", exc_info=e) - raise e + logging.error("Failed to query Azure TRE costs", exc_info=e) + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=strings.API_GET_COSTS_INTERNAL_SERVER_ERROR) diff --git a/api_app/main.py b/api_app/main.py index f7fb62a583..bdc3cd5d98 100644 --- a/api_app/main.py +++ b/api_app/main.py @@ -87,5 +87,6 @@ async def watch_deployment_status() -> None: async def update_airlock_request_status() -> None: await receive_step_result_message_and_update_status(app) + if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/api_app/models/schemas/costs.py b/api_app/models/schemas/costs.py index ecaa99f54a..39f3f5b046 100644 --- a/api_app/models/schemas/costs.py +++ b/api_app/models/schemas/costs.py @@ -38,7 +38,7 @@ def get_cost_report_responses(): "example": { "error": { "code": "429", - "message": "Too many requests to Azure cost management api. 
Please retry.", + "message": "Too many requests to Azure cost management API. Please retry.", "retry-after": "30" } } @@ -52,7 +52,7 @@ def get_cost_report_responses(): "example": { "error": { "code": "503", - "message": "Azure cost management api is temporarly unavaiable. Please retry.", + "message": "Azure cost management API is temporarly unavaiable. Please retry.", "retry-after": "30" } } @@ -98,7 +98,7 @@ def get_workspace_cost_report_responses(): "example": { "error": { "code": "429", - "message": "Too many requests to Azure cost management api. Please retry.", + "message": "Too many requests to Azure cost management API. Please retry.", "retry-after": "30" } } @@ -112,7 +112,7 @@ def get_workspace_cost_report_responses(): "example": { "error": { "code": "503", - "message": "Azure cost management api is temporarly unavaiable. Please retry.", + "message": "Azure cost management API is temporarly unavaiable. Please retry.", "retry-after": "30" } } diff --git a/api_app/resources/strings.py b/api_app/resources/strings.py index 8a261c9169..4e4eef03c7 100644 --- a/api_app/resources/strings.py +++ b/api_app/resources/strings.py @@ -72,8 +72,9 @@ API_GET_COSTS_TO_DATE_NEED_TO_BE_LATER_THEN_FROM_DATE = "to_date needs to be later than from_date" API_GET_COSTS_FROM_DATE_NEED_TO_BE_BEFORE_TO_DATE = "from_date needs to be before to_date" API_GET_COSTS_SUBSCRIPTION_NOT_SUPPORTED = "Azure subscription doesn't support cost management" -API_GET_COSTS_TOO_MANY_REQUESTS = "Too many requests to Azure cost management api. Please retry." -API_GET_COSTS_SERVICE_UNAVAILABLE = "Azure cost management api is temporarily unavailable. Please retry." +API_GET_COSTS_TOO_MANY_REQUESTS = "Too many requests to Azure cost management API. Please retry." +API_GET_COSTS_SERVICE_UNAVAILABLE = "Azure cost management API is temporarily unavailable. Please retry." +API_GET_COSTS_INTERNAL_SERVER_ERROR = "Failed to query Azure TRE costs." 
# State store status diff --git a/api_app/services/cost_service.py b/api_app/services/cost_service.py index fd36686ad7..23d4b1fa4c 100644 --- a/api_app/services/cost_service.py +++ b/api_app/services/cost_service.py @@ -1,6 +1,7 @@ -from datetime import datetime, date +from datetime import datetime, date, timedelta from enum import Enum -from typing import Dict, Optional +from functools import lru_cache +from typing import Dict, Optional, Union import pandas as pd import logging @@ -63,9 +64,22 @@ def __init__(self, retry_after: int, *args: object) -> None: self.retry_after = retry_after +class CostCacheItem(): + """Holds cost query result and time to live for storing in cache""" + result: QueryResult + ttl: datetime + + def __init__(self, item: QueryResult, ttl: datetime) -> None: + self.result = item + self.ttl = ttl + + +# make sure CostService is singleton +@lru_cache(maxsize=None) class CostService: scope: str client: CostManagementClient + cache: Dict[str, CostCacheItem] TRE_ID_TAG: str = "tre_id" TRE_CORE_SERVICE_ID_TAG: str = "tre_core_service_id" TRE_WORKSPACE_ID_TAG: str = "tre_workspace_id" @@ -76,10 +90,49 @@ class CostService: RATE_LIMIT_RETRY_AFTER_HEADER_KEY: str = "x-ms-ratelimit-microsoft.costmanagement-entity-retry-after" SERVICE_UNAVAILABLE_RETRY_AFTER_HEADER_KEY: str = "Retry-After" - def __init__(self): + def __init__(self) -> None: self.scope = "/subscriptions/{}".format(config.SUBSCRIPTION_ID) self.client = CostManagementClient(credential=credentials.get_credential()) self.resource_client = ResourceManagementClient(credentials.get_credential(), config.SUBSCRIPTION_ID) + self.cache = {} + + def get_cached_result(self, key: str) -> Union[QueryResult, None]: + """Returns cached item result. + + Args: + key (str): key of the cached item in cache. + Returns: + result (Union[QueryResult, None]): cost query result or None if not found or expired.
+ """ + cached_item: CostCacheItem = self.cache.get(key, None) + + # return None if key doesn't exist + if cached_item is None: + return None + + # return None if key expired + if (datetime.now() > cached_item.ttl): + # remove expired cache item + self.cache.pop(key) + return None + + return cached_item.result + + def clear_expired_cache_items(self) -> None: + """Clears all expired cache items.""" + expired_keys = [key for key in self.cache.keys() if datetime.now() > self.cache[key].ttl] + for key in expired_keys: + self.cache.pop(key) + + def cache_result(self, key: str, result: QueryResult, timedelta: timedelta) -> None: + """Add cost result to cache. + + Args: + key (str) : key of the cached item in cache. + result (QueryResult) : cost query result to cache. + """ + self.cache[key] = CostCacheItem(result, datetime.now() + timedelta) + self.clear_expired_cache_items() def query_tre_costs(self, tre_id, granularity: GranularityEnum, from_date: datetime, to_date: datetime, workspace_repo: WorkspaceRepository, @@ -87,7 +140,12 @@ def query_tre_costs(self, tre_id, granularity: GranularityEnum, from_date: datet resource_groups_dict = self.get_resource_groups_by_tag(self.TRE_ID_TAG, tre_id) - query_result = self.query_costs(CostService.TRE_ID_TAG, tre_id, granularity, from_date, to_date, list(resource_groups_dict.keys())) + cache_key = f"{CostService.TRE_ID_TAG}_{tre_id}_granularity{granularity}_from_date{from_date}_to_date{to_date}_rgs{'_'.join(list(resource_groups_dict.keys()))}" + query_result = self.get_cached_result(cache_key) + + if query_result is None: + query_result = self.query_costs(CostService.TRE_ID_TAG, tre_id, granularity, from_date, to_date, list(resource_groups_dict.keys())) + self.cache_result(cache_key, query_result, timedelta(hours=2)) summerized_result = self.summerize_untagged(query_result, granularity, resource_groups_dict) @@ -112,7 +170,14 @@ def query_tre_workspace_costs(self, workspace_id: str, granularity: GranularityE user_resource_repo) -> 
WorkspaceCostReport: resource_groups_dict = self.get_resource_groups_by_tag(self.TRE_WORKSPACE_ID_TAG, workspace_id) - query_result = self.query_costs(CostService.TRE_WORKSPACE_ID_TAG, workspace_id, granularity, from_date, to_date, list(resource_groups_dict.keys())) + + cache_key = f"{CostService.TRE_WORKSPACE_ID_TAG}_{workspace_id}_granularity{granularity}_from_date{from_date}_to_date{to_date}_rgs{'_'.join(list(resource_groups_dict.keys()))}" + query_result = self.get_cached_result(cache_key) + + if query_result is None: + query_result = self.query_costs(CostService.TRE_WORKSPACE_ID_TAG, workspace_id, granularity, from_date, to_date, list(resource_groups_dict.keys())) + self.cache_result(cache_key, query_result, timedelta(hours=2)) + summerized_result = self.summerize_untagged(query_result, granularity, resource_groups_dict) query_result_dict = self.__query_result_to_dict(summerized_result, granularity) @@ -330,3 +395,8 @@ def __query_result_to_dict(self, query_result: list, granularity: GranularityEnu def __parse_cost_management_date_value(self, date_value: int): return datetime.strptime(str(date_value), "%Y%m%d").date() + + +@lru_cache(maxsize=None) +def cost_service_factory() -> CostService: + return CostService() diff --git a/api_app/tests_ma/test_services/test_cost_service.py b/api_app/tests_ma/test_services/test_cost_service.py index 513dabc325..763ff955a8 100644 --- a/api_app/tests_ma/test_services/test_cost_service.py +++ b/api_app/tests_ma/test_services/test_cost_service.py @@ -11,10 +11,18 @@ from azure.core.exceptions import ResourceNotFoundError +@pytest.fixture(autouse=True) +def clear_lru_cache(): + CostService.cache_clear() + yield + CostService.cache_clear() + + @patch('db.repositories.workspaces.WorkspaceRepository') @patch('db.repositories.shared_services.SharedServiceRepository') @patch('services.cost_service.CostManagementClient') -@patch('services.cost_service.CostService.get_resource_groups_by_tag') +# CostService is lru_cached which creates 
a wrapper method +@patch('services.cost_service.CostService.__wrapped__.get_resource_groups_by_tag') def test_query_tre_costs_with_granularity_none_returns_correct_cost_report(get_resource_groups_by_tag_mock, client_mock, shared_service_repo_mock, workspace_repo_mock): client_mock.return_value.query.usage.return_value = __get_cost_management_query_result() @@ -55,7 +63,8 @@ def test_query_tre_costs_with_granularity_none_returns_correct_cost_report(get_r @patch('db.repositories.workspaces.WorkspaceRepository') @patch('db.repositories.shared_services.SharedServiceRepository') @patch('services.cost_service.CostManagementClient') -@patch('services.cost_service.CostService.get_resource_groups_by_tag') +# CostService is lru_cached which creates a wrapper method +@patch('services.cost_service.CostService.__wrapped__.get_resource_groups_by_tag') def test_query_tre_costs_with_granularity_daily_returns_correct_cost_report( get_resource_groups_by_tag_mock, client_mock, shared_service_repo_mock, workspace_repo_mock): client_mock.return_value.query.usage.return_value = __set_cost_management_client_mock_query_result() @@ -156,7 +165,8 @@ def __get_daily_cost_management_query_result(): @patch('db.repositories.workspaces.WorkspaceRepository') @patch('db.repositories.shared_services.SharedServiceRepository') @patch('services.cost_service.CostManagementClient') -@patch('services.cost_service.CostService.get_resource_groups_by_tag') +# CostService is lru_cached which creates a wrapper method +@patch('services.cost_service.CostService.__wrapped__.get_resource_groups_by_tag') def test_query_tre_costs_with_granularity_none_and_missing_costs_data_returns_empty_cost_report(get_resource_groups_by_tag_mock, client_mock, shared_service_repo_mock, @@ -191,7 +201,8 @@ def test_query_tre_costs_with_granularity_none_and_missing_costs_data_returns_em @patch('db.repositories.workspaces.WorkspaceRepository') @patch('db.repositories.shared_services.SharedServiceRepository') 
@patch('services.cost_service.CostManagementClient') -@patch('services.cost_service.CostService.get_resource_groups_by_tag') +# CostService is lru_cached which creates a wrapper method +@patch('services.cost_service.CostService.__wrapped__.get_resource_groups_by_tag') def test_query_tre_costs_for_unsupported_subscription_raises_subscription_not_supported_exception(get_resource_groups_by_tag_mock, client_mock, shared_service_repo_mock, @@ -218,7 +229,8 @@ def test_query_tre_costs_for_unsupported_subscription_raises_subscription_not_su @patch('db.repositories.workspaces.WorkspaceRepository') @patch('db.repositories.shared_services.SharedServiceRepository') @patch('services.cost_service.CostManagementClient') -@patch('services.cost_service.CostService.get_resource_groups_by_tag') +# CostService is lru_cached which creates a wrapper method +@patch('services.cost_service.CostService.__wrapped__.get_resource_groups_by_tag') def test_query_tre_costs_with_granularity_daily_and_missing_costs_data_returns_empty_cost_report(get_resource_groups_by_tag_mock, client_mock, shared_service_repo_mock, @@ -255,7 +267,8 @@ def test_query_tre_costs_with_granularity_daily_and_missing_costs_data_returns_e @patch('db.repositories.workspaces.WorkspaceRepository') @patch('db.repositories.shared_services.SharedServiceRepository') @patch('services.cost_service.CostManagementClient') -@patch('services.cost_service.CostService.get_resource_groups_by_tag') +# CostService is lru_cached which creates a wrapper method +@patch('services.cost_service.CostService.__wrapped__.get_resource_groups_by_tag') def test_query_tre_costs_with_granularity_none_and_display_name_data_returns_template_name_in_cost_report(get_resource_groups_by_tag_mock, client_mock, shared_service_repo_mock, @@ -298,7 +311,8 @@ def test_query_tre_costs_with_granularity_none_and_display_name_data_returns_tem @patch('db.repositories.workspaces.WorkspaceRepository') @patch('db.repositories.shared_services.SharedServiceRepository') 
@patch('services.cost_service.CostManagementClient') -@patch('services.cost_service.CostService.get_resource_groups_by_tag') +# CostService is lru_cached which creates a wrapper method +@patch('services.cost_service.CostService.__wrapped__.get_resource_groups_by_tag') def test_query_tre_costs_with_dates_set_as_none_calls_client_with_month_to_date(get_resource_groups_by_tag_mock, client_mock, shared_service_repo_mock, workspace_repo_mock, from_date, @@ -308,6 +322,7 @@ def test_query_tre_costs_with_dates_set_as_none_calls_client_with_month_to_date( __set_resource_group_by_tag_return_value(get_resource_groups_by_tag_mock) cost_service = CostService() + CostService.cache_clear() cost_service.query_tre_costs( "guy22", GranularityEnum.none, from_date, to_date, workspace_repo_mock, shared_service_repo_mock) @@ -318,7 +333,8 @@ def test_query_tre_costs_with_dates_set_as_none_calls_client_with_month_to_date( @patch('db.repositories.workspaces.WorkspaceRepository') @patch('db.repositories.shared_services.SharedServiceRepository') @patch('services.cost_service.CostManagementClient') -@patch('services.cost_service.CostService.get_resource_groups_by_tag') +# CostService is lru_cached which creates a wrapper method +@patch('services.cost_service.CostService.__wrapped__.get_resource_groups_by_tag') def test_query_tre_costs_with_dates_set_as_none_calls_client_with_custom_dates(get_resource_groups_by_tag_mock, client_mock, shared_service_repo_mock, workspace_repo_mock): @@ -427,7 +443,8 @@ def __set_user_resource_repo_mock_return_value(user_resource_repo_mock): @patch('db.repositories.workspace_services.WorkspaceServiceRepository') @patch('db.repositories.workspaces.WorkspaceRepository') @patch('services.cost_service.CostManagementClient') -@patch('services.cost_service.CostService.get_resource_groups_by_tag') +# CostService is lru_cached which creates a wrapper method +@patch('services.cost_service.CostService.__wrapped__.get_resource_groups_by_tag') def 
test_query_tre_workspace_costs_with_granularity_none_returns_correct_workspace_cost_report(get_resource_groups_by_tag_mock, client_mock, workspace_repo_mock, @@ -481,7 +498,8 @@ def test_query_tre_workspace_costs_with_granularity_none_returns_correct_workspa @patch('db.repositories.workspace_services.WorkspaceServiceRepository') @patch('db.repositories.workspaces.WorkspaceRepository') @patch('services.cost_service.CostManagementClient') -@patch('services.cost_service.CostService.get_resource_groups_by_tag') +# CostService is lru_cached which creates a wrapper method +@patch('services.cost_service.CostService.__wrapped__.get_resource_groups_by_tag') def test_query_tre_workspace_costs_with_granularity_daily_returns_correct_workspace_cost_report(get_resource_groups_by_tag_mock, client_mock, workspace_repo_mock,