#if defined(WINDOWS_CUDA_12_1_WORKAROUND) && defined(_WIN32)
// Workaround-specific definitions elided in this excerpt.
#endif

#include <cuda_runtime.h>
#include <cublas_v2.h>

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>
inline void gpuAssert(cudaError_t code,
                      const char* file,
                      int line,
                      bool abort = true) {
  if (code != cudaSuccess) {
    std::string err_msg =
        "GPUassert: " + std::string(cudaGetErrorString(code)) + " " +
        std::string(file) + " " + std::to_string(line);
    if (abort)
      throw std::runtime_error(err_msg);
  }
}
#define gpuErrchk(ans) \
  { gpuAssert((ans), __FILE__, __LINE__); }
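// Illustrative usage sketch (not part of the original header): wrap CUDA
// runtime calls so failures surface as std::runtime_error.
//
//   float* d_buf = nullptr;
//   gpuErrchk(cudaMalloc(&d_buf, 1024 * sizeof(float)));
//   gpuErrchk(cudaMemset(d_buf, 0, 1024 * sizeof(float)));
//   gpuErrchk(cudaFree(d_buf));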
void rmsnorm_vector_cuda(const float* x_dev,
                         const float* weight_dev,
                         float* out_dev,
                         int n,
                         float eps,
                         cudaStream_t stream = 0);
void rmsnorm_vector_cuda(const std::vector<float>& x_in_host,
                         const std::vector<float>& weight_host,
                         std::vector<float>& out_host,
                         int n,
                         float eps);
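// Illustrative usage sketch for the host-vector overload (assumes it handles
// device allocation and transfers internally, as the signature suggests):
//
//   const int hidden = 2048;  // hypothetical size
//   std::vector<float> x(hidden, 1.0f), w(hidden, 1.0f), y(hidden);
//   rmsnorm_vector_cuda(x, w, y, hidden, 1e-5f);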
void rmsnorm_batch_cuda(float* d_out,
                        float* d_in,
                        const float* d_weight,
                        int num_tokens,
                        int hidden_size,
                        float eps,
                        cudaStream_t stream);
__global__ void reduce_partial_sums_kernel(const float* partial_sums,
                                           float* total_sum_sq_out,
                                           int num_partial_sums);
void matvec_f32_f32_cuda(cublasHandle_t handle,
                         const std::vector<float>& mat_f32_host,
                         const std::vector<float>& vec_f32_host,
                         std::vector<float>& out_f32_host,
                         int rows,
                         int cols);
void matvec_f32_f32_cuda(cublasHandle_t handle,
                         const float* mat_f32_dev,
                         const float* vec_f32_dev,
                         float* out_f32_dev,
                         int rows,
                         int cols,
                         cudaStream_t stream = 0);
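// Illustrative usage sketch for the host-vector overload (assumes a
// [rows x cols] matrix multiplied by a length-cols vector, the conventional
// reading of the signature; sizes here are hypothetical):
//
//   cublasHandle_t handle;
//   cublasCreate(&handle);
//   const int rows = 4, cols = 8;
//   std::vector<float> W(rows * cols, 0.5f), x(cols, 1.0f), y(rows);
//   matvec_f32_f32_cuda(handle, W, x, y, rows, cols);
//   cublasDestroy(handle);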
// Note: the m/n/k, alpha, beta, and C/ldc parameters below are reconstructed
// by analogy with gemm_bf16_bf16_cuda; verify against the implementation.
void gemm_f32_f32_cuda(cublasHandle_t handle,
                       bool transa,
                       bool transb,
                       int m,
                       int n,
                       int k,
                       const float* alpha,
                       const float* A,
                       int lda,
                       const float* B,
                       int ldb,
                       const float* beta,
                       float* C,
                       int ldc,
                       cudaStream_t stream);
void silu_cuda(const std::vector<float>& x_host,
               std::vector<float>& out_host,
               int n);
void softmax_vector_cuda(const std::vector<float>& x_host,
                         std::vector<float>& out_host,
                         int n);
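// For reference, softmax over a vector x of length n is
//   out[i] = exp(x[i] - max(x)) / sum_j exp(x[j] - max(x)),
// with the max subtracted for numerical stability (standard formulation; the
// exact kernel strategy is implementation-defined).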
void rope_cuda(float* vec,
               int num_heads,
               int head_dim,
               const float* freqs_cis_dev,
               int pos,
               bool use_adjacent_pairing,
               cudaStream_t stream);
void attention_cuda(const float* Q_current_dev,
                    const float* K_layer_cache_base,
                    const float* V_layer_cache_base,
                    float* out_dev,
                    int num_heads,
                    int current_seq_len,
                    int head_dim,
                    float scale,
                    int cache_max_seq_len,
                    int cache_num_kv_heads,
                    cudaStream_t stream = 0);
// Note: the output pointer parameter is reconstructed by analogy with
// attention_cuda; verify against the implementation.
void attention_cuda_selective_dequant(const float* Q_current_dev,
                                      const int8_t* K_quantized_cache_base,
                                      const int8_t* V_quantized_cache_base,
                                      const float* K_scales_cache_base,
                                      const float* V_scales_cache_base,
                                      float* selective_k_dequant_buffer,
                                      float* selective_v_dequant_buffer,
                                      float* out_dev,
                                      int num_heads,
                                      int current_seq_len,
                                      int head_dim,
                                      float scale,
                                      int cache_max_seq_len,
                                      int cache_num_kv_heads,
                                      cudaStream_t stream = 0);
void add_vectors_cuda(const float* a_dev,
                      const float* b_dev,
                      float* result_dev,
                      int n,
                      cudaStream_t stream = 0);
void add_residual_cuda(const float* matvec_out_dev,
                       const float* residual_dev,
                       float* result_dev,
                       int n,
                       cudaStream_t stream = 0);
void update_kv_cache_cuda(float* cache_base_ptr,
                          const float* current_kv_vector,
                          int pos,
                          int kv_head_idx,
                          int max_seq_len,
                          int num_kv_heads,
                          int head_dim,
                          cudaStream_t stream = 0);
void rope_and_update_kv_cache_cuda(float* cache_base_ptr,
                                   const float* kv_vector_head,
                                   const float* all_freqs_cis_base,
                                   int pos,
                                   int kv_head_idx,
                                   int max_seq_len,
                                   int num_kv_heads,
                                   int head_dim,
                                   cudaStream_t stream = 0);
void swiglu_cuda(const float* gate_dev,
                 const float* up_dev,
                 float* out_dev,
                 int n,
                 cudaStream_t stream = 0);
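// For reference, SwiGLU combines the gate and up projections elementwise as
//   out[i] = silu(gate[i]) * up[i],  where silu(x) = x / (1 + exp(-x)).
// (Standard definition; the kernel's exact fusion strategy is not shown here.)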
void lookup_embedding_bf16_f32_cuda(const uint16_t* embedding_table_dev,
                                    float* output_vector_dev,
                                    int token_id,
                                    int hidden_size,
                                    int vocab_size,
                                    cudaStream_t stream = 0);
void lookup_embedding_cuda(const void* table_dev,
                           float* output_dev,
                           int token_id,
                           int hidden_size,
                           int vocab_size,
                           bool is_bf16,
                           cudaStream_t stream);
// Note: the output pointer and rows/cols parameters are reconstructed by
// analogy with matvec_f32_f32_cuda; verify against the implementation.
void matvec_bf16_f32_cuda(cublasHandle_t handle,
                          const uint16_t* mat_bf16_dev,
                          const float* vec_f32_dev,
                          float* out_f32_dev,
                          int rows,
                          int cols,
                          bool use_tensor_cores,
                          cudaStream_t stream = 0);
__global__ void convert_bf16_to_fp32_kernel(
    const uint16_t* __restrict__ bf16_in,
    float* __restrict__ fp32_out,
    size_t n_elements);  // element-count parameter reconstructed
void quantize_fp32_to_int8_symmetric_per_tensor_cuda(
    const float* fp32_in_dev,
    int8_t* int8_out_dev,
    float* scale_out_dev,
    int num_elements,  // element-count parameter reconstructed
    cudaStream_t stream = 0);
void dequantize_int8_to_fp32_symmetric_per_tensor_cuda(
    const int8_t* int8_in_dev,
    const float* scale_in_dev,
    float* fp32_out_dev,  // output pointer reconstructed
    int num_elements,     // element-count parameter reconstructed
    cudaStream_t stream = 0);
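// For reference, symmetric per-tensor int8 quantization is conventionally
//   scale = max_i |x[i]| / 127,  q[i] = round(x[i] / scale),  x~[i] = q[i] * scale,
// with scale_out_dev / scale_in_dev presumably holding a single per-tensor
// scale; the exact rounding/clamping behaviour is implementation-defined.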
void swiglu_batch_cuda(float* d_out_batch,
                       const float* d_gate_act_batch,
                       const float* d_up_act_batch,
                       int num_tokens,  // reconstructed by analogy with the other *_batch_cuda functions
                       int intermediate_size,
                       cudaStream_t stream);
void rope_batch_cuda(float* d_q_batch,
                     float* d_k_batch,
                     const float* d_all_freqs_cis_base,
                     int num_tokens,
                     int num_q_heads,
                     int num_kv_heads,
                     int head_dim,
                     int start_pos_offset,
                     bool use_adjacent_pairing,
                     cudaStream_t stream);
// Note: several mid-list parameters are missing from this excerpt and are
// reconstructed here by analogy with attention_cuda and rope_batch_cuda;
// verify against the implementation.
void attention_batch_prefill_cuda(const float* d_q_batch_strided,
                                  const float* d_k_batch_strided,
                                  const float* d_v_batch_strided,
                                  float* d_kv_cache_k_base,
                                  float* d_kv_cache_v_base,
                                  float* d_output_batch_strided,
                                  int num_tokens_in_batch,
                                  int start_pos_in_kv_cache,
                                  int cache_max_seq_len,
                                  int num_q_heads,
                                  int num_kv_heads,
                                  int head_dim,
                                  float scale,
                                  cudaStream_t stream = 0,
                                  const int* attention_mask_cu = nullptr);
void add_residual_batch_cuda(float* d_output_batch,
                             const float* d_input_a_batch,
                             const float* d_input_b_batch,
                             int num_tokens,
                             int hidden_size,
                             cudaStream_t stream);
// Note: the kv-head/head-dim parameters and the trailing stream argument are
// missing from this excerpt and are reconstructed by analogy with
// update_kv_cache_cuda; verify against the implementation.
void update_kv_cache_batch_cuda(float* d_kv_cache_layer_base,
                                const float* d_keys_or_values_batch,
                                int start_pos_in_kv_cache,
                                int num_tokens_in_batch,
                                int num_kv_heads,
                                int head_dim,
                                int cache_max_seq_len,
                                cudaStream_t stream);
void rmsnorm_vector_cuda_optimized(const float* x_dev,
                                   const float* weight_dev,
                                   float* out_dev,
                                   int n,
                                   float eps,
                                   cudaStream_t stream = 0);
void softmax_vector_cuda_optimized(const float* x_dev,
                                   float* out_dev,
                                   int n,
                                   cudaStream_t stream = 0);
void attention_cuda_optimized(const float* Q_current_dev,
                              const float* K_layer_cache_base,
                              const float* V_layer_cache_base,
                              float* out_dev,
                              int num_heads,
                              int current_seq_len,
                              int head_dim,
                              float scale,
                              int cache_max_seq_len,
                              int cache_num_kv_heads,
                              cudaStream_t stream = 0);
void gemm_bf16_bf16_cuda(cublasHandle_t handle,
                         bool transa_user,
                         bool transb_user,
                         int m_user,
                         int n_user,
                         int k_user,
                         const float* alpha_user,
                         const uint16_t* A_bf16_user,
                         int lda_user,
                         const uint16_t* B_bf16_user,
                         int ldb_user,
                         const float* beta_user,
                         uint16_t* C_bf16_user,
                         int ldc_user,
                         cudaStream_t stream);
void gemm_f32_to_bf16_f32_cuda(cublasHandle_t handle,
                               bool transa_user,
                               bool transb_user,
                               int m_user,
                               int n_user,
                               int k_user,
                               const float* alpha_user,
                               const float* A_f32_user,
                               int lda_user,
                               const uint16_t* B_bf16_user,
                               int ldb_user,
                               const float* beta_user,
                               float* C_f32_user,
                               int ldc_user,
                               cudaStream_t stream);
void convert_fp32_to_bf16_cuda(const float* fp32_in_dev,
                               uint16_t* bf16_out_dev,
                               size_t n_elements,
                               cudaStream_t stream);
void convert_bf16_to_fp32_cuda(const uint16_t* bf16_in_dev,
                               float* fp32_out_dev,
                               size_t n_elements,
                               cudaStream_t stream);
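// For reference, bfloat16 is the upper 16 bits of an IEEE-754 float32
// (1 sign, 8 exponent, 7 mantissa bits). A minimal host-side sketch of the
// truncating direction is shown below (the CUDA kernels may round instead;
// needs <cstring> for std::memcpy):
//
//   inline uint16_t fp32_to_bf16_truncate(float f) {
//     uint32_t bits;
//     std::memcpy(&bits, &f, sizeof(bits));
//     return static_cast<uint16_t>(bits >> 16);
//   }
//   inline float bf16_to_fp32(uint16_t h) {
//     uint32_t bits = static_cast<uint32_t>(h) << 16;
//     float f;
//     std::memcpy(&f, &bits, sizeof(f));
//     return f;
//   }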
__global__ void convert_fp32_to_bf16_kernel(
    const float* __restrict__ fp32_in,
    uint16_t* __restrict__ bf16_out,
    size_t n_elements);  // element-count parameter reconstructed
// Logging utilities for the TinyLlama implementation.
static void error(const std::string& message);