diff --git a/posthog/client.py b/posthog/client.py index 3c4d8b89..368592c2 100644 --- a/posthog/client.py +++ b/posthog/client.py @@ -232,6 +232,7 @@ def __init__( self.distinct_ids_feature_flags_reported = SizeLimitedDict(MAX_DICT_SIZE, set) self.flag_cache = self._initialize_flag_cache(flag_fallback_cache_url) self.flag_definition_version = 0 + self._flags_etag: Optional[str] = None self.disabled = disabled self.disable_geoip = disable_geoip self.historical_migration = historical_migration @@ -1183,11 +1184,29 @@ def _load_feature_flags(self): f"/api/feature_flag/local_evaluation/?token={self.api_key}&send_cohorts", self.host, timeout=10, + etag=self._flags_etag, ) - self.feature_flags = response["flags"] or [] - self.group_type_mapping = response["group_type_mapping"] or {} - self.cohorts = response["cohorts"] or {} + # Update stored ETag (clear if server stops sending one) + self._flags_etag = response.etag + + # If 304 Not Modified, flags haven't changed - skip processing + if response.not_modified: + self.log.debug( + "[FEATURE FLAGS] Flags not modified (304), using cached data" + ) + self._last_feature_flag_poll = datetime.now(tz=tzutc()) + return + + if response.data is None: + self.log.error( + "[FEATURE FLAGS] Unexpected empty response data in non-304 response" + ) + return + + self.feature_flags = response.data["flags"] or [] + self.group_type_mapping = response.data["group_type_mapping"] or {} + self.cohorts = response.data["cohorts"] or {} # Check if flag definitions changed and update version if self.flag_cache and old_flags_by_key != ( diff --git a/posthog/request.py b/posthog/request.py index 2b97872f..2540f0e7 100644 --- a/posthog/request.py +++ b/posthog/request.py @@ -1,5 +1,7 @@ import json import logging +import re +from dataclasses import dataclass from datetime import date, datetime from gzip import GzipFile from io import BytesIO @@ -12,6 +14,21 @@ from posthog.utils import remove_trailing_slash from posthog.version import VERSION + +def _mask_tokens_in_url(url: str) -> str: + """Mask token values in URLs for safe logging, keeping first 10 chars visible.""" + return re.sub(r"(token=)([^&]{10})[^&]*", r"\1\2...", url) + + +@dataclass +class GetResponse: + """Response from a GET request with ETag support.""" + + data: Any + etag: Optional[str] = None + not_modified: bool = False + + # Retry on both connect and read errors # by default read errors will only retry idempotent HTTP methods (so not POST) adapter = requests.adapters.HTTPAdapter( @@ -139,12 +156,13 @@ def remote_config( timeout: int = 15, ) -> Any: """Get remote config flag value from remote_config API endpoint""" - return get( + response = get( personal_api_key, f"/api/projects/@current/feature_flags/{key}/remote_config?token={project_api_key}", host, timeout, ) + return response.data def batch_post( @@ -162,15 +180,42 @@ def batch_post( def get( - api_key: str, url: str, host: Optional[str] = None, timeout: Optional[int] = None -) -> requests.Response: - url = remove_trailing_slash(host or DEFAULT_HOST) + url - res = requests.get( - url, - headers={"Authorization": "Bearer %s" % api_key, "User-Agent": USER_AGENT}, - timeout=timeout, + api_key: str, + url: str, + host: Optional[str] = None, + timeout: Optional[int] = None, + etag: Optional[str] = None, +) -> GetResponse: + """ + Make a GET request with optional ETag support. + + If an etag is provided, sends If-None-Match header. Returns GetResponse with: + - not_modified=True and data=None if server returns 304 + - not_modified=False and data=response if server returns 200 + """ + log = logging.getLogger("posthog") + full_url = remove_trailing_slash(host or DEFAULT_HOST) + url + headers = {"Authorization": "Bearer %s" % api_key, "User-Agent": USER_AGENT} + + if etag: + headers["If-None-Match"] = etag + + res = _session.get(full_url, headers=headers, timeout=timeout) + + masked_url = _mask_tokens_in_url(full_url) + + # Handle 304 Not Modified + if res.status_code == 304: + log.debug(f"GET {masked_url} returned 304 Not Modified") + response_etag = res.headers.get("ETag") + return GetResponse(data=None, etag=response_etag or etag, not_modified=True) + + # Handle normal response + data = _process_response( + res, success_message=f"GET {masked_url} completed successfully" ) - return _process_response(res, success_message=f"GET {url} completed successfully") + response_etag = res.headers.get("ETag") + return GetResponse(data=data, etag=response_etag, not_modified=False) class APIError(Exception): diff --git a/posthog/test/test_client.py b/posthog/test/test_client.py index e97d5349..a32f2322 100644 --- a/posthog/test/test_client.py +++ b/posthog/test/test_client.py @@ -9,7 +9,7 @@ from posthog.client import Client from posthog.contexts import get_context_session_id, new_context, set_context_session -from posthog.request import APIError +from posthog.request import APIError, GetResponse from posthog.test.test_utils import FAKE_TEST_API_KEY from posthog.types import FeatureFlag, LegacyFlagMetadata from posthog.version import VERSION @@ -2095,13 +2095,21 @@ def test_enable_local_evaluation_false_disables_poller( self, patch_get, patch_poller ): """Test that when enable_local_evaluation=False, the poller is not started""" - patch_get.return_value = { - "flags": [ - {"id": 1, "name": "Beta Feature", "key": "beta-feature", "active": True} - ], - "group_type_mapping": {}, - "cohorts": {}, - } + patch_get.return_value = GetResponse( + data={ + "flags": [ + { + "id": 1, + "name": "Beta Feature", + "key": "beta-feature", + "active": True, + } + ], + "group_type_mapping": {}, + "cohorts": {}, + }, + etag='"test-etag"', + ) client = Client( FAKE_TEST_API_KEY, @@ -2123,13 +2131,21 @@ def test_enable_local_evaluation_false_disables_poller( @mock.patch("posthog.client.get") def test_enable_local_evaluation_true_starts_poller(self, patch_get, patch_poller): """Test that when enable_local_evaluation=True (default), the poller is started""" - patch_get.return_value = { - "flags": [ - {"id": 1, "name": "Beta Feature", "key": "beta-feature", "active": True} - ], - "group_type_mapping": {}, - "cohorts": {}, - } + patch_get.return_value = GetResponse( + data={ + "flags": [ + { + "id": 1, + "name": "Beta Feature", + "key": "beta-feature", + "active": True, + } + ], + "group_type_mapping": {}, + "cohorts": {}, + }, + etag='"test-etag"', + ) client = Client( FAKE_TEST_API_KEY, diff --git a/posthog/test/test_feature_flags.py b/posthog/test/test_feature_flags.py index 06ee233d..1756c778 100644 --- a/posthog/test/test_feature_flags.py +++ b/posthog/test/test_feature_flags.py @@ -11,7 +11,7 @@ match_property, relative_date_parse_for_feature_flag_matching, ) -from posthog.request import APIError +from posthog.request import APIError, GetResponse from posthog.test.test_utils import FAKE_TEST_API_KEY @@ -2348,23 +2348,27 @@ def test_production_style_multivariate_dependency_chain( @mock.patch("posthog.client.Poller") @mock.patch("posthog.client.get") def test_load_feature_flags(self, patch_get, patch_poll): - patch_get.return_value = { - "flags": [ - { - "id": 1, - "name": "Beta Feature", - "key": "beta-feature", - "active": True, - }, - { - "id": 2, - "name": "Alpha Feature", - "key": "alpha-feature", - "active": False, - }, - ], - "group_type_mapping": {"0": "company"}, - } + patch_get.return_value = GetResponse( + data={ + "flags": [ + { + "id": 1, + "name": "Beta Feature", + "key": "beta-feature", + "active": True, + }, + { + "id": 2, + "name": "Alpha Feature", + "key": "alpha-feature", + "active": False, + }, + ], + "group_type_mapping": {"0": "company"}, + "cohorts": {}, + }, + etag='"abc123"', + ) client = Client(FAKE_TEST_API_KEY, personal_api_key="test") with freeze_time("2020-01-01T12:01:00.0000Z"): client.load_feature_flags() @@ -2375,6 +2379,139 @@ def test_load_feature_flags(self, patch_get, patch_poll): client._last_feature_flag_poll.isoformat(), "2020-01-01T12:01:00+00:00" ) self.assertEqual(patch_poll.call_count, 1) + # Verify ETag is stored + self.assertEqual(client._flags_etag, '"abc123"') + + @mock.patch("posthog.client.Poller") + @mock.patch("posthog.client.get") + def test_load_feature_flags_sends_etag_on_subsequent_requests( + self, patch_get, patch_poll + ): + """Test that the ETag is sent in If-None-Match header on subsequent requests""" + patch_get.return_value = GetResponse( + data={ + "flags": [{"id": 1, "key": "beta-feature", "active": True}], + "group_type_mapping": {}, + "cohorts": {}, + }, + etag='"initial-etag"', + ) + client = Client(FAKE_TEST_API_KEY, personal_api_key="test") + client.load_feature_flags() + + # First call should have no etag + first_call_kwargs = patch_get.call_args_list[0][1] + self.assertIsNone(first_call_kwargs.get("etag")) + + # Simulate second call + client._load_feature_flags() + + # Second call should have the etag + second_call_kwargs = patch_get.call_args_list[1][1] + self.assertEqual(second_call_kwargs.get("etag"), '"initial-etag"') + + @mock.patch("posthog.client.Poller") + @mock.patch("posthog.client.get") + def test_load_feature_flags_304_not_modified(self, patch_get, patch_poll): + """Test that 304 Not Modified responses skip flag processing""" + # First response with flags + initial_response = GetResponse( + data={ + "flags": [{"id": 1, "key": "beta-feature", "active": True}], + "group_type_mapping": {"0": "company"}, + "cohorts": {}, + }, + etag='"test-etag"', + ) + # Second response is 304 Not Modified + not_modified_response = GetResponse( + data=None, + etag='"test-etag"', + not_modified=True, + ) + patch_get.side_effect = [initial_response, not_modified_response] + + client = Client(FAKE_TEST_API_KEY, personal_api_key="test") + client.load_feature_flags() + + # Verify initial flags are loaded + self.assertEqual(len(client.feature_flags), 1) + self.assertEqual(client.feature_flags[0]["key"], "beta-feature") + self.assertEqual(client.group_type_mapping, {"0": "company"}) + + # Second call with 304 + client._load_feature_flags() + + # Flags should still be the same (not cleared) + self.assertEqual(len(client.feature_flags), 1) + self.assertEqual(client.feature_flags[0]["key"], "beta-feature") + self.assertEqual(client.group_type_mapping, {"0": "company"}) + + @mock.patch("posthog.client.Poller") + @mock.patch("posthog.client.get") + def test_load_feature_flags_etag_updated_on_new_response( + self, patch_get, patch_poll + ): + """Test that ETag is updated when flags change""" + patch_get.side_effect = [ + GetResponse( + data={ + "flags": [{"id": 1, "key": "flag-v1", "active": True}], + "group_type_mapping": {}, + "cohorts": {}, + }, + etag='"etag-v1"', + ), + GetResponse( + data={ + "flags": [{"id": 1, "key": "flag-v2", "active": True}], + "group_type_mapping": {}, + "cohorts": {}, + }, + etag='"etag-v2"', + ), + ] + + client = Client(FAKE_TEST_API_KEY, personal_api_key="test") + client.load_feature_flags() + self.assertEqual(client._flags_etag, '"etag-v1"') + + client._load_feature_flags() + self.assertEqual(client._flags_etag, '"etag-v2"') + self.assertEqual(client.feature_flags[0]["key"], "flag-v2") + + @mock.patch("posthog.client.Poller") + @mock.patch("posthog.client.get") + def test_load_feature_flags_clears_etag_when_server_stops_sending( + self, patch_get, patch_poll + ): + """Test that ETag is cleared when server stops sending it""" + patch_get.side_effect = [ + GetResponse( + data={ + "flags": [{"id": 1, "key": "flag-v1", "active": True}], + "group_type_mapping": {}, + "cohorts": {}, + }, + etag='"etag-v1"', + ), + GetResponse( + data={ + "flags": [{"id": 1, "key": "flag-v2", "active": True}], + "group_type_mapping": {}, + "cohorts": {}, + }, + etag=None, # Server stopped sending ETag + ), + ] + + client = Client(FAKE_TEST_API_KEY, personal_api_key="test") + client.load_feature_flags() + self.assertEqual(client._flags_etag, '"etag-v1"') + + client._load_feature_flags() + self.assertIsNone(client._flags_etag) + self.assertEqual(client.feature_flags[0]["key"], "flag-v2") def test_load_feature_flags_wrong_key(self): client = Client(FAKE_TEST_API_KEY, personal_api_key=FAKE_TEST_API_KEY) diff --git a/posthog/test/test_request.py b/posthog/test/test_request.py index 89318250..7eee835f 100644 --- a/posthog/test/test_request.py +++ b/posthog/test/test_request.py @@ -7,15 +7,53 @@ import requests from posthog.request import ( + APIError, DatetimeSerializer, + GetResponse, QuotaLimitError, + _mask_tokens_in_url, batch_post, decide, determine_server_host, + get, ) from posthog.test.test_utils import TEST_API_KEY +@pytest.mark.parametrize( + "url, expected", + [ + # Token with params after - masks keeping first 10 chars + ( + "https://example.com/api/flags?token=phc_abc123xyz789&send_cohorts", + "https://example.com/api/flags?token=phc_abc123...&send_cohorts", + ), + # Token at end of URL + ( + "https://example.com/api/flags?token=phc_abc123xyz789", + "https://example.com/api/flags?token=phc_abc123...", + ), + # No token - unchanged + ( + "https://example.com/api/flags?other=value", + "https://example.com/api/flags?other=value", + ), + # Short token (<10 chars) - unchanged + ( + "https://example.com/api/flags?token=short", + "https://example.com/api/flags?token=short", + ), + # Exactly 10 char token - gets ellipsis + ( + "https://example.com/api/flags?token=1234567890", + "https://example.com/api/flags?token=1234567890...", + ), + ], +) +def test_mask_tokens_in_url(url, expected): + assert _mask_tokens_in_url(url) == expected + + class TestRequests(unittest.TestCase): def test_valid_request(self): res = batch_post( @@ -107,6 +145,184 @@ def test_normal_decide_response(self): self.assertEqual(response["featureFlags"], {"flag1": True}) +class TestGet(unittest.TestCase): + """Unit tests for the get() function HTTP-level behavior.""" + + @mock.patch("posthog.request._session.get") + def test_get_returns_data_and_etag(self, mock_get): + """Test that get() returns GetResponse with data and etag from headers.""" + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response.headers["ETag"] = '"abc123"' + mock_response._content = json.dumps({"flags": [{"key": "test-flag"}]}).encode( + "utf-8" + ) + mock_get.return_value = mock_response + + response = get("api_key", "/test-url", host="https://example.com") + + self.assertIsInstance(response, GetResponse) + self.assertEqual(response.data, {"flags": [{"key": "test-flag"}]}) + self.assertEqual(response.etag, '"abc123"') + self.assertFalse(response.not_modified) + + @mock.patch("posthog.request._session.get") + def test_get_sends_if_none_match_header_when_etag_provided(self, mock_get): + """Test that If-None-Match header is sent when etag parameter is provided.""" + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response.headers["ETag"] = '"new-etag"' + mock_response._content = json.dumps({"flags": []}).encode("utf-8") + mock_get.return_value = mock_response + + get("api_key", "/test-url", host="https://example.com", etag='"previous-etag"') + + call_kwargs = mock_get.call_args[1] + self.assertEqual(call_kwargs["headers"]["If-None-Match"], '"previous-etag"') + + @mock.patch("posthog.request._session.get") + def test_get_does_not_send_if_none_match_when_no_etag(self, mock_get): + """Test that If-None-Match header is not sent when no etag provided.""" + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = json.dumps({"flags": []}).encode("utf-8") + mock_get.return_value = mock_response + + get("api_key", "/test-url", host="https://example.com") + + call_kwargs = mock_get.call_args[1] + self.assertNotIn("If-None-Match", call_kwargs["headers"]) + + @mock.patch("posthog.request._session.get") + def test_get_handles_304_not_modified(self, mock_get): + """Test that 304 Not Modified response returns not_modified=True with no data.""" + mock_response = requests.Response() + mock_response.status_code = 304 + mock_response.headers["ETag"] = '"unchanged-etag"' + mock_get.return_value = mock_response + + response = get( + "api_key", "/test-url", host="https://example.com", etag='"unchanged-etag"' + ) + + self.assertIsInstance(response, GetResponse) + self.assertIsNone(response.data) + self.assertEqual(response.etag, '"unchanged-etag"') + self.assertTrue(response.not_modified) + + @mock.patch("posthog.request._session.get") + def test_get_304_without_etag_header_uses_request_etag(self, mock_get): + """Test that 304 response without ETag header falls back to request etag.""" + mock_response = requests.Response() + mock_response.status_code = 304 + # Server doesn't return ETag header on 304 + mock_get.return_value = mock_response + + response = get( + "api_key", "/test-url", host="https://example.com", etag='"original-etag"' + ) + + self.assertTrue(response.not_modified) + self.assertEqual(response.etag, '"original-etag"') + + @mock.patch("posthog.request._session.get") + def test_get_200_without_etag_header(self, mock_get): + """Test that 200 response without ETag header returns None for etag.""" + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = json.dumps({"flags": []}).encode("utf-8") + # No ETag header + mock_get.return_value = mock_response + + response = get("api_key", "/test-url", host="https://example.com") + + self.assertFalse(response.not_modified) + self.assertIsNone(response.etag) + self.assertEqual(response.data, {"flags": []}) + + @mock.patch("posthog.request._session.get") + def test_get_error_response_raises_api_error(self, mock_get): + """Test that error responses raise APIError.""" + mock_response = requests.Response() + mock_response.status_code = 401 + mock_response._content = json.dumps({"detail": "Unauthorized"}).encode("utf-8") + mock_get.return_value = mock_response + + with self.assertRaises(APIError) as ctx: + get("bad_key", "/test-url", host="https://example.com") + + self.assertEqual(ctx.exception.status, 401) + self.assertEqual(ctx.exception.message, "Unauthorized") + + @mock.patch("posthog.request._session.get") + def test_get_sends_authorization_header(self, mock_get): + """Test that Authorization header is sent with Bearer token.""" + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = json.dumps({}).encode("utf-8") + mock_get.return_value = mock_response + + get("my-api-key", "/test-url", host="https://example.com") + + call_kwargs = mock_get.call_args[1] + self.assertEqual(call_kwargs["headers"]["Authorization"], "Bearer my-api-key") + + @mock.patch("posthog.request._session.get") + def test_get_sends_user_agent_header(self, mock_get): + """Test that User-Agent header is sent.""" + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = json.dumps({}).encode("utf-8") + mock_get.return_value = mock_response + + get("api_key", "/test-url", host="https://example.com") + + call_kwargs = mock_get.call_args[1] + self.assertIn("User-Agent", call_kwargs["headers"]) + self.assertTrue( + call_kwargs["headers"]["User-Agent"].startswith("posthog-python/") + ) + + @mock.patch("posthog.request._session.get") + def test_get_passes_timeout(self, mock_get): + """Test that timeout parameter is passed to the request.""" + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = json.dumps({}).encode("utf-8") + mock_get.return_value = mock_response + + get("api_key", "/test-url", host="https://example.com", timeout=30) + + call_kwargs = mock_get.call_args[1] + self.assertEqual(call_kwargs["timeout"], 30) + + @mock.patch("posthog.request._session.get") + def test_get_constructs_full_url(self, mock_get): + """Test that host and url are combined correctly.""" + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = json.dumps({}).encode("utf-8") + mock_get.return_value = mock_response + + get("api_key", "/api/flags", host="https://example.com") + + call_args = mock_get.call_args[0] + self.assertEqual(call_args[0], "https://example.com/api/flags") + + @mock.patch("posthog.request._session.get") + def test_get_removes_trailing_slash_from_host(self, mock_get): + """Test that trailing slash is removed from host.""" + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = json.dumps({}).encode("utf-8") + mock_get.return_value = mock_response + + get("api_key", "/api/flags", host="https://example.com/") + + call_args = mock_get.call_args[0] + self.assertEqual(call_args[0], "https://example.com/api/flags") + + @pytest.mark.parametrize( "host, expected", [