From fd7fbc9e35410bc574ad194c3c5485a216a19407 Mon Sep 17 00:00:00 2001
From: Laurel Orr <57237365+lorr1@users.noreply.github.com>
Date: Sat, 1 Jul 2023 21:09:06 -0700
Subject: [PATCH] fix: lru cache HF get model params (#105)

---
 CHANGELOG.rst                             | 4 ++++
 manifest/clients/diffuser.py              | 2 ++
 manifest/clients/huggingface.py           | 2 ++
 manifest/clients/huggingface_embedding.py | 2 ++
 4 files changed, 10 insertions(+)

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 1d7bbfe..e18f607 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,5 +1,9 @@
 0.1.9 - Unreleased
 ---------------------
+Fixed
+^^^^^
+* Added trust remote code params to HF models
+* Added LRU cache to HF model param calls to avoid extra calls
 
 0.1.8 - 2023-05-22
 ---------------------
diff --git a/manifest/clients/diffuser.py b/manifest/clients/diffuser.py
index cfdeefe..551c3c1 100644
--- a/manifest/clients/diffuser.py
+++ b/manifest/clients/diffuser.py
@@ -1,5 +1,6 @@
 """Diffuser client."""
 import logging
+from functools import lru_cache
 from typing import Any, Dict, Optional
 
 import numpy as np
@@ -79,6 +80,7 @@ def supports_streaming_inference(self) -> bool:
         """
         return False
 
+    @lru_cache(maxsize=1)
     def get_model_params(self) -> Dict:
         """
         Get model params.
diff --git a/manifest/clients/huggingface.py b/manifest/clients/huggingface.py
index 9a61027..7f43539 100644
--- a/manifest/clients/huggingface.py
+++ b/manifest/clients/huggingface.py
@@ -1,5 +1,6 @@
 """Hugging Face client."""
 import logging
+from functools import lru_cache
 from typing import Any, Dict, Optional
 
 import requests
@@ -73,6 +74,7 @@ def supports_streaming_inference(self) -> bool:
         """
         return False
 
+    @lru_cache(maxsize=1)
     def get_model_params(self) -> Dict:
         """
         Get model params.
diff --git a/manifest/clients/huggingface_embedding.py b/manifest/clients/huggingface_embedding.py
index 02b3d81..9478355 100644
--- a/manifest/clients/huggingface_embedding.py
+++ b/manifest/clients/huggingface_embedding.py
@@ -1,5 +1,6 @@
 """Hugging Face client."""
 import logging
+from functools import lru_cache
 from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
@@ -65,6 +66,7 @@ def supports_streaming_inference(self) -> bool:
         """
         return False
 
+    @lru_cache(maxsize=1)
     def get_model_params(self) -> Dict:
         """
         Get model params.
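
Editorial note on the pattern this patch applies (not part of the commit): decorating an instance method with `functools.lru_cache` memoizes on the call arguments, and for a method that includes `self`, so repeated `get_model_params()` calls on the same client return the cached dict instead of re-querying the model server. The sketch below is a minimal, self-contained illustration of that behavior under stated assumptions; `DummyClient`, its `host` field, and the `fetch_count` counter are hypothetical stand-ins for manifest's real clients, which fetch the params over HTTP.

```python
from functools import lru_cache
from typing import Dict


class DummyClient:
    """Hypothetical stand-in for a manifest client; the real ones call a server."""

    def __init__(self, host: str) -> None:
        self.host = host
        self.fetch_count = 0  # demo instrumentation: counts "expensive" fetches

    @lru_cache(maxsize=1)
    def get_model_params(self) -> Dict:
        # The real clients issue a request here; the cache ensures the
        # expensive call happens at most once (the entry is keyed on `self`).
        self.fetch_count += 1
        return {"model_name": "dummy", "model_path": self.host}


if __name__ == "__main__":
    client = DummyClient("http://localhost:5000")
    client.get_model_params()
    client.get_model_params()
    print(client.fetch_count)  # -> 1: the second call was served from the cache
```

One caveat with this idiom: the cache lives on the class-level wrapper and keys on `self`, so it holds a strong reference to the instance until eviction, and with `maxsize=1` a second client instance evicts the first one's cached entry, forcing a re-fetch if the first client is used again.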