Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1326,11 +1326,12 @@ struct server_context_impl {
}

void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
const size_t n_probs = slot.task->params.sampling.n_probs;
const size_t n_probs_request = slot.task->params.sampling.n_probs;

if (post_sampling) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
const size_t max_probs = cur_p->size;
const size_t n_probs = std::min(max_probs, n_probs_request);

// set probability for sampled token
for (size_t i = 0; i < max_probs; i++) {
Expand All @@ -1341,8 +1342,8 @@ struct server_context_impl {
}

// set probability for top n_probs tokens
result.probs.reserve(max_probs);
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
result.probs.reserve(n_probs);
for (size_t i = 0; i < n_probs; i++) {
result.probs.push_back({
cur_p->data[i].id,
common_token_to_piece(ctx, cur_p->data[i].id, special),
Expand All @@ -1352,9 +1353,11 @@ struct server_context_impl {
} else {
// TODO: optimize this with min-p optimization
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
const size_t max_probs = cur.size();
const size_t n_probs = std::min(max_probs, n_probs_request);

// set probability for sampled token
for (size_t i = 0; i < cur.size(); i++) {
for (size_t i = 0; i < max_probs; i++) {
// set probability for sampled token
if (cur[i].id == result.tok) {
result.prob = cur[i].p;
Expand All @@ -1364,7 +1367,7 @@ struct server_context_impl {

// set probability for top n_probs tokens
result.probs.reserve(n_probs);
for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
for (size_t i = 0; i < n_probs; i++) {
result.probs.push_back({
cur[i].id,
common_token_to_piece(ctx, cur[i].id, special),
Expand Down
Loading