Skip to content

Commit

Permalink
fix: made the to method optional on the encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
KennethEnevoldsen committed Feb 17, 2024
1 parent 0b6d0be commit 157a91c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
15 changes: 11 additions & 4 deletions src/seb/interfaces/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging
from dataclasses import dataclass
from datetime import date
from pathlib import Path
Expand All @@ -14,6 +15,9 @@
from .task import Task


logger = logging.getLogger(__name__)


@runtime_checkable
class Encoder(Protocol):
"""
Expand Down Expand Up @@ -42,10 +46,10 @@ def encode(
"""
...

def to(self, device: torch.device):
...

# The following methods are optional and can be implemented if the model supports them.
# def to(self, device: torch.device):
# ...

# def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
# ...

Expand Down Expand Up @@ -109,7 +113,10 @@ def load_model(self):

def to(self, device: torch.device):
self.load_model()
self._model = self._model.to(device) # type: ignore
try:
self._model = self._model.to(device) # type: ignore
except AttributeError:
logging.debug(f"Model {self._model} does not have a to method")

@property
def model(self) -> Encoder:
Expand Down
1 change: 0 additions & 1 deletion tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np
import pytest

import seb
from seb.cli import cli, run_benchmark_cli

Expand Down

0 comments on commit 157a91c

Please sign in to comment.