Skip to content

Commit

Permalink
fixes after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
Valery Chernov committed Feb 8, 2024
1 parent 8686a1d commit e34d042
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
13 changes: 13 additions & 0 deletions serve/mlc_serve/model/model_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,23 @@ def sample_from_logits(
return outputs


def get_logprob_infos(
    i: int,
    logprob_infos: Optional[List[Optional[RawLogprobsInfo]]],
) -> Optional[List[Optional[RawLogprobsInfo]]]:
    """Wrap the i-th logprob entry in a singleton list.

    Returns None when the container itself is None or when the entry at
    position ``i`` is None; otherwise returns ``[logprob_infos[i]]`` so the
    caller receives a one-element list for this sequence.
    """
    if logprob_infos is not None:
        entry = logprob_infos[i]
        if entry is not None:
            return [entry]
    return None


def sample_loglikelihood_from_logits(
logits: Union[tvm.nd.NDArray, torch.Tensor],
sequence_ids: List[SequenceId],
) -> List[TextGenerationResult]:
# Convert to torch tensors if logits are in tvm ndarray
if isinstance(logits, tvm.nd.NDArray):
logits = torch.from_dlpack(logits)

# TODO(vvchernov): cut prompt lengths from logits
try:
logprob_infos = loglikelihood_sample(logits)
Expand Down
2 changes: 0 additions & 2 deletions serve/mlc_serve/model/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,6 @@ def get_loglikelihood_logprob_infos(
def loglikelihood_sample(
logits: torch.Tensor,
) -> List[Optional[RawLogprobsInfo]]:
logits = torch.from_dlpack(logits)

# TODO(vvchernov): we need token_ids from input
# It is not necessarily top-1
res = torch.argmax(logits, -1)
Expand Down
2 changes: 1 addition & 1 deletion serve/mlc_serve/model/tvm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def generate(
request_past_decode_tokens = [self.vocab_size]
elif isinstance(request, LoglikelihoodRequest):
seq_id = get_prompt_sequence_id(request.request_id)
# TODO(vvchernov): it it needed?
# TODO(vvchernov): is it needed?
request_past_decode_tokens = [self.vocab_size]
elif isinstance(request, DecodeRequest):
seq_id = request.sequence_id
Expand Down

0 comments on commit e34d042

Please sign in to comment.