Skip to content

Commit

Permalink
Fixed syntax error in Cohere
Browse files Browse the repository at this point in the history
  • Loading branch information
x-tabdeveloping committed Jan 31, 2024
1 parent 834e5c7 commit e1d18f1
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/seb/registered_models/cohere_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def encode(
*,
task: Optional[Task] = None,
**kwargs: Any, # noqa: ARG002
) -> torch.Tensor
) -> torch.Tensor:
if task and task.task_type == "Classification":
input_type = "classification"
elif task and task.task_type == "Clustering":
Expand All @@ -55,14 +55,21 @@ def encode(
def encode_queries(self, queries: list[str], batch_size: int, **kwargs): # noqa
return self._embed(queries, input_type="search_query")

def encode_corpus(self, corpus: list[dict[str, str]], batch_size: int, **kwargs): # noqa
def encode_corpus(
self, corpus: list[dict[str, str]], batch_size: int, **kwargs
): # noqa
if isinstance(corpus, dict):
sentences = [
(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() # type: ignore
for i in range(len(corpus["text"])) # type: ignore
]
else:
sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
sentences = [
(doc["title"] + self.sep + doc["text"]).strip()
if "title" in doc
else doc["text"].strip()
for doc in corpus
]
return self._embed(sentences, input_type="search_document")


Expand Down

0 comments on commit e1d18f1

Please sign in to comment.