    const std::vector<float>& k_batch_for_layer,
    const std::vector<float>& v_batch_for_layer,
    int num_tokens_in_batch,
    int start_pos_in_sequence,
  if (layer_idx < 0 || static_cast<size_t>(layer_idx) >= kv_cache->layers.size()) {
    Logger::error("update_kv_cache_batch_cpu: layer_idx " + std::to_string(layer_idx) +
                  " is out of bounds for KVCache layers (size " +
                  std::to_string(kv_cache->layers.size()) + ").");
    return;  // (assumed) abort the update on an invalid layer index
  }
  Logger::info("[CPU_KV_UPDATE] Layer=" + std::to_string(layer_idx) +
               ", start_pos=" + std::to_string(start_pos_in_sequence) +
               ", num_tokens=" + std::to_string(num_tokens_in_batch) +
               ", k_batch_first_vals=[" + std::to_string(k_batch_for_layer[0]) +
               "," + std::to_string(k_batch_for_layer[1]) +
               "," + std::to_string(k_batch_for_layer[2]) + "]");
  int kv_dim = num_kv_heads * head_dim;
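  // Layout assumption, consistent with the offsets used throughout this file: each layer's
  // K and V caches are flat row-major buffers of shape [max_seq_len][num_kv_heads][head_dim],
  // so sequence position p of kv-head h starts at element (p * num_kv_heads + h) * head_dim,
  // and one token occupies kv_dim = num_kv_heads * head_dim floats per tensor.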
  if (k_batch_for_layer.size() != static_cast<size_t>(num_tokens_in_batch * kv_dim)) {
    Logger::error("[KV_BATCH_UPDATE L" + std::to_string(layer_idx) +
                  "] k_batch_for_layer size mismatch. Expected " +
                  std::to_string(num_tokens_in_batch * kv_dim) +
                  ", got " + std::to_string(k_batch_for_layer.size()));
    return;  // (assumed) abort on malformed input
  }
  if (v_batch_for_layer.size() != static_cast<size_t>(num_tokens_in_batch * kv_dim)) {
    Logger::error("[KV_BATCH_UPDATE L" + std::to_string(layer_idx) +
                  "] v_batch_for_layer size mismatch. Expected " +
                  std::to_string(num_tokens_in_batch * kv_dim) +
                  ", got " + std::to_string(v_batch_for_layer.size()));
    return;  // (assumed) abort on malformed input
  }
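  // The next check uses layer_cache, presumably kv_cache->layers[layer_idx], and
  // expected_total_elements_in_layer_cache, presumably max_seq_len_config * kv_dim
  // (see KVCache::initialize).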
  if (layer_cache.k.size() != expected_total_elements_in_layer_cache ||
      layer_cache.v.size() != expected_total_elements_in_layer_cache) {
    Logger::error("[KV_BATCH_UPDATE L" + std::to_string(layer_idx) +
                  "] Precondition failed: Layer cache not sized to max_seq_len_config. K size: " +
                  std::to_string(layer_cache.k.size()) +
                  ", V size: " + std::to_string(layer_cache.v.size()) +
                  ", Expected size: " + std::to_string(expected_total_elements_in_layer_cache) +
                  ". Check KVCache::initialize.");
    return;  // (assumed) abort if the cache was not initialized to full size
  }
  for (int token_idx_in_batch = 0; token_idx_in_batch < num_tokens_in_batch; ++token_idx_in_batch) {
    size_t current_token_batch_offset = static_cast<size_t>(token_idx_in_batch) * kv_dim;
    int global_seq_pos = start_pos_in_sequence + token_idx_in_batch;
    // (assumed guard) skip tokens whose global position falls outside the allocated cache
    if (global_seq_pos < 0 || static_cast<size_t>(global_seq_pos) * kv_dim + kv_dim > layer_cache.k.size()) {
      Logger::error("[KV_BATCH_UPDATE L" + std::to_string(layer_idx) +
                    "] Error: global_seq_pos (" + std::to_string(global_seq_pos) +
                    ") is out of bounds for total cache size. Skipping update for this token.");
      continue;
    }
    size_t destination_offset_in_layer_cache = static_cast<size_t>(global_seq_pos) * kv_dim;
    size_t k_size_before = layer_cache.k.size();
    std::string k_vals_to_log = " vals to copy: ";
    for (int i = 0; i < std::min(3, kv_dim); ++i) {
      k_vals_to_log += std::to_string(k_batch_for_layer[current_token_batch_offset + i]) + " ";
    }
    if (kv_dim > 3) k_vals_to_log += "...";
    std::copy(k_batch_for_layer.begin() + current_token_batch_offset,
              k_batch_for_layer.begin() + current_token_batch_offset + kv_dim,
              layer_cache.k.begin() + destination_offset_in_layer_cache);
    size_t v_size_before = layer_cache.v.size();
    std::string v_vals_to_log = " vals to copy: ";
    for (int i = 0; i < std::min(3, kv_dim); ++i) {
      v_vals_to_log += std::to_string(v_batch_for_layer[current_token_batch_offset + i]) + " ";
    }
    if (kv_dim > 3) v_vals_to_log += "...";
    std::copy(v_batch_for_layer.begin() + current_token_batch_offset,
              v_batch_for_layer.begin() + current_token_batch_offset + kv_dim,
              layer_cache.v.begin() + destination_offset_in_layer_cache);
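    // Worked example (hypothetical sizes): with num_kv_heads = 4 and head_dim = 64,
    // kv_dim = 256, so the token at global_seq_pos = 7 is written to elements
    // [7 * 256, 7 * 256 + 256) of layer_cache.k and layer_cache.v.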
    const std::vector<float>& q_batch_roped,
    std::vector<float>& batch_attn_output,
    int num_tokens_in_batch,
    int start_pos_in_sequence,
    float attention_scale
  size_t expected_q_size = (size_t)num_tokens_in_batch * num_q_heads * head_dim;
  if (q_batch_roped.size() != expected_q_size) {
    Logger::error("[ATTN_BATCH_CPU] q_batch_roped size mismatch. Expected: " + std::to_string(expected_q_size) +
                  ", Got: " + std::to_string(q_batch_roped.size()));
    std::fill(batch_attn_output.begin(), batch_attn_output.end(), 0.0f);
    return;  // (assumed) nothing to compute with malformed Q input
  }
  Logger::info("[ATTENTION_BATCH_CPU_ENTRY] Called with num_tokens=" + std::to_string(num_tokens_in_batch));
  size_t expected_output_size = (size_t)num_tokens_in_batch * num_q_heads * head_dim;
  batch_attn_output.assign(expected_output_size, 0.0f);
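  // The output buffer is sized and zero-filled up front so each head can accumulate its
  // probability-weighted V rows with += in the loops below.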
  for (int token_idx = 0; token_idx < num_tokens_in_batch; ++token_idx) {
    size_t q_token_offset = (size_t)token_idx * num_q_heads * head_dim;
    size_t attn_out_token_offset = (size_t)token_idx * num_q_heads * head_dim;
    int current_token_absolute_pos = start_pos_in_sequence + token_idx;
    for (int h_q = 0; h_q < num_q_heads; ++h_q) {
      const float* q_head_for_token_ptr = q_batch_roped.data() + q_token_offset + (h_q * head_dim);
      int kv_group_head_idx = h_q / (num_q_heads / num_kv_heads);
      bool log_details_for_this_head = (token_idx == 0 && h_q == 0);
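      // Grouped-query attention: consecutive query heads share one KV head. For example
      // (hypothetical sizes), with num_q_heads = 32 and num_kv_heads = 8, query heads 0..3
      // all read kv head 0, heads 4..7 read kv head 1, and so on.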
      int history_len = current_token_absolute_pos + 1;
      if (history_len <= 0) {
        Logger::warning("[ATTN_BATCH_CPU] Token_idx " + std::to_string(token_idx) +
                        ", Q_Head " + std::to_string(h_q) +
                        ": history_len is " + std::to_string(history_len) +
                        ". Skipping score calculation for this head.");
        continue;  // (assumed) skip this head
      }
      std::vector<float> scores(history_len);
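      // Causal attention: the token at absolute position P scores itself against cached
      // positions 0..P, so exactly history_len = P + 1 scores are produced.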
      for (int t_hist = 0; t_hist < history_len; ++t_hist) {
        size_t k_cache_offset = ((size_t)t_hist * num_kv_heads + kv_group_head_idx) * head_dim;
        if (token_idx == 0 && h_q == 0 && t_hist < 3) {
          Logger::info("[CPU_ATTN_MEM] T" + std::to_string(token_idx) +
                       "_H" + std::to_string(h_q) +
                       " accessing K_cache[pos=" + std::to_string(t_hist) +
                       ",kv_head=" + std::to_string(kv_group_head_idx) +
                       "]: offset=" + std::to_string(k_cache_offset) +
                       ", k_vals=[" + std::to_string(current_layer_kv_cache.k[k_cache_offset]) +
                       "," + std::to_string(current_layer_kv_cache.k[k_cache_offset + 1]) +
                       "," + std::to_string(current_layer_kv_cache.k[k_cache_offset + 2]) + "]");
        }
        if (k_cache_offset + head_dim > current_layer_kv_cache.k.size()) {
          Logger::error("[ATTN_BATCH_CPU] K cache out of bounds. Token_idx " + std::to_string(token_idx) +
                        " (abs_pos " + std::to_string(current_token_absolute_pos) +
                        "), Q_Head " + std::to_string(h_q) +
                        ", history_pos " + std::to_string(t_hist) +
                        ". Required k_cache_offset " + std::to_string(k_cache_offset + head_dim) +
                        " > cache_k_size " + std::to_string(current_layer_kv_cache.k.size()));
          scores[t_hist] = -std::numeric_limits<float>::infinity();
          continue;  // (assumed) leave this position masked out
        }
        float current_dot_product = 0.0f;
        for (int d = 0; d < head_dim; ++d) {
          current_dot_product += q_head_for_token_ptr[d] * current_layer_kv_cache.k[k_cache_offset + d];
        }
        if (token_idx == 0 && h_q == 0 && t_hist < 2) {
          Logger::info("[CPU_ATTN_SCORE] T0_H0 pos=" + std::to_string(t_hist) +
                       ", q_vals=[" + std::to_string(q_head_for_token_ptr[0]) +
                       "," + std::to_string(q_head_for_token_ptr[1]) + "]" +
                       ", k_vals=[" + std::to_string(current_layer_kv_cache.k[k_cache_offset]) +
                       "," + std::to_string(current_layer_kv_cache.k[k_cache_offset + 1]) + "]" +
                       ", dot=" + std::to_string(current_dot_product) +
                       ", scale=" + std::to_string(attention_scale));
        }
        scores[t_hist] = current_dot_product * attention_scale;
      }
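      // A softmax over `scores` is presumably applied at this point (the log below prints
      // post-softmax probabilities), turning the scaled dot products q·k * attention_scale
      // (attention_scale is typically 1/sqrt(head_dim)) into weights that sum to 1 over the
      // causal history.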
      if (token_idx == 0 && h_q == 0) {
        std::string scores_str = "";
        for (int i = 0; i < std::min(3, (int)scores.size()); i++) {
          scores_str += std::to_string(scores[i]) + " ";
        }
        Logger::info("[CPU_SOFTMAX] T0_H0 first_3_probs=[" + scores_str + "]");
      }
      float* current_attn_out_head_ptr = batch_attn_output.data() + attn_out_token_offset + (h_q * head_dim);
      for (int t_hist = 0; t_hist < history_len; ++t_hist) {
        if (scores[t_hist] == -std::numeric_limits<float>::infinity() || scores[t_hist] == 0.0f)
          continue;
        size_t v_cache_offset = ((size_t)t_hist * num_kv_heads + kv_group_head_idx) * head_dim;
        if (v_cache_offset + head_dim > current_layer_kv_cache.v.size()) {
          Logger::error("[ATTN_BATCH_CPU] V cache out of bounds. Token_idx " + std::to_string(token_idx) +
                        " (abs_pos " + std::to_string(current_token_absolute_pos) +
                        "), Q_Head " + std::to_string(h_q) +
                        ", history_pos " + std::to_string(t_hist) +
                        ". Required v_cache_offset " + std::to_string(v_cache_offset + head_dim) +
                        " > cache_v_size " + std::to_string(current_layer_kv_cache.v.size()));
          continue;  // (assumed) skip this history position
        }
        for (int d = 0; d < head_dim; ++d) {
          float val_before = (log_details_for_this_head && t_hist < 2 && d < 2) ? current_attn_out_head_ptr[d] : 0.0f;
          current_attn_out_head_ptr[d] += scores[t_hist] * current_layer_kv_cache.v[v_cache_offset + d];
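          // Each output element accumulates the probability-weighted value rows:
          //   out[h_q][d] = sum over t_hist of scores[t_hist] * V[t_hist][kv_group_head_idx][d]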
    const std::vector<float>& k_batch_for_layer,
    const std::vector<float>& v_batch_for_layer,
    int num_tokens_in_batch,
    const std::vector<int>& sequence_indices,
    const std::vector<int>& position_in_sequence,
  if (!kv_cache) {
    Logger::error("update_kv_cache_batch_cpu_sequence_aware: KVCache is null.");
    return;  // (assumed) nothing to update without a cache
  }
  if (layer_idx < 0 || static_cast<size_t>(layer_idx) >= kv_cache->layers.size()) {
    Logger::error("update_kv_cache_batch_cpu_sequence_aware: layer_idx " + std::to_string(layer_idx) +
                  " is out of bounds for KVCache layers (size " +
                  std::to_string(kv_cache->layers.size()) + ").");
    return;  // (assumed) abort on an invalid layer index
  }
  int kv_dim = num_kv_heads * head_dim;
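  // Sequence-aware variant: the per-layer cache is presumably partitioned into fixed-size
  // regions of max_seq_len_per_sequence positions, one region per sequence in the batch,
  // so a token is stored at cache position seq_idx * max_seq_len_per_sequence + pos_in_seq.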
  for (int token_idx = 0; token_idx < num_tokens_in_batch; ++token_idx) {
    size_t current_token_batch_offset = static_cast<size_t>(token_idx) * kv_dim;
    int seq_idx = sequence_indices[token_idx];
    int pos_in_seq = position_in_sequence[token_idx];
    int sequence_base_offset = seq_idx * max_seq_len_per_sequence;  // (assumed, mirrors the attention variant below)
    int actual_cache_position = sequence_base_offset + pos_in_seq;
    // (assumed guard) skip tokens whose slot falls outside the allocated cache
    if (actual_cache_position < 0 || static_cast<size_t>(actual_cache_position) * kv_dim + kv_dim > layer_cache.k.size()) {
      Logger::error("[KV_BATCH_UPDATE_SEQ_AWARE L" + std::to_string(layer_idx) +
                    "] Error: actual_cache_position (" + std::to_string(actual_cache_position) +
                    ") is out of bounds for total cache size. Skipping update for this token.");
      continue;
    }
    size_t destination_offset_in_layer_cache = static_cast<size_t>(actual_cache_position) * kv_dim;
    std::copy(k_batch_for_layer.begin() + current_token_batch_offset,
              k_batch_for_layer.begin() + current_token_batch_offset + kv_dim,
              layer_cache.k.begin() + destination_offset_in_layer_cache);
    std::copy(v_batch_for_layer.begin() + current_token_batch_offset,
              v_batch_for_layer.begin() + current_token_batch_offset + kv_dim,
              layer_cache.v.begin() + destination_offset_in_layer_cache);
    const std::vector<float>& q_batch_roped,
    std::vector<float>& batch_attn_output,
    int num_tokens_in_batch,
    const std::vector<int>& sequence_indices,
    const std::vector<int>& position_in_sequence,
    float attention_scale,
    int max_seq_len_per_sequence
  size_t expected_q_size = (size_t)num_tokens_in_batch * num_q_heads * head_dim;
  if (q_batch_roped.size() != expected_q_size) {
    Logger::error("[ATTN_BATCH_CPU_SEQ_AWARE] q_batch_roped size mismatch. Expected: " + std::to_string(expected_q_size) +
                  ", Got: " + std::to_string(q_batch_roped.size()));
    std::fill(batch_attn_output.begin(), batch_attn_output.end(), 0.0f);
    return;  // (assumed) nothing to compute with malformed Q input
  }
  batch_attn_output.assign((size_t)num_tokens_in_batch * num_q_heads * head_dim, 0.0f);
  for (int token_idx = 0; token_idx < num_tokens_in_batch; ++token_idx) {
    size_t q_token_offset = (size_t)token_idx * num_q_heads * head_dim;
    size_t attn_out_token_offset = (size_t)token_idx * num_q_heads * head_dim;
    int seq_idx = sequence_indices[token_idx];
    int pos_in_seq = position_in_sequence[token_idx];
    int sequence_base_offset = seq_idx * max_seq_len_per_sequence;
    for (int h_q = 0; h_q < num_q_heads; ++h_q) {
      const float* q_head_for_token_ptr = q_batch_roped.data() + q_token_offset + (h_q * head_dim);
      int kv_group_head_idx = h_q / (num_q_heads / num_kv_heads);
      int history_len = pos_in_seq + 1;
      std::vector<float> scores(history_len);
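      // Unlike the non-sequence-aware path, the causal window here is per sequence: this
      // token attends to positions 0..pos_in_seq of its own sequence, and every cache offset
      // below is shifted by sequence_base_offset so sequences never overlap.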
      for (int t_hist = 0; t_hist < history_len; ++t_hist) {
        size_t k_cache_offset = ((size_t)(sequence_base_offset + t_hist) * num_kv_heads + kv_group_head_idx) * head_dim;
        if (k_cache_offset + head_dim > current_layer_kv_cache.k.size()) {
          scores[t_hist] = -std::numeric_limits<float>::infinity();
          continue;  // (assumed) mask out-of-range positions
        }
        float current_dot_product = 0.0f;
        for (int d = 0; d < head_dim; ++d) {
          current_dot_product += q_head_for_token_ptr[d] * current_layer_kv_cache.k[k_cache_offset + d];
        }
        scores[t_hist] = current_dot_product * attention_scale;
      }
      // A softmax over `scores` is presumably applied here as well, as in the
      // non-sequence-aware path above.
      float* current_attn_out_head_ptr = batch_attn_output.data() + attn_out_token_offset + (h_q * head_dim);
      for (int t_hist = 0; t_hist < history_len; ++t_hist) {
        if (scores[t_hist] == -std::numeric_limits<float>::infinity() || scores[t_hist] == 0.0f)
          continue;
        size_t v_cache_offset = ((size_t)(sequence_base_offset + t_hist) * num_kv_heads + kv_group_head_idx) * head_dim;
        if (v_cache_offset + head_dim > current_layer_kv_cache.v.size()) {
          continue;  // (assumed) skip unreadable history positions
        }
        for (int d = 0; d < head_dim; ++d) {
          current_attn_out_head_ptr[d] += scores[t_hist] * current_layer_kv_cache.v[v_cache_offset + d];