TinyLlama.cpp 1.0
A lightweight C++ implementation of the TinyLlama language model
model_config.cpp
1#include "model_config.h"
2
3#include "logger.h"
4#include "gguf_parser.h"
5#include <algorithm>
6#include <cmath>
7#include <cstring>
8#include <fstream>
9#include <iomanip>
10#include <limits>
11#include <memory>
12#include <sstream>
13#include <stdexcept>
14#include <cassert>
15#include <cstdint>
16#include <iostream>
17#include <numeric>
18#include <variant>
19
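// Build a ModelConfig from a HuggingFace-style config.json (the SafeTensors
// loading path). Missing keys fall back to the defaults passed to
// json.value(), and the tokenizer family is inferred from vocab_size plus the
// architecture string.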
ModelConfig parse_model_config(const nlohmann::json& json) {
  ModelConfig cfg;
  cfg.hidden_size = json.value("hidden_size", 0);
  cfg.intermediate_size = json.value("intermediate_size", 0);
  cfg.num_attention_heads = json.value("num_attention_heads", 0);
  cfg.num_key_value_heads = json.value("num_key_value_heads", 0);
  cfg.num_hidden_layers = json.value("num_hidden_layers", 0);
  cfg.vocab_size = json.value("vocab_size", 0);
  cfg.max_position_embeddings = json.value("max_position_embeddings", 0);
  cfg.rms_norm_eps = json.value("rms_norm_eps", 1e-5f);
  cfg.rope_theta = json.value("rope_theta", 10000.0f);
  cfg.hidden_act = json.value("hidden_act", "silu");
  cfg.torch_dtype = json.value("torch_dtype", "bfloat16");
  cfg.bos_token_id = json.value("bos_token_id", 1);
  cfg.eos_token_id = json.value("eos_token_id", 2);
  cfg.unk_token_id = json.value("unk_token_id", -1);
  cfg.pad_token_id = json.value("pad_token_id", -1);

  // Infer the architecture if available.
  if (json.contains("architectures") && json["architectures"].is_array() &&
      !json["architectures"].empty()) {
    // Take the first architecture string if multiple are listed.
    cfg.architecture = json["architectures"][0].get<std::string>();
  } else {
    cfg.architecture = "unknown";
  }
  // Use model_type, or fall back to the architecture string.
  cfg.model_name = json.value("model_type", cfg.architecture);

  Logger::info("[parse_json_config] Inferring tokenizer family for SafeTensors. Arch: '" +
               cfg.architecture + "', Vocab: " + std::to_string(cfg.vocab_size));
  bool is_llama3_vocab_size_json = (cfg.vocab_size == 128256);
  bool is_llama3_arch_hint_json =
      (cfg.architecture.find("LlamaForCausalLM") != std::string::npos &&  // Llama 3 commonly reports this
       cfg.architecture.find("Llama2") == std::string::npos);             // explicitly exclude Llama 2

  if (is_llama3_vocab_size_json && is_llama3_arch_hint_json) {
    cfg.tokenizer_family = ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN;
    Logger::info("[parse_json_config] Result: Identified LLAMA3_TIKTOKEN (vocab size + arch hint).");
    if (cfg.rope_theta == 10000.0f) {
      // Llama 3 uses a much larger rope_theta; re-check config.json with a Llama 3 default.
      float llama3_rope_candidate = json.value("rope_theta", 500000.0f);
      if (llama3_rope_candidate > 10000.0f) {
        cfg.rope_theta = llama3_rope_candidate;
        Logger::info("[parse_json_config] Adjusted rope_theta to " + std::to_string(cfg.rope_theta) +
                     " for Llama 3 model (was 10000.0).");
      }
    }
  } else if (cfg.vocab_size == 32000 ||
             cfg.architecture.find("Llama") != std::string::npos) {  // Common for Llama 1/2/TinyLlama
    cfg.tokenizer_family = ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE;
    Logger::info("[parse_json_config] Result: Identified LLAMA_SENTENCEPIECE (vocab size or arch hint).");
  } else {
    cfg.tokenizer_family = ModelConfig::TokenizerFamily::UNKNOWN;
    Logger::warning("[parse_json_config] Result: UNKNOWN tokenizer family.");
  }

  return cfg;
}

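// Build a ModelConfig from GGUF metadata (the key/value pairs parsed out of
// the GGUF file), falling back to defaults when keys are missing or have an
// unexpected type.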
ModelConfig parse_model_config_from_gguf(const GGUFData& gguf) {
  ModelConfig config;
  Logger::info("[parse_gguf_config] Entered function.");

  auto get_meta_string = [&](const std::string& key,
                             const std::string& default_val) -> std::string {
    auto it = gguf.metadata.find(key);
    if (it != gguf.metadata.end() &&
        std::holds_alternative<std::string>(it->second)) {
      return std::get<std::string>(it->second);
    }
    return default_val;
  };

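  // Fetch a metadata value converted to the type of `default_value` by
  // visiting the GGUFMetadataValue variant; integer narrowing is
  // overflow-checked and any type mismatch falls back to the default.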
  auto get_meta_value = [&](const std::string& key, auto default_value) {
    using TargetType = typename std::decay<decltype(default_value)>::type;
    auto it = gguf.metadata.find(key);
    if (it != gguf.metadata.end()) {
      return std::visit(
          [&](const auto& val) -> TargetType {
            using T = std::decay_t<decltype(val)>;

            if constexpr (std::is_integral_v<TargetType>) {
              if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>) {
                if constexpr (std::is_unsigned_v<T> &&
                              std::is_signed_v<TargetType>) {
                  if (val > static_cast<std::make_unsigned_t<TargetType>>(
                                std::numeric_limits<TargetType>::max())) {
                    Logger::warning("Metadata key '" + key + "' value " +
                                    std::to_string(val) +
                                    " overflows TargetType. Using default.");
                    return default_value;
                  }
                } else if constexpr (std::is_signed_v<T> &&
                                     std::is_signed_v<TargetType> &&
                                     sizeof(T) > sizeof(TargetType)) {
                  if (val > static_cast<T>(
                                std::numeric_limits<TargetType>::max()) ||
                      val < static_cast<T>(
                                std::numeric_limits<TargetType>::lowest())) {
                    Logger::warning("Metadata key '" + key + "' value " +
                                    std::to_string(val) +
                                    " overflows TargetType. Using default.");
                    return default_value;
                  }
                }
                return static_cast<TargetType>(val);
              }
            } else if constexpr (std::is_floating_point_v<TargetType>) {
              if constexpr (std::is_floating_point_v<T>) {
                return static_cast<TargetType>(val);
              }
            } else if constexpr (std::is_same_v<TargetType, bool>) {
              if constexpr (std::is_same_v<T, bool>) {
                return val;
              }
            } else if constexpr (std::is_same_v<TargetType, std::string>) {
              if constexpr (std::is_same_v<T, std::string>) {
                return val;
              }
            }
            Logger::warning("Metadata key '" + key +
                            "' has stored type incompatible with requested "
                            "TargetType. Using default.");
            return default_value;
          },
          it->second);
    } else {
      return default_value;
    }
  };

  config.vocab_size = get_meta_value("tokenizer.ggml.vocab_size",
                                     get_meta_value("llama.vocab_size", 32000));
  config.hidden_size = get_meta_value("llama.embedding_length", 4096);
  config.intermediate_size = get_meta_value("llama.feed_forward_length", 11008);
  config.num_attention_heads = get_meta_value("llama.attention.head_count", 32);
  config.num_hidden_layers = get_meta_value("llama.block_count", 32);
  config.num_key_value_heads = get_meta_value("llama.attention.head_count_kv",
                                              config.num_attention_heads);
  config.max_position_embeddings = get_meta_value("llama.context_length", 4096);
  if (config.max_position_embeddings == 0 ||
      config.max_position_embeddings > 8192) {
    Logger::warning("max_position_embeddings from GGUF is " +
                    std::to_string(config.max_position_embeddings) +
                    ", overriding to sensible default (2048)");
    config.max_position_embeddings = 2048;
  }
  config.rms_norm_eps =
      get_meta_value("llama.attention.layer_norm_rms_epsilon", 1e-5f);
  config.rope_theta = get_meta_value("llama.rope.freq_base", 10000.0f);
  config.hidden_act = "silu";
  config.bos_token_id = get_meta_value("tokenizer.ggml.bos_token_id", -1);
  config.eos_token_id = get_meta_value("tokenizer.ggml.eos_token_id", -1);
  config.unk_token_id = get_meta_value("tokenizer.ggml.unk_token_id", -1);
  config.pad_token_id = get_meta_value("tokenizer.ggml.padding_token_id", -1);

  config.architecture = get_meta_string("general.architecture", "unknown");
  config.model_name = get_meta_string("general.name", "unknown");
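  // Classify the tokenizer family: Llama 3 models ship a tiktoken-style BPE
  // tokenizer (GPT-2-style merges, 128256-token vocab), while Llama 1/2 and
  // TinyLlama use SentencePiece.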
  bool has_pre_key = gguf.metadata.count("tokenizer.ggml.pre");
  bool has_merges = !gguf.tokenizer_merges.empty();

  Logger::info("[parse_gguf_config] Architecture: " + config.architecture +
               ", Vocab Size: " + std::to_string(config.vocab_size) +
               ", Has Merges: " + (has_merges ? "Yes" : "No"));

  Logger::info("[parse_gguf_config] Identifying tokenizer family...");
  bool is_llama3_arch_hint = (config.architecture.find("llama3") != std::string::npos ||
                              config.architecture.find("Llama-3") != std::string::npos ||
                              config.architecture.find("Meta-Llama-3") != std::string::npos);
  bool is_llama3_vocab_size = (config.vocab_size == 128256);
  std::string ggml_tokenizer_model = get_meta_string("tokenizer.ggml.model", "");
  bool is_tiktoken_style_tokenizer_model = (ggml_tokenizer_model == "gpt2");

  Logger::info("[parse_gguf_config] L3 hints: arch_hint=" + std::string(is_llama3_arch_hint ? "Y" : "N") +
               ", vocab_size_match=" + std::string(is_llama3_vocab_size ? "Y" : "N") +
               ", has_merges=" + std::string(has_merges ? "Y" : "N") +
               ", ggml_tokenizer_model_key='" + ggml_tokenizer_model +
               "' (is_tiktoken_style: " + std::string(is_tiktoken_style_tokenizer_model ? "Y" : "N") + ")");

  if (has_merges && is_llama3_vocab_size && is_tiktoken_style_tokenizer_model) {
    config.tokenizer_family = ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN;
    Logger::info("[parse_gguf_config] Result: Identified LLAMA3_TIKTOKEN (merges + vocab_size + "
                 "ggml_tokenizer_model='gpt2'). Architecture string was: '" + config.architecture + "'");
    if (!is_llama3_arch_hint && config.architecture == "llama") {
      Logger::info("[parse_gguf_config] Note: Classified as Llama 3 based on tokenizer/vocab, "
                   "but arch string was 'llama'.");
    }
    if (config.rope_theta == 10000.0f) {
      float llama3_rope_candidate = get_meta_value("llama.rope.freq_base", 500000.0f);
      if (llama3_rope_candidate > 10000.0f) {
        config.rope_theta = llama3_rope_candidate;
        Logger::info("[parse_gguf_config] Adjusted rope_theta to " + std::to_string(config.rope_theta) +
                     " for Llama 3 model (was 10000.0).");
      }
    }
  } else if (config.architecture == "llama" ||
             config.architecture.find("Llama-2") != std::string::npos ||
             config.architecture.find("TinyLlama") != std::string::npos) {
    config.tokenizer_family = ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE;
    Logger::info("[parse_gguf_config] Result: Identified LLAMA_SENTENCEPIECE based on architecture: '" +
                 config.architecture + "'");
  } else {
    config.tokenizer_family = ModelConfig::TokenizerFamily::UNKNOWN;
    Logger::info("[parse_gguf_config] Result: UNKNOWN tokenizer family for architecture: '" +
                 config.architecture + "'");
  }

  // Determine chat_template_type and pre_tokenizer_type from the architecture
  // and the presence of the tokenizer.ggml.pre key.
  if (config.model_name.find("TinyLlama") != std::string::npos ||
      (config.architecture == "llama" && has_pre_key)) {
    config.chat_template_type = "tinyllama";
  } else if (config.architecture == "llama" && !has_pre_key) {
    config.chat_template_type = "llama2";
  } else {
    config.chat_template_type = "unknown";
    Logger::warning("Could not determine chat template type for arch='" +
                    config.architecture + "', name='" + config.model_name +
                    "'.");
  }

  if (has_pre_key) {
    config.pre_tokenizer_type =
        get_meta_string("tokenizer.ggml.pre", "unknown");
  } else if (config.architecture == "llama") {
    config.pre_tokenizer_type = "llama";
  } else {
    config.pre_tokenizer_type = "unknown";
  }
  Logger::info("Determined config: architecture='" + config.architecture +
               "', model_name='" + config.model_name + "', chat_template='" +
               config.chat_template_type + "', pre_tokenizer='" +
               config.pre_tokenizer_type + "'");

  if (config.model_name == "llama" && config.pre_tokenizer_type != "llama") {
    config.chat_template_type = "llama2";
    Logger::info(
        "Inferred chat_template_type='llama2' based on model_type and "
        "missing/different pre_tokenizer_type.");
  }

  auto template_it = gguf.metadata.find("tokenizer.chat_template");
  if (template_it != gguf.metadata.end() &&
      std::holds_alternative<std::string>(template_it->second)) {
    config.chat_template_string = std::get<std::string>(template_it->second);
    Logger::info("Found tokenizer.chat_template in metadata.");
  } else {
    Logger::info(
        "tokenizer.chat_template not found or not a string in metadata. Will "
        "use fallback logic.");
    config.chat_template_string = "";
  }
  if (config.chat_template_type == "unknown") {
    if (config.model_name == "llama" && config.pre_tokenizer_type != "llama") {
      config.chat_template_type = "llama2";
      Logger::info(
          "Inferred chat_template_type='llama2' based on model name and "
          "missing/different pre_tokenizer_type.");
    } else if (config.tokenizer_family == ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN) {
      Logger::info("Llama 3 model identified. Chat template will primarily rely on "
                   "'tokenizer.chat_template' from GGUF if present.");
      // Set a generic type for now; actual application will use the template string.
      if (gguf.metadata.count("tokenizer.chat_template")) {
        config.chat_template_type = "llama3_gguf_direct";
      } else {
        config.chat_template_type = "llama3_fallback";  // or some other indicator
        Logger::warning("Llama 3 model detected, but 'tokenizer.chat_template' not found in GGUF metadata.");
      }
    }
  }

  Logger::info(std::string("[parse_gguf_config] Finished parsing. Returning config. Family: ") +
               (config.tokenizer_family == ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN
                    ? "L3_TIKTOKEN"
                    : (config.tokenizer_family == ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE
                           ? "L2_SPM"
                           : "UNKNOWN")));
  return config;
}
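A minimal usage sketch for the JSON entry point (not part of model_config.cpp): the helper name and error handling below are illustrative assumptions; only ModelConfig and parse_model_config come from this file. A GGUFData populated elsewhere by the GGUF parser would instead be passed to parse_model_config_from_gguf.

#include <fstream>
#include <stdexcept>
#include <string>
#include <nlohmann/json.hpp>
#include "model_config.h"

// Hypothetical caller: read a HuggingFace-style config.json and build a ModelConfig.
ModelConfig load_config_from_json_file(const std::string& path) {
  std::ifstream in(path);
  if (!in) {
    throw std::runtime_error("Could not open config file: " + path);
  }
  nlohmann::json j;
  in >> j;                        // nlohmann::json parses the stream
  return parse_model_config(j);   // maps JSON keys onto ModelConfig fields
}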
static void warning(const std::string &message)
Definition logger.cpp:139
static void info(const std::string &message)
Definition logger.cpp:135
Parser for GGUF (GPT-Generated Unified Format) files.
Logging utilities for the TinyLlama implementation.
ModelConfig parse_model_config(const nlohmann::json &json)
ModelConfig parse_model_config_from_gguf(const GGUFData &gguf)
nlohmann::json json
Definition server.cpp:54
Complete representation of a GGUF file's contents.
std::vector< std::string > tokenizer_merges
std::map< std::string, GGUFMetadataValue > metadata
Model configuration structure holding architecture and hyperparameters.
Definition model.h:80
int hidden_size
Definition model.h:81
int vocab_size
Definition model.h:86
int pad_token_id
Definition model.h:95
std::string chat_template_string
Definition model.h:100
std::string pre_tokenizer_type
Definition model.h:99
std::string architecture
Definition model.h:96
std::string model_name
Definition model.h:97
float rms_norm_eps
Definition model.h:88
int num_attention_heads
Definition model.h:83
std::string chat_template_type
Definition model.h:98
int intermediate_size
Definition model.h:82
int eos_token_id
Definition model.h:93
std::string torch_dtype
Definition model.h:91
float rope_theta
Definition model.h:89
int num_hidden_layers
Definition model.h:85
int num_key_value_heads
Definition model.h:84
int bos_token_id
Definition model.h:92
std::string hidden_act
Definition model.h:90
TokenizerFamily tokenizer_family
Definition model.h:117
int unk_token_id
Definition model.h:94
int max_position_embeddings
Definition model.h:87
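For quick reference, a rough sketch of ModelConfig assembled from the member list above; the field ordering and the TokenizerFamily enumerators are assumptions inferred from their usage in model_config.cpp, see model.h:80 for the actual definition.

#include <string>

struct ModelConfig {
  // Enumerators assumed from the tokenizer-family classification code above.
  enum class TokenizerFamily { UNKNOWN, LLAMA_SENTENCEPIECE, LLAMA3_TIKTOKEN };

  int hidden_size;                   // model.h:81
  int intermediate_size;             // model.h:82
  int num_attention_heads;           // model.h:83
  int num_key_value_heads;           // model.h:84
  int num_hidden_layers;             // model.h:85
  int vocab_size;                    // model.h:86
  int max_position_embeddings;       // model.h:87
  float rms_norm_eps;                // model.h:88
  float rope_theta;                  // model.h:89
  std::string hidden_act;            // model.h:90
  std::string torch_dtype;           // model.h:91
  int bos_token_id;                  // model.h:92
  int eos_token_id;                  // model.h:93
  int unk_token_id;                  // model.h:94
  int pad_token_id;                  // model.h:95
  std::string architecture;          // model.h:96
  std::string model_name;            // model.h:97
  std::string chat_template_type;    // model.h:98
  std::string pre_tokenizer_type;    // model.h:99
  std::string chat_template_string;  // model.h:100
  TokenizerFamily tokenizer_family;  // model.h:117
};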