Skip to content

Commit

Permalink
Cast embeddings to float32 before computing distances.
Browse files Browse the repository at this point in the history
This fixes a bug where bfloat16 is not supported in nonGPU or nonTPU machine.

PiperOrigin-RevId: 646651355
  • Loading branch information
llcourage authored and LIT team committed Jun 25, 2024
1 parent 2e9d267 commit 5456011
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lit_nlp/components/nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,11 @@ def run(

# <float32>[emb_size]
dataset_embs = [output[nnconf.embedding_name] for output in dataset_outputs]
dataset_embs = [emb.astype(np.float32) for emb in dataset_embs]

example_embs = [example_output[nnconf.embedding_name]]
example_embs = [emb.astype(np.float32) for emb in example_embs]

distances = distance.cdist(example_embs, dataset_embs)[0]
sorted_indices = np.argsort(distances)
k = nnconf.num_neighbors
Expand Down

0 comments on commit 5456011

Please sign in to comment.