    const std::vector<float>& batch_input_activations,
    int num_tokens_in_batch,
    int num_cpu_layers_to_process,
    int start_pos_in_sequence,
    KVCache* kv_cache,
    const std::vector<int>& prompt_lengths) {
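  // Shorthand used below: hs = hidden size, is = MLP intermediate size,
  // n_heads / n_kv_heads = attention and KV head counts, eps = RMSNorm epsilon,
  // max_pos_embeddings = maximum position; all come from the model config.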
  if (batch_input_activations.size() != (size_t)num_tokens_in_batch * hs) {
    Logger::error("[CPU_BATCH_FWD] Input size mismatch. Expected: " +
                  std::to_string((size_t)num_tokens_in_batch * hs) + ", got " +
                  std::to_string(batch_input_activations.size()));
    return {};
  }
  if (n_heads == 0) {
    Logger::error("[CPU_BATCH_FWD] Error: num_attention_heads is zero.");
    return {};
  }
  int head_dim = hs / n_heads;
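  // Scaled dot-product attention: raw scores are multiplied by 1/sqrt(head_dim).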
  float attention_scale = 1.0f / SAFE_SQRT(static_cast<float>(head_dim));
  std::vector<float> current_batch_activations = batch_input_activations;

  std::vector<int> sequence_indices(num_tokens_in_batch);
  std::vector<int> position_in_sequence(num_tokens_in_batch);
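  // Map every token in the flat batch to its (sequence index, position) pair.
  // With prompt_lengths the batch packs several prompts back to back; otherwise
  // it is one contiguous sequence starting at start_pos_in_sequence.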
  if (!prompt_lengths.empty()) {
    int token_offset = 0;
    for (size_t seq_idx = 0; seq_idx < prompt_lengths.size(); ++seq_idx) {
      for (int pos = 0; pos < prompt_lengths[seq_idx]; ++pos) {
        if (token_offset >= num_tokens_in_batch) {
          Logger::error("[CPU_BATCH_FWD] Token offset exceeded num_tokens_in_batch");
          break;
        }
        sequence_indices[token_offset] = seq_idx;
        position_in_sequence[token_offset] = pos;
        ++token_offset;
      }
    }
  } else {
    for (int token_idx = 0; token_idx < num_tokens_in_batch; ++token_idx) {
      sequence_indices[token_idx] = 0;
      position_in_sequence[token_idx] = start_pos_in_sequence + token_idx;
    }
  }
  for (int l = 0; l < num_cpu_layers_to_process; ++l) {
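    // lw: the weight set for layer l.
    // Pre-attention RMSNorm over each token's hidden vector.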
    std::vector<float> batch_x_norm1(current_batch_activations.size());
    const std::vector<float>& w_input_norm_vec =
        lw.input_layernorm_f32.empty()
            ? /* ... */
            : lw.input_layernorm_f32;
    rmsnorm_batch_cpu(current_batch_activations, w_input_norm_vec, batch_x_norm1,
                      num_tokens_in_batch, hs, eps);
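    // Save the pre-attention activations for the residual add, then project the
    // normalized activations to Q, K and V. Each projection uses whichever
    // weight format is present (F32, Q8_0, Q6_K or Q4_K).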
    std::vector<float> residual_batch_component_attn = current_batch_activations;

    std::vector<float> q_batch((size_t)num_tokens_in_batch * hs);
    std::vector<float> k_batch((size_t)num_tokens_in_batch * n_kv_heads * head_dim);
    std::vector<float> v_batch((size_t)num_tokens_in_batch * n_kv_heads * head_dim);
    if (!lw.q_proj_f32.empty()) {
      // ... (F32 Q projection: batch_x_norm1 -> q_batch)
    } else if (!lw.q_proj_q8_0.empty()) {
      // ... (Q8_0 Q projection: batch_x_norm1 -> q_batch)
    } else if (!lw.q_proj_q6k.empty()) {
      // ... (Q6_K Q projection: batch_x_norm1 -> q_batch)
    } else if (!lw.q_proj_q4k.empty()) {
      // ... (Q4_K Q projection: batch_x_norm1 -> q_batch)
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) +
                    ": No Q proj weights found for CPU");
    }
    if (!lw.k_proj_f32.empty()) {
      // ... (F32 K projection: batch_x_norm1 -> k_batch)
    } else if (!lw.k_proj_q8_0.empty()) {
      // ... (Q8_0 K projection: batch_x_norm1 -> k_batch)
    } else if (!lw.k_proj_q6k.empty()) {
      // ... (Q6_K K projection: batch_x_norm1 -> k_batch)
    } else if (!lw.k_proj_q4k.empty()) {
      // ... (Q4_K K projection: batch_x_norm1 -> k_batch)
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) +
                    ": No K proj weights found for CPU");
    }
    if (!lw.v_proj_f32.empty()) {
      // ... (F32 V projection: batch_x_norm1 -> v_batch)
    } else if (!lw.v_proj_q8_0.empty()) {
      // ... (Q8_0 V projection: batch_x_norm1 -> v_batch)
    } else if (!lw.v_proj_q6k.empty()) {
      // ... (Q6_K V projection: batch_x_norm1 -> v_batch)
    } else if (!lw.v_proj_q4k.empty()) {
      // ... (Q4_K V projection: batch_x_norm1 -> v_batch)
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) +
                    ": No V proj weights found for CPU");
    }
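    // Apply RoPE to Q and K of each token, using that token's absolute position
    // within its sequence.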
    if (!prompt_lengths.empty()) {
      for (int t = 0; t < num_tokens_in_batch; ++t) {
        int current_token_pos = position_in_sequence[t];
        int seq_idx = sequence_indices[t];

        if (current_token_pos < 0 || current_token_pos >= max_pos_embeddings) {
          Logger::warning("[CPU_BATCH_FWD] Token " + std::to_string(t) +
                          " (seq=" + std::to_string(seq_idx) +
                          ", pos=" + std::to_string(current_token_pos) +
                          ") is out of range. Skipping RoPE.");
          continue;
        }

        // Work on per-token copies of this token's Q and K rows.
        std::vector<float> q_token(hs);
        std::vector<float> k_token(n_kv_heads * head_dim);
        std::copy(q_batch.begin() + (size_t)t * hs,
                  q_batch.begin() + (size_t)(t + 1) * hs, q_token.begin());
        std::copy(k_batch.begin() + (size_t)t * n_kv_heads * head_dim,
                  k_batch.begin() + (size_t)(t + 1) * n_kv_heads * head_dim,
                  k_token.begin());

        // ... (rotate q_token and k_token in place for position current_token_pos)

        std::copy(q_token.begin(), q_token.end(), q_batch.begin() + (size_t)t * hs);
        std::copy(k_token.begin(), k_token.end(),
                  k_batch.begin() + (size_t)t * n_kv_heads * head_dim);
      }
    } else {
      // ... (RoPE for the single-sequence case, positions from start_pos_in_sequence)
    }
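    // Write this batch's K and V vectors into the KV cache for layer l.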
    if (!prompt_lengths.empty()) {
      // ... (per-token path: sequence_indices, position_in_sequence,
      //      n_kv_heads, head_dim)
    } else {
      // ... (single-sequence path: start_pos_in_sequence, n_kv_heads, head_dim)
    }
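    // Multi-head attention over the cached K/V; each token attends within its
    // own sequence, and the result is written to batch_attn_output.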
    std::vector<float> batch_attn_output((size_t)num_tokens_in_batch * hs);

    if (kv_cache && static_cast<size_t>(l) < kv_cache->layers.size()) {
      if (!prompt_lengths.empty()) {
        // ... (batched attention, per-token positions: num_tokens_in_batch,
        //      sequence_indices, position_in_sequence,
        //      n_heads, n_kv_heads, head_dim, attention_scale, ...)
      } else {
        // ... (batched attention, contiguous positions: num_tokens_in_batch,
        //      start_pos_in_sequence, n_heads, n_kv_heads, head_dim,
        //      attention_scale)
      }
    } else if (kv_cache) {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) +
                    " is out of bounds for KV Cache access during attention. KVCache layers size: " +
                    std::to_string(kv_cache->layers.size()) +
                    ". Filling attention output with zeros.");
      std::fill(batch_attn_output.begin(), batch_attn_output.end(), 0.0f);
    } else {
      Logger::error("[CPU_BATCH_FWD] KV Cache is null, cannot perform attention for layer " +
                    std::to_string(l) + ". Filling attention output with zeros.");
      std::fill(batch_attn_output.begin(), batch_attn_output.end(), 0.0f);
    }
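    // Output (O) projection back to the hidden size.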
    std::vector<float> batch_attn_proj_out((size_t)num_tokens_in_batch * hs);
    if (!lw.o_proj_f32.empty()) {
      // ... (F32 O projection: batch_attn_output -> batch_attn_proj_out)
    } else if (!lw.o_proj_q8_0.empty()) {
      // ... (Q8_0 O projection: batch_attn_output -> batch_attn_proj_out)
    } else if (!lw.o_proj_q6k.empty()) {
      // ... (Q6_K O projection: batch_attn_output -> batch_attn_proj_out)
    } else if (!lw.o_proj_q4k.empty()) {
      // ... (Q4_K O projection: batch_attn_output -> batch_attn_proj_out)
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) +
                    ": No O proj weights found for CPU");
    }
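    // First residual connection: activations = pre-attention input + attention output.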
    for (size_t i = 0; i < current_batch_activations.size(); ++i) {
      current_batch_activations[i] =
          residual_batch_component_attn[i] + batch_attn_proj_out[i];
    }
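    // MLP block: post-attention RMSNorm, then SwiGLU
    // (silu(gate_proj(x)) * up_proj(x), followed by down_proj), and a second residual add.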
    std::vector<float> residual_batch_component_mlp = current_batch_activations;
    std::vector<float> batch_x_norm2(current_batch_activations.size());
    const std::vector<float>& w_post_attn_norm_vec =
        lw.post_attention_layernorm_f32.empty()
            ? /* ... */
            : lw.post_attention_layernorm_f32;
    rmsnorm_batch_cpu(current_batch_activations, w_post_attn_norm_vec, batch_x_norm2,
                      num_tokens_in_batch, hs, eps);
    std::vector<float> batch_gate_proj_out((size_t)num_tokens_in_batch * is);
    std::vector<float> batch_up_proj_out((size_t)num_tokens_in_batch * is);
    if (!lw.gate_proj_f32.empty()) {
      // ... (F32 gate projection: batch_x_norm2 -> batch_gate_proj_out)
    } else if (!lw.gate_proj_q8_0.empty()) {
      // ... (Q8_0 gate projection: batch_x_norm2 -> batch_gate_proj_out)
    } else if (!lw.gate_proj_q6k.empty()) {
      // ... (Q6_K gate projection: batch_x_norm2 -> batch_gate_proj_out)
    } else if (!lw.gate_proj_q4k.empty()) {
      // ... (Q4_K gate projection: batch_x_norm2 -> batch_gate_proj_out)
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) +
                    ": No gate_proj weights found for CPU");
    }
    if (!lw.up_proj_f32.empty()) {
      // ... (F32 up projection: batch_x_norm2 -> batch_up_proj_out)
    } else if (!lw.up_proj_q8_0.empty()) {
      // ... (Q8_0 up projection: batch_x_norm2 -> batch_up_proj_out)
    } else if (!lw.up_proj_q6k.empty()) {
      // ... (Q6_K up projection: batch_x_norm2 -> batch_up_proj_out)
    } else if (!lw.up_proj_q4k.empty()) {
      // ... (Q4_K up projection: batch_x_norm2 -> batch_up_proj_out)
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) +
                    ": No up_proj weights found for CPU");
    }
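    // SwiGLU: silu(gate) * up, with silu(x) = x / (1 + e^(-x)).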
    std::vector<float> batch_swiglu_out((size_t)num_tokens_in_batch * is);
    for (size_t i = 0; i < batch_gate_proj_out.size(); ++i) {
      float gate_val = batch_gate_proj_out[i];
      float silu_gate_val = gate_val / (1.0f + std::exp(-gate_val));
      batch_swiglu_out[i] = silu_gate_val * batch_up_proj_out[i];
    }
    std::vector<float> batch_mlp_down_proj_out((size_t)num_tokens_in_batch * hs);
    if (!lw.down_proj_f32.empty()) {
      // ... (F32 down projection: batch_swiglu_out -> batch_mlp_down_proj_out)
    } else if (!lw.down_proj_q8_0.empty()) {
      // ... (Q8_0 down projection: batch_swiglu_out -> batch_mlp_down_proj_out)
    } else if (!lw.down_proj_q6k.empty()) {
      // ... (Q6_K down projection: batch_swiglu_out -> batch_mlp_down_proj_out)
    } else if (!lw.down_proj_q4k.empty()) {
      // ... (Q4_K down projection: batch_swiglu_out -> batch_mlp_down_proj_out)
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) +
                    ": No down_proj weights found for CPU");
    }
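    // Second residual connection: activations = pre-MLP input + MLP output.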
    for (size_t i = 0; i < current_batch_activations.size(); ++i) {
      current_batch_activations[i] =
          residual_batch_component_mlp[i] + batch_mlp_down_proj_out[i];
    }
  }  // end of layer loop
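  // Advance the cache's sequence length so subsequent steps append after this batch.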
  if (kv_cache && num_tokens_in_batch > 0) {
    kv_cache->seq_len = start_pos_in_sequence + num_tokens_in_batch;
  }

  return current_batch_activations;
}