84 {
85 if (logits.empty()) {
86 throw std::runtime_error("Cannot sample from empty logits.");
87 }
88
89
90 if (temperature < 0.05f) {
91 return std::distance(logits.begin(), std::max_element(logits.begin(), logits.end()));
92 }
93
94 int vocab_size = logits.size();
95
96 top_k = (std::min)(top_k, vocab_size);
97 if (top_k <= 0) top_k = vocab_size;
98
99 std::vector<float> scaled_logits(vocab_size);
100 float max_logit = -std::numeric_limits<float>::infinity();
101 for (float logit : logits) max_logit = (std::max)(max_logit, logit);
102
103
104 const float scale = 1.0f / temperature;
105 for (int i = 0; i < vocab_size; ++i) {
106 scaled_logits[i] = (logits[i] - max_logit) * scale;
107 }
108
109 std::vector<double> probs_double(vocab_size);
110 double sum_exp = 0.0;
111 for (int i = 0; i < vocab_size; ++i) {
112 probs_double[i] = std::exp(static_cast<double>(scaled_logits[i]));
113 sum_exp += probs_double[i];
114 }
115
116
117 if (sum_exp > 0.0) {
118 for (int i = 0; i < vocab_size; ++i) {
119 probs_double[i] /= sum_exp;
120 }
121 } else {
122
123 for (int i = 0; i < vocab_size; ++i) {
124 probs_double[i] = 1.0 / vocab_size;
125 }
126 }
127
128 std::vector<std::pair<float, int>> prob_idx(vocab_size);
129 for (int i = 0; i < vocab_size; ++i) {
130 prob_idx[i] = {static_cast<float>(probs_double[i]), i};
131 }
132
133 std::sort(prob_idx.begin(), prob_idx.end(),
134 std::greater<std::pair<float, int>>());
135
136 if (top_k < vocab_size) {
137 prob_idx.resize(top_k);
138 }
139
140 float cumulative_prob = 0.0f;
141 int last_idx = 0;
142 for (int i = 0; i < prob_idx.size(); ++i) {
143 cumulative_prob += prob_idx[i].first;
144 last_idx = i;
145 if (cumulative_prob >= top_p) {
146 break;
147 }
148 }
149 prob_idx.resize(last_idx + 1);
150
151 float final_sum = 0.0f;
152 for (const auto& pi : prob_idx) {
153 final_sum += pi.first;
154 }
155
156
157 std::vector<float> final_probs(prob_idx.size());
158 if (final_sum > 0.0f) {
159 for (size_t i = 0; i < prob_idx.size(); ++i) {
160 final_probs[i] = prob_idx[i].first / final_sum;
161 }
162 } else {
163
164 float uniform_prob = 1.0f / prob_idx.size();
165 std::fill(final_probs.begin(), final_probs.end(), uniform_prob);
166 }
167
168 std::discrete_distribution<int> dist(final_probs.begin(), final_probs.end());
169 int sampled_idx_in_filtered = dist(rng);
170
171 return prob_idx[sampled_idx_in_filtered].second;
172}