Skip to content

Commit

Permalink
fix: pass `do_sample` through to generation (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
lorr1 committed Jan 17, 2024
1 parent 637fb14 commit c84b2fd
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ format:
black manifest/ tests/ web_app/

check:
isort -c -v manifest/ tests/ web_app/
isort -c manifest/ tests/ web_app/
black manifest/ tests/ web_app/ --check
flake8 manifest/ tests/ web_app/
mypy manifest/ tests/ web_app/
Expand Down
2 changes: 1 addition & 1 deletion manifest/api/models/diffuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_init_params(self) -> Dict:
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[int], List[float]]]:
) -> List[Tuple[Any, float, List[str], List[float]]]:
"""
Generate the prompt from model.
Expand Down
7 changes: 4 additions & 3 deletions manifest/api/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(

def __call__(
self, text: Union[str, List[str]], **kwargs: Any
) -> List[Dict[str, Union[str, List[float]]]]:
) -> List[Dict[str, Union[str, List[float], List[str]]]]:
"""Generate from text.
Args:
Expand Down Expand Up @@ -162,6 +162,7 @@ def __call__(
top_p=kwargs.get("top_p"),
repetition_penalty=kwargs.get("repetition_penalty"),
num_return_sequences=kwargs.get("num_return_sequences"),
do_sample=kwargs.get("do_sample"),
)
kwargs_to_pass = {k: v for k, v in kwargs_to_pass.items() if v is not None}
output_dict = self.model.generate( # type: ignore
Expand Down Expand Up @@ -587,7 +588,7 @@ def embed(self, prompt: Union[str, List[str]], **kwargs: Any) -> np.ndarray:
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[int], List[float]]]:
) -> List[Tuple[Any, float, List[str], List[float]]]:
"""
Generate the prompt from model.
Expand Down Expand Up @@ -616,7 +617,7 @@ def generate(
(
cast(str, r["generated_text"]),
sum(cast(List[float], r["logprobs"])),
cast(List[int], r["tokens"]),
cast(List[str], r["tokens"]),
cast(List[float], r["logprobs"]),
)
for r in result
Expand Down
2 changes: 1 addition & 1 deletion manifest/api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_init_params(self) -> Dict:

def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[int], List[float]]]:
) -> List[Tuple[Any, float, List[str], List[float]]]:
"""
Generate the prompt from model.
Expand Down
2 changes: 1 addition & 1 deletion manifest/api/models/sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_init_params(self) -> Dict:
@torch.no_grad()
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[Any, float, List[int], List[float]]]:
) -> List[Tuple[Any, float, List[str], List[float]]]:
"""
Generate the prompt from model.
Expand Down

0 comments on commit c84b2fd

Please sign in to comment.