#include <nlohmann/json.hpp>
#include <unordered_map>

#if defined(WINDOWS_CUDA_12_1_WORKAROUND) && defined(_WIN32)
#include <cuda_runtime.h>
  std::vector<float> k;
  std::vector<float> v;

  float* k_dev_fp32 = nullptr;
  float* v_dev_fp32 = nullptr;

  int8_t* k_dev_quantized = nullptr;
  int8_t* v_dev_quantized = nullptr;
  float* k_dev_scales = nullptr;
  float* v_dev_scales = nullptr;
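// Illustrative sketch (not part of this header): one plausible encoding behind the
// k_dev_quantized / k_dev_scales pairs is symmetric INT8 quantization with one float
// scale per head_dim-sized vector. The block granularity and the helper name are
// assumptions for illustration only. CPU-side for clarity; requires <cmath>, <cstdint>.
inline void quantize_kv_vector_int8_sketch(const float* src, int8_t* dst,
                                           float* scale, int head_dim) {
  float max_abs = 0.0f;
  for (int i = 0; i < head_dim; ++i) max_abs = std::fmax(max_abs, std::fabs(src[i]));
  *scale = (max_abs > 0.0f) ? max_abs / 127.0f : 1.0f;   // one scale per vector
  for (int i = 0; i < head_dim; ++i)
    dst[i] = static_cast<int8_t>(std::lround(src[i] / *scale));
  // Dequantization is dst[i] * (*scale), which is what the *_dev_scales buffers enable.
}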
  void initialize(const ModelConfig& config,
                  int total_num_model_layers, int num_gpu_layers_to_allocate,
                  int max_seq_len_arg, int num_kv_heads, int head_dim,
                  int max_batch_size_arg = 1);

  for (auto& layer : layers) {
    std::fill(layer.k.begin(), layer.k.end(), 0.0f);
    std::fill(layer.v.begin(), layer.v.end(), 0.0f);

          ". Using max batch size.");
  int allocated_num_layers = 0;
  int allocated_max_seq_len = 0;
  int allocated_num_kv_heads = 0;
  int allocated_head_dim = 0;

using ForwardDiagCallback = std::function<void(
    int layer, const std::string& name, const std::vector<float>& v)>;
  float* input_layernorm_dev = nullptr;
  float* post_attention_layernorm_dev = nullptr;

  float* q_proj_f32_dev = nullptr;
  float* k_proj_f32_dev = nullptr;
  float* v_proj_f32_dev = nullptr;
  float* o_proj_f32_dev = nullptr;
  float* gate_proj_f32_dev = nullptr;
  float* up_proj_f32_dev = nullptr;
  float* down_proj_f32_dev = nullptr;
                 std::unique_ptr<GGUFData> gguf_data_from_session);

  std::vector<float> forward(std::vector<float>& input,
                             int n_tokens, KVCache* kv_cache,
                             const std::vector<int>* attention_mask);
                      int m_user, int n_user, int k_user,
                      const float* alpha_user,
                      const float* A_f32_user, int lda_user,
                      const float* B_f32_user, int ldb_user,
                      const float* beta_user,
                      float* C_f32_user, int ldc_user,
                      const char* operation_name = "GEMM");
  std::vector<float> forward_device(
      /* ... */
      const std::vector<int>* attention_mask = nullptr,
      cudaStream_t stream = 0);
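// The GEMM helper declared above (m/n/k, alpha/beta, lda/ldb/ldc, operation_name) is
// shaped like a thin cuBLAS wrapper. This is a sketch of how such a wrapper could
// forward to cublasSgemm on the caller's stream; it is an assumption about the
// implementation, not the actual definition, and requires <cublas_v2.h>.
inline void gemm_f32_cublas_sketch(cublasHandle_t handle, cudaStream_t stream,
                                   bool transa, bool transb, int m, int n, int k,
                                   const float* alpha, const float* A, int lda,
                                   const float* B, int ldb,
                                   const float* beta, float* C, int ldc) {
  cublasSetStream(handle, stream);                    // run the GEMM on the caller's stream
  cublasSgemm(handle,
              transa ? CUBLAS_OP_T : CUBLAS_OP_N,
              transb ? CUBLAS_OP_T : CUBLAS_OP_N,
              m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
  // Note: cuBLAS assumes column-major storage, so row-major callers typically swap
  // operands or transpose flags before reaching this point.
}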
  float* get_x_dev() { return x_dev_; }

  void forward_device(int token_id, int pos, KVCache* kv_cache,
                      cudaStream_t stream = 0);
  void forward_device_token(int token_id, int pos, KVCache* kv_cache,
                            cudaStream_t stream = 0);

  std::vector<float> forward_device_batch_prefill(
      float* d_batch_input_hidden_states,
      int num_tokens_in_batch,
      int start_pos_in_kv_cache,
      /* ... */);

  std::vector<std::vector<float>> forward_device_batch_generation(
      float* d_batch_input_hidden_states,
      const std::vector<int>& token_positions,
      const std::vector<int>& original_sequence_indices,
      int num_tokens_in_batch,
      /* ... */);
  static constexpr int MAX_BATCH_TOKENS = 2048;

  float* d_persistent_batch_input_ = nullptr;
  float* d_persistent_batch_norm_out_ = nullptr;
  float* d_persistent_batch_residual_ = nullptr;
  float* d_persistent_q_batch_ = nullptr;
  float* d_persistent_k_batch_ = nullptr;
  float* d_persistent_v_batch_ = nullptr;
  float* d_persistent_attn_output_ = nullptr;
  float* d_persistent_attn_proj_out_ = nullptr;
  float* d_persistent_gate_proj_out_ = nullptr;
  float* d_persistent_up_proj_out_ = nullptr;
  float* d_persistent_swiglu_out_ = nullptr;
  float* d_persistent_mlp_down_out_ = nullptr;

  void allocate_persistent_batch_buffers();
  void free_persistent_batch_buffers();
  void resize_persistent_batch_buffers_if_needed(int required_batch_size);
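// A sketch of the grow-only pattern that resize_persistent_batch_buffers_if_needed()
// suggests, shown for a single device buffer. The real sizing (e.g. against
// MAX_BATCH_TOKENS and the hidden dimension), the error handling, and which buffers
// participate are assumptions and not shown here.
inline void grow_device_buffer_sketch(float*& d_buf, size_t& current_elems,
                                      size_t required_elems) {
  if (required_elems <= current_elems) return;        // only grow, never shrink
  if (d_buf) cudaFree(d_buf);
  cudaMalloc(reinterpret_cast<void**>(&d_buf), required_elems * sizeof(float));
  current_elems = required_elems;
}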
  std::vector<float> forward_cpu_batch(
      const std::vector<float>& batch_input_activations,
      int num_tokens_in_batch,
      int num_cpu_layers_to_process,
      int start_pos_in_sequence,
      KVCache* kv_cache,
      const std::vector<int>& prompt_lengths = {});

  std::vector<float> forward_cpu_logits_batch(
      const std::vector<float>& final_batch_activations,
      int num_tokens_in_batch);

  std::vector<std::vector<float>> forward_cpu_batch_generation(
      const std::vector<float>& batch_input_activations,
      const std::vector<int>& token_positions,
      const std::vector<int>& original_sequence_indices,
      int num_tokens_in_batch,
      KVCache* kv_cache);
  float* final_norm_dev = nullptr;
  float* all_freqs_cis_dev = nullptr;
  uint16_t* token_embedding_table_dev_ = nullptr;
  uint16_t* w_q_dev_ = nullptr;
  uint16_t* w_k_dev_ = nullptr;
  uint16_t* w_v_dev_ = nullptr;
  uint16_t* w_o_dev_ = nullptr;
  uint16_t* w_gate_dev_ = nullptr;
  uint16_t* w_up_dev_ = nullptr;
  uint16_t* w_down_dev_ = nullptr;
  uint16_t* lm_head_dev_ = nullptr;
  float* token_embedding_table_f32_dev_ = nullptr;
  float* w_q_f32_dev_ = nullptr;
  float* w_k_f32_dev_ = nullptr;
  float* w_v_f32_dev_ = nullptr;
  float* w_o_f32_dev_ = nullptr;
  float* w_gate_f32_dev_ = nullptr;
  float* w_up_f32_dev_ = nullptr;
  float* w_down_f32_dev_ = nullptr;
  float* lm_head_f32_dev_ = nullptr;
  cublasHandle_t cublas_handle_ = nullptr;
  float* x_dev_ = nullptr;
  float* x_norm_dev_ = nullptr;
  float* x_resid1_dev_ = nullptr;
  float* x_resid2_dev_ = nullptr;
  float* q_dev_ = nullptr;
  float* k_dev_ = nullptr;
  float* v_dev_ = nullptr;
  float* attn_out_dev_ = nullptr;
  float* attn_proj_dev_ = nullptr;
  float* gate_vec_dev_ = nullptr;
  float* up_vec_dev_ = nullptr;
  float* swiglu_vec_dev_ = nullptr;
  float* mlp_down_dev_ = nullptr;
  float* logits_dev_ = nullptr;

  float* dequant_k_cache_buffer_dev_ = nullptr;
  float* dequant_v_cache_buffer_dev_ = nullptr;

  float* selective_k_dequant_buffer_dev_ = nullptr;
  float* selective_v_dequant_buffer_dev_ = nullptr;
  size_t selective_dequant_buffer_size_ = 0;
  uint16_t* w_q_bf16_dev_ = nullptr;
  uint16_t* w_k_bf16_dev_ = nullptr;
  uint16_t* w_v_bf16_dev_ = nullptr;
  uint16_t* w_o_bf16_dev_ = nullptr;
  uint16_t* w_gate_bf16_dev_ = nullptr;
  uint16_t* w_up_bf16_dev_ = nullptr;
  uint16_t* w_down_bf16_dev_ = nullptr;
  bool bf16_concatenated_weights_loaded_ = false;
int argmax(const std::vector<float>& v);

void rmsnorm(const std::vector<float>& x, const std::vector<uint16_t>& weight,
             float eps, std::vector<float>& out);

void matvec_bf16_f32(const std::vector<uint16_t>& mat, const std::vector<float>& vec,
                     std::vector<float>& out, int M, int N);

std::vector<uint16_t> uint8_vector_to_uint16_vector(const std::vector<uint8_t>& bytes,
                                                    size_t numel);

void log_vector_summary_batch(const std::string& name, const std::vector<float>& batch_vector,
                              int num_tokens_in_batch, int single_token_vector_size,
                              int head_count = 5);
// Logging helper:
static void warning(const std::string& message);

// SafeTensorsLoader: main class for loading tensors from SafeTensors format files
// (single or sharded).

// TinyLlamaModel: main transformer model class for TinyLlama.

// Lifecycle and accessors:
~TinyLlamaModel();                               // Destructor. Cleans up all allocated resources.
const ModelConfig& get_config() const;
int get_vocab_size() const;                      // Get the vocabulary size for the model.
const GGUFData* get_gguf_data() const;
GGUFData* get_gguf_data_ptr();
std::vector<LayerWeights>& get_layers();
const std::vector<uint16_t>& get_embed_tokens() const;
const std::vector<uint16_t>& get_lm_head() const;

// Weight loading, GPU residency, and dequantization:
void initialize_weights(const SafeTensorsLoader* loader, const GGUFData* gguf);
void initialize_gpu_and_rope();
void initialize_rope_freqs();
void ensure_layer_weights_on_gpu(int layer_idx);
void free_layer_gpu_weights(int layer_idx);
void ensure_embed_tokens_dequantized();
void ensure_lm_head_dequantized();
void ensure_q_proj_dequantized(int layer_idx);   // lazy-dequant pattern sketched after LayerWeights below
void ensure_k_proj_dequantized(int layer_idx);
void ensure_v_proj_dequantized(int layer_idx);
void ensure_o_proj_dequantized(int layer_idx);
void ensure_gate_proj_dequantized(int layer_idx);
void ensure_up_proj_dequantized(int layer_idx);
void ensure_down_proj_dequantized(int layer_idx);
void clear_layer_dequantized_weights(int layer_idx);
void ensure_f32_concatenated_weights_loaded();
void ensure_bf16_concatenated_weights_loaded();
void free_bf16_concatenated_weights();
friend void map_gguf_weights(const GGUFData& gguf, TinyLlamaModel& model);

// Forward passes:
std::vector<float> lookup_embedding(int token_id);   // Lookup the embedding vector for a given token ID (sketched below).
std::vector<float> forward(std::vector<float>& input, int n_tokens, KVCache* kv_cache,
                           const std::vector<int>* attention_mask);   // Run the forward pass for the model on CPU layers.
std::vector<float> forward_cpu_batch(const std::vector<float>& batch_input_activations,
                                     int num_tokens_in_batch, int num_cpu_layers_to_process,
                                     int start_pos_in_sequence, KVCache* kv_cache,
                                     const std::vector<int>& prompt_lengths = {});
std::vector<std::vector<float>> forward_cpu_batch_generation(
    const std::vector<float>& batch_input_activations, const std::vector<int>& token_positions,
    const std::vector<int>& original_sequence_indices, int num_tokens_in_batch, KVCache* kv_cache);
std::vector<float> forward_cpu_logits_batch(const std::vector<float>& final_batch_activations,
                                            int num_tokens_in_batch);
void smart_gemm_batch_cuda(bool transa_user, bool transb_user, int m_user, int n_user, int k_user,
                           const float* alpha_user, const float* A_f32_user, int lda_user,
                           const float* B_f32_user, int ldb_user, const float* beta_user,
                           float* C_f32_user, int ldc_user, cudaStream_t stream,
                           const char* operation_name = "GEMM");

// Data members:
std::vector<LayerWeights> layers;
std::vector<uint16_t> embed_tokens;
std::vector<float> embed_tokens_f32;
std::vector<block_q4_K> embed_tokens_q4k;
std::vector<block_q6_K> embed_tokens_q6k;
std::vector<block_q8_0> embed_tokens_q8_0;
std::vector<block_q8_K> embed_tokens_q8k;
std::vector<uint16_t> lm_head;
std::vector<float> lm_head_f32;
std::vector<block_q4_K> lm_head_q4k;
std::vector<block_q6_K> lm_head_q6k;
std::vector<block_q8_0> lm_head_q8_0;
std::vector<block_q8_K> lm_head_q8k;
std::vector<uint16_t> final_norm;
std::vector<float> final_norm_f32;
std::vector<block_q4_K> final_norm_q4k;
std::vector<block_q6_K> final_norm_q6k;
std::vector<std::pair<float, float>> precomputed_freqs_cis_;
std::unique_ptr<GGUFData> gguf_data_;
std::unique_ptr<class CPUBatchProcessor> cpu_batch_processor_;
bool use_bf16_tensor_cores_;
bool f32_concatenated_weights_loaded_;

// TensorName: enumeration of tensor names used in the TinyLlama model.
static std::string tensor_name_to_string(TensorName tn);

// Free functions:
int argmax(const std::vector<float>& v);
void softmax(std::vector<float>& x);
float bfloat16_to_float32(uint16_t b16);
void rmsnorm(const std::vector<float>& x, const std::vector<uint16_t>& weight, float eps,
             std::vector<float>& out);
void matvec_bf16_f32(const std::vector<uint16_t>& mat, const std::vector<float>& vec,
                     std::vector<float>& out, int M, int N);
std::vector<uint16_t> uint8_vector_to_uint16_vector(const std::vector<uint8_t>& bytes, size_t numel);
ModelConfig parse_model_config(const nlohmann::json& json);
ModelConfig parse_model_config_from_gguf(const GGUFData& gguf);
void log_vector_summary(const std::string& name, const std::vector<float>& v, int head_count = 5);
void log_vector_summary_batch(const std::string& name, const std::vector<float>& batch_vector,
                              int num_tokens_in_batch, int single_token_vector_size, int head_count = 5);
using ForwardDiagCallback = std::function<void(int layer, const std::string& name,
                                               const std::vector<float>& v)>;
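// Minimal sketches consistent with the free-function declarations above; the project's
// actual implementations may differ (e.g. SIMD paths). The _sketch suffix marks these as
// illustrations, and hidden_size in lookup_embedding_sketch is an assumed parameter, not
// the real API. Requires <cmath>, <cstring>, <cstdint>, <vector>.
float bfloat16_to_float32_sketch(uint16_t b16) {
  uint32_t bits = static_cast<uint32_t>(b16) << 16;   // bf16 occupies the high 16 bits of an f32
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

void rmsnorm_sketch(const std::vector<float>& x, const std::vector<uint16_t>& weight,
                    float eps, std::vector<float>& out) {
  double sum_sq = 0.0;
  for (float v : x) sum_sq += static_cast<double>(v) * v;
  const float inv_rms = 1.0f / std::sqrt(static_cast<float>(sum_sq / x.size()) + eps);
  out.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    out[i] = x[i] * inv_rms * bfloat16_to_float32_sketch(weight[i]);   // bf16 gain per element
}

// lookup_embedding(token_id) is conceptually a row copy out of the embedding table;
// shown here against an f32 table with an assumed hidden_size.
std::vector<float> lookup_embedding_sketch(const std::vector<float>& embed_tokens_f32,
                                           int token_id, int hidden_size) {
  const size_t offset = static_cast<size_t>(token_id) * hidden_size;
  return std::vector<float>(embed_tokens_f32.begin() + offset,
                            embed_tokens_f32.begin() + offset + hidden_size);
}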
// Weight quantization structures and functions for model compression.
// SafeTensorsLoader: SafeTensors format loader for efficient tensor loading, supporting
// single and sharded models.
// GGUFData: complete representation of a GGUF file's contents.
// KVCacheLayer: Key-Value cache for a single transformer layer.
// KVCache: complete Key-Value cache for all transformer layers.

// KVCache:
void initialize(const ModelConfig& config, int total_num_model_layers,
                int num_gpu_layers_to_allocate, int max_seq_len_arg, int num_kv_heads,
                int head_dim, int max_batch_size_arg = 1);   // Initializes the KV cache with given dimensions (sketched below).
void initialize_batch(int batch_size);                       // Initialize batch mode with specified number of sequences.
void destroy_gpu_resources();
std::vector<KVCacheLayer> layers;
std::vector<int> batch_seq_lens;
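// A minimal sketch of what initialize() implies for host-side buffer sizing, assuming a
// flat [max_batch_size x max_seq_len x num_kv_heads x head_dim] layout per layer; the GPU
// and quantized buffers, and any config-driven clamping, are not shown.
inline void kvcache_initialize_sketch(std::vector<KVCacheLayer>& layers,
                                      int total_num_model_layers, int max_seq_len,
                                      int num_kv_heads, int head_dim, int max_batch_size) {
  layers.resize(total_num_model_layers);
  const size_t per_layer = static_cast<size_t>(max_batch_size) * max_seq_len *
                           num_kv_heads * head_dim;
  for (auto& layer : layers) {
    layer.k.assign(per_layer, 0.0f);    // matches the std::fill zeroing seen earlier
    layer.v.assign(per_layer, 0.0f);
  }
}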
// LayerWeights: structure holding all weights for a single transformer layer.
std::vector<uint16_t> input_layernorm;
std::vector<float> input_layernorm_f32;
std::vector<uint16_t> post_attention_layernorm;
std::vector<float> post_attention_layernorm_f32;

std::vector<uint16_t> q_proj;
std::vector<float> q_proj_f32;
std::vector<block_q4_K> q_proj_q4k;
std::vector<block_q6_K> q_proj_q6k;
std::vector<block_q8_0> q_proj_q8_0;
std::vector<block_q8_K> q_proj_q8k;

std::vector<uint16_t> k_proj;
std::vector<float> k_proj_f32;
std::vector<block_q4_K> k_proj_q4k;
std::vector<block_q6_K> k_proj_q6k;
std::vector<block_q8_0> k_proj_q8_0;
std::vector<block_q8_K> k_proj_q8k;

std::vector<uint16_t> v_proj;
std::vector<float> v_proj_f32;
std::vector<block_q4_K> v_proj_q4k;
std::vector<block_q6_K> v_proj_q6k;
std::vector<block_q8_0> v_proj_q8_0;
std::vector<block_q8_K> v_proj_q8k;

std::vector<uint16_t> o_proj;
std::vector<float> o_proj_f32;
std::vector<block_q4_K> o_proj_q4k;
std::vector<block_q6_K> o_proj_q6k;
std::vector<block_q8_0> o_proj_q8_0;
std::vector<block_q8_K> o_proj_q8k;

std::vector<uint16_t> gate_proj;
std::vector<float> gate_proj_f32;
std::vector<block_q4_K> gate_proj_q4k;
std::vector<block_q6_K> gate_proj_q6k;
std::vector<block_q8_0> gate_proj_q8_0;
std::vector<block_q8_K> gate_proj_q8k;

std::vector<uint16_t> up_proj;
std::vector<float> up_proj_f32;
std::vector<block_q4_K> up_proj_q4k;
std::vector<block_q6_K> up_proj_q6k;
std::vector<block_q8_0> up_proj_q8_0;
std::vector<block_q8_K> up_proj_q8k;

std::vector<uint16_t> down_proj;
std::vector<float> down_proj_f32;
std::vector<block_q4_K> down_proj_q4k;
std::vector<block_q6_K> down_proj_q6k;
std::vector<block_q8_0> down_proj_q8_0;
std::vector<block_q8_K> down_proj_q8k;
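// Sketch of the lazy-dequantization pattern implied by the ensure_*_dequantized()
// methods: materialise an f32 copy of a projection only when it is first needed on the
// CPU path. dequantize_q4_k_rows() is a hypothetical stand-in for the project's actual
// dequantisation routines, and n_elements is assumed to be rows x cols of the projection.
inline void ensure_q_proj_dequantized_sketch(LayerWeights& lw, size_t n_elements) {
  if (!lw.q_proj_f32.empty()) return;                 // already materialised
  if (!lw.q_proj_q4k.empty()) {
    lw.q_proj_f32.resize(n_elements);
    dequantize_q4_k_rows(lw.q_proj_q4k.data(), lw.q_proj_f32.data(), n_elements);  // hypothetical helper
  } else if (!lw.q_proj.empty()) {
    lw.q_proj_f32.resize(n_elements);
    for (size_t i = 0; i < n_elements; ++i)
      lw.q_proj_f32[i] = bfloat16_to_float32(lw.q_proj[i]);    // bf16 fallback path
  }
}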
// ModelConfig: model configuration structure holding architecture and hyperparameters.
std::string chat_template_string;
std::string chat_template_type;
std::string pre_tokenizer_type;
TokenizerFamily tokenizer_family;
int max_position_embeddings;
int num_cpu_offload_layers;
bool enable_memory_efficient_layers;
bool enable_prefill_chunking;
bool use_kvcache_quantization;
bool use_optimized_cuda_kernels;
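// A minimal sketch of how parse_model_config() can map a HuggingFace-style config.json
// onto ModelConfig with nlohmann::json. Only two fields are shown; the key names,
// defaults, and the treatment of runtime-only options are assumptions, not the project's
// actual parsing logic.
inline ModelConfig parse_model_config_sketch(const nlohmann::json& j) {
  ModelConfig cfg{};
  cfg.max_position_embeddings = j.value("max_position_embeddings", 2048);
  // Runtime options such as num_cpu_offload_layers or use_kvcache_quantization usually
  // come from the caller rather than from config.json; defaulted here for illustration.
  cfg.num_cpu_offload_layers = 0;
  cfg.use_kvcache_quantization = false;
  return cfg;
}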