Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid leaking memory when Client.with_options is used #956

Merged
merged 1 commit into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ select = [
"T203",
]
ignore = [
# lru_cache in methods, will be fixed separately
"B019",
# mutable defaults
"B006",
]
Expand Down
28 changes: 15 additions & 13 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,14 +403,12 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
headers_dict = _merge_mappings(self.default_headers, custom_headers)
self._validate_headers(headers_dict, custom_headers)

# headers are case-insensitive while dictionaries are not.
headers = httpx.Headers(headers_dict)

idempotency_header = self._idempotency_header
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
if not options.idempotency_key:
options.idempotency_key = self._idempotency_key()

headers[idempotency_header] = options.idempotency_key
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()

return headers

Expand Down Expand Up @@ -594,16 +592,8 @@ def base_url(self) -> URL:
def base_url(self, url: URL | str) -> None:
self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))

@lru_cache(maxsize=None)
def platform_headers(self) -> Dict[str, str]:
return {
"X-Stainless-Lang": "python",
"X-Stainless-Package-Version": self._version,
"X-Stainless-OS": str(get_platform()),
"X-Stainless-Arch": str(get_architecture()),
"X-Stainless-Runtime": platform.python_implementation(),
"X-Stainless-Runtime-Version": platform.python_version(),
}
return platform_headers(self._version)

def _calculate_retry_timeout(
self,
Expand Down Expand Up @@ -1691,6 +1681,18 @@ def get_platform() -> Platform:
return "Unknown"


@lru_cache(maxsize=None)
def platform_headers(version: str) -> Dict[str, str]:
return {
"X-Stainless-Lang": "python",
"X-Stainless-Package-Version": version,
"X-Stainless-OS": str(get_platform()),
"X-Stainless-Arch": str(get_architecture()),
"X-Stainless-Runtime": platform.python_implementation(),
"X-Stainless-Runtime-Version": platform.python_version(),
}


class OtherArch:
def __init__(self, name: str) -> None:
self.name = name
Expand Down
4 changes: 2 additions & 2 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def copy(
return self.__class__(
api_key=api_key or self.api_key,
organization=organization or self.organization,
base_url=base_url or str(self.base_url),
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
Expand Down Expand Up @@ -402,7 +402,7 @@ def copy(
return self.__class__(
api_key=api_key or self.api_key,
organization=organization or self.organization,
base_url=base_url or str(self.base_url),
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
Expand Down
124 changes: 124 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

import gc
import os
import json
import asyncio
import inspect
import tracemalloc
from typing import Any, Union, cast
from unittest import mock

Expand Down Expand Up @@ -195,6 +197,67 @@ def test_copy_signature(self) -> None:
copy_param = copy_signature.parameters.get(name)
assert copy_param is not None, f"copy() signature is missing the {name} param"

def test_copy_build_request(self) -> None:
options = FinalRequestOptions(method="get", url="/foo")

def build_request(options: FinalRequestOptions) -> None:
client = self.client.copy()
client._build_request(options)

# ensure that the machinery is warmed up before tracing starts.
build_request(options)
gc.collect()

tracemalloc.start(1000)

snapshot_before = tracemalloc.take_snapshot()

ITERATIONS = 10
for _ in range(ITERATIONS):
build_request(options)
gc.collect()

snapshot_after = tracemalloc.take_snapshot()

tracemalloc.stop()

def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
if diff.count == 0:
# Avoid false positives by considering only leaks (i.e. allocations that persist).
return

if diff.count % ITERATIONS != 0:
# Avoid false positives by considering only leaks that appear per iteration.
return

for frame in diff.traceback:
if any(
frame.filename.endswith(fragment)
for fragment in [
# to_raw_response_wrapper leaks through the @functools.wraps() decorator.
#
# removing the decorator fixes the leak for reasons we don't understand.
"openai/_response.py",
# pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
"openai/_compat.py",
# Standard library leaks we don't care about.
"/logging/__init__.py",
]
):
return

leaks.append(diff)

leaks: list[tracemalloc.StatisticDiff] = []
for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
add_leak(leaks, diff)
if leaks:
for leak in leaks:
print("MEMORY LEAK:", leak)
for frame in leak.traceback:
print(frame)
raise AssertionError()

def test_request_timeout(self) -> None:
request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
Expand Down Expand Up @@ -858,6 +921,67 @@ def test_copy_signature(self) -> None:
copy_param = copy_signature.parameters.get(name)
assert copy_param is not None, f"copy() signature is missing the {name} param"

def test_copy_build_request(self) -> None:
options = FinalRequestOptions(method="get", url="/foo")

def build_request(options: FinalRequestOptions) -> None:
client = self.client.copy()
client._build_request(options)

# ensure that the machinery is warmed up before tracing starts.
build_request(options)
gc.collect()

tracemalloc.start(1000)

snapshot_before = tracemalloc.take_snapshot()

ITERATIONS = 10
for _ in range(ITERATIONS):
build_request(options)
gc.collect()

snapshot_after = tracemalloc.take_snapshot()

tracemalloc.stop()

def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
if diff.count == 0:
# Avoid false positives by considering only leaks (i.e. allocations that persist).
return

if diff.count % ITERATIONS != 0:
# Avoid false positives by considering only leaks that appear per iteration.
return

for frame in diff.traceback:
if any(
frame.filename.endswith(fragment)
for fragment in [
# to_raw_response_wrapper leaks through the @functools.wraps() decorator.
#
# removing the decorator fixes the leak for reasons we don't understand.
"openai/_response.py",
# pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
"openai/_compat.py",
# Standard library leaks we don't care about.
"/logging/__init__.py",
]
):
return

leaks.append(diff)

leaks: list[tracemalloc.StatisticDiff] = []
for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
add_leak(leaks, diff)
if leaks:
for leak in leaks:
print("MEMORY LEAK:", leak)
for frame in leak.traceback:
print(frame)
raise AssertionError()

async def test_request_timeout(self) -> None:
request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
Expand Down