Skip to content

Commit

Permalink
Adding custom User-Agent to client (#121)
Browse files Browse the repository at this point in the history
* Adding custom User-Agent to client

* Lint and flake fixes

* Fixing user-agent capitalization

* Update spicepy/_http.py

Co-authored-by: Sergei Grebnov <sergei.grebnov@gmail.com>

* Adding else branch for HTTP headers

---------

Co-authored-by: Sergei Grebnov <sergei.grebnov@gmail.com>
  • Loading branch information
slyons and sgrebnov authored Nov 25, 2024
1 parent 07082a2 commit 8a27c04
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Querying data is done through a `Client` object that initialize the connection w
- **api_key** (string, required): API key to authenticate with the endpoint.
- **url** (string, optional): URL of the endpoint to use (default: grpc+tls://flight.spiceai.io; firecache: grpc+tls://firecache.spiceai.io)
- **tls_root_cert** (Path or string, optional): Path to the tls certificate to use for the secure connection (omit for automatic detection)
- **user_agent** (string, optional): A custom `User-Agent` string to pass when connecting to Spice. Use `spicepy.config.get_user_agent` to build the custom `User-Agent`

Once a `Client` is obtained queries can be made using the `query()` function. The `query()` function has the following arguments:

Expand Down
25 changes: 16 additions & 9 deletions spicepy/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,16 @@ def read_cert(self, tls_root_cert):

class _SpiceFlight:
@staticmethod
def _user_agent():
def _user_agent(custom_user_agent=None):
# headers kwargs claim to support Tuple[str, str], but it's actually Tuple[bytes, bytes] :|
# Open issue in Arrow: https://github.com/apache/arrow/issues/35288
return (str.encode("x-spice-user-agent"), str.encode(config.SPICE_USER_AGENT))
user_agent = custom_user_agent or config.SPICE_USER_AGENT
return (str.encode("user-agent"), str.encode(user_agent))

def __init__(self, grpc: str, api_key: str, tls_root_certs):
def __init__(self, grpc: str, api_key: str, tls_root_certs, user_agent=None):
self._flight_client = flight.connect(grpc, tls_root_certs=tls_root_certs)
self._api_key = api_key
self.headers = [_SpiceFlight._user_agent()]
self.headers = [_SpiceFlight._user_agent(user_agent)]
self._flight_options = flight.FlightCallOptions(
headers=self.headers, timeout=DEFAULT_QUERY_TIMEOUT_SECS
)
Expand Down Expand Up @@ -134,25 +135,31 @@ def _threaded_flight_do_get(self, ticket: Ticket):


class Client:
# pylint: disable=R0917
def __init__(
self,
api_key: str = None,
flight_url: str = config.DEFAULT_LOCAL_FLIGHT_URL,
http_url: str = config.DEFAULT_HTTP_URL,
tls_root_cert: Union[str, Path, None] = None,
user_agent: Optional[str] = None,
): # pylint: disable=R0913
tls_root_certs = _Cert(tls_root_cert).tls_root_certs
self._flight = _SpiceFlight(flight_url, api_key, tls_root_certs)
self._flight = _SpiceFlight(flight_url, api_key, tls_root_certs, user_agent)

self.api_key = api_key
self.http = HttpRequests(http_url, self._headers())
self.http = HttpRequests(http_url, self._headers(user_agent))

def _headers(self) -> Dict[str, str]:
return {
def _headers(self, user_agent=None) -> Dict[str, str]:
headers = {
"X-API-Key": self._api_key(),
"Accept": "application/json",
"User-Agent": "spicepy 2.0",
}
if user_agent is not None:
headers["user-agent"] = user_agent
else:
headers["user-agent"] = config.SPICE_USER_AGENT
return headers

def _api_key(self) -> str:
key = self.api_key
Expand Down
5 changes: 3 additions & 2 deletions spicepy/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ class HttpRequests:
def __init__(self, base_url: str, headers: Dict[str, str]) -> None:
self.session = self._create_session(headers)

# set the x-spice-user-agent header
self.session.headers["X-Spice-User-Agent"] = SPICE_USER_AGENT
# set the user-agent header
if "user-agent" not in self.session.headers:
self.session.headers["user-agent"] = SPICE_USER_AGENT

self.base_url = base_url

Expand Down
19 changes: 15 additions & 4 deletions spicepy/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import platform
from importlib.metadata import version
from typing import Optional

DEFAULT_FLIGHT_URL = os.environ.get("SPICE_FLIGHT_URL", "grpc+tls://flight.spiceai.io")
DEFAULT_HTTP_URL = os.environ.get("SPICE_HTTP_URL", "https://data.spiceai.io")
Expand All @@ -11,16 +12,26 @@
DEFAULT_LOCAL_HTTP_URL = os.environ.get("SPICE_LOCAL_HTTP_URL", "http://localhost:8090")


def get_user_agent():
package_version = version("spicepy")
###
# Get the default `User-Agent` string, or build a custom one
#
# client_name: Optional[str] = None : The name of the client. Default is `spicepy`.
# client_version: Optional[str] = None : The version of the client. Default is the version of the `spicepy` package.
# client_system: Optional[str] = None : The system information of the client.
# Default is the system information of the current system, e.g. `Linux/5.4.0-1043-aws x86_64`.
###
def get_user_agent(client_name: Optional[str] = None, client_version: Optional[str] = None,
client_system: Optional[str] = None) -> str:
package_version = version("spicepy") if client_version is None else client_version
system = platform.system()
release = platform.release()
arch = platform.machine()
if arch == "AMD64":
arch = "x86_64"

system_info = f"{system}/{release} {arch}"
return f"spicepy {package_version} ({system_info})"
system_info = f"{system}/{release} {arch}" if client_system is None else client_system
client = "spicepy" if client_name is None else client_name
return f"{client}/{package_version} ({system_info})"


SPICE_USER_AGENT = get_user_agent()
3 changes: 2 additions & 1 deletion test.requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pylint>=3.3.1
flake8>=7.1.1
pytest>=8.3.3
pytest>=8.3.3
pytest_httpserver==1.1.0
32 changes: 31 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import time
import re
import json
import pytest
from spicepy import Client, RefreshOpts
from spicepy.config import (
SPICE_USER_AGENT,
DEFAULT_LOCAL_FLIGHT_URL,
DEFAULT_LOCAL_HTTP_URL,
get_user_agent
)


Expand All @@ -27,7 +29,7 @@ def get_local_client():

def test_user_agent_is_populated():
# use a regex to match the expected user agent string
matching_regex = r"spicepy \d+\.\d+\.\d+ \((Linux|Windows|Darwin)/[\d\w\.\-\_]+ (x86_64|aarch64|i386|arm64)\)"
matching_regex = r"spicepy/\d+\.\d+\.\d+ \((Linux|Windows|Darwin)/[\d\w\.\-\_]+ (x86_64|aarch64|i386|arm64)\)"

assert re.match(matching_regex, SPICE_USER_AGENT)

Expand Down Expand Up @@ -140,9 +142,37 @@ def test_local_runtime_refresh():
assert len(pandas_data) == 20


# pylint: disable=E1120
def test_user_agent(httpserver):
reply = {"message": "OK"}
httpserver.expect_request("/v1/datasets/test/acceleration/refresh", headers={"User-Agent": SPICE_USER_AGENT})\
.respond_with_data(json.dumps(reply), content_type="application/json")
client = Client(flight_url=DEFAULT_LOCAL_FLIGHT_URL, http_url=httpserver.url_for("/"))
response = client.refresh_dataset("test")
httpserver.check_assertions()
assert response == reply

httpserver.expect_request("/v1/datasets/test/acceleration/refresh", headers={"User-Agent": "custom-agent"})\
.respond_with_data(json.dumps(reply), content_type="application/json")
client = Client(flight_url=DEFAULT_LOCAL_FLIGHT_URL, http_url=httpserver.url_for("/"), user_agent="custom-agent")
response = client.refresh_dataset("test")
httpserver.check_assertions()
assert response == reply

custom_ua = get_user_agent("custom-client", "1.0.0", "custom-system")
httpserver.expect_request("/v1/datasets/test/acceleration/refresh",
headers={"User-Agent": "custom-client/1.0.0 (custom-system)"})\
.respond_with_data(json.dumps(reply), content_type="application/json")
client = Client(flight_url=DEFAULT_LOCAL_FLIGHT_URL, http_url=httpserver.url_for("/"), user_agent=custom_ua)
response = client.refresh_dataset("test")
httpserver.check_assertions()
assert response == reply


if __name__ == "__main__":
test_flight_recent_blocks()
test_flight_streaming()
test_flight_timeout()
test_local_runtime()
test_local_runtime_refresh()
test_user_agent()

0 comments on commit 8a27c04

Please sign in to comment.