⚡️ Speed up function global_get_all_tag_names by 6%
#427
📄 6% (0.06x) speedup for global_get_all_tag_names in litellm/proxy/spend_tracking/spend_management_endpoints.py
⏱️ Runtime: 7.01 milliseconds → 6.60 milliseconds (best of 93 runs)
📝 Explanation and details
The optimization achieves a 6% runtime improvement and 1.1% throughput increase through two key changes:
Primary optimization: List comprehension replacement
The explicit loop that called append() for each row is replaced with a list comprehension: [row.get("individual_request_tag") for row in db_response]
Secondary optimization: Simplified None check
The check changes from if db_response is None: return [] to if not db_response: return {"tag_names": []}
This handles both the None and empty list cases in a single check, eliminating redundant branching
Why this matters:
The function processes database query results containing tag arrays, which can be substantial in spend tracking systems. The list comprehension optimization is particularly effective for workloads with many tags (as shown in test cases with 500-1000 tags), where the performance difference compounds. The improvement is consistent across different load patterns - from basic single-tag cases to large concurrent scenarios with 100 simultaneous calls.
Impact on workloads:
Since this is a spend tracking endpoint, it's likely called frequently in financial reporting and analytics workflows where tag aggregation is common. The 6% improvement becomes significant when processing large datasets or handling high-frequency API calls for budget monitoring dashboards.
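The two changes described above can be sketched in isolation. This is a minimal illustration, not the endpoint itself: the helper names tag_names_loop and tag_names_comprehension are hypothetical, and only the None check, the append() loop, and the list comprehension are taken from the description.

```python
from typing import Any, Dict, List, Optional

Row = Dict[str, Any]


def tag_names_loop(db_response: Optional[List[Row]]):
    """Shape before the change: explicit None check, then append() in a loop."""
    if db_response is None:
        return []
    tag_names = []
    for row in db_response:
        # row.get() yields None when the key is missing
        tag_names.append(row.get("individual_request_tag"))
    return {"tag_names": tag_names}


def tag_names_comprehension(db_response: Optional[List[Row]]):
    """Shape after the change: one falsy check, then a list comprehension."""
    if not db_response:  # covers both None and an empty list
        return {"tag_names": []}
    return {
        "tag_names": [row.get("individual_request_tag") for row in db_response]
    }
```

For non-empty input both helpers return the same dictionary. As written in the description, the None case now returns {"tag_names": []} where the earlier check returned [], so callers that relied on the bare list would see a different shape.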
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import asyncio  # used to run async functions
from typing import TYPE_CHECKING, Any, List
from unittest.mock import AsyncMock, patch

import pytest  # used for our unit tests
# function to test
# (copied exactly as provided)
from fastapi import APIRouter, Depends, HTTPException, status
from litellm.proxy._types import (TYPE_CHECKING, Any, List, LiteLLM_SpendLogs,
                                  ProxyException)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.spend_tracking.spend_management_endpoints import global_get_all_tag_names

router = APIRouter()

# ---- UNIT TESTS ----

# Helper: Patch prisma_client in the correct import path
PRISMA_CLIENT_PATH = "litellm.proxy.proxy_server.prisma_client"

@pytest.mark.asyncio
async def test_global_get_all_tag_names_basic_tags():
    """
    Basic: Test with a typical db_response with several tags.
    """
    mock_db_response = [
        {"individual_request_tag": "finance"},
        {"individual_request_tag": "engineering"},
        {"individual_request_tag": "marketing"},
    ]
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = mock_db_response

@pytest.mark.asyncio
async def test_global_get_all_tag_names_empty_tags():
    """
    Basic: Test with an empty db_response (no tags in database).
    """
    mock_db_response = []
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = mock_db_response

@pytest.mark.asyncio
async def test_global_get_all_tag_names_db_response_none():
    """
    Edge: Test when db.query_raw returns None (should return empty list).
    """
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = None

@pytest.mark.asyncio
async def test_global_get_all_tag_names_prisma_client_none():
    """
    Edge: Test when prisma_client is None (should raise ProxyException).
    """
    with patch(PRISMA_CLIENT_PATH, None):
        with pytest.raises(Exception) as exc_info:
            await global_get_all_tag_names()

@pytest.mark.asyncio
async def test_global_get_all_tag_names_duplicate_tags():
    """
    Edge: Test with duplicate tags in db_response (should still return all).
    """
    mock_db_response = [
        {"individual_request_tag": "finance"},
        {"individual_request_tag": "finance"},
        {"individual_request_tag": "engineering"},
    ]
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = mock_db_response

@pytest.mark.asyncio
async def test_global_get_all_tag_names_row_missing_key():
    """
    Edge: Test a row missing the 'individual_request_tag' key (should append None).
    """
    mock_db_response = [
        {"individual_request_tag": "finance"},
        {},  # Missing key
        {"individual_request_tag": "engineering"},
    ]
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = mock_db_response

@pytest.mark.asyncio
async def test_global_get_all_tag_names_concurrent_execution():
    """
    Edge: Test concurrent execution of the async function.
    """
    mock_db_response = [
        {"individual_request_tag": "finance"},
        {"individual_request_tag": "engineering"},
    ]
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = mock_db_response

@pytest.mark.asyncio
async def test_global_get_all_tag_names_exception_in_query_raw():
    """
    Edge: Simulate an exception being raised by db.query_raw.
    """
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.side_effect = RuntimeError("DB error")

@pytest.mark.asyncio
async def test_global_get_all_tag_names_http_exception():
    """
    Edge: Simulate an HTTPException being raised in the try block.
    """
    class DummyHTTPException(HTTPException):
        def __init__(self):
            super().__init__(status_code=400, detail="HTTP error", headers=None)
            self.param = "test_param"

@pytest.mark.asyncio
async def test_global_get_all_tag_names_proxy_exception():
    """
    Edge: Simulate a ProxyException being raised in the try block.
    """
    class DummyProxyException(Exception):
        pass

@pytest.mark.asyncio
async def test_global_get_all_tag_names_large_scale():
    """
    Large Scale: Test with a large number of tags (e.g., 500).
    """
    tag_count = 500
    mock_db_response = [
        {"individual_request_tag": f"tag_{i}"} for i in range(tag_count)
    ]
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = mock_db_response

@pytest.mark.asyncio
async def test_global_get_all_tag_names_large_concurrent():
    """
    Large Scale: Test with many concurrent calls (e.g., 50).
    """
    mock_db_response = [
        {"individual_request_tag": "finance"},
        {"individual_request_tag": "engineering"},
        {"individual_request_tag": "marketing"},
    ]
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = mock_db_response

@pytest.mark.asyncio
async def test_global_get_all_tag_names_throughput_small_load():
    """
    Throughput: Test throughput with a small number of concurrent calls (10).
    """
    mock_db_response = [
        {"individual_request_tag": "finance"},
        {"individual_request_tag": "engineering"},
    ]
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = mock_db_response

@pytest.mark.asyncio
async def test_global_get_all_tag_names_throughput_medium_load():
    """
    Throughput: Test throughput with a medium number of concurrent calls (50).
    """
    mock_db_response = [
        {"individual_request_tag": "finance"},
        {"individual_request_tag": "engineering"},
        {"individual_request_tag": "marketing"},
    ]
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = mock_db_response

@pytest.mark.asyncio
async def test_global_get_all_tag_names_throughput_large_load():
    """
    Throughput: Test throughput with a large number of concurrent calls (100).
    """
    tag_count = 100
    mock_db_response = [
        {"individual_request_tag": f"tag_{i}"} for i in range(tag_count)
    ]
    mock_prisma_client = AsyncMock()
    mock_prisma_client.db.query_raw.return_value = mock_db_response

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import asyncio  # used to run async functions
from typing import TYPE_CHECKING, Any, List
from unittest.mock import AsyncMock, MagicMock, patch

import pytest  # used for our unit tests
# function to test
# (Copy-pasted as per instructions)
from fastapi import APIRouter, Depends, HTTPException, status
from litellm.proxy._types import (TYPE_CHECKING, Any, List, LiteLLM_SpendLogs,
                                  ProxyException)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.spend_tracking.spend_management_endpoints import global_get_all_tag_names

router = APIRouter()

# ------------------ UNIT TESTS ------------------

# Helper: Patch prisma_client in the function's import context
PRISMA_CLIENT_PATH = "litellm.proxy.proxy_server.prisma_client"

@pytest.mark.asyncio
async def test_global_get_all_tag_names_basic_single_tag():
    """Test basic case: one row, one tag."""
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(return_value=[{"individual_request_tag": "tag1"}])
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

@pytest.mark.asyncio
async def test_global_get_all_tag_names_basic_multiple_tags():
    """Test multiple distinct tags returned."""
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(return_value=[
        {"individual_request_tag": "tag1"},
        {"individual_request_tag": "tag2"},
        {"individual_request_tag": "tag3"},
    ])
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

@pytest.mark.asyncio
async def test_global_get_all_tag_names_basic_empty_result():
    """Test empty result from DB returns empty tag_names list."""
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(return_value=[])
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

@pytest.mark.asyncio
async def test_global_get_all_tag_names_basic_db_none():
    """Test db.query_raw returns None (should return empty list)."""
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(return_value=None)
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

# ------------------ EDGE CASES ------------------

@pytest.mark.asyncio
async def test_global_get_all_tag_names_db_not_connected():
    """Test when prisma_client is None, should raise ProxyException."""
    with patch(PRISMA_CLIENT_PATH, None):
        with pytest.raises(Exception) as excinfo:
            await global_get_all_tag_names()

@pytest.mark.asyncio
async def test_global_get_all_tag_names_row_missing_tag_key():
    """Test when a row is missing 'individual_request_tag' key."""
    mock_db = MagicMock()
    # One row has the key, one does not
    mock_db.query_raw = AsyncMock(return_value=[
        {"individual_request_tag": "tag1"},
        {},
        {"individual_request_tag": "tag2"},
    ])
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

@pytest.mark.asyncio
async def test_global_get_all_tag_names_concurrent_requests():
    """Test multiple concurrent calls (asyncio.gather) return correct results."""
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(return_value=[
        {"individual_request_tag": "tagA"},
        {"individual_request_tag": "tagB"},
    ])
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

@pytest.mark.asyncio
async def test_global_get_all_tag_names_db_query_raises_exception():
    """Test if db.query_raw raises an exception, function raises ProxyException."""
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(side_effect=Exception("DB error"))
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

# ------------------ LARGE SCALE CASES ------------------

@pytest.mark.asyncio
async def test_global_get_all_tag_names_large_number_of_tags():
    """Test with a large number of unique tags (up to 1000)."""
    tags = [f"tag{i}" for i in range(1000)]
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(return_value=[{"individual_request_tag": t} for t in tags])
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

@pytest.mark.asyncio
async def test_global_get_all_tag_names_large_concurrent_calls():
    """Test with many concurrent calls (e.g., 50)."""
    tags = [f"tag{i}" for i in range(10)]
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(return_value=[{"individual_request_tag": t} for t in tags])
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

# ------------------ THROUGHPUT TEST CASES ------------------

@pytest.mark.asyncio
async def test_global_get_all_tag_names_throughput_small_load():
    """Throughput: Run 10 concurrent calls with small dataset."""
    tags = ["tag1", "tag2"]
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(return_value=[{"individual_request_tag": t} for t in tags])
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

@pytest.mark.asyncio
async def test_global_get_all_tag_names_throughput_medium_load():
    """Throughput: Run 50 concurrent calls with medium dataset."""
    tags = [f"tag{i}" for i in range(50)]
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(return_value=[{"individual_request_tag": t} for t in tags])
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

@pytest.mark.asyncio
async def test_global_get_all_tag_names_throughput_large_load():
    """Throughput: Run 100 concurrent calls with large dataset (100 tags)."""
    tags = [f"tag{i}" for i in range(100)]
    mock_db = MagicMock()
    mock_db.query_raw = AsyncMock(return_value=[{"individual_request_tag": t} for t in tags])
    mock_prisma_client = MagicMock()
    mock_prisma_client.db = mock_db

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
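
The listed tests show only their mock setup, so the following sketch illustrates how an AsyncMock-backed client of that shape could be driven, including the concurrent case with asyncio.gather. It is an assumption-laden stand-in: get_tag_names, make_client, and run_concurrently are hypothetical helpers, the client is passed as an argument instead of being patched at PRISMA_CLIENT_PATH, and the SQL string is a placeholder.

```python
import asyncio
from unittest.mock import AsyncMock


async def get_tag_names(prisma_client):
    """Hypothetical stand-in for the endpoint body (not the litellm function)."""
    db_response = await prisma_client.db.query_raw("SELECT ...")  # placeholder query
    if not db_response:
        return {"tag_names": []}
    return {"tag_names": [row.get("individual_request_tag") for row in db_response]}


def make_client(rows):
    """Build a mock client whose db.query_raw coroutine resolves to `rows`."""
    client = AsyncMock()
    client.db.query_raw.return_value = rows
    return client


async def run_concurrently(rows, n):
    """Issue n concurrent calls against one shared mock client."""
    client = make_client(rows)
    return await asyncio.gather(*(get_tag_names(client) for _ in range(n)))


if __name__ == "__main__":
    sample = [{"individual_request_tag": "finance"}, {}]
    print(asyncio.run(run_concurrently(sample, 3)))
```

Passing the client in directly keeps the sketch free of import-path patching; the real tests would instead wrap the call in patch(PRISMA_CLIENT_PATH, mock_prisma_client).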
To edit these changes, git checkout codeflash/optimize-global_get_all_tag_names-mhtxuihs and push.