TinyLlama.cpp 1.0
A lightweight C++ implementation of the TinyLlama language model
quantization.h
#pragma once

#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

#include "ggml_types.h"
#include "gguf_parser.h" // Include for GGML_QK_K

// Define RESTRICT macro based on compiler
#ifdef _MSC_VER
#define RESTRICT
#else
#define RESTRICT __restrict__
#endif

// Forward declarations
struct block_q2_K;
struct block_q3_K;
struct block_q4_K;
struct block_q6_K;

float fp16_to_fp32(uint16_t h, bool is_gguf_scale_field = false);

uint16_t fp32_to_fp16(float f);

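// Usage sketch (illustrative): the d / dmin fields of the quantized blocks below
// are stored as 16-bit floats, so converting a scale back and forth looks like:
//
//   float scale = 0.0123f;
//   uint16_t h  = fp32_to_fp16(scale);
//   float back  = fp16_to_fp32(h);   // ~0.0123, up to fp16 rounding error
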
#pragma pack(push, 1)

struct block_q4_K {
    uint16_t d;
    uint16_t dmin;
    uint8_t scales[12];
    uint8_t qs[GGML_QK_K / 2];
};
static_assert(sizeof(block_q4_K) == 2 + 2 + 12 + 128, "Size mismatch for standard block_q4_K");

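// Storage sketch: per the static_assert above, GGML_QK_K == 256 and one
// block_q4_K packs 256 weights into 144 bytes, i.e. 144 * 8 / 256 = 4.5 bits
// per weight. A row of `cols` weights (cols a multiple of GGML_QK_K) therefore
// needs:
//
//   size_t blocks_per_row = cols / GGML_QK_K;
//   size_t row_bytes      = blocks_per_row * sizeof(block_q4_K);
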
struct block_q6_K {
    uint8_t ql[GGML_QK_K / 2];
    uint8_t qh[GGML_QK_K / 4];
    int8_t scales[GGML_QK_K / 16];
    uint16_t d;
};
static_assert(sizeof(block_q6_K) == 128 + 64 + 16 + 2, "Size mismatch for block_q6_K");

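// Layout sketch: as the ql/qh split suggests, each 6-bit weight is stored as
// 4 low bits in ql (128 bytes) and 2 high bits in qh (64 bytes), with sixteen
// signed per-group scales and one fp16 super-block scale d. That is 210 bytes
// per 256 weights:
//
//   constexpr double q6k_bits_per_weight = sizeof(block_q6_K) * 8.0 / GGML_QK_K; // 6.5625
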
struct block_q2_K {
    uint16_t d;
    uint16_t dmin;
    uint8_t scales[GGML_QK_K / 16];
    uint8_t qs[GGML_QK_K / 4];
};
static_assert(sizeof(block_q2_K) == 2 + 2 + 16 + 64, "Size mismatch for block_q2_K");

struct block_q3_K {
    uint8_t hmask[GGML_QK_K / 8];
    uint8_t qs[GGML_QK_K / 4];
    uint8_t scales[12];
    uint16_t d;
    uint16_t dmin;
};
static_assert(sizeof(block_q3_K) == 32 + 64 + 12 + 2 + 2, "Size mismatch for block_q3_K");

struct block_q8_K {
    uint16_t d;
    int8_t qs[GGML_QK_K];
    int16_t bsums[GGML_QK_K / 16];
};

struct block_q8_0 {
    uint16_t d;
    int8_t qs[GGML_QK8_0];
};
static_assert(sizeof(block_q8_0) == sizeof(uint16_t) + GGML_QK8_0, "Size mismatch for block_q8_0");

#pragma pack(pop)

const char* ggml_type_name(GGMLType type);

size_t ggml_type_size(GGMLType type);

size_t ggml_type_block_size(GGMLType type);

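// Sizing sketch, assuming ggml_type_size() reports bytes per block and
// ggml_type_block_size() reports elements per block, as the names suggest:
//
//   GGMLType t        = /* tensor type read from the GGUF header */;
//   size_t   n_blocks = n_elements / ggml_type_block_size(t);
//   size_t   n_bytes  = n_blocks * ggml_type_size(t);
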
void dequantize_q2_k(const void* q_data, float* f_data,
                     int num_weights_in_block,
                     bool log_details_for_this_block = false);

void dequantize_q4_k_m(const block_q4_K* qblock, float* RESTRICT output_f32,
                       int num_elements, bool log_this_block = false);

void dequantize_q6_k(const block_q6_K* qblock, float* RESTRICT output_f32,
                     int num_elements, bool log_this_block = false);

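// Usage sketch: dequantizing a contiguous run of Q4_K blocks one block at a
// time (illustrative; assumes n is a multiple of GGML_QK_K):
//
//   const block_q4_K* blocks = /* n / GGML_QK_K blocks */;
//   std::vector<float> out(n);
//   for (size_t b = 0; b < n / GGML_QK_K; ++b)
//       dequantize_q4_k_m(&blocks[b], out.data() + b * GGML_QK_K, (int)GGML_QK_K);
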
void dequantize_vector_q6k_to_f32(const std::vector<block_q6_K>& q_weights,
                                  std::vector<float>& f32_weights,
                                  size_t total_num_elements,
                                  int log_first_n_blocks = 0);

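// The vector-level helpers wrap the per-block loop above; a sketch for a whole
// Q6_K weight tensor of rows * cols elements (output buffer pre-sized here as a
// precaution, the callee may also resize it):
//
//   std::vector<block_q6_K> q6_weights = /* rows * cols / GGML_QK_K blocks */;
//   std::vector<float> f32_weights(rows * cols);
//   dequantize_vector_q6k_to_f32(q6_weights, f32_weights, rows * cols);
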
void dequantize_q3_k(const void* q_data, float* f_data,
                     int num_weights_in_block);

void handle_i8_tensor(const void* i8_data, float* f_data, size_t num_elements);

void quantize_q4_k_m(const float* f_data, void* q_data, int num_elements);

void quantize_q6_k(const float* f_data, void* q_data, int num_elements);

std::vector<block_q8_K> quantize_fp32_to_q8_K(const std::vector<float>& f_data);

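// Activation quantization sketch: float activations are converted to Q8_K
// before the quantized dot products below (illustrative; the input length
// should be a multiple of GGML_QK_K):
//
//   std::vector<float> hidden = /* one activation vector */;
//   std::vector<block_q8_K> hidden_q8 = quantize_fp32_to_q8_K(hidden);
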
float vec_dot_q6_k_q8_k_cpu(int n, const std::vector<block_q6_K>& x,
                            const std::vector<block_q8_K>& y,
                            bool log_this_call);

void matvec_q6k_q8k_cpu(const std::vector<block_q6_K>& mat_q6k,
                        const std::vector<block_q8_K>& vec_q8k,
                        std::vector<float>& out_f32, int rows, int cols,
                        bool log_calls);

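// End-to-end sketch for a quantized matrix-vector product y = W * x, with W
// stored as Q6_K (rows x cols) and x quantized on the fly (names illustrative):
//
//   std::vector<block_q6_K> W = /* rows * cols / GGML_QK_K blocks */;
//   std::vector<float> x = /* length cols */;
//   std::vector<block_q8_K> x_q8 = quantize_fp32_to_q8_K(x);
//   std::vector<float> y(rows);
//   matvec_q6k_q8k_cpu(W, x_q8, y, rows, cols, /*log_calls=*/false);
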
float vec_dot_q4_k_q8_k_cpu(int n, const std::vector<block_q4_K>& x_vec,
                            const std::vector<block_q8_K>& y_vec,
                            bool log_this_call);

void matvec_q4k_q8k_cpu(const std::vector<block_q4_K>& mat_q4k,
                        const std::vector<block_q8_K>& vec_q8k,
                        std::vector<float>& out_f32, int rows, int cols,
                        bool log_calls);

void dequantize_q8_0_block(const block_q8_0* qblock, float* output);

void dequantize_vector_q4k_to_f32(const std::vector<block_q4_K>& q_weights,
                                  std::vector<float>& f32_weights,
                                  size_t total_num_elements,
                                  int log_first_n_blocks = 0);

void dequantize_vector_q8_0_to_f32(const std::vector<block_q8_0>& q_weights,
                                   std::vector<float>& f32_weights,
                                   size_t total_num_elements,
                                   int log_first_n_blocks = 0);
Referenced headers and constants

ggml_types.h: Type definitions for GGML (Georgi Gerganov Machine Learning) library.
  GGMLType: Enumeration of GGML tensor data types. Definition: ggml_types.h:21
gguf_parser.h: Parser for GGUF (GPT-Generated Unified Format) files; block size constants for different quantization formats.
  constexpr size_t GGML_QK_K. Definition: gguf_parser.h:42
  constexpr size_t GGML_QK8_0. Definition: gguf_parser.h:43

Block structures

  block_q2_K: 2-bit K-quantized block structure.
  block_q3_K: 3-bit K-quantized block structure.
  block_q4_K: 4-bit K-quantized block structure.
  block_q6_K: 6-bit K-quantized block structure.
  block_q8_K: 8-bit K-quantized block structure with block sums.
  block_q8_0: Simple 8-bit quantized block structure.

Macros

  RESTRICT: Expands to __restrict__ on non-MSVC compilers, empty on MSVC.

Functions

  fp16_to_fp32(): Converts a 16-bit floating point number to 32-bit float.
  fp32_to_fp16(): Converts a 32-bit float to 16-bit floating point.
  ggml_type_name(): Gets the string name of a GGML type.
  ggml_type_size(): Gets the size in bytes of a GGML type.
  ggml_type_block_size(): Gets the block size for a GGML type.
  dequantize_q2_k(): Dequantizes a Q2_K quantized block to float32.
  dequantize_q3_k(): Dequantizes a Q3_K quantized block to float32.
  dequantize_q4_k_m(): Dequantizes a Q4_K quantized block to float32.
  dequantize_q6_k(): Dequantizes a Q6_K quantized block to float32.
  dequantize_q8_0_block(): Dequantizes a Q8_0 block to float32.
  dequantize_vector_q4k_to_f32(): Dequantizes a vector of Q4_K blocks to a vector of float32.
  dequantize_vector_q6k_to_f32(): Dequantizes a vector of Q6_K blocks to a vector of float32.
  dequantize_vector_q8_0_to_f32(): Dequantizes a vector of Q8_0 blocks to a vector of float32.
  handle_i8_tensor(): Handles conversion of int8 tensor data to float32.
  quantize_q4_k_m(): Quantizes float32 data to Q4_K format.
  quantize_q6_k(): Quantizes float32 data to Q6_K format.
  quantize_fp32_to_q8_K(): Quantizes float32 data to Q8_K format.
  vec_dot_q4_k_q8_k_cpu(): Computes dot product between Q4_K and Q8_K vectors on CPU.
  vec_dot_q6_k_q8_k_cpu(): Computes dot product between Q6_K and Q8_K vectors on CPU.
  matvec_q4k_q8k_cpu(): Computes matrix-vector product between Q4_K matrix and Q8_K vector on CPU.
  matvec_q6k_q8k_cpu(): Computes matrix-vector product between Q6_K matrix and Q8_K vector on CPU.