TinyLlama.cpp 1.0
A lightweight C++ implementation of the TinyLlama language model
cuda_kernels.h
#ifndef CUDA_KERNELS_H
#define CUDA_KERNELS_H

#ifdef HAS_CUDA

// Use safe headers only for the Windows CUDA 12.1+ workaround; use the normal headers everywhere else.
#if defined(WINDOWS_CUDA_12_1_WORKAROUND) && defined(_WIN32)
#include "cuda_safe_headers.h"
#else
// Normal CUDA header inclusion for non-problematic platforms (Ubuntu, etc.)
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#endif

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

#include "logger.h"

// Checks a CUDA error code, logs the error, and (optionally) throws on failure.
inline void gpuAssert(cudaError_t code, const char* file, int line,
                      bool abort = true) {
  if (code != cudaSuccess) {
    std::string err_msg =
        "GPUassert: " + std::string(cudaGetErrorString(code)) + " " +
        std::string(file) + " " + std::to_string(line);
    Logger::error(err_msg);
    if (abort) throw std::runtime_error(err_msg);
  }
}

// Convenience wrapper that forwards the call site's file and line to gpuAssert.
#define gpuErrchk(ans) \
  { gpuAssert((ans), __FILE__, __LINE__); }

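// Illustrative sketch (not part of the original API): a typical error-checked
// allocate/copy/free sequence using gpuErrchk. The buffer size and contents
// are placeholders chosen for the example.
inline void example_gpuErrchk_usage() {
  std::vector<float> host(1024, 1.0f);
  float* dev = nullptr;
  gpuErrchk(cudaMalloc(&dev, host.size() * sizeof(float)));
  gpuErrchk(cudaMemcpy(dev, host.data(), host.size() * sizeof(float),
                       cudaMemcpyHostToDevice));
  gpuErrchk(cudaFree(dev));
}
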
// RMSNorm over a single vector resident on the device.
void rmsnorm_vector_cuda(const float* x_dev, const float* weight_dev,
                         float* out_dev, int n, float eps,
                         cudaStream_t stream = 0);

// Host-vector convenience overload of the RMSNorm above.
void rmsnorm_vector_cuda(const std::vector<float>& x_in_host,
                         const std::vector<float>& weight_host,
                         std::vector<float>& out_host, int n, float eps);

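// Illustrative sketch: RMSNorm is conventionally
// out[i] = weight[i] * x[i] / sqrt(mean(x^2) + eps). The call below uses the
// host-vector overload; the hidden size and epsilon are placeholders, not
// values taken from a model config.
inline void example_rmsnorm_host_call() {
  const int n = 2048;
  std::vector<float> x(n, 0.5f), weight(n, 1.0f), out(n, 0.0f);
  rmsnorm_vector_cuda(x, weight, out, n, 1e-5f);
}
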
// Batched RMSNorm
void rmsnorm_batch_cuda(float* d_out, float* d_in, const float* d_weight,
                        int num_tokens, int hidden_size, float eps,
                        cudaStream_t stream);

__global__ void reduce_partial_sums_kernel(const float* partial_sums,
                                           float* total_sum_sq_out,
                                           int num_partial_sums);

// Matrix-vector product (FP32 matrix, FP32 vector), host-vector convenience overload.
void matvec_f32_f32_cuda(cublasHandle_t handle,
                         const std::vector<float>& mat_f32_host,
                         const std::vector<float>& vec_f32_host,
                         std::vector<float>& out_f32_host, int rows, int cols);

// Matrix-vector product (FP32 matrix, FP32 vector) on device buffers.
void matvec_f32_f32_cuda(cublasHandle_t handle, const float* mat_f32_dev,
                         const float* vec_f32_dev, float* out_f32_dev,
                         int rows, int cols, cudaStream_t stream = 0);

// Generic FP32 GEMM (can be adapted for BF16 later or routed through cublasGemmEx).
void gemm_f32_f32_cuda(cublasHandle_t handle,
                       bool transa, bool transb,
                       int m, int n, int k,
                       const float* alpha,
                       const float* A, int lda,
                       const float* B, int ldb,
                       const float* beta,
                       float* C, int ldc,
                       cudaStream_t stream);

185
203void silu_cuda(const std::vector<float>& x_host,
204 std::vector<float>& out_host, int n);
205
215void softmax_vector_cuda(const std::vector<float>& x_host,
216 std::vector<float>& out_host, int n);
217
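// Illustrative sketch: the host-vector SiLU and softmax overloads take and
// return std::vector buffers of the same length. Sizes and values here are
// placeholders for the example.
inline void example_silu_softmax_host() {
  const int n = 4096;
  std::vector<float> gate(n, 1.0f), gate_act(n, 0.0f);
  silu_cuda(gate, gate_act, n);

  std::vector<float> logits(n, 0.0f), probs(n, 0.0f);
  softmax_vector_cuda(logits, probs, n);
}
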
// Applies rotary position embeddings (RoPE) in place to a device vector.
void rope_cuda(float* vec, int num_heads, int head_dim,
               const float* freqs_cis_dev, int pos, bool use_adjacent_pairing,
               cudaStream_t stream);

// Single-token attention over an FP32 KV cache.
void attention_cuda(const float* Q_current_dev, const float* K_layer_cache_base,
                    const float* V_layer_cache_base, float* out_dev,
                    int num_heads, int current_seq_len, int head_dim,
                    float scale, int cache_max_seq_len, int cache_num_kv_heads,
                    cudaStream_t stream = 0);

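#include <cmath>  // std::sqrt, used only by the illustrative sketch below

// Illustrative sketch: a single-token decode step over an FP32 KV cache. The
// scale is the usual 1/sqrt(head_dim); all pointers and dimensions are
// placeholders supplied by the caller.
inline void example_attention_decode(const float* d_q, const float* d_k_cache,
                                     const float* d_v_cache, float* d_out,
                                     int num_heads, int num_kv_heads, int head_dim,
                                     int current_seq_len, int cache_max_seq_len,
                                     cudaStream_t stream) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  attention_cuda(d_q, d_k_cache, d_v_cache, d_out,
                 num_heads, current_seq_len, head_dim,
                 scale, cache_max_seq_len, num_kv_heads, stream);
}
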
// Attention over an INT8-quantized KV cache, dequantizing selectively into the
// provided scratch buffers.
void attention_cuda_selective_dequant(const float* Q_current_dev,
                                      const int8_t* K_quantized_cache_base,
                                      const int8_t* V_quantized_cache_base,
                                      const float* K_scales_cache_base,
                                      const float* V_scales_cache_base,
                                      float* selective_k_dequant_buffer,
                                      float* selective_v_dequant_buffer,
                                      float* out_dev,
                                      int num_heads, int current_seq_len, int head_dim,
                                      float scale, int cache_max_seq_len,
                                      int cache_num_kv_heads,
                                      cudaStream_t stream = 0);

// Element-wise addition of two device vectors.
void add_vectors_cuda(const float* a_dev, const float* b_dev,
                      float* result_dev, int n, cudaStream_t stream = 0);

// Adds the residual vector to a matvec output.
void add_residual_cuda(const float* matvec_out_dev, const float* residual_dev,
                       float* result_dev, int n, cudaStream_t stream = 0);

// Writes one K or V head vector into the layer cache at the given position.
void update_kv_cache_cuda(float* cache_base_ptr,
                          const float* current_kv_vector,
                          int pos, int kv_head_idx, int max_seq_len,
                          int num_kv_heads, int head_dim,
                          cudaStream_t stream = 0);

// Applies RoPE to a single K/V head vector and writes it into the cache in one pass.
void rope_and_update_kv_cache_cuda(float* cache_base_ptr,
                                   const float* kv_vector_head,
                                   const float* all_freqs_cis_base,
                                   int pos, int kv_head_idx, int max_seq_len,
                                   int num_kv_heads, int head_dim,
                                   cudaStream_t stream = 0);

// SwiGLU: SiLU(gate) multiplied element-wise with up, on device buffers.
void swiglu_cuda(const float* gate_dev, const float* up_dev,
                 float* out_dev, int n, cudaStream_t stream = 0);

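#include <cmath>  // std::exp, used only by the illustrative sketch below

// Illustrative CPU reference for the SwiGLU activation used above:
// out[i] = silu(gate[i]) * up[i], with silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
// This scalar loop is for clarity only; it is not a drop-in replacement for swiglu_cuda.
inline void swiglu_reference_host(const float* gate, const float* up,
                                  float* out, int n) {
  for (int i = 0; i < n; ++i) {
    const float silu = gate[i] / (1.0f + std::exp(-gate[i]));
    out[i] = silu * up[i];
  }
}
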
// Looks up one token's embedding row from a BF16 table into an FP32 vector.
void lookup_embedding_bf16_f32_cuda(const uint16_t* embedding_table_dev,
                                    float* output_vector_dev,
                                    int token_id, int hidden_size,
                                    int vocab_size, cudaStream_t stream = 0);

// Embedding lookup for either BF16 or FP32 tables, selected by is_bf16.
void lookup_embedding_cuda(const void* table_dev, float* output_dev,
                           int token_id, int hidden_size, int vocab_size,
                           bool is_bf16, cudaStream_t stream);

// Matrix-vector product with a BF16 matrix and FP32 vector, optionally using tensor cores.
void matvec_bf16_f32_cuda(cublasHandle_t handle,
                          const uint16_t* mat_bf16_dev,
                          const float* vec_f32_dev,
                          float* out_f32_dev,
                          int rows, int cols,
                          bool use_tensor_cores,
                          cudaStream_t stream = 0);

__global__ void convert_bf16_to_fp32_kernel(const uint16_t* __restrict__ bf16_in,
                                            float* __restrict__ fp32_out,
                                            size_t n_elements);

// KV cache quantization kernels (FP32 <-> INT8)

// Symmetric per-tensor quantization: writes INT8 values and a single scale.
void quantize_fp32_to_int8_symmetric_per_tensor_cuda(
    const float* fp32_in_dev,
    int8_t* int8_out_dev,
    float* scale_out_dev,
    int num_elements,
    cudaStream_t stream = 0);

// Inverse of the above: reconstructs FP32 values from INT8 data and the scale.
void dequantize_int8_to_fp32_symmetric_per_tensor_cuda(
    const int8_t* int8_in_dev,
    const float* scale_in_dev,
    float* fp32_out_dev,
    int num_elements,
    cudaStream_t stream = 0);

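// Illustrative CPU reference for symmetric per-tensor INT8 quantization as
// commonly formulated: scale = max(|x|) / 127, q = clamp(round(x / scale), -127, 127),
// and dequantization is x' = q * scale. The GPU kernels above may differ in
// details such as rounding mode; this sketch is for clarity only.
inline float quantize_int8_reference_host(const float* in, int8_t* out, int n) {
  float max_abs = 0.0f;
  for (int i = 0; i < n; ++i) {
    const float a = in[i] < 0.0f ? -in[i] : in[i];
    if (a > max_abs) max_abs = a;
  }
  const float scale = (max_abs > 0.0f) ? max_abs / 127.0f : 1.0f;
  for (int i = 0; i < n; ++i) {
    float q = in[i] / scale;
    if (q > 127.0f) q = 127.0f;
    if (q < -127.0f) q = -127.0f;
    // Round half away from zero; the GPU kernel's rounding may differ.
    out[i] = static_cast<int8_t>(q >= 0.0f ? q + 0.5f : q - 0.5f);
  }
  return scale;  // The caller stores this alongside the INT8 data.
}
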
// Batched SwiGLU (SiLU + element-wise multiply)
void swiglu_batch_cuda(float* d_out_batch,             // Output: [num_tokens, intermediate_size]
                       const float* d_gate_act_batch,  // Gate activations: [num_tokens, intermediate_size]
                       const float* d_up_act_batch,    // Up activations: [num_tokens, intermediate_size]
                       int num_tokens,
                       int intermediate_size,
                       cudaStream_t stream);

// Batched RoPE applied to Q and K for a whole batch of tokens.
void rope_batch_cuda(float* d_q_batch, float* d_k_batch,
                     const float* d_all_freqs_cis_base,
                     int num_tokens, int num_q_heads, int num_kv_heads, int head_dim,
                     int start_pos_offset,
                     bool use_adjacent_pairing,
                     cudaStream_t stream);

// Batched attention for prefill
void attention_batch_prefill_cuda(
    const float* d_q_batch_strided,   // Input Q: [B, H_q, D_h]
    const float* d_k_batch_strided,   // Input K for the current batch
    const float* d_v_batch_strided,   // Input V for the current batch
    float* d_kv_cache_k_base,         // K cache: [S_max, H_kv, D_h]
    float* d_kv_cache_v_base,         // V cache: [S_max, H_kv, D_h]
    float* d_output_batch_strided,    // Output: [B, H_q, D_h]
    int num_tokens_in_batch,          // B
    int start_pos_in_kv_cache,        // Start position of this batch in the KV cache
    int cache_max_seq_len,            // Maximum capacity of the KV cache
    int num_q_heads,                  // H_q
    int num_kv_heads,                 // H_kv
    int head_dim,                     // D_h
    float scale,
    cudaStream_t stream,
    const int* attention_mask_cu = nullptr);  // Optional attention mask

// Batched residual add
void add_residual_batch_cuda(float* d_output_batch,         // Output: [num_tokens, hidden_size]
                             const float* d_input_a_batch,  // Input A: [num_tokens, hidden_size]
                             const float* d_input_b_batch,  // Input B: [num_tokens, hidden_size]
                             int num_tokens, int hidden_size,
                             cudaStream_t stream);

// Batched KV cache update.
// Note: a separate update_kv_cache_batch_cuda_fp32 taking distinct K and V cache
// pointers is not defined; the generic function below is intended to be called
// twice per layer, once for the K cache and once for the V cache.
/*
void update_kv_cache_batch_cuda_fp32(
    float* d_kvcache_k, float* d_kvcache_v,
    const float* d_k_batch_current, const float* d_v_batch_current,
    int num_tokens_in_batch, int start_pos_in_kv_cache,
    int max_seq_len_in_cache, int num_kv_heads, int head_dim,
    cudaStream_t stream);
*/
void update_kv_cache_batch_cuda(
    float* d_kv_cache_layer_base,         // Device pointer to the K or V cache for the current layer
    const float* d_keys_or_values_batch,  // Device pointer to the batch of K or V vectors to write
    int start_pos_in_kv_cache,            // Sequence position in the cache where writing for this batch begins
    int num_tokens_in_batch,              // Number of tokens in d_keys_or_values_batch
    int num_kv_heads,                     // Number of K/V heads
    int head_dim,                         // Dimension of each K/V head
    int cache_max_seq_len,                // Maximum sequence-length capacity of the cache
    cudaStream_t stream);

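// Illustrative sketch of the note above: the generic batched update is called
// twice per layer, once for the K cache and once for the V cache. All pointers
// and dimensions are placeholders supplied by the caller.
inline void example_update_kv_cache_batch(float* d_k_cache_layer, float* d_v_cache_layer,
                                          const float* d_k_batch, const float* d_v_batch,
                                          int start_pos, int num_tokens,
                                          int num_kv_heads, int head_dim,
                                          int cache_max_seq_len, cudaStream_t stream) {
  update_kv_cache_batch_cuda(d_k_cache_layer, d_k_batch, start_pos, num_tokens,
                             num_kv_heads, head_dim, cache_max_seq_len, stream);
  update_kv_cache_batch_cuda(d_v_cache_layer, d_v_batch, start_pos, num_tokens,
                             num_kv_heads, head_dim, cache_max_seq_len, stream);
}
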
// Optimized single-vector RMSNorm variant.
void rmsnorm_vector_cuda_optimized(const float* x_dev, const float* weight_dev,
                                   float* out_dev, int n, float eps,
                                   cudaStream_t stream = 0);

// Optimized single-vector softmax variant operating on device buffers.
void softmax_vector_cuda_optimized(const float* x_dev, float* out_dev, int n,
                                   cudaStream_t stream = 0);

// Optimized single-token attention variant.
void attention_cuda_optimized(const float* Q_current_dev, const float* K_layer_cache_base,
                              const float* V_layer_cache_base, float* out_dev,
                              int num_heads, int current_seq_len, int head_dim,
                              float scale, int cache_max_seq_len, int cache_num_kv_heads,
                              cudaStream_t stream = 0);

// BF16 tensor-core matrix-matrix operations
void gemm_bf16_bf16_cuda(cublasHandle_t handle,
                         bool transa_user, bool transb_user,
                         int m_user, int n_user, int k_user,
                         const float* alpha_user,
                         const uint16_t* A_bf16_user, int lda_user,
                         const uint16_t* B_bf16_user, int ldb_user,
                         const float* beta_user,
                         uint16_t* C_bf16_user, int ldc_user,
                         cudaStream_t stream);

// Mixed-precision GEMM: FP32 A, BF16 B, FP32 C.
void gemm_f32_to_bf16_f32_cuda(cublasHandle_t handle,
                               bool transa_user, bool transb_user,
                               int m_user, int n_user, int k_user,
                               const float* alpha_user,
                               const float* A_f32_user, int lda_user,
                               const uint16_t* B_bf16_user, int ldb_user,
                               const float* beta_user,
                               float* C_f32_user, int ldc_user,
                               cudaStream_t stream);

// Conversion utilities
void convert_fp32_to_bf16_cuda(const float* fp32_in_dev, uint16_t* bf16_out_dev,
                               size_t n_elements, cudaStream_t stream);

void convert_bf16_to_fp32_cuda(const uint16_t* bf16_in_dev, float* fp32_out_dev,
                               size_t n_elements, cudaStream_t stream);

__global__ void convert_fp32_to_bf16_kernel(const float* __restrict__ fp32_in,
                                            uint16_t* __restrict__ bf16_out,
                                            size_t n_elements);

644
645#endif // HAS_CUDA
646
647#endif // CUDA_KERNELS_H