Skip to content

Commit

Permalink
Improve multinomial
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed May 24, 2024
1 parent c453e3e commit 9038308
Showing 1 changed file with 28 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -296,48 +296,44 @@ class Sampler {
}

/// Draws one token id per batch row by multinomial sampling over the logits
/// of the last sequence position.
///
/// @param logits       Tensor whose shape is read as {batch, seq_len, vocab}.
///                     NOTE(review): assumed to be contiguous, row-major f32 with
///                     seq_len >= 1 and vocab >= 1 -- confirm against callers.
/// @param temperature  Forwarded to TemperatureLogitTransform.
/// @param top_p        Forwarded to TopPFilter; 0.0f skips the top-p filter.
/// @param top_k        Forwarded to TopKFilter; 0 skips the top-k filter.
/// @param n            Currently not consulted: exactly one token is drawn per
///                     batch row. NOTE(review): confirm whether callers still
///                     expect `n` samples.
/// @return One sampled token id per batch row, in batch order.
std::vector<int64_t> _multinomial_sample(ov::Tensor logits, float temperature, float top_p, size_t top_k, size_t n) {
    ov::Shape logits_shape = logits.get_shape();
    size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];

    std::vector<int64_t> out_tokens;
    out_tokens.reserve(batch_size);

    const float* all_logits = logits.data<const float>();
    // Each batch row holds seq_len * vocab_size floats; only the last position
    // of every row is sampled.
    const size_t batch_stride = seq_len * vocab_size;
    const size_t last_position_offset = (seq_len - 1) * vocab_size;

    for (size_t batch = 0; batch < batch_size; ++batch) {
        // Row start plus offset of the last position. The previous expression
        // `(seq_len - 1) * vocab_size * batch` pointed at position 0 for the
        // first row and at unrelated positions for every other row.
        const float* logits_data = all_logits + batch * batch_stride + last_position_offset;

        // Pair every logit with its vocabulary index so the index survives filtering.
        std::vector<LogitWithIdx> logit_vector(vocab_size);
        for (size_t token = 0; token < logit_vector.size(); ++token) {
            logit_vector[token] = LogitWithIdx(logits_data[token], token);
        }

        auto temperature_transform = TemperatureLogitTransform(temperature);
        std::vector<ProbabilityWithIdx> filtered = temperature_transform.apply(logit_vector);

        if (top_p != 0.0f) {
            auto filter = TopPFilter(top_p);
            filtered = filter.filter(filtered);
        }

        if (top_k != 0) {
            auto filter = TopKFilter(top_k);
            filtered = filter.filter(filtered);
        }

        auto normalize_transform = ProbabilityNormalizeTransform();
        filtered = normalize_transform.apply(filtered);

        // `.first` is used as the sampling weight, `.second` as the token id.
        std::vector<float> multinomial_weights(filtered.size());
        for (size_t k = 0; k < filtered.size(); ++k) {
            multinomial_weights[k] = filtered[k].first;
        }

        // Equivalent to a multinomial draw with number of trials == 1.
        auto dist = std::discrete_distribution<size_t>(multinomial_weights.begin(), multinomial_weights.end());
        size_t element_to_pick = dist(rng_engine);
        out_tokens.push_back(filtered[element_to_pick].second);
    }

    return out_tokens;
}
Expand Down

0 comments on commit 9038308

Please sign in to comment.