Skip to content

Commit

Permalink
Break: get_model returns tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Mar 28, 2024
1 parent 2f725dc commit a710ea9
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 15 deletions.
12 changes: 12 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Contributing to UForm

We welcome contributions to UForm!
Before submitting any changes, install the package with the extras you need and make sure that the tests pass:

```sh
pip install -e . # For core dependencies
pip install -e ".[torch]" # For PyTorch
pip install -e ".[onnx]" # For ONNX on CPU
pip install -e ".[onnx-gpu]" # For ONNX on GPU, available for some platforms
pytest scripts/ -s -x -Wd -v # Run the tests: no output capture, stop on first failure, show warnings, verbose
```
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ For that pick the encoder of the model you want to run in parallel (`text_encode
```python
import uform

model = uform.get_model('unum-cloud/uform-vl-english')
model, processor = uform.get_model('unum-cloud/uform-vl-english')
model_image = nn.DataParallel(model.image_encoder)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down
16 changes: 8 additions & 8 deletions scripts/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
import requests
import torch
from PIL import Image
from transformers import (AutoProcessor, InstructBlipForConditionalGeneration,
InstructBlipProcessor, LlavaForConditionalGeneration)
from transformers import (
AutoProcessor,
InstructBlipForConditionalGeneration,
InstructBlipProcessor,
LlavaForConditionalGeneration,
)

from uform import get_model
from uform.gen_model import VLMForCausalLM, VLMProcessor
Expand Down Expand Up @@ -76,9 +80,7 @@ def bench_image_embeddings(model, images):
total_embeddings = 0
images *= 10
while total_duration < 10:
seconds, embeddings = duration(
lambda: model.encode_image(model.preprocess_image(images))
)
seconds, embeddings = duration(lambda: model.encode_image(processor.preprocess_image(images)))
total_duration += seconds
total_embeddings += len(embeddings)

Expand All @@ -90,9 +92,7 @@ def bench_text_embeddings(model, texts):
total_embeddings = 0
texts *= 10
while total_duration < 10:
seconds, embeddings = duration(
lambda: model.encode_text(model.preprocess_text(texts))
)
seconds, embeddings = duration(lambda: model.encode_text(processor.preprocess_text(texts)))
total_duration += seconds
total_embeddings += len(embeddings)

Expand Down
12 changes: 6 additions & 6 deletions scripts/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

@pytest.mark.parametrize("model_name", torch_models)
def test_one_embedding(model_name: str):
model = uform.get_model(model_name)
model, processor = uform.get_model(model_name)
text = "a small red panda in a zoo"
image_path = "assets/unum.png"

image = Image.open(image_path)
image_data = model.preprocess_image(image)
text_data = model.preprocess_text(text)
image_data = processor.preprocess_image(image)
text_data = processor.preprocess_text(text)

_, image_embedding = model.encode_image(image_data, return_features=True)
_, text_embedding = model.encode_text(text_data, return_features=True)
Expand All @@ -28,13 +28,13 @@ def test_one_embedding(model_name: str):
@pytest.mark.parametrize("model_name", torch_models)
@pytest.mark.parametrize("batch_size", [1, 2])
def test_many_embeddings(model_name: str, batch_size: int):
model = uform.get_model(model_name)
model, processor = uform.get_model(model_name)
texts = ["a small red panda in a zoo"] * batch_size
image_paths = ["assets/unum.png"] * batch_size

images = [Image.open(path) for path in image_paths]
image_data = model.preprocess_image(images)
text_data = model.preprocess_text(texts)
image_data = processor.preprocess_image(images)
text_data = processor.preprocess_text(texts)

image_embeddings = model.encode_image(image_data, return_features=False)
text_embeddings = model.encode_text(text_data, return_features=False)
Expand Down

0 comments on commit a710ea9

Please sign in to comment.