     min_p=0.0,
     # strictly disabled for now
     top_k=0,
-    # top_p=0.0,
+    top_p=1.0,
     # frequency_penalties=0.0,
     # presence_penalties=0.0,
     # repetition_penalties=0.0,
@@ -26,11 +26,9 @@ class TPUSupportedSamplingMetadata:
     temperature: torch.Tensor = None

     min_p: torch.Tensor = None
-    # Still too slow on forward_native!
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None

-    # Greedy sampling flag for compiling single xla graph.
     all_greedy: bool = True

     # unsupported, you need to return an extra tensor of static size BxV
@@ -103,17 +101,17 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
                    DEFAULT_SAMPLING_PARAMS["min_p"])
         fill_slice(input_batch.top_k_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["top_k"])
-        # TODO Temporarily disabled until sampling options are enabled
-        # fill_slice(input_batch.top_p_cpu_tensor,
-        #            DEFAULT_SAMPLING_PARAMS["top_p"])
+        fill_slice(input_batch.top_p_cpu_tensor,
+                   DEFAULT_SAMPLING_PARAMS["top_p"])

         # Slice persistent device tensors to a fixed pre-compiled padded shape.
         return cls(
             temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].
             to(xla_device),
             all_greedy=input_batch.all_greedy,
             # TODO enable more and avoid returning None values
-            top_p=None,  # input_batch.top_p[:padded_num_reqs],
+            top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
             top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                 xla_device),
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
                 xla_device))
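For readers skimming the diff: the change switches the default `top_p` from disabled to `1.0` and passes the per-request `top_p` tensor through to the TPU sampler instead of `None`. As a point of reference only, below is a minimal, generic sketch of top-p (nucleus) filtering over a `[batch, vocab]` logits tensor in plain PyTorch; the `apply_top_p` helper is an illustrative assumption, not the sampler implementation in this repo. It also shows why `1.0` can serve as the "disabled" default: with `top_p=1.0` nothing is filtered.

```python
# Illustrative sketch only (assumed helper, not code from this PR):
# generic top-p (nucleus) filtering over a [batch, vocab] logits tensor.
import torch


def apply_top_p(logits: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    """Mask tokens outside the smallest set whose probability mass reaches top_p."""
    probs = logits.softmax(dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Exclude a token if the mass accumulated *before* it already reaches top_p;
    # this always keeps at least the highest-probability token per request.
    exclude = cumulative - sorted_probs > top_p.unsqueeze(-1)
    sorted_logits = logits.gather(-1, sorted_idx).masked_fill(exclude, float("-inf"))
    # Scatter the filtered logits back to their original vocabulary positions.
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)


# Example: with top_p=1.0 (the new default) effectively nothing is filtered.
logits = torch.randn(2, 8)
filtered = apply_top_p(logits, torch.tensor([1.0, 0.9]))
```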