Skip to content

Commit

Permalink
[FEAT][prep_torch_inference]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 23, 2023
1 parent a9b91d4 commit f39d722
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 11 deletions.
2 changes: 1 addition & 1 deletion swarms/memory/weaviate_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

try:
import weaviate
except ImportError as error:
except ImportError:
print("pip install weaviate-client")


Expand Down
6 changes: 3 additions & 3 deletions swarms/models/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def time_to_first_token(self, prompt: str) -> float:
float: _description_
"""
start_time = time.time()
tokens = self.track_resource_utilization(
self.track_resource_utilization(
prompt
) # assuming `generate` is a method that generates tokens
first_token_time = time.time()
Expand All @@ -411,7 +411,7 @@ def generation_latency(self, prompt: str) -> float:
float: _description_
"""
start_time = time.time()
tokens = self.run(prompt)
self.run(prompt)
end_time = time.time()
return end_time - start_time

Expand All @@ -426,6 +426,6 @@ def throughput(self, prompts: List[str]) -> float:
"""
start_time = time.time()
for prompt in prompts:
tokens = self.run(prompt)
self.run(prompt)
end_time = time.time()
return len(prompts) / (end_time - start_time)
4 changes: 4 additions & 0 deletions swarms/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from swarms.utils.llm_metrics_decorator import metrics_decorator
from swarms.utils.device_checker_cuda import check_device
from swarms.utils.load_model_torch import load_model_torch
from swarms.utils.prep_torch_model_inference import (
prep_torch_inference,
)

__all__ = [
"display_markdown_message",
Expand All @@ -18,4 +21,5 @@
"metrics_decorator",
"check_device",
"load_model_torch",
"prep_torch_inference",
]
30 changes: 30 additions & 0 deletions swarms/utils/prep_torch_model_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
from swarms.utils.load_model_torch import load_model_torch


def prep_torch_inference(
    model_path: str = None,
    device: torch.device = None,
    *args,
    **kwargs,
):
    """
    Prepare a Torch model for inference.

    Loads the model through ``load_model_torch`` and switches it to
    evaluation mode by calling ``model.eval()``.

    Args:
        model_path (str): Path to the model file. When it is ``None``
            no load is attempted and ``None`` is returned.
        device (torch.device): Device to run the model on; passed
            through unchanged to ``load_model_torch``.
        *args: Additional positional arguments (accepted, currently unused).
        **kwargs: Additional keyword arguments (accepted, currently unused).

    Returns:
        torch.nn.Module: The prepared model, or ``None`` when no path was
        given or when loading/preparing the model raised an exception.
    """
    # Make the "no path" contract explicit instead of relying on the
    # loader to raise when handed None.
    if model_path is None:
        print(
            "Error occurred while preparing Torch model:"
            " model_path is required"
        )
        return None

    try:
        model = load_model_torch(model_path, device)
        model.eval()
    except Exception as e:
        # Best-effort API: callers receive None rather than an exception,
        # so report the failure and fall through to the None result.
        print(f"Error occurred while preparing Torch model: {e}")
        return None
    return model
10 changes: 5 additions & 5 deletions tests/models/test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def test_init_huggingface_llm():
assert llm.model_id == "test_model"
assert llm.device == "cuda"
assert llm.max_length == 1000
assert llm.quantize == True
assert llm.quantize is True
assert llm.quantization_config == {"config_key": "config_value"}
assert llm.verbose == True
assert llm.distributed == True
assert llm.decoding == True
assert llm.verbose is True
assert llm.distributed is True
assert llm.decoding is True
assert llm.max_workers == 3
assert llm.repitition_penalty == 1.5
assert llm.no_repeat_ngram_size == 4
Expand All @@ -90,7 +90,7 @@ def test_load_model(mock_huggingface_llm):
# Test running the model
def test_run(mock_huggingface_llm):
llm = HuggingfaceLLM(model_id="test_model")
result = llm.run("Test prompt")
llm.run("Test prompt")

# Ensure that the run function is called
assert True
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/load_models_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ def test_load_model_torch_no_device_specified(mocker):
mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch("torch.load", return_value=mock_model)
mocker.patch("torch.cuda.is_available", return_value=False)
model = load_model_torch("model_path")
load_model_torch("model_path")
mock_model.to.assert_called_once_with(torch.device("cpu"))


def test_load_model_torch_device_specified(mocker):
mock_model = MagicMock(spec=torch.nn.Module)
mocker.patch("torch.load", return_value=mock_model)
model = load_model_torch(
load_model_torch(
"model_path", device=torch.device("cuda")
)
mock_model.to.assert_called_once_with(torch.device("cuda"))
Expand Down
50 changes: 50 additions & 0 deletions tests/utils/prep_torch_model_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
from unittest.mock import MagicMock
from swarms.utils.prep_torch_model_inference import (
prep_torch_inference,
)


def test_prep_torch_inference_no_model_path():
    """Calling without any model path yields ``None``."""
    assert prep_torch_inference() is None


def test_prep_torch_inference_model_not_found(mocker):
    """A ``FileNotFoundError`` from the loader results in ``None``."""
    loader_target = (
        "swarms.utils.prep_torch_model_inference.load_model_torch"
    )
    # Make the patched loader raise as if the file did not exist.
    mocker.patch(loader_target, side_effect=FileNotFoundError)
    assert prep_torch_inference("non_existent_model_path") is None


def test_prep_torch_inference_runtime_error(mocker):
    """A ``RuntimeError`` from the loader results in ``None``."""
    loader_target = (
        "swarms.utils.prep_torch_model_inference.load_model_torch"
    )
    # Make the patched loader fail with a generic runtime error.
    mocker.patch(loader_target, side_effect=RuntimeError)
    assert prep_torch_inference("model_path") is None


def test_prep_torch_inference_no_device_specified(mocker):
    """Without a device, the loaded model is put in eval mode and returned."""
    mock_model = MagicMock(spec=torch.nn.Module)
    mock_load = mocker.patch(
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        return_value=mock_model,
    )
    result = prep_torch_inference("model_path")
    # The loader receives the path and the default (None) device.
    mock_load.assert_called_once_with("model_path", None)
    mock_model.eval.assert_called_once()
    # The object handed back must be the model the loader produced.
    assert result is mock_model


def test_prep_torch_inference_device_specified(mocker):
    """An explicit device is forwarded to the loader; the model is returned in eval mode."""
    mock_model = MagicMock(spec=torch.nn.Module)
    mock_load = mocker.patch(
        "swarms.utils.prep_torch_model_inference.load_model_torch",
        return_value=mock_model,
    )
    cuda_device = torch.device("cuda")
    result = prep_torch_inference("model_path", device=cuda_device)
    # The point of this test: the requested device must reach the loader.
    mock_load.assert_called_once_with("model_path", cuda_device)
    mock_model.eval.assert_called_once()
    assert result is mock_model

0 comments on commit f39d722

Please sign in to comment.