{

  // NOTE: the size-check condition below is reconstructed; it assumes the
  // expected flat input size is num_tokens_in_batch * hs floats.
  if (batch_input_activations.size() != (size_t)num_tokens_in_batch * hs) {
    Logger::error("[CPU_BATCH_FWD] Input size mismatch. Expected: " +
                  std::to_string((size_t)num_tokens_in_batch * hs) + ", Got: " +
                  std::to_string(batch_input_activations.size()));
    return {};
  }

  if (n_heads == 0) {
    Logger::error("[CPU_BATCH_FWD] Error: num_attention_heads is zero.");
    return {};
  }
  int head_dim = hs / n_heads;
  float attention_scale = 1.0f / SAFE_SQRT(static_cast<float>(head_dim));

  std::vector<float> current_batch_activations = batch_input_activations;

  std::vector<int> sequence_indices(num_tokens_in_batch);
  std::vector<int> position_in_sequence(num_tokens_in_batch);

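  // Map every token in the flat batch to the sequence it belongs to and its
  // position within that sequence.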
  if (!prompt_lengths.empty()) {
    int token_offset = 0;
    for (size_t seq_idx = 0; seq_idx < prompt_lengths.size(); ++seq_idx) {
      for (int pos = 0; pos < prompt_lengths[seq_idx]; ++pos) {
        if (token_offset >= num_tokens_in_batch) {
          Logger::error("[CPU_BATCH_FWD] Token offset exceeded num_tokens_in_batch");
          return {};
        }
        sequence_indices[token_offset] = seq_idx;
        position_in_sequence[token_offset] = pos;
        token_offset++;
      }
    }
  } else {
    for (int token_idx = 0; token_idx < num_tokens_in_batch; ++token_idx) {
      sequence_indices[token_idx] = 0;
      position_in_sequence[token_idx] = start_pos_in_sequence + token_idx;
    }
  }

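  // Run each CPU-resident transformer layer over the whole batch.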
  for (int l = 0; l < num_cpu_layers_to_process; ++l) {
    // NOTE: reconstructed; assumes this layer's weights are exposed through a
    // `layers` member of the surrounding class.
    const LayerWeights& lw = layers[l];

    std::vector<float> batch_x_norm1(current_batch_activations.size());
    // NOTE: the bf16 fallback branch is reconstructed; it assumes a bf16 copy of
    // the input layernorm weights named lw.input_layernorm.
    const std::vector<float>& w_input_norm_vec =
        lw.input_layernorm_f32.empty()
            ? bf16vec_to_float_vec(lw.input_layernorm)
            : lw.input_layernorm_f32;
    rmsnorm_batch_cpu(current_batch_activations, w_input_norm_vec, batch_x_norm1,
                      num_tokens_in_batch, hs, eps);

    std::vector<float> residual_batch_component_attn = current_batch_activations;

    std::vector<float> q_batch((size_t)num_tokens_in_batch * hs);
    std::vector<float> k_batch((size_t)num_tokens_in_batch * n_kv_heads * head_dim);
    std::vector<float> v_batch((size_t)num_tokens_in_batch * n_kv_heads * head_dim);

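    // Q/K/V projections: each dispatch uses whichever weight format
    // (f32, q8_0, q6_k, q4_k) is present for this layer.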
    // NOTE: the matmul calls in this and the following projection dispatches are
    // reconstructed from the *_batch_cpu helper signatures
    // (weights, input, output, num_tokens, output_dim, input_dim); the dimension
    // arguments are assumptions.
    if (!lw.q_proj_f32.empty()) {
      matmul_f32_f32_batch_cpu(lw.q_proj_f32, batch_x_norm1, q_batch, num_tokens_in_batch, hs, hs);
    } else if (!lw.q_proj_q8_0.empty()) {
      matmul_q8_0_f32_batch_cpu(lw.q_proj_q8_0, batch_x_norm1, q_batch, num_tokens_in_batch, hs, hs);
    } else if (!lw.q_proj_q6k.empty()) {
      matmul_q6k_f32_batch_cpu(lw.q_proj_q6k, batch_x_norm1, q_batch, num_tokens_in_batch, hs, hs);
    } else if (!lw.q_proj_q4k.empty()) {
      matmul_q4k_f32_batch_cpu(lw.q_proj_q4k, batch_x_norm1, q_batch, num_tokens_in_batch, hs, hs);
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) + ": No Q proj weights found for CPU");
      return {};
    }

    if (!lw.k_proj_f32.empty()) {
      matmul_f32_f32_batch_cpu(lw.k_proj_f32, batch_x_norm1, k_batch, num_tokens_in_batch, n_kv_heads * head_dim, hs);
    } else if (!lw.k_proj_q8_0.empty()) {
      matmul_q8_0_f32_batch_cpu(lw.k_proj_q8_0, batch_x_norm1, k_batch, num_tokens_in_batch, n_kv_heads * head_dim, hs);
    } else if (!lw.k_proj_q6k.empty()) {
      matmul_q6k_f32_batch_cpu(lw.k_proj_q6k, batch_x_norm1, k_batch, num_tokens_in_batch, n_kv_heads * head_dim, hs);
    } else if (!lw.k_proj_q4k.empty()) {
      matmul_q4k_f32_batch_cpu(lw.k_proj_q4k, batch_x_norm1, k_batch, num_tokens_in_batch, n_kv_heads * head_dim, hs);
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) + ": No K proj weights found for CPU");
      return {};
    }

    if (!lw.v_proj_f32.empty()) {
      matmul_f32_f32_batch_cpu(lw.v_proj_f32, batch_x_norm1, v_batch, num_tokens_in_batch, n_kv_heads * head_dim, hs);
    } else if (!lw.v_proj_q8_0.empty()) {
      matmul_q8_0_f32_batch_cpu(lw.v_proj_q8_0, batch_x_norm1, v_batch, num_tokens_in_batch, n_kv_heads * head_dim, hs);
    } else if (!lw.v_proj_q6k.empty()) {
      matmul_q6k_f32_batch_cpu(lw.v_proj_q6k, batch_x_norm1, v_batch, num_tokens_in_batch, n_kv_heads * head_dim, hs);
    } else if (!lw.v_proj_q4k.empty()) {
      matmul_q4k_f32_batch_cpu(lw.v_proj_q4k, batch_x_norm1, v_batch, num_tokens_in_batch, n_kv_heads * head_dim, hs);
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) + ": No V proj weights found for CPU");
      return {};
    }

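    // Apply RoPE to Q and K: per token when multiple sequences share the batch,
    // otherwise in one batched call for the single sequence.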
    if (!prompt_lengths.empty()) {
      for (int t = 0; t < num_tokens_in_batch; ++t) {
        int current_token_pos = position_in_sequence[t];
        int seq_idx = sequence_indices[t];

        if (current_token_pos < 0 || current_token_pos >= max_pos_embeddings) {
          Logger::warning("[CPU_BATCH_FWD] Token " + std::to_string(t) +
                          " (seq=" + std::to_string(seq_idx) +
                          ", pos=" + std::to_string(current_token_pos) +
                          ") is out of range. Skipping RoPE.");
          continue;
        }

        std::vector<float> q_token(hs);
        std::vector<float> k_token(n_kv_heads * head_dim);

        std::copy(q_batch.begin() + (size_t)t * hs,
                  q_batch.begin() + (size_t)(t + 1) * hs,
                  q_token.begin());
        std::copy(k_batch.begin() + (size_t)t * n_kv_heads * head_dim,
                  k_batch.begin() + (size_t)(t + 1) * n_kv_heads * head_dim,
                  k_token.begin());

        // NOTE: the per-token RoPE calls are reconstructed from apply_rope_vector's
        // signature; the `use_adjacent_pairing` flag name is an assumption.
        apply_rope_vector(q_token, n_heads, head_dim, current_token_pos,
                          precomputed_freqs_cis_, max_pos_embeddings, use_adjacent_pairing);
        apply_rope_vector(k_token, n_kv_heads, head_dim, current_token_pos,
                          precomputed_freqs_cis_, max_pos_embeddings, use_adjacent_pairing);

        std::copy(q_token.begin(), q_token.end(), q_batch.begin() + (size_t)t * hs);
        std::copy(k_token.begin(), k_token.end(), k_batch.begin() + (size_t)t * n_kv_heads * head_dim);
      }
    } else {
      // NOTE: batch RoPE call reconstructed from apply_rope_batch_cpu's signature.
      apply_rope_batch_cpu(q_batch, k_batch, num_tokens_in_batch, n_heads, n_kv_heads, head_dim,
                           start_pos_in_sequence, precomputed_freqs_cis_, max_pos_embeddings,
                           use_adjacent_pairing);
    }

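    // Append this batch's K/V to the KV cache for the current layer.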
    if (kv_cache) {
      if (!prompt_lengths.empty()) {
        // NOTE: cache-update calls reconstructed from the update_kv_cache_* signatures.
        update_kv_cache_batch_cpu_sequence_aware(kv_cache, l, k_batch, v_batch, num_tokens_in_batch,
                                                 sequence_indices, position_in_sequence, n_kv_heads, head_dim);
      } else {
        update_kv_cache_batch_cpu(kv_cache, l, k_batch, v_batch, num_tokens_in_batch,
                                  start_pos_in_sequence, n_kv_heads, head_dim);
      }
    }

    std::vector<float> batch_attn_output((size_t)num_tokens_in_batch * hs);

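    // Attention over the cached keys/values, sequence-aware when multiple
    // prompts share the batch.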
    if (kv_cache && static_cast<size_t>(l) < kv_cache->layers.size()) {
      if (!prompt_lengths.empty()) {
        // NOTE: attention calls reconstructed from the attention_batch_cpu* signatures;
        // the final max-sequence-length argument is assumed to be max_pos_embeddings.
        attention_batch_cpu_sequence_aware(q_batch, kv_cache->layers[l], batch_attn_output,
                                           num_tokens_in_batch, sequence_indices, position_in_sequence,
                                           n_heads, n_kv_heads, head_dim, attention_scale,
                                           max_pos_embeddings);
      } else {
        attention_batch_cpu(q_batch, kv_cache->layers[l], batch_attn_output,
                            num_tokens_in_batch, start_pos_in_sequence,
                            n_heads, n_kv_heads, head_dim, attention_scale);
      }
    } else if (kv_cache) {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) +
                    " is out of bounds for KV Cache access during attention. KVCache layers size: " +
                    std::to_string(kv_cache->layers.size()) +
                    ". Filling attention output with zeros.");
      std::fill(batch_attn_output.begin(), batch_attn_output.end(), 0.0f);
    } else {
      Logger::error("[CPU_BATCH_FWD] KV Cache is null, cannot perform attention for layer " +
                    std::to_string(l) + ". Filling attention output with zeros.");
      std::fill(batch_attn_output.begin(), batch_attn_output.end(), 0.0f);
    }

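    // Output projection of the attention result, then add the attention residual.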
    std::vector<float> batch_attn_proj_out((size_t)num_tokens_in_batch * hs);
    if (!lw.o_proj_f32.empty()) {
      matmul_f32_f32_batch_cpu(lw.o_proj_f32, batch_attn_output, batch_attn_proj_out, num_tokens_in_batch, hs, hs);
    } else if (!lw.o_proj_q8_0.empty()) {
      matmul_q8_0_f32_batch_cpu(lw.o_proj_q8_0, batch_attn_output, batch_attn_proj_out, num_tokens_in_batch, hs, hs);
    } else if (!lw.o_proj_q6k.empty()) {
      matmul_q6k_f32_batch_cpu(lw.o_proj_q6k, batch_attn_output, batch_attn_proj_out, num_tokens_in_batch, hs, hs);
    } else if (!lw.o_proj_q4k.empty()) {
      matmul_q4k_f32_batch_cpu(lw.o_proj_q4k, batch_attn_output, batch_attn_proj_out, num_tokens_in_batch, hs, hs);
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) + ": No O proj weights found for CPU");
      return {};
    }

    for (size_t i = 0; i < current_batch_activations.size(); ++i) {
      current_batch_activations[i] = residual_batch_component_attn[i] + batch_attn_proj_out[i];
    }

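    // MLP block: RMSNorm, gate/up projections, SwiGLU, down projection, residual.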
    std::vector<float> residual_batch_component_mlp = current_batch_activations;
    std::vector<float> batch_x_norm2(current_batch_activations.size());
    // NOTE: bf16 fallback branch reconstructed; assumes bf16 weights named lw.post_attention_layernorm.
    const std::vector<float>& w_post_attn_norm_vec =
        lw.post_attention_layernorm_f32.empty()
            ? bf16vec_to_float_vec(lw.post_attention_layernorm)
            : lw.post_attention_layernorm_f32;

    rmsnorm_batch_cpu(current_batch_activations, w_post_attn_norm_vec, batch_x_norm2,
                      num_tokens_in_batch, hs, eps);

    std::vector<float> batch_gate_proj_out((size_t)num_tokens_in_batch * is);
    std::vector<float> batch_up_proj_out((size_t)num_tokens_in_batch * is);

    if (!lw.gate_proj_f32.empty()) {
      matmul_f32_f32_batch_cpu(lw.gate_proj_f32, batch_x_norm2, batch_gate_proj_out, num_tokens_in_batch, is, hs);
    } else if (!lw.gate_proj_q8_0.empty()) {
      matmul_q8_0_f32_batch_cpu(lw.gate_proj_q8_0, batch_x_norm2, batch_gate_proj_out, num_tokens_in_batch, is, hs);
    } else if (!lw.gate_proj_q6k.empty()) {
      matmul_q6k_f32_batch_cpu(lw.gate_proj_q6k, batch_x_norm2, batch_gate_proj_out, num_tokens_in_batch, is, hs);
    } else if (!lw.gate_proj_q4k.empty()) {
      matmul_q4k_f32_batch_cpu(lw.gate_proj_q4k, batch_x_norm2, batch_gate_proj_out, num_tokens_in_batch, is, hs);
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) + ": No gate_proj weights found for CPU");
      return {};
    }

    if (!lw.up_proj_f32.empty()) {
      matmul_f32_f32_batch_cpu(lw.up_proj_f32, batch_x_norm2, batch_up_proj_out, num_tokens_in_batch, is, hs);
    } else if (!lw.up_proj_q8_0.empty()) {
      matmul_q8_0_f32_batch_cpu(lw.up_proj_q8_0, batch_x_norm2, batch_up_proj_out, num_tokens_in_batch, is, hs);
    } else if (!lw.up_proj_q6k.empty()) {
      matmul_q6k_f32_batch_cpu(lw.up_proj_q6k, batch_x_norm2, batch_up_proj_out, num_tokens_in_batch, is, hs);
    } else if (!lw.up_proj_q4k.empty()) {
      matmul_q4k_f32_batch_cpu(lw.up_proj_q4k, batch_x_norm2, batch_up_proj_out, num_tokens_in_batch, is, hs);
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) + ": No up_proj weights found for CPU");
      return {};
    }

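    // SwiGLU: silu(gate) * up, applied element-wise.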
    std::vector<float> batch_swiglu_out((size_t)num_tokens_in_batch * is);
    for (size_t i = 0; i < batch_gate_proj_out.size(); ++i) {
      float gate_val = batch_gate_proj_out[i];
      float silu_gate_val = gate_val / (1.0f + std::exp(-gate_val));
      batch_swiglu_out[i] = silu_gate_val * batch_up_proj_out[i];
    }

    std::vector<float> batch_mlp_down_proj_out((size_t)num_tokens_in_batch * hs);
    if (!lw.down_proj_f32.empty()) {
      matmul_f32_f32_batch_cpu(lw.down_proj_f32, batch_swiglu_out, batch_mlp_down_proj_out, num_tokens_in_batch, hs, is);
    } else if (!lw.down_proj_q8_0.empty()) {
      matmul_q8_0_f32_batch_cpu(lw.down_proj_q8_0, batch_swiglu_out, batch_mlp_down_proj_out, num_tokens_in_batch, hs, is);
    } else if (!lw.down_proj_q6k.empty()) {
      matmul_q6k_f32_batch_cpu(lw.down_proj_q6k, batch_swiglu_out, batch_mlp_down_proj_out, num_tokens_in_batch, hs, is);
    } else if (!lw.down_proj_q4k.empty()) {
      matmul_q4k_f32_batch_cpu(lw.down_proj_q4k, batch_swiglu_out, batch_mlp_down_proj_out, num_tokens_in_batch, hs, is);
    } else {
      Logger::error("[CPU_BATCH_FWD] Layer " + std::to_string(l) + ": No down_proj weights found for CPU");
      return {};
    }

    for (size_t i = 0; i < current_batch_activations.size(); ++i) {
      current_batch_activations[i] = residual_batch_component_mlp[i] + batch_mlp_down_proj_out[i];
    }
  }

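  // Record how far this batch advanced the sequence in the KV cache.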
  if (kv_cache && num_tokens_in_batch > 0) {
    kv_cache->seq_len = start_pos_in_sequence + num_tokens_in_batch;
  }
  return current_batch_activations;
}