TinyLlama.cpp 1.0
A lightweight C++ implementation of the TinyLlama language model
tokenizer.cpp
#include "tokenizer.h"

#include <algorithm>
#include <cctype>
#include <cstdint>   // uint8_t, int32_t
#include <cstring>   // strchr
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <nlohmann/json.hpp>
#include <queue>
#include <boost/regex.hpp>
#include <boost/xpressive/xpressive.hpp>
#include <sstream>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <string>
#include <limits>
#include <utility> // For std::pair
#include <functional> // For std::less
#include <filesystem>

#include "logger.h"

// Define BPE_SPACE_CHAR at file scope for broader accessibility
const std::string BPE_SPACE_CHAR = "\xC4\xA0"; // GPT-2 BPE space character (Ġ)

using json = nlohmann::json;

// Forward declaration for helper function defined later in an anonymous namespace
namespace {
size_t unicode_char_len(char src);
} // end anonymous namespace

// Helper function to check if a string consists solely of ASCII digits.
bool is_numeric(const std::string& s) {
  if (s.empty()) {
    return false; // An empty string is not considered numeric
  }
  for (char c : s) {
    if (!std::isdigit(static_cast<unsigned char>(c))) {
      return false; // Found a non-digit character
    }
  }
  return true; // All characters are digits
}
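
// Usage sketch (illustrative, not part of the original file): is_numeric()
// accepts only unsigned ASCII digit strings, so signs, decimal points, and
// whitespace all disqualify the input.
//
//   is_numeric("12345");  // true
//   is_numeric("12.5");   // false ('.' is not a digit)
//   is_numeric("-7");     // false ('-' is not a digit)
//   is_numeric("");       // false (empty string)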

// Finds the rank of a potential BPE merge.
// Returns the rank (lower is better) if the merge exists, otherwise -1.
int Tokenizer::find_bpe_rank(const std::string& token_left, const std::string& token_right) const {
  auto it = bpe_merges_.find(token_left + token_right); // Ensure this uses the correct combined form if prefixes are involved
  if (it != bpe_merges_.end()) {
    return it->second; // Return the rank
  }
  return -1; // Merge not found
}
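
// Behavior sketch for find_bpe_rank() (hypothetical merge table, for
// illustration only): with bpe_merges_ = {{"th", 0}, {"he", 1}}:
//
//   find_bpe_rank("t", "h");  // -> 0  (lowest rank = highest-priority merge)
//   find_bpe_rank("h", "e");  // -> 1
//   find_bpe_rank("x", "y");  // -> -1 (pair is not a known merge)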

std::vector<std::string> Tokenizer::bpe_tokenize_from_scores(
    const std::string& text) const {
  std::vector<std::string> all_tokens;
  std::vector<std::string> initial_units; // Pre-tokenized parts (words, symbols, spaces)

  // Llama-like regex for pre-tokenization
  boost::regex llama_regex(
      // This pattern is common for SentencePiece-like splitting by words, numbers, symbols, and whitespace.
      R"([\r\n]+|[[:space:]]+|[^\r\n[:space:][:alnum:]]+|[[:alnum:]]+)");
  boost::smatch match;
  std::string text_to_search = text;

  // Pre-tokenize the text using the regex
  while (boost::regex_search(text_to_search, match, llama_regex)) {
    if (!match.str(0).empty()) { // Ensure no empty strings are added
      initial_units.push_back(match.str(0));
    }
    text_to_search = match.suffix().str();
  }
  if (!text_to_search.empty()) { // Add any trailing part not matched
    initial_units.push_back(text_to_search);
  }

  Logger::debug("[BPE_SCORES] Regex pre-tokenization resulted in " + std::to_string(initial_units.size()) + " initial units.");

  const std::string sp_space_prefix = "\xE2\x96\x81"; // SentencePiece space U+2581
  bool next_word_needs_prefix = true;

  for (const std::string& unit_raw : initial_units) {
    if (unit_raw.empty()) continue;

    // Check if the unit is purely whitespace
    bool unit_is_whitespace = true;
    for (char c : unit_raw) {
      if (!std::isspace(static_cast<unsigned char>(c))) {
        unit_is_whitespace = false;
        break;
      }
    }

    if (unit_is_whitespace) {
      // Whitespace signals that the *next* non-whitespace unit needs the prefix.
      next_word_needs_prefix = true;
      Logger::debug("[BPE_SCORES] Unit '" + unit_raw + "' is whitespace. Setting prefix flag for next word.");
      continue; // Skip to the next unit
    }

    std::string unit_to_bpe = unit_raw;
    if (next_word_needs_prefix) {
      unit_to_bpe = sp_space_prefix + unit_to_bpe;
      Logger::debug("[BPE_SCORES] Prefixed unit: '" + unit_raw + "' -> '" + unit_to_bpe + "'");
      next_word_needs_prefix = false; // Reset flag after applying prefix
    } else {
      Logger::debug("[BPE_SCORES] Processing unit without prefix: '" + unit_to_bpe + "'");
    }

    if (unit_raw == "\n") {
      Logger::debug("[BPE_SCORES] Raw unit is newline. It will be split into chars. Current unit_to_bpe: '" + unit_to_bpe + "'");
      // If a newline is a standalone token, it should be found. If it's part of merges, it will be handled.
    }

    std::vector<std::string> chars; // Characters/sub-units of the current unit_to_bpe
    // Split unit_to_bpe into UTF-8 characters
    for (size_t i = 0; i < unit_to_bpe.size();) {
      size_t bytes = unicode_char_len(unit_to_bpe[i]);

      if (i + bytes <= unit_to_bpe.size()) {
        chars.push_back(unit_to_bpe.substr(i, bytes));
      } else {
        Logger::warning("[BPE_SCORES] Invalid UTF-8 sequence or length error for: '" + unit_to_bpe.substr(i) + "'");
        chars.push_back(unit_to_bpe.substr(i));
        break;
      }
      i += bytes;
    }

    if (chars.empty()) {
      Logger::warning("[BPE_SCORES] Unit '" + unit_to_bpe + "' (original: '" + unit_raw + "') produced no chars for BPE.");
      continue;
    }

    // Perform BPE merges based on scores (ranks in bpe_merges_)
    bool changes = true;
    while (changes && chars.size() > 1) {
      changes = false;
      int best_rank = std::numeric_limits<int>::max(); // For rank-based merges, lower is better
      int best_i = -1;

      for (size_t i = 0; i < chars.size() - 1; ++i) {
        std::string pair = chars[i] + chars[i + 1];
        auto it = bpe_merges_.find(pair);
        if (it != bpe_merges_.end() && it->second < best_rank) { // Using rank from bpe_merges_
          best_rank = it->second;
          best_i = static_cast<int>(i);
        }
      }

      if (best_i >= 0) { // If a merge was found
        std::string merged = chars[best_i] + chars[best_i + 1];
        chars[best_i] = merged;
        chars.erase(chars.begin() + best_i + 1);
        changes = true;
      }
    }

    all_tokens.insert(all_tokens.end(), chars.begin(), chars.end());
  }

  Logger::debug("[BPE_SCORES] Final token count after BPE: " + std::to_string(all_tokens.size()));
  return all_tokens;
}
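
// Worked example of the greedy merge loop above (hypothetical ranks, for
// illustration). Suppose a unit is split into {"▁", "l", "o", "w"} and
// bpe_merges_ contains {"lo": 0, "▁l": 1, "ow": 2}:
//
//   pass 1: candidate pairs are "▁l" (rank 1), "lo" (rank 0), "ow" (rank 2);
//           the lowest rank wins -> {"▁", "lo", "w"}
//   pass 2: "▁lo" and "low" are not in the merge table -> no change, loop ends
//
// Each pass scans all adjacent pairs but applies only the single best-ranked
// merge, mirroring classic BPE where merge order follows training order.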

std::vector<int> Tokenizer::tokens_to_ids(
    const std::vector<std::string>& tokens) const {
  std::vector<int> ids;
  ids.reserve(tokens.size());

  for (const auto& token : tokens) {
    if (token == "\n") {
      Logger::debug("[TOK_TO_ID_NL_DEBUG] Processing token: '\n' (actual newline char). Length: " + std::to_string(token.length()));
      bool found_in_added = false;
      for (const auto& pair : added_tokens_) {
        if (pair.first == "\n") {
          Logger::debug("[TOK_TO_ID_NL_DEBUG] Found '\n' key in added_tokens_ map. ID: " + std::to_string(pair.second));
          found_in_added = true;
          break;
        }
      }
      if (!found_in_added) {
        Logger::debug("[TOK_TO_ID_NL_DEBUG] '\n' key NOT found in added_tokens_ map by direct string compare.");
        // Log all keys in added_tokens_ if newline is not found, to see what IS there
        std::string keys_in_map = "Keys in added_tokens_: ";
        for (const auto& pair : added_tokens_) {
          std::string key_escaped;
          for (char c_key : pair.first) {
            if (c_key == '\n') key_escaped += "<NL>";
            else if (c_key == '\r') key_escaped += "<CR>";
            else if (c_key == '\t') key_escaped += "<TAB>";
            else if (std::isprint(static_cast<unsigned char>(c_key))) key_escaped += c_key;
            else { std::stringstream ss_hex; ss_hex << "<0x" << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(static_cast<unsigned char>(c_key)) << ">"; key_escaped += ss_hex.str(); }
          }
          keys_in_map += "['" + key_escaped + "' (len:" + std::to_string(pair.first.length()) + ")] ";
        }
        Logger::debug(keys_in_map);
      }
    }

    auto added_it = added_tokens_.find(token);
    if (added_it != added_tokens_.end()) { // Check added tokens first
      ids.push_back(added_it->second);
      Logger::debug("[TOK_TO_ID] Found added token: '" + token +
                    "' -> ID: " + std::to_string(added_it->second));
    } else { // Not an added token, check base vocabulary
      auto base_it = token_to_id_.find(token);
      if (base_it != token_to_id_.end()) {
        ids.push_back(base_it->second);
        Logger::debug("[TOK_TO_ID] Found base token: '" + token +
                      "' -> ID: " + std::to_string(base_it->second));
      } else { // Not in base vocab, try capitalized version
        std::string capitalized_token = capitalize_first_letter(token);
        if (capitalized_token != token) { // If capitalization changed something
          auto capitalized_it = token_to_id_.find(capitalized_token);
          if (capitalized_it != token_to_id_.end()) {
            ids.push_back(capitalized_it->second);
            Logger::debug(
                "[TOK_TO_ID] FALLBACK: Found capitalized base token: '" +
                token + "' -> '" + capitalized_token +
                "' -> ID: " + std::to_string(capitalized_it->second));
            continue; // Skip further fallbacks for this token
          }
        }

        // Fallback for single-byte tokens if not found yet
        if (token.length() == 1) {
          char c = token[0];
          auto byte_it = byte_char_to_id_.find(c);
          if (byte_it != byte_char_to_id_.end()) {
            ids.push_back(byte_it->second);
            Logger::debug("[TOK_TO_ID] FALLBACK: Mapped single-byte token '" +
                          std::string(1, c) + "' to byte token ID " +
                          std::to_string(byte_it->second));
            continue; // Skip further fallbacks
          }
        }

        // If all fallbacks fail, use UNK token
        Logger::debug("[TOK_TO_ID] UNKNOWN: Token '" + token +
                      "' not found in added, base, capitalized fallback, or "
                      "byte tokens. Using UNK ID: " +
                      std::to_string(unk_token_id_));
        ids.push_back(unk_token_id_);
      }
    }
  }

  return ids;
}
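
// Lookup order used by tokens_to_ids(), summarized as a sketch (the chat-token
// ID comes from the manual injection elsewhere in this file; "Hello" -> 15043
// is a hypothetical example):
//
//   1. added_tokens_          e.g. "<|user|>" -> 32001
//   2. token_to_id_           base vocabulary
//   3. capitalized variant    "hello" absent but "Hello" present -> 15043
//   4. byte_char_to_id_       single-byte tokens only, e.g. "\x07"
//   5. unk_token_id_          last resort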
std::vector<std::string> Tokenizer::ids_to_tokens(
    const std::vector<int>& ids) const {
  std::vector<std::string> tokens;
  tokens.reserve(ids.size());

  for (int id : ids) {
    auto added_it = id_to_added_token_.find(id); // Check added tokens first
    if (added_it != id_to_added_token_.end()) {
      tokens.push_back(added_it->second);
    } else if (id >= 0 && static_cast<size_t>(id) < id_to_token_.size()) { // Check base vocabulary
      if (!id_to_token_[id].empty()) { // Ensure token string is not empty
        tokens.push_back(id_to_token_[id]);
      } else {
        tokens.push_back(unk_token_); // Fallback to UNK string
        Logger::warning(
            "ID " + std::to_string(id) +
            " found in base vocab range but has empty string. Using UNK token string: '" + unk_token_ + "'.");
      }
    } else { // ID is out of bounds or negative (and not an added token)
      tokens.push_back(unk_token_); // Fallback to UNK string
    }
  }

  return tokens;
}


Tokenizer::Tokenizer(const std::string& vocab_path,
                     const std::string& model_path,
                     const ModelConfig& config)
    : tokenizer_family_(config.tokenizer_family),
      unk_token_("<unk>"),
      bos_token_("<s>"),
      eos_token_("</s>"),
      pad_token_("<pad>") {
  Logger::info("[Tokenizer Constructor JSON] vocab_path: '" + vocab_path + "', model_path: '" + model_path + "'"); // Diagnostic log
  try {
    std::filesystem::path vocab_json_path_abs(vocab_path);
    if (!std::filesystem::exists(vocab_json_path_abs)) {
      throw std::runtime_error("Tokenizer vocab_path (tokenizer.json) does not exist: " + vocab_json_path_abs.string());
    }

    Logger::info(std::string("Loading tokenizer and vocab from: ") + vocab_json_path_abs.string());
    std::string family_str = "UNKNOWN";
    if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE) family_str = "LLAMA_SENTENCEPIECE";
    else if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN) family_str = "LLAMA3_TIKTOKEN";
    Logger::info(std::string("Tokenizer family based on config: ") + family_str);

    load_vocab_from_json(vocab_json_path_abs.string(), token_to_id_, id_to_token_);

    if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE) {
      Logger::info("LLAMA_SENTENCEPIECE family detected for JSON constructor, attempting to load BPE merges from: " + vocab_json_path_abs.string());
      load_bpe_merges_from_json(vocab_json_path_abs.string());
    }

    unk_token_id_ = (token_to_id_.count(unk_token_)) ? token_to_id_[unk_token_] : config.bos_token_id; // Fallback to BOS if UNK not in vocab
    bos_token_id_ = config.bos_token_id;
    eos_token_id_ = config.eos_token_id;
    pad_token_id_ = config.pad_token_id;

    if (bos_token_id_ >= 0 && static_cast<size_t>(bos_token_id_) < id_to_token_.size() && !token_to_id_.count(bos_token_)) bos_token_ = id_to_token_[bos_token_id_];
    if (eos_token_id_ >= 0 && static_cast<size_t>(eos_token_id_) < id_to_token_.size() && !token_to_id_.count(eos_token_)) eos_token_ = id_to_token_[eos_token_id_];
    if (unk_token_id_ >= 0 && static_cast<size_t>(unk_token_id_) < id_to_token_.size() && !token_to_id_.count(unk_token_)) unk_token_ = id_to_token_[unk_token_id_];
    if (pad_token_id_ >= 0 && static_cast<size_t>(pad_token_id_) < id_to_token_.size()) pad_token_ = id_to_token_[pad_token_id_];

    Logger::info("Final Special Tokens (JSON constructor path): BOS=" + std::to_string(bos_token_id_) +
                 " ('" + bos_token_ + "'), EOS=" + std::to_string(eos_token_id_) + " ('" +
                 eos_token_ + "'), UNK=" + std::to_string(unk_token_id_) + " ('" +
                 unk_token_ + "'), PAD=" + std::to_string(pad_token_id_) + " ('" +
                 pad_token_ + "')"); // Removed extra backslashes from PAD log

    std::string init_log_message = "Tokenizer successfully initialized from JSON/Config. Detected type based on config: ";
    init_log_message += (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN ? "LLAMA3_TIKTOKEN (assumed BPE)" :
                         (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE ? "LLAMA_SENTENCEPIECE (assumed BPE/SPM)" : "UNKNOWN"));
    Logger::info(init_log_message);

    if (model_path.size() > 0) {
      if (model_path.size() > 6 &&
          model_path.substr(model_path.size() - 6) == ".model") {
        Logger::info("Loading SentencePiece model: " + model_path);
        load_sentencepiece_model(model_path);
      } else if (model_path.size() > 5 &&
                 model_path.substr(model_path.size() - 5) == ".json") {
        Logger::info("Loading BPE merges from JSON: " + model_path);
        load_bpe_merges_from_json(model_path);
      } else {
        Logger::info("Unsupported model format: " + model_path +
                     " - falling back to space tokenization");
      }
    } else {
      Logger::info(
          "No model path provided - falling back to space tokenization");
    }
  } catch (const std::exception& e) {
    std::cerr << "Failed to load tokenizer or vocab from " << vocab_path << ": "
              << e.what() << std::endl;
    Logger::error(std::string("Failed to load tokenizer or vocab from \"") +
                  vocab_path + "\": " + e.what());
    throw;
  }

  if (id_to_token_.empty()) {
    throw std::runtime_error(
        "Failed to initialize tokenizer vocabulary from: " + vocab_path);
  }

  Logger::info("Loaded " + std::to_string(id_to_token_.size()) +
               " tokens from vocabulary file: " + vocab_path);

  if (id_to_token_.size() > 0) {
    std::string first_few_tokens_log = "First few (up to 10 or vocab size) tokens from " + vocab_path + ": ";
    for (size_t i = 0; i < std::min((size_t)10, id_to_token_.size()); ++i) {
      first_few_tokens_log += "ID[" + std::to_string(i) + "]=";
      std::string escaped_token;
      for (char c_tok : id_to_token_[i]) {
        if (c_tok == '\\') {
          escaped_token += "\\\\";
        } else if (c_tok == '\'') {
          escaped_token += "\\'";
        } else if (std::isprint(static_cast<unsigned char>(c_tok))) {
          escaped_token += c_tok;
        } else {
          std::stringstream ss_hex;
          ss_hex << "<0x" << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(static_cast<unsigned char>(c_tok)) << ">";
          escaped_token += ss_hex.str();
        }
      }
      first_few_tokens_log += "'" + escaped_token + "' "; // Enclose in single quotes
    }
    Logger::info(first_few_tokens_log);
  }

  const std::vector<std::pair<std::string, int>> known_chat_tokens = {
      {"<|system|>", 32000}, {"<|user|>", 32001}, {"<|assistant|>", 32002}};
  int manually_injected_count = 0;
  size_t vocab_size = id_to_token_.size();
  for (const auto& pair : known_chat_tokens) {
    const std::string& tok = pair.first;
    int id = pair.second;

    if (added_tokens_.find(tok) == added_tokens_.end() &&
        static_cast<size_t>(id) >= vocab_size) {
      added_tokens_[tok] = id;
      id_to_added_token_[id] = tok;
      manually_injected_count++;
      Logger::info("[MANUAL INJECT] Added missing chat token: '" + tok +
                   "' with assumed ID: " + std::to_string(id));
    } else if (added_tokens_.find(tok) != added_tokens_.end()) {
      Logger::debug("[MANUAL INJECT] Chat token '" + tok +
                    "' already loaded from JSON. Skipping injection.");
    } else {
      Logger::warning("[MANUAL INJECT] Cannot add chat token '" + tok +
                      "', assumed ID " + std::to_string(id) +
                      " clashes with loaded vocab size (" +
                      std::to_string(vocab_size) + ").");
    }
  }
  if (manually_injected_count > 0) {
    Logger::info("Manually injected " +
                 std::to_string(manually_injected_count) +
                 " missing chat tokens.");
  }
}
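
// Construction sketch (hedged: the exact ModelConfig fields are inferred from
// the usages visible in this file; paths and IDs are hypothetical):
//
//   ModelConfig cfg;
//   cfg.tokenizer_family = ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE;
//   cfg.bos_token_id = 1;
//   cfg.pad_token_id = -1;  // -1 means "no pad token"
//   Tokenizer tok("models/tokenizer.json", "models/tokenizer.model", cfg);
//
// The second argument may be a SentencePiece ".model", a BPE-merges ".json",
// or empty, in which case the constructor logs a fallback to space
// tokenization.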

static std::unordered_map<std::string, int> generate_bpe_merges_from_vocab_scores(
    const std::vector<std::string>& id_to_token,
    const std::vector<float>& token_scores) {
  std::unordered_map<std::string, int> generated_merges;

  if (token_scores.empty() || id_to_token.empty()) {
    Logger::warning("Cannot generate BPE merges: empty scores or vocabulary");
    return generated_merges;
  }

  Logger::info("Generating BPE merges from vocabulary and scores for older Llama models...");

  // Create a list of tokens with their scores, sorted by score (higher score = higher priority)
  std::vector<std::pair<float, std::string>> scored_tokens;
  for (size_t id = 0; id < id_to_token.size(); ++id) {
    if (id < token_scores.size()) {
      const std::string& token = id_to_token[id];
      // Skip special tokens and single characters
      if (token.length() > 1 &&
          token.find("<") == std::string::npos &&
          token.find(">") == std::string::npos &&
          token != "▁") { // Skip SentencePiece space token
        scored_tokens.emplace_back(token_scores[id], token);
      }
    }
  }

  // Sort by score (descending - higher scores first)
  std::sort(scored_tokens.begin(), scored_tokens.end(),
            [](const auto& a, const auto& b) { return a.first > b.first; });

  Logger::info("Found " + std::to_string(scored_tokens.size()) + " candidate tokens for merge generation");

  // Generate merges by finding tokens that can be decomposed into pairs
  int merge_rank = 0;
  std::unordered_set<std::string> processed_tokens;

  for (const auto& [score, token] : scored_tokens) {
    if (processed_tokens.count(token)) continue;

    // Try to find the best split point for this token
    std::string best_left, best_right;
    float best_combined_score = -std::numeric_limits<float>::infinity();

    // Try all possible split points
    for (size_t split = 1; split < token.length(); ++split) {
      std::string left = token.substr(0, split);
      std::string right = token.substr(split);

      // Check if both parts exist in vocabulary
      auto left_it = std::find(id_to_token.begin(), id_to_token.end(), left);
      auto right_it = std::find(id_to_token.begin(), id_to_token.end(), right);

      if (left_it != id_to_token.end() && right_it != id_to_token.end()) {
        // Both parts exist, calculate combined score
        size_t left_id = std::distance(id_to_token.begin(), left_it);
        size_t right_id = std::distance(id_to_token.begin(), right_it);
        float left_score = (left_id < token_scores.size()) ?
                           token_scores[left_id] : 0.0f;
        float right_score = (right_id < token_scores.size()) ?
                            token_scores[right_id] : 0.0f;
        float combined_score = left_score + right_score;

        if (combined_score > best_combined_score) {
          best_combined_score = combined_score;
          best_left = left;
          best_right = right;
        }
      }
    }

    // If we found a valid decomposition, add it as a merge rule
    if (!best_left.empty() && !best_right.empty()) {
      std::string merge_key = best_left + best_right;
      if (generated_merges.find(merge_key) == generated_merges.end()) {
        generated_merges[merge_key] = merge_rank++;
        Logger::debug("Generated merge: '" + best_left + "' + '" + best_right + "' -> '" + token + "' (rank " + std::to_string(merge_rank - 1) + ")");
      }
    }

    processed_tokens.insert(token);

    // Limit the number of merges to prevent excessive computation
    if (merge_rank >= 50000) {
      Logger::info("Reached maximum merge limit (50000), stopping generation");
      break;
    }
  }

  Logger::info("Generated " + std::to_string(generated_merges.size()) + " BPE merge rules from vocabulary and scores");
  return generated_merges;
}
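
// Worked example of the score-based merge generation above (hypothetical
// vocabulary and scores, for illustration):
//
//   id_to_token  = {"a", "b", "c", "ab", "abc"}
//   token_scores = {-5,  -5,  -5,  -2,   -1   }
//
// Candidates (length > 1, no "<"/">" markers) sorted by score: "abc", "ab".
// For "abc", the best split whose halves both exist is "ab" + "c"
// (combined score -2 + -5 = -7; "a" + "bc" is rejected because "bc" is not in
// the vocab), yielding merge key "abc" at rank 0. Then "ab" = "a" + "b" gets
// rank 1. Lower rank means the merge is applied earlier during BPE.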

Tokenizer::Tokenizer(const GGUFData& gguf_data, const ModelConfig& config)
    : tokenizer_family_(config.tokenizer_family),
      initialized_from_gguf_(true) {
  Logger::info("Initializing Tokenizer from GGUFData...");
  std::string family_str_gguf = "UNKNOWN";
  if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE) family_str_gguf = "LLAMA_SENTENCEPIECE";
  else if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN) family_str_gguf = "LLAMA3_TIKTOKEN";
  Logger::info(std::string("Tokenizer family from ModelConfig: ") + family_str_gguf);

  // Attempt to load chat template from GGUF metadata
  try {
    auto it = gguf_data.metadata.find("tokenizer.chat_template");
    if (it != gguf_data.metadata.end()) {
      if (std::holds_alternative<std::string>(it->second)) {
        gguf_chat_template_ = std::get<std::string>(it->second);
        if (!gguf_chat_template_.empty()) {
          Logger::info("[Tokenizer GGUF Init] Found and loaded 'tokenizer.chat_template' from GGUF metadata.");
          // Further log the template content if it's not too long, or a snippet
          size_t log_len = std::min(gguf_chat_template_.length(), (size_t)70); // Log up to 70 chars
          std::string template_snippet = gguf_chat_template_.substr(0, log_len);
          if (gguf_chat_template_.length() > log_len) template_snippet += "...";
          // Replace newlines with printable \n for one-line logging
          std::string loggable_snippet;
          for (char ch : template_snippet) {
            if (ch == '\n') loggable_snippet += "\\n";
            else if (ch == '\r') loggable_snippet += "\\r";
            else if (ch == '\t') loggable_snippet += "\\t";
            else if (std::isprint(static_cast<unsigned char>(ch))) loggable_snippet += ch;
            else loggable_snippet += "."; // Replace non-printable with a dot
          }
          Logger::debug("[Tokenizer GGUF Init] Chat template snippet: " + loggable_snippet);
        } else {
          Logger::info("[Tokenizer GGUF Init] 'tokenizer.chat_template' found in GGUF metadata but is empty.");
        }
      } else {
        Logger::warning("[Tokenizer GGUF Init] 'tokenizer.chat_template' found in GGUF metadata but is not a string type.");
      }
    } else {
      Logger::info("[Tokenizer GGUF Init] 'tokenizer.chat_template' not found in GGUF metadata.");
    }
  } catch (const std::exception& e) {
    Logger::error("[Tokenizer GGUF Init] Exception while trying to access 'tokenizer.chat_template': " + std::string(e.what()));
  }

  if (gguf_data.tokenizer_tokens.empty()) {
    throw std::runtime_error(
        "GGUF data does not contain 'tokenizer.ggml.tokens'");
  }

  // Common vocabulary loading
  id_to_token_ = gguf_data.tokenizer_tokens;
  token_to_id_.clear(); // Ensure map is clear before populating
  token_to_id_.reserve(id_to_token_.size());
  for (size_t i = 0; i < id_to_token_.size(); ++i) {
    token_to_id_[id_to_token_[i]] = static_cast<int>(i);

    if (static_cast<int>(i) == 1734) {
      const std::string& token_at_1734 = id_to_token_[i];
      std::string escaped_token_1734;
      for (char c : token_at_1734) {
        if (c == '\n') escaped_token_1734 += "\\n";
        else if (c == '\r') escaped_token_1734 += "\\r";
        else if (c == '\t') escaped_token_1734 += "\\t";
        else if (c == '\\') escaped_token_1734 += "\\\\";
        else if (std::isprint(static_cast<unsigned char>(c))) escaped_token_1734 += c;
        else {
          std::stringstream ss_hex;
          ss_hex << "<0x" << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(static_cast<unsigned char>(c)) << ">";
          escaped_token_1734 += ss_hex.str();
        }
      }
      Logger::info("[GGUF_VOCAB_SCAN] Token string at ID 1734 is: '" + escaped_token_1734 + "' (length: " + std::to_string(token_at_1734.length()) + ")");
    }
  }
  Logger::info("Loaded " + std::to_string(id_to_token_.size()) +
               " tokens from GGUF tokenizer_tokens.");

  // Log first few tokens for inspection
  if (id_to_token_.size() > 0) {
    std::string first_few_tokens_log = "First few (up to 10 or vocab size) GGUF tokens: ";
    for (size_t i = 0; i < std::min((size_t)10, id_to_token_.size()); ++i) {
      first_few_tokens_log += "ID[" + std::to_string(i) + "]='";
      // Safely print token, escaping non-printables for logging
      for (char c_tok : id_to_token_[i]) {
        if (std::isprint(static_cast<unsigned char>(c_tok))) {
          first_few_tokens_log += c_tok;
        } else {
          std::stringstream ss_hex;
          ss_hex << "<0x" << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(static_cast<unsigned char>(c_tok)) << ">";
          first_few_tokens_log += ss_hex.str();
        }
      }
      first_few_tokens_log += "' ";
    }
    Logger::info(first_few_tokens_log);
  }

  // Conditional loading based on family
  if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN) {
    Logger::info("Configuring for LLAMA3_TIKTOKEN (gpt2-style BPE).");

    if (gguf_data.tokenizer_merges.empty()) {
      Logger::warning("Llama 3 Tiktoken family specified, but GGUF data does not contain 'tokenizer.ggml.merges'. Tiktoken BPE may not function correctly without explicit merges.");
    } else {
      bpe_merges_.clear();
      int rank = 0;
      // Removed sample_merges vector and related logging logic
      for (const std::string& merge_str : gguf_data.tokenizer_merges) {
        std::string part1, part2;
        size_t space_pos = merge_str.find(' ');
        if (space_pos != std::string::npos && space_pos > 0 && space_pos < merge_str.length() - 1) {
          part1 = merge_str.substr(0, space_pos);
          part2 = merge_str.substr(space_pos + 1);
          std::string merged = part1 + part2;
          bpe_merges_[merged] = rank++; // Simplified rank assignment
        } else {
          Logger::warning("Skipping malformed Tiktoken merge rule from GGUF: '" + merge_str + "'");
        }
      }

      Logger::info("Processed " + std::to_string(bpe_merges_.size()) +
                   " Tiktoken merges from GGUF tokenizer_merges into bpe_merges_ map with ranks.");
    }
    // Scores are usually not the primary driver for Tiktoken BPE but load if present.
    if (!gguf_data.tokenizer_scores.empty()) {
      Logger::info("Llama 3 GGUF contains " + std::to_string(gguf_data.tokenizer_scores.size()) + " scores. Loaded.");
      token_scores_ = gguf_data.tokenizer_scores;
    }

    // DEBUGGING: Log vocab/merges for neoplasm
    Logger::debug("[DEBUG_VOCAB] LLAMA3_TIKTOKEN bpe_merges_ size: " + std::to_string(bpe_merges_.size()));
    std::string target_token_neoplasm = BPE_SPACE_CHAR + "neoplasm"; // "Ġneoplasm"
    std::string target_sub_ne = BPE_SPACE_CHAR + "ne"; // "Ġne"
    std::string target_sub_o = BPE_SPACE_CHAR + "o"; // "Ġo"
    std::string target_sub_oplasm = "oplasm";
    std::string target_sub_goplasm = BPE_SPACE_CHAR + "oplasm"; // "Ġoplasm"

    auto check_and_log_vocab = [&](const std::string& token_to_check) {
      if (token_to_id_.count(token_to_check)) {
        Logger::debug("[DEBUG_VOCAB] Found '" + token_to_check + "' in vocab with ID: " + std::to_string(token_to_id_.at(token_to_check)));
      } else {
        Logger::debug("[DEBUG_VOCAB] Token '" + token_to_check + "' NOT FOUND in vocab.");
      }
    };

    auto check_and_log_merge = [&](const std::string& p1, const std::string& p2) {
      auto merge_it = bpe_merges_.find(p1 + p2);
      if (merge_it != bpe_merges_.end()) {
        Logger::debug("[DEBUG_VOCAB] Found merge for '" + p1 + "' + '" + p2 + "' ('" + (p1 + p2) + "') with rank: " + std::to_string(merge_it->second));
      } else {
        Logger::debug("[DEBUG_VOCAB] Merge for '" + p1 + "' + '" + p2 + "' ('" + (p1 + p2) + "') NOT FOUND.");
      }
    };

  } else if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE) {
    Logger::info("Configuring for LLAMA_SENTENCEPIECE.");
    if (!gguf_data.tokenizer_scores.empty()) {
      token_scores_ = gguf_data.tokenizer_scores;
      Logger::info("Loaded " + std::to_string(token_scores_.size()) + " token scores from GGUF for SentencePiece style.");
      if (id_to_token_.size() != token_scores_.size()) {
        Logger::warning("GGUF (SentencePiece path) token and score array sizes mismatch: tokens=" +
                        std::to_string(id_to_token_.size()) + ", scores=" + std::to_string(token_scores_.size()));
      }
    } else {
      Logger::warning("SentencePiece family: No scores found. BPE merging will likely not work if no other SP model data is available.");
    }

    if (!gguf_data.tokenizer_merges.empty()) {
      Logger::info("SentencePiece family path: Found 'tokenizer.ggml.merges' in GGUF. Loading them into bpe_merges_ map.");
      bpe_merges_.clear();
      int rank = 0;
      for (const std::string& merge_str : gguf_data.tokenizer_merges) {
        std::string part1, part2;
        size_t space_pos = merge_str.find(' ');
        if (space_pos != std::string::npos && space_pos > 0 && space_pos < merge_str.length() - 1) {
          part1 = merge_str.substr(0, space_pos);
          part2 = merge_str.substr(space_pos + 1);
          bpe_merges_[part1 + part2] = rank++;
        } else {
          Logger::warning("Skipping malformed SentencePiece merge rule from GGUF: '" + merge_str + "'");
        }
      }
      Logger::info("Processed " + std::to_string(bpe_merges_.size()) +
                   " merges from GGUF tokenizer_merges into bpe_merges_ map (SentencePiece path).");
    } else {
      Logger::warning("SentencePiece family path: No 'tokenizer.ggml.merges' found in GGUF. Attempting to generate merges from vocabulary and scores...");

      // Generate BPE merges from vocabulary and scores (llama.cpp approach)
      std::unordered_map<std::string, int> generated_merges =
          generate_bpe_merges_from_vocab_scores(id_to_token_, token_scores_);
      if (!generated_merges.empty()) {
        bpe_merges_ = std::move(generated_merges);
        Logger::info("Successfully generated " + std::to_string(bpe_merges_.size()) + " BPE merges from vocabulary and scores for SentencePiece tokenizer");
      } else {
        Logger::warning("Failed to generate BPE merges. Tokenization may be suboptimal for this model.");
      }
    }

  } else { // UNKNOWN tokenizer family
    Logger::warning("Tokenizer family is UNKNOWN. Tokenizer may not function as expected. Will attempt to load basic vocab and scores if present.");
    if (!gguf_data.tokenizer_scores.empty()) {
      token_scores_ = gguf_data.tokenizer_scores;
      Logger::info("Loaded " + std::to_string(token_scores_.size()) + " token scores from GGUF for UNKNOWN family as a fallback.");
    }
  }

  if (!gguf_data.tokenizer_token_types.empty() && gguf_data.tokenizer_token_types.size() == id_to_token_.size()) {
    token_types_.resize(gguf_data.tokenizer_token_types.size());
    std::transform(gguf_data.tokenizer_token_types.begin(),
                   gguf_data.tokenizer_token_types.end(), token_types_.begin(),
                   [](unsigned int u) { return static_cast<int32_t>(u); });
    Logger::info("Loaded and transformed " + std::to_string(token_types_.size()) + " token types from GGUF.");

    // Populate byte_char_to_id_ and added_tokens_ using token_types_
    byte_char_to_id_.clear();
    added_tokens_.clear();
    id_to_added_token_.clear();
    int byte_tokens_from_type = 0;
    int special_tokens_from_type = 0;

    for (size_t i = 0; i < token_types_.size(); ++i) {
      int32_t tt = token_types_[i];
      const std::string& token_str = id_to_token_[i];
      int token_id = static_cast<int>(i);
      bool processed_as_byte = false; // Flag to track if token was handled as byte

      if (tt == 6) { // LLAMA_TOKEN_TYPE_BYTE
        bool added_byte = false;
        if (token_str.length() == 1) {
          byte_char_to_id_[token_str[0]] = token_id;
          added_byte = true;
        } else if (token_str.rfind("<0x", 0) == 0 && token_str.back() == '>' && token_str.length() == 6) {
          try {
            int byte_val = std::stoi(token_str.substr(3, 2), nullptr, 16);
            byte_char_to_id_[static_cast<char>(byte_val)] = token_id;
            added_byte = true;
          } catch (const std::exception& e) {
            Logger::warning("Could not parse byte value from type-BYTE (6) token string: '" + token_str + "'");
          }
        } else {
          // Log if a token is marked as BYTE but doesn't match expected formats
          Logger::warning("Token type is BYTE (6) but does not match single char or <0xNN> format: '" + token_str + "' ID: " + std::to_string(token_id));
        }

        if (added_byte) {
          byte_tokens_from_type++;
          processed_as_byte = true;
        }
      }
      if (!processed_as_byte && (tt == 2 || tt == 3 || tt == 4 || tt == 5)) {
        if (added_tokens_.find(token_str) == added_tokens_.end()) {
          added_tokens_[token_str] = token_id;
          id_to_added_token_[token_id] = token_str;
          special_tokens_from_type++;
        }
      }
    }
    // Log message now reflects bytes identified from type 6 tokens
    Logger::info("From GGUF token_types (BYTE=6): Identified " + std::to_string(byte_tokens_from_type) + " byte tokens (for byte_char_to_id_). " +
                 "Identified " + std::to_string(special_tokens_from_type) + " other special/added tokens (types 2,3,4,5).");

    // If token types were processed but yielded no byte tokens for Tiktoken, try the fallback vocab scan.
    if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN && byte_tokens_from_type == 0) {
      Logger::warning("No byte tokens identified via token_types metadata for Tiktoken. Attempting fallback scan of vocabulary.");
      // Manually populate byte_char_to_id_ by checking vocab for <0xNN> and literal byte strings
      byte_char_to_id_.clear(); // Clear again in case some non-byte type 3 were added incorrectly before
      int bytes_found_in_vocab_fallback = 0;
      for (int i = 0; i < 256; ++i) {
        std::stringstream ss_hex_repr;
        ss_hex_repr << "<0x" << std::hex << std::setw(2) << std::setfill('0') << i << ">";
        std::string byte_token_str_repr = ss_hex_repr.str();
        std::string literal_byte_char_str(1, static_cast<char>(i));
        bool is_space_char = (static_cast<char>(i) == ' ');

        if (is_space_char) {
          Logger::debug("[BYTE_FALLBACK_DEBUG] Checking for SPACE (byte 32). Looking for '<0x20>' and ' '.");
        }

        auto it = token_to_id_.find(byte_token_str_repr);
        if (it != token_to_id_.end()) {
          if (is_space_char) {
            Logger::debug("[BYTE_FALLBACK_DEBUG] Found '<0x20>' token with ID: " + std::to_string(it->second) + ". Adding to map.");
          }
          byte_char_to_id_[static_cast<char>(i)] = it->second;
          bytes_found_in_vocab_fallback++;
        } else {
          // Also check for literal single-byte characters if they are printable
          if (std::isprint(static_cast<unsigned char>(i))) {
            auto lit_it = token_to_id_.find(literal_byte_char_str);
            if (lit_it != token_to_id_.end()) {
              if (is_space_char) {
                Logger::debug("[BYTE_FALLBACK_DEBUG] Did not find '<0x20>', but found literal ' ' token with ID: " + std::to_string(lit_it->second));
              }

              // Ensure this token ID hasn't already been mapped (e.g., by a <0xNN> entry)
              bool id_already_mapped = false;
              for (const auto& pair : byte_char_to_id_) { if (pair.second == lit_it->second) { id_already_mapped = true; break; } }
              if (!id_already_mapped) {
                if (is_space_char) {
                  Logger::debug("[BYTE_FALLBACK_DEBUG] ID " + std::to_string(lit_it->second) + " for ' ' not already mapped. Adding to map.");
                }
                byte_char_to_id_[static_cast<char>(i)] = lit_it->second;
                bytes_found_in_vocab_fallback++;
                // Don't need a continue here, just prevents double-counting if somehow both exist
              } else {
                if (is_space_char) {
                  Logger::debug("[BYTE_FALLBACK_DEBUG] ID " + std::to_string(lit_it->second) + " for ' ' was already mapped (likely by <0x20>). Skipping literal add.");
                }
              }
            } else {
              if (is_space_char) {
                Logger::debug("[BYTE_FALLBACK_DEBUG] Did not find '<0x20>' OR literal ' ' token in vocab.");
              }
            }
          } else {
            if (is_space_char) {
              Logger::debug("[BYTE_FALLBACK_DEBUG] Did not find '<0x20>' token, and space is not printable, so didn't check for literal ' '.");
            }
          }
        }
      }
      Logger::info("Fallback byte_char_to_id_ map population: Found representations for " + std::to_string(bytes_found_in_vocab_fallback) +
                   " byte values in GGUF vocab (using <0xNN> or literal). Intended for Tiktoken BPE.");
      byte_tokens_from_type = bytes_found_in_vocab_fallback;
    }

  } else {
    Logger::warning("GGUF tokenizer_token_types array missing or size mismatch. Byte token and special token identification will be limited.");
    byte_char_to_id_.clear();
    int bytes_found_in_vocab_fallback = 0;
    for (int i = 0; i < 256; ++i) {
      std::stringstream ss_hex_repr;
      ss_hex_repr << "<0x" << std::hex << std::setw(2) << std::setfill('0') << i << ">";
      std::string byte_token_str_repr = ss_hex_repr.str();
      std::string literal_byte_char_str(1, static_cast<char>(i));
      bool is_space_char = (static_cast<char>(i) == ' ');

      if (is_space_char) {
        Logger::debug("[BYTE_FALLBACK_DEBUG] Checking for SPACE (byte 32). Looking for '<0x20>' and ' '.");
      }

      auto it = token_to_id_.find(byte_token_str_repr);
      if (it != token_to_id_.end()) {
        if (is_space_char) {
          Logger::debug("[BYTE_FALLBACK_DEBUG] Found '<0x20>' token with ID: " + std::to_string(it->second) + ". Adding to map.");
        }
        byte_char_to_id_[static_cast<char>(i)] = it->second;
        bytes_found_in_vocab_fallback++;
      } else {
        // Also check for literal single-byte characters if they are printable
        if (std::isprint(static_cast<unsigned char>(i))) {
          auto lit_it = token_to_id_.find(literal_byte_char_str);
          if (lit_it != token_to_id_.end()) {
            if (is_space_char) {
              Logger::debug("[BYTE_FALLBACK_DEBUG] Did not find '<0x20>', but found literal ' ' token with ID: " + std::to_string(lit_it->second));
            }

            // Ensure this token ID hasn't already been mapped (e.g., by a <0xNN> entry)
            bool id_already_mapped = false;
            for (const auto& pair : byte_char_to_id_) { if (pair.second == lit_it->second) { id_already_mapped = true; break; } }
            if (!id_already_mapped) {
              if (is_space_char) {
                Logger::debug("[BYTE_FALLBACK_DEBUG] ID " + std::to_string(lit_it->second) + " for ' ' not already mapped. Adding to map.");
              }
              byte_char_to_id_[static_cast<char>(i)] = lit_it->second;
              bytes_found_in_vocab_fallback++;
              continue;
            } else {
              if (is_space_char) {
                Logger::debug("[BYTE_FALLBACK_DEBUG] ID " + std::to_string(lit_it->second) + " for ' ' was already mapped (likely by <0x20>). Skipping literal add.");
              }
            }
          } else {
            if (is_space_char) {
              Logger::debug("[BYTE_FALLBACK_DEBUG] Did not find '<0x20>' OR literal ' ' token in vocab.");
            }
          }
        } else {
          if (is_space_char) {
            Logger::debug("[BYTE_FALLBACK_DEBUG] Did not find '<0x20>' token, and space is not printable, so didn't check for literal ' '.");
          }
        }
      }
    }
    Logger::info("Fallback byte_char_to_id_ map population: Found representations for " + std::to_string(bytes_found_in_vocab_fallback) +
                 " byte values in GGUF vocab (using <0xNN> or literal). Intended for Tiktoken BPE.");
  }

  if (byte_char_to_id_.find(' ') == byte_char_to_id_.end()) {
    Logger::info("[GENERAL_BYTE_FALLBACK] Space ' ' not found in byte_char_to_id_. Attempting to populate from vocab.");
    int general_fallback_bytes_added = 0;
    for (int i = 0; i < 256; ++i) {
      char current_char = static_cast<char>(i);
      // Only add if not already present from a more primary source (like token_types)
      if (byte_char_to_id_.count(current_char)) {
        continue;
      }

      std::stringstream ss_hex_repr;
      ss_hex_repr << "<0x" << std::hex << std::setw(2) << std::setfill('0') << i << ">";
      std::string byte_token_str_repr = ss_hex_repr.str();
      std::string literal_byte_char_str(1, current_char);

      auto it_hex = token_to_id_.find(byte_token_str_repr);
      if (it_hex != token_to_id_.end()) {
        byte_char_to_id_[current_char] = it_hex->second;
        general_fallback_bytes_added++;
        if (current_char == ' ') Logger::debug("[GENERAL_BYTE_FALLBACK] Found space as '" + byte_token_str_repr + "' -> ID: " + std::to_string(it_hex->second));
      } else {
        auto it_lit = token_to_id_.find(literal_byte_char_str);
        if (it_lit != token_to_id_.end()) {
          byte_char_to_id_[current_char] = it_lit->second;
          general_fallback_bytes_added++;
          if (current_char == ' ') Logger::debug("[GENERAL_BYTE_FALLBACK] Found space as literal '" + literal_byte_char_str + "' -> ID: " + std::to_string(it_lit->second));
        }
      }
    }
    Logger::info("[GENERAL_BYTE_FALLBACK] Added " + std::to_string(general_fallback_bytes_added) +
                 " new entries to byte_char_to_id_ map. Final size: " + std::to_string(byte_char_to_id_.size()));
    if (byte_char_to_id_.find(' ') == byte_char_to_id_.end()) {
      Logger::warning("[GENERAL_BYTE_FALLBACK] Space ' ' still not found in byte_char_to_id_ after fallback scan!");
    }

    if (byte_char_to_id_.find(' ') == byte_char_to_id_.end()) { // Check again if space wasn't found by hex/literal
      const std::string sp_space_token = "\xE2\x96\x81"; // U+2581
      auto it_sp_space = token_to_id_.find(sp_space_token);
      if (it_sp_space != token_to_id_.end()) {
        byte_char_to_id_[' '] = it_sp_space->second; // Map standard space char to the ID of the SP space token
        Logger::info("[GENERAL_BYTE_FALLBACK] SUCCESS: Found SentencePiece space token '" + sp_space_token +
                     "' (ID: " + std::to_string(it_sp_space->second) + "). Mapped standard space ' ' to this ID.");
      } else {
        // This is the final warning if space still not found
        Logger::warning("[GENERAL_BYTE_FALLBACK] Space ' ' still not found in byte_char_to_id_ after fallback scan AND specific SP space check!");
      }
    }
  }

  // Special-token IDs come from the model config (assumption: these
  // assignments were lost in the source listing; they are implied by the
  // validation logic below, which reads the fields before checking them).
  bos_token_id_ = config.bos_token_id;
  eos_token_id_ = config.eos_token_id;
  unk_token_id_ = config.unk_token_id;
  pad_token_id_ = config.pad_token_id;

  // Ensure UNK token ID is valid (non-negative). Default to 0 if invalid.
  if (unk_token_id_ < 0) {
    Logger::warning("[Tokenizer GGUF Init] UNK token ID from config was invalid (" + std::to_string(unk_token_id_) + "). Forcing to 0.");
    unk_token_id_ = 0;
  }

  auto setup_special_token = [&](const std::string& name, int& id_field, std::string& str_field, const std::string& default_str_val) {
    if (id_field >= 0 && static_cast<size_t>(id_field) < id_to_token_.size()) {
      str_field = id_to_token_[id_field];
    } else {
      str_field = default_str_val; // Use default string if ID is invalid or -1
      if (id_field != -1) { // Log warning only if ID was supposed to be valid but wasn't found
        Logger::warning(name + " token ID " + std::to_string(id_field) +
                        " from config is out of vocab bounds or invalid. Using default string: '" + default_str_val + "'.");
      }
      // Attempt to find the default string in the vocab to set its ID, if ID was bad
      auto it = token_to_id_.find(default_str_val);
      if (it != token_to_id_.end()) {
        if (id_field == -1 || (id_field >= 0 && static_cast<size_t>(id_field) >= id_to_token_.size())) { // If original ID was invalid/none
          id_field = it->second;
          Logger::info("Set " + name + " token ID to " + std::to_string(id_field) + " based on default string '" + default_str_val + "'.");
        }
      } else if (id_field != -1) {
        Logger::warning("Default string '" + default_str_val + "' for " + name + " token also not found in vocab.");
      }
    }
  };

  setup_special_token("BOS", bos_token_id_, bos_token_, "<s>");
  setup_special_token("EOS", eos_token_id_, eos_token_, "</s>");
  setup_special_token("UNK", unk_token_id_, unk_token_, "<unk>");
  // For PAD, if config.pad_token_id is -1, it means no pad token. String should be empty.
  // If it's a valid ID, str_field will be set. If it's an invalid positive ID, str_field becomes <pad> by default.
  if (config.pad_token_id == -1) {
    pad_token_ = ""; // Explicitly empty if ID is -1
    // bos_token_id_ etc. are already set directly from config, so no change is needed here when pad_token_id == -1
  } else {
    setup_special_token("PAD", pad_token_id_, pad_token_, "<pad>");
  }

  Logger::info("Final Special Tokens (GGUF constructor): BOS ID=" + std::to_string(bos_token_id_) +
               " ('" + bos_token_ + "'), EOS ID=" + std::to_string(eos_token_id_) + " ('" + eos_token_ +
               "'), UNK ID=" + std::to_string(unk_token_id_) + " ('" + unk_token_ +
               "'), PAD ID=" + std::to_string(pad_token_id_) + " ('" + pad_token_ + "')");

  Logger::info(std::string("Tokenizer successfully initialized from GGUFData. Final type: ") +
               (type_ == Type::TIKTOKEN_BPE ? "TIKTOKEN_BPE" :
                (type_ == Type::SENTENCEPIECE_BPE ? "SENTENCEPIECE_BPE" : "UNKNOWN")));
}
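
// Construction sketch for the GGUF path (hedged: the loader and config helper
// named below are hypothetical; the GGUFData field names follow the usages in
// this constructor):
//
//   GGUFData gguf = load_gguf_file("models/tinyllama.gguf");  // hypothetical
//   ModelConfig cfg = make_config_from_gguf(gguf);            // hypothetical
//   Tokenizer tok(gguf, cfg);
//
// The vocabulary comes from gguf.tokenizer_tokens, merge rules from
// gguf.tokenizer_merges, scores from gguf.tokenizer_scores, and token types
// from gguf.tokenizer_token_types.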

// NOTE: the defining signature of this tokens-to-text helper was lost in the
// source listing; the parameter list and body below are preserved as captured.
    const std::vector<std::string>& tokens) const {
  std::string result;

  bool using_space_prefix = false;
  const std::string gpt2_space_prefix = "\xC4\xA0";
  const std::string tinyllama_space_prefix = "\xE2\x96\x81";

  for (const auto& token : tokens) {
    if (!token.empty() &&
        ((token.size() >= 2 && token.substr(0, 2) == gpt2_space_prefix) ||
         (token.size() >= 3 && token.substr(0, 3) == tinyllama_space_prefix))) {
      using_space_prefix = true;
      break;
    }
  }

  for (size_t i = 0; i < tokens.size(); ++i) {
    std::string token = tokens[i];

    if (using_space_prefix) {
      if (!token.empty()) {
        if (token.size() >= 2 && token.substr(0, 2) == gpt2_space_prefix) {
          if (token.size() > 2) {
            result += ' ' + token.substr(2);
          } else {
            result += ' ';
          }
        } else if (token.size() >= 3 &&
                   token.substr(0, 3) == tinyllama_space_prefix) {
          if (token.size() > 3) {
            result += ' ' + token.substr(3);
          } else {
            result += ' ';
          }
        } else {
          result += token;
        }
      }
    } else {
      if (token.size() >= 4 && token.substr(token.size() - 4) == "</w>") {
        result += token.substr(0, token.size() - 4);
        result += " ";
        continue;
      }

      if (i > 0) {
        result += " ";
      }

      result += token;
    }
  }

  if (!result.empty() && result[0] == ' ') {
    result = result.substr(1);
  }

  std::string clean_result;
  bool prev_space = false;
  for (char c : result) {
    if (c == ' ') {
      if (!prev_space) {
        clean_result += c;
      }
      prev_space = true;
    } else {
      clean_result += c;
      prev_space = false;
    }
  }

  return clean_result;
}
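
// Example of the space-prefix handling above (illustrative). With GPT-2 style
// markers ("\xC4\xA0" = "Ġ"):
//
//   {"ĠHello", "Ġworld", "!"}   ->  "Hello world!"
//
// Without any space-prefix token, "</w>" end-of-word markers are honored and
// the cleanup pass collapses the double space this can produce:
//
//   {"Hel", "lo</w>", "world"}  ->  "Hello world"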


// Function to determine the length of a UTF-8 character based on its first byte.
// Similar to the lookup method used in llama.cpp.
// Keep this function local to this translation unit.
namespace {
inline size_t unicode_char_len(char src) {
  const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
  uint8_t highbits = static_cast<uint8_t>(src) >> 4;
  // Bounds check for safety, although highbits should always be 0-15
  return (highbits < 16) ? lookup[highbits] : 1; // Default to 1 for invalid highbits
}
} // end anonymous namespace
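
// How the lookup table above maps a UTF-8 lead byte's high nibble to the
// sequence length:
//
//   0x0-0x7 (0xxxxxxx) -> 1 byte   (ASCII)
//   0x8-0xB (10xxxxxx) -> 1 byte   (continuation byte; treated as length 1 so
//                                   malformed input still advances)
//   0xC-0xD (110xxxxx) -> 2 bytes
//   0xE     (1110xxxx) -> 3 bytes
//   0xF     (11110xxx) -> 4 bytes
//
// e.g. unicode_char_len("\xE2\x96\x81"[0]) == 3 for the SentencePiece "▁".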

std::vector<int> Tokenizer::encode(const std::string& text, bool add_bos,
                                   bool add_eos,
                                   PreTokenizeMethod pre_tok_override) const {
  std::vector<int> final_ids; // Initialize the vector to store final token IDs.
  std::string family_str_enc = "UNKNOWN"; // String representation for logging.

  // Determine the tokenizer family string for logging purposes.
  if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE) family_str_enc = "LLAMA_SENTENCEPIECE";
  else if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN) family_str_enc = "LLAMA3_TIKTOKEN";

  // Log the start of the encoding process with relevant parameters.
  std::stringstream log_ss_main;
  log_ss_main << "[ENCODE] Encoding text: '" << text << "'"
              << " (add_bos=" << add_bos
              << ", add_eos=" << add_eos
              << ", family=" << family_str_enc
              << ", pre_tok_override=" << static_cast<int>(pre_tok_override)
              << ")";
  Logger::debug(log_ss_main.str());

  if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN) {
    Logger::debug("[ENCODE] Using LLAMA3_TIKTOKEN (bpe_tokenize_to_ids) path.");

    if (add_bos && this->bos_token_id_ != -1) {
      // Check if the text already starts with the BOS token string
      if (this->bos_token_.empty() || text.rfind(this->bos_token_, 0) != 0) {
        final_ids.push_back(this->bos_token_id_);
        Logger::debug("[ENCODE Llama 3 Path] Added BOS token: " + std::to_string(this->bos_token_id_) +
                      " (text did not already start with it).");
      } else {
        Logger::debug("[ENCODE Llama 3 Path] BOS token flag was true, but text already started with BOS string. Skipping explicit BOS ID addition.");
      }
    }

    std::vector<int> token_ids = this->bpe_tokenize_to_ids(text, false, false, false);
    final_ids.insert(final_ids.end(), token_ids.begin(), token_ids.end());

    if (add_eos && this->eos_token_id_ != -1) {
      final_ids.push_back(this->eos_token_id_);
      Logger::debug("[ENCODE Llama 3 Path] Added EOS token: " + std::to_string(this->eos_token_id_));
    }

  } else if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE) {
    Logger::debug("[ENCODE] Using LLAMA_SENTENCEPIECE (old SentencePiece/BPE logic) path.");

    if (!this->initialized_from_gguf_) {
      Logger::debug(
          "[ENCODE SPM Path] Using simplified merge-based tokenizer path (calling "
          "bpe_tokenize directly).");

      std::vector<std::string> bpe_pieces = this->bpe_tokenize(text);
      Logger::debug("[ENCODE SPM Path] bpe_tokenize returned " +
                    std::to_string(bpe_pieces.size()) + " pieces.");

      final_ids = this->tokens_to_ids(bpe_pieces);

      if (add_bos && this->bos_token_id_ != -1) {
        final_ids.insert(final_ids.begin(), this->bos_token_id_);
        Logger::debug("[ENCODE SPM Path] Prepended BOS token: " +
                      std::to_string(this->bos_token_id_));
      }
      if (add_eos && this->eos_token_id_ != -1) {
        final_ids.push_back(this->eos_token_id_);
        Logger::debug("[ENCODE SPM Path] Appended EOS token: " +
                      std::to_string(this->eos_token_id_));
      }
      Logger::debug("[ENCODE SPM Path] Final IDs (Simplified Merge Path): " +
                    std::to_string(final_ids.size()) + " tokens.");
    } else {
      Logger::debug("[ENCODE SPM Path] Using GGUF score-based tokenizer path.");

      if (add_bos && this->bos_token_id_ != -1) {
        final_ids.push_back(this->bos_token_id_);
        Logger::debug("[ENCODE SPM GGUF Path] Added BOS token: " +
                      std::to_string(this->bos_token_id_));
      }

      std::vector<std::pair<std::string, bool>> segments;
      std::string text_to_process = text;
      PreTokenizeMethod method_to_use;

      if (pre_tok_override == PreTokenizeMethod::DEFAULT) {
        if (this->pre_tok_type_ == "default") {
          method_to_use = PreTokenizeMethod::DEFAULT;
          Logger::debug("[ENCODE SPM GGUF Path] Using DEFAULT pre-tokenization (split by special, BPE for non-specials).");
        } else if (this->pre_tok_type_ == "llama") {
          method_to_use = PreTokenizeMethod::LLAMA_REGEX;
          Logger::debug("[ENCODE SPM GGUF Path] Using LLAMA_REGEX pre-tokenization.");
        } else {
          Logger::warning("[ENCODE SPM GGUF Path] pre_tok_type_ is '" + this->pre_tok_type_ + "' or unset. Defaulting to WHITESPACE pre-tokenization for GGUF/SPM path.");
          method_to_use = PreTokenizeMethod::DEFAULT; // Fallback to DEFAULT, whitespace logic handled below
        }
      } else {
        method_to_use = pre_tok_override;
      }

      std::string method_str_log;
      if (method_to_use == PreTokenizeMethod::LLAMA_REGEX) method_str_log = "LLAMA_REGEX";
      else method_str_log = "DEFAULT (Special Token Split or WHITESPACE Fallback)";
      Logger::debug("[ENCODE SPM GGUF Path] Effective pre-tokenization method: " + method_str_log);

      if (method_to_use == PreTokenizeMethod::DEFAULT && this->pre_tok_type_ == "default") {
        std::unordered_set<std::string> all_special_tokens_set;
        for (const auto& pair : this->added_tokens_) {
          if (!pair.first.empty()) all_special_tokens_set.insert(pair.first);
        }
        if (!this->bos_token_.empty()) all_special_tokens_set.insert(this->bos_token_);
        if (!this->eos_token_.empty()) all_special_tokens_set.insert(this->eos_token_);
        if (!this->unk_token_.empty()) all_special_tokens_set.insert(this->unk_token_);

        std::string special_pattern_str = "(";
        bool first_special = true;
        for (const std::string& st : all_special_tokens_set) {
          if (!first_special) special_pattern_str += "|";
          std::string escaped_st;
          for (char c : st) {
            if (strchr(".^$*+?()[{\\|", c)) escaped_st += '\\';
            escaped_st += c;
          }
          special_pattern_str += escaped_st;
          first_special = false;
        }
        special_pattern_str += ")";

        if (all_special_tokens_set.empty()) {
          Logger::debug("[ENCODE SPM GGUF Path] No special tokens defined for DEFAULT pre-tok. Treating whole text as one segment.");
          segments.push_back({text_to_process, false});
        } else {
          Logger::debug("[ENCODE SPM GGUF Path] Splitting by special tokens regex: " + special_pattern_str);
          try {
            boost::regex special_regex(special_pattern_str);
            boost::sregex_iterator it(text_to_process.begin(), text_to_process.end(), special_regex);
            boost::sregex_iterator end;
            size_t last_pos = 0;
            while (it != end) {
              boost::smatch match = *it;
              if (match.position() > last_pos) {
                segments.push_back({text_to_process.substr(last_pos, match.position() - last_pos), false});
              }
              segments.push_back({match.str(), true});
              last_pos = match.position() + match.length();
              ++it;
            }
            if (last_pos < text_to_process.length()) {
              segments.push_back({text_to_process.substr(last_pos), false});
            }
          } catch (const boost::regex_error& e) {
            Logger::error("[ENCODE SPM GGUF Path] Regex error splitting by special tokens: " + std::string(e.what()) + ". Treating as single segment.");
            segments.clear();
            segments.push_back({text_to_process, false});
          }
        }
      } else if (method_to_use == PreTokenizeMethod::LLAMA_REGEX) {
        boost::regex llama_segment_regex(R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)");
        Logger::debug("[ENCODE SPM GGUF Path] Using LLAMA_REGEX for pre-tokenization.");
        try {
          boost::sregex_iterator it(text_to_process.begin(), text_to_process.end(), llama_segment_regex);
          boost::sregex_iterator end;
          size_t last_pos = 0;
          while (it != end) {
            boost::smatch match = *it;
            if (match.position() > last_pos) {
              segments.push_back({text_to_process.substr(last_pos, match.position() - last_pos), false});
            }
            segments.push_back({match.str(), false});
            last_pos = match.position() + match.length();
            ++it;
          }
          if (last_pos < text_to_process.length()) {
            segments.push_back({text_to_process.substr(last_pos), false});
          }
        } catch (const boost::regex_error& e) {
          Logger::error("[ENCODE SPM GGUF Path] Regex error during LLAMA_REGEX splitting: " + std::string(e.what()) + ". Treating as single segment.");
          segments.clear();
          segments.push_back({text_to_process, false});
        }
      } else { // WHITESPACE or fallback (method_to_use is DEFAULT here if pre_tok_type_ was not "default")
        Logger::debug("[ENCODE SPM GGUF Path] Using WHITESPACE pre-tokenization (or fallback).");
        std::string current_ws_segment;
        for (char c : text_to_process) {
          if (std::isspace(static_cast<unsigned char>(c))) {
            if (!current_ws_segment.empty()) {
              segments.push_back({current_ws_segment, false});
              current_ws_segment.clear();
            }
            segments.push_back({{c}, false});
          } else {
            current_ws_segment += c;
          }
        }
        if (!current_ws_segment.empty()) {
          segments.push_back({current_ws_segment, false});
        }
      }

      Logger::debug("[ENCODE SPM GGUF Path] Pre-tokenization resulted in " + std::to_string(segments.size()) + " segments.");

      std::vector<int> segment_ids;
      for (const auto& seg_pair : segments) {
        const std::string& segment_str = seg_pair.first;
        bool is_special = seg_pair.second;

        if (segment_str.empty()) continue;

        if (is_special) {
          auto it = this->token_to_id_.find(segment_str);
          if (it != this->token_to_id_.end()) {
            segment_ids.push_back(it->second);
            Logger::debug("[ENCODE SPM GGUF Path] Found special segment: '" + segment_str + "' -> ID: " + std::to_string(it->second));
          } else {
            Logger::warning("[ENCODE SPM GGUF Path] Special segment '" + segment_str +
                            "' not in vocab. Using UNK ID: " + std::to_string(this->unk_token_id_));
            segment_ids.push_back(this->unk_token_id_);
          }
        } else {
          std::vector<std::string> pieces = this->bpe_tokenize_from_scores(segment_str);
          std::vector<int> piece_ids = this->tokens_to_ids(pieces);
          segment_ids.insert(segment_ids.end(), piece_ids.begin(), piece_ids.end());
          Logger::debug("[ENCODE SPM GGUF Path] BPE for non-special segment '" + segment_str + "' -> " + std::to_string(piece_ids.size()) + " IDs.");
        }
      }
      final_ids.insert(final_ids.end(), segment_ids.begin(), segment_ids.end());

      if (add_eos && this->eos_token_id_ != -1) {
        final_ids.push_back(this->eos_token_id_);
        Logger::debug("[ENCODE SPM GGUF Path] Appended EOS token: " +
                      std::to_string(this->eos_token_id_));
      }
      Logger::debug("[ENCODE SPM GGUF Path] Final IDs (GGUF Score Path): " + std::to_string(final_ids.size()) + " tokens.");
    }
  } // This closes the LLAMA_SENTENCEPIECE block
  else { // Unknown Tokenizer Family Path
    Logger::error("[ENCODE] Unknown or unsupported tokenizer family: " + family_str_enc + ". Cannot encode text.");
    if (add_bos && this->bos_token_id_ != -1) {
      final_ids.push_back(this->bos_token_id_);
      Logger::debug("[ENCODE Unknown Path] Added BOS token: " + std::to_string(this->bos_token_id_));
    }
    if (add_eos && this->eos_token_id_ != -1) {
      final_ids.push_back(this->eos_token_id_);
      Logger::debug("[ENCODE Unknown Path] Added EOS token: " + std::to_string(this->eos_token_id_));
    }
  }

  Logger::debug("[ENCODE] Final IDs count (end of function): " + std::to_string(final_ids.size()));
  if (final_ids.empty() && !text.empty()) {
    Logger::warning("[ENCODE] Tokenization resulted in empty ID list for non-empty text: '" + text + "'");
  }

  return final_ids;
}
1393
1394std::string Tokenizer::decode(const std::vector<int>& ids,
1395 bool skip_special_tokens) const {
 1396  // Dispatch based on tokenizer family
 1397  if (tokenizer_family_ == ModelConfig::TokenizerFamily::LLAMA_SENTENCEPIECE) {
1398 // Use the dedicated SentencePiece decoding logic
1399 return decode_sentencepiece(ids, skip_special_tokens);
1400 }
1401
1402 // Default to Llama 3 / Tiktoken BPE decoding logic
1403 Logger::debug("[decode] Decoding using Llama 3 / Tiktoken logic.");
1404 std::stringstream ss;
1405 bool first_token = true;
1406
1407 for (int id : ids) {
1408 // Handle potential invalid IDs first
1409 if (id < 0 || static_cast<size_t>(id) >= id_to_token_.size()) {
1410 if (!skip_special_tokens) { // Only show invalid ID if not skipping specials
1411 ss << "[INVALID_ID:" << id << "]";
1412 Logger::debug("[decode] Invalid token ID: " + std::to_string(id));
1413 first_token = false; // Considered as outputted content
1414 }
1415 continue;
1416 }
1417
1418 // Handle special tokens skip
1419 if (skip_special_tokens) {
1420 if (id == bos_token_id_ || id == eos_token_id_ || id == pad_token_id_ || id == unk_token_id_) {
1421 Logger::debug("[decode] Skipping special token ID: " + std::to_string(id) +
1422 " (BOS/EOS/PAD/UNK)");
1423 continue;
1424 }
1425 if (id_to_added_token_.count(id)) {
1426 Logger::debug("[decode] Skipping added token ID: " + std::to_string(id));
1427 continue;
1428 }
1429 }
1430
1431 std::string token = id_to_token_[id];
1432 std::string token_debug = token;
1433 // Make non-printable characters visible in logs
1434 for (size_t i = 0; i < token_debug.length(); i++) {
1435 if (!std::isprint(static_cast<unsigned char>(token_debug[i]))) {
1436 char hex[5];
1437 snprintf(hex, sizeof(hex), "\\x%02x", static_cast<unsigned char>(token_debug[i]));
1438 token_debug.replace(i, 1, hex);
1439 i += 3; // Skip the added hex chars
1440 }
1441 }
1442 Logger::debug("[decode] Processing token ID " + std::to_string(id) +
1443 ": '" + token_debug + "'");
1444
1445 if (token.empty()) {
1446 if (!skip_special_tokens && unk_token_id_ != -1) {
1447 token = unk_token_;
1448 Logger::debug("[decode] Empty token replaced with UNK token");
1449 } else {
1450 Logger::debug("[decode] Empty token skipped");
1451 continue;
1452 }
1453 }
1454
1455 if (token.size() >= BPE_SPACE_CHAR.size() &&
1456 token.substr(0, BPE_SPACE_CHAR.size()) == BPE_SPACE_CHAR) {
1457 if (!first_token) {
1458 ss << " ";
1459 Logger::debug("[decode] Added space before token with BPE_SPACE_CHAR prefix");
1460 }
1461 ss << token.substr(BPE_SPACE_CHAR.size());
1462 Logger::debug("[decode] Added token content after BPE_SPACE_CHAR: '" +
1463 token.substr(BPE_SPACE_CHAR.size()) + "'");
1464 first_token = false;
1465 } else { // Token does NOT start with Ġ
1466 ss << token; // Append the token itself
1467 Logger::debug("[decode] Added non-BPE_SPACE_CHAR token: '" + token + "'");
1468 first_token = false; // Still set first_token to false as content has been added
1469 }
1470 }
1471 std::string final_text = ss.str();
1472 Logger::debug("[decode] Final decoded text: '" + final_text + "'");
1473 return final_text;
1474}
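
// Usage sketch for the Tiktoken decode path above, assuming a hypothetical
// vocabulary where id_to_token_ maps 0 -> "Hello" and 1 -> "ĠWorld":
//
//   Tokenizer tok(vocab_path, model_path, config);   // placeholder paths/config
//   std::string text = tok.decode({0, 1}, /*skip_special_tokens=*/true);
//   // text == "Hello World": the Ġ prefix on token 1 becomes one separating space.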
1475
1476std::string Tokenizer::decode_sentencepiece(const std::vector<int>& ids,
1477 bool skip_special_tokens) const {
1478 Logger::debug("[decode_sentencepiece] Decoding using SentencePiece logic.");
1479
1480 std::stringstream ss;
1481 bool first_token = true; // Flag to handle potential leading space correctly
1482 const std::string sp_space_prefix = "\xE2\x96\x81"; // Actual UTF-8 sequence for U+2581
1483 const std::string gpt2_space_prefix = "\xC4\xA0"; // Actual UTF-8 sequence for U+0120 (Ġ) - check just in case
1484
1485 for (int id : ids) {
1486 std::string token_str; // Holds the string representation of the current token
1487 bool is_special_or_invalid = false;
1488
1489 // Handle special token skipping FIRST
1490 if (skip_special_tokens) {
1491 if (id == bos_token_id_ || id == eos_token_id_ || id == pad_token_id_ || id == unk_token_id_) {
1492 is_special_or_invalid = true;
1493 continue; // Skip this token entirely
1494 }
1495 if (id_to_added_token_.count(id)) {
1496 is_special_or_invalid = true;
1497 continue; // Skip added special tokens
1498 }
1499 }
1500
1501 // Get the token string (handling invalid IDs)
1502 if (id >= 0 && static_cast<size_t>(id) < id_to_token_.size()) {
1503 token_str = id_to_token_[id];
1504 } else {
1505 auto added_it = id_to_added_token_.find(id);
1506 if (added_it != id_to_added_token_.end()) {
1507 token_str = added_it->second; // It's an added token (might be special)
1508 } else { // Truly invalid ID
1509 // Don't output if skipping specials/invalid
1510 if (!skip_special_tokens) {
1511 token_str = "[INVALID_ID:" + std::to_string(id) + "]";
1512 } else {
1513 token_str = ""; // Effectively skip
1514 }
1515 is_special_or_invalid = true; // Treat invalid ID as special for spacing
1516 }
1517 }
1518
1519 // NEW: Check for <0xNN> format and convert if necessary
1520 if (token_str.length() == 6 && token_str.rfind("<0x", 0) == 0 && token_str[5] == '>') {
1521 try {
1522 std::string hex_val_str = token_str.substr(3, 2);
1523 int byte_val = std::stoi(hex_val_str, nullptr, 16);
1524 token_str = std::string(1, static_cast<char>(byte_val));
1525 Logger::debug("[decode_sentencepiece] Converted '<0x" + hex_val_str + ">' to char: " + std::to_string(byte_val));
1526 } catch (const std::exception& e) {
1527 Logger::warning("[decode_sentencepiece] Failed to parse hex from token: '" + token_str + "'. Error: " + e.what());
1528 // Keep original token_str if parsing fails
1529 }
1530 }
1531
1532 if (token_str.empty() && !is_special_or_invalid) {
1533 if (unk_token_id_ != -1) {
 1534          // Substitute the UNK string unless specials are being skipped and this ID is itself the UNK token
 1535          if (!skip_special_tokens || unk_token_id_ != id) {
1536 token_str = unk_token_;
1537 } else {
1538 // Skipping specials, and UNK is considered special
1539 is_special_or_invalid = true;
1540 continue; // Skip this empty token
1541 }
1542 } else {
1543 if (!skip_special_tokens){
1544 token_str = "[EMPTY_TOKEN_FOR_ID:" + std::to_string(id) + "]";
1545 } else {
1546 is_special_or_invalid = true;
1547 continue; // Skip
1548 }
1549 }
1550 if (!is_special_or_invalid) { // Only log if we actually output something
1551 Logger::warning("[decode_sentencepiece] Encountered empty token string for valid ID " + std::to_string(id) +
1552 ". Using: '" + token_str + "'");
1553 }
1554 }
1555
1556 // Handle cases where the token IS the space prefix itself
1557 if (token_str == sp_space_prefix || token_str == gpt2_space_prefix) {
1558 if (first_token) {
1559 // If it's the first token and it's just a space prefix, ignore it and wait for actual content.
1560 // first_token remains true.
1561 Logger::debug("[decode_sentencepiece] Ignored leading standalone space prefix token.");
1562 continue;
1563 }
1564 // If not the first token, and it's a standalone space prefix, ensure one space is added.
1565 std::string current_output_check = ss.str();
1566 if (current_output_check.empty() || current_output_check.back() != ' ') {
1567 ss << " ";
1568 Logger::debug("[decode_sentencepiece] Added space for standalone prefix token mid-sequence.");
1569 }
1570 first_token = false; // A space was effectively output.
1571 continue; // Move to the next token.
1572 }
1573
1574 // Process the token string: handle prefixes and append to result
1575 if (!token_str.empty()) {
1576 bool starts_with_sp_prefix = (token_str.rfind(sp_space_prefix, 0) == 0);
1577 // Check for GPT2 prefix only if SP prefix wasn't found
1578 bool starts_with_gpt2_prefix = (!starts_with_sp_prefix && token_str.rfind(gpt2_space_prefix, 0) == 0);
1579
1580 if (starts_with_sp_prefix) {
1581 std::string current_output = ss.str();
1582 if (!first_token && (current_output.empty() || current_output.back() != ' ')) {
1583 ss << " ";
1584 }
1585 std::string content = token_str.substr(sp_space_prefix.length());
1586 // RE-ADD: Trim any leading literal spaces from the content itself
1587 size_t first_non_space = content.find_first_not_of(' ');
1588 if (std::string::npos != first_non_space) {
1589 content = content.substr(first_non_space);
1590 }
1591 ss << content;
1592 first_token = false; // We have outputted something
1593 }
1594 else if (starts_with_gpt2_prefix) { // Handle Ġ prefix if found
1595 std::string current_output = ss.str();
1596 if (!first_token && (current_output.empty() || current_output.back() != ' ')) {
1597 ss << " ";
1598 }
1599 std::string content = token_str.substr(gpt2_space_prefix.length());
1600 // RE-ADD: Trim any leading literal spaces from the content itself
1601 size_t first_non_space = content.find_first_not_of(' ');
1602 if (std::string::npos != first_non_space) {
1603 content = content.substr(first_non_space);
1604 }
1605 ss << content;
1606 first_token = false;
1607 }
1608 else { // Token does not start with a known space prefix
1609 ss << token_str;
1610 first_token = false; // Mark that we've outputted content
1611 }
1612 }
1613 } // End for loop over IDs
1614
1615 return ss.str();
1616}
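
// Illustrative example with assumed IDs: given a SentencePiece vocabulary where
// 5 -> "▁Hello", 6 -> "▁world" and 7 -> "<0x21>", decode_sentencepiece({5, 6, 7}, true)
// produces "Hello world!". The leading ▁ of the first token is dropped, the later
// ▁ prefix becomes a space, and the byte token <0x21> is converted to the raw '!'.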
1617
1618// Helper function to replace all occurrences of a substring
1619static std::string replace_all(std::string str, const std::string& from, const std::string& to) {
1620 size_t start_pos = 0;
1621 while((start_pos = str.find(from, start_pos)) != std::string::npos) {
1622 str.replace(start_pos, from.length(), to);
 1623    start_pos += to.length(); // Skip past the replacement so a 'to' that contains 'from' cannot loop forever
1624 }
1625 return str;
1626}
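
// Usage sketch: replace_all("{{user_prompt}}!", "{{user_prompt}}", "Hi") returns "Hi!".
// Because start_pos advances by to.length(), a replacement that contains the search
// string (e.g. from = "a", to = "aa") cannot be re-matched and loop forever.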
1627
1628std::string Tokenizer::apply_chat_template(const std::string& user_prompt,
1629 const std::string& system_message,
1630 const ModelConfig& config) const {
1631 // Check if the GGUF template seems like a Jinja2 template
1632 bool is_jinja_template = (!gguf_chat_template_.empty() &&
1633 (gguf_chat_template_.find("{%") != std::string::npos ||
1634 gguf_chat_template_.find("{{") != std::string::npos));
1635
1636 // Log the determined template type and GGUF template content for debugging
1637 if (!gguf_chat_template_.empty()) {
1638 Logger::debug("[apply_chat_template] GGUF chat template content (first 100 chars): " + gguf_chat_template_.substr(0, 100));
1639 if (is_jinja_template) {
1640 Logger::info("[apply_chat_template] GGUF chat template detected as Jinja2-like.");
1641 } else {
1642 Logger::info("[apply_chat_template] GGUF chat template detected as simple placeholder template.");
1643 }
1644 }
1645
1646 if (!gguf_chat_template_.empty() && !is_jinja_template) {
1647 Logger::info("[apply_chat_template] Using simple GGUF chat template (non-Jinja).");
1648 std::string processed_template = gguf_chat_template_;
1649
1650 std::string bos_s = this->bos_token_id_ != -1 ? this->bos_token_ : "";
1651 std::string eos_s = this->eos_token_id_ != -1 ? this->eos_token_ : "";
1652
1653 processed_template = replace_all(processed_template, "{{bos_token}}", bos_s);
1654 processed_template = replace_all(processed_template, "{{eos_token}}", eos_s);
1655 processed_template = replace_all(processed_template, "{{user_prompt}}", user_prompt);
1656 if (!system_message.empty()) {
1657 processed_template = replace_all(processed_template, "{{system_message}}", system_message);
1658 } else {
1659 processed_template = replace_all(processed_template, "{{system_message}}", "");
1660 }
1661
1662 std::string snippet_to_log = processed_template.substr(0, std::min((size_t)100, processed_template.length()));
1663 Logger::debug(std::string("[apply_chat_template] Processed simple GGUF template. Snippet: ") + snippet_to_log);
1664 return processed_template;
1665 } else {
1666 if (is_jinja_template) {
1667 Logger::warning("[apply_chat_template] GGUF chat template appears to be Jinja2, which is not fully supported by this C++ implementation. Falling back to hardcoded Llama 3 Instruct template. The model's intended GGUF chat template will be ignored.");
1668 } else { // Empty GGUF template
1669 Logger::info("[apply_chat_template] GGUF chat template not found or empty. Falling back to hardcoded Llama 3 Instruct template.");
1670 }
1671
1672 // Fallback to a hardcoded Llama 3 Instruct style template
1673 auto find_added_token_str_fallback = [&](const std::string& content,
1674 const std::string& fallback_value) -> std::string {
1675 if (this->added_tokens_.count(content)) return content;
1676 if (this->token_to_id_.count(content)) return content;
 1677      if ((!this->added_tokens_.empty() || !this->token_to_id_.empty()) && content.rfind("<|", 0) == 0 && content.length() >= 4 && content.compare(content.length() - 2, 2, "|>") == 0) {
1678 Logger::warning("[apply_chat_template_fallback] Could not find special token string '" + content +
1679 "' in added_tokens_ or vocab. Using default/fallback string: '" + fallback_value + "'");
1680 }
1681 return fallback_value;
1682 };
1683
1684 // Use member versions of bos_token_, etc. which are set up during constructor
1685 std::string bos_s_fallback = this->bos_token_id_ != -1 ? this->bos_token_ : "<s>";
1686 // For Llama3 specific tokens, ensure they are correctly fetched or have sensible defaults
1687 std::string start_header_s_fallback = find_added_token_str_fallback("<|start_header_id|>", "<|start_header_id|>");
1688 std::string end_header_s_fallback = find_added_token_str_fallback("<|end_header_id|>", "<|end_header_id|>");
1689 std::string eot_s_fallback = find_added_token_str_fallback("<|eot_id|>", "<|eot_id|>");
1690 // For role names, they are typically just strings, not special tokens themselves
1691 std::string system_role_name = "system";
1692 std::string user_role_name = "user";
1693 std::string assistant_role_name = "assistant";
1694
1695 std::stringstream ss;
1696 ss << bos_s_fallback;
1697 if (!system_message.empty()) {
1698 ss << start_header_s_fallback << system_role_name << end_header_s_fallback << "\n\n" << system_message << eot_s_fallback;
1699 }
1700 ss << start_header_s_fallback << user_role_name << end_header_s_fallback << "\n\n" << user_prompt << eot_s_fallback;
1701 ss << start_header_s_fallback << assistant_role_name << end_header_s_fallback << "\n\n";
1702
1703 Logger::info("[apply_chat_template] Applied hardcoded Llama 3 Instruct-like chat template as fallback. Prompt snippet: " + ss.str().substr(0,100));
1704 return ss.str();
1705 }
1706}
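
// Illustrative example of the hardcoded fallback path (special-token strings
// assumed resolvable; the BOS string may differ per model):
//
//   apply_chat_template("Hi", "Be brief.", config)
//   // -> "<s><|start_header_id|>system<|end_header_id|>\n\nBe brief.<|eot_id|>"
//   //    "<|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>"
//   //    "<|start_header_id|>assistant<|end_header_id|>\n\n"
//   // (one contiguous string, shown here split across comment lines)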
1707
 1708void Tokenizer::load_vocab_from_json(
 1709    const std::string& vocab_path,
1710 std::unordered_map<std::string, int>& token_to_id_map,
1711 std::vector<std::string>& id_to_token_vec) {
1712 token_to_id_map.clear();
1713 id_to_token_vec.clear();
1714
1715 try {
1716 std::ifstream file(vocab_path);
1717 if (!file.is_open()) {
1718 throw std::runtime_error("Failed to open vocabulary file: " + vocab_path);
1719 }
1720
1721 json vocab_json;
1722 file >> vocab_json;
1723
1724 // Try to determine format (HuggingFace tokenizer.json vs. plain vocab)
1725 if (vocab_json.contains("model") && vocab_json["model"].is_object() &&
1726 vocab_json["model"].contains("vocab") && vocab_json["model"]["vocab"].is_object()) {
1727 Logger::info("load_vocab_from_json: Detected HuggingFace tokenizer.json format.");
1728 const auto& vocab = vocab_json["model"]["vocab"];
1729 size_t max_id = 0;
1730
1731 // First pass to determine max_id to size id_to_token_vec appropriately
1732 for (auto it = vocab.begin(); it != vocab.end(); ++it) {
1733 int id = it.value().get<int>();
1734 if (id < 0) {
1735 Logger::warning("load_vocab_from_json: Skipping token with negative ID: " + it.key());
1736 continue;
1737 }
1738 if (static_cast<size_t>(id) > max_id) {
1739 max_id = static_cast<size_t>(id);
1740 }
1741 }
 1742      id_to_token_vec.resize(max_id + 1, "<unk>");  // Initialize every slot with the "<unk>" placeholder
1743
1744 // Second pass to populate maps
1745 for (auto it = vocab.begin(); it != vocab.end(); ++it) {
1746 std::string token = it.key();
1747 int id = it.value().get<int>();
1748 if (id < 0) continue; // Already warned
1749
1750 token_to_id_map[token] = id;
1751 if (static_cast<size_t>(id) < id_to_token_vec.size()) {
1752 id_to_token_vec[id] = token;
1753 } else {
1754 // This should ideally not happen if resize was correct
1755 Logger::warning("load_vocab_from_json: ID out of bounds during vocab population: " + std::to_string(id));
1756 }
1757 }
1758
1759 // Ensure `added_tokens_` (member) is populated here.
1760 if (vocab_json.contains("added_tokens") &&
1761 vocab_json["added_tokens"].is_array()) {
1762 const auto& added_tokens_json = vocab_json["added_tokens"];
1763 Logger::info("load_vocab_from_json: Processing " + std::to_string(added_tokens_json.size()) + " added_tokens.");
1764 for (const auto& token_obj : added_tokens_json) {
1765 if (token_obj.contains("content") && token_obj.contains("id")) {
1766 std::string token_content = token_obj["content"];
1767 int token_id = token_obj["id"];
1768
1769 if (token_id < 0) {
1770 Logger::warning("load_vocab_from_json: Skipping added_token with negative ID: " + token_content);
1771 continue;
1772 }
1773
1774 // Update maps for added tokens
1775 token_to_id_map[token_content] = token_id; // Also add to the main map for direct lookup
1776 this->added_tokens_[token_content] = token_id; // Populate member variable
1777 this->id_to_added_token_[token_id] = token_content; // Populate member variable
1778
1779 if (static_cast<size_t>(token_id) >= id_to_token_vec.size()) {
1780 id_to_token_vec.resize(token_id + 1, "<unk>"); // Ensure vector is large enough
1781 }
1782 id_to_token_vec[token_id] = token_content; // Ensure id_to_token_vec also has added tokens
1783
1784 if (token_content == this->unk_token_) this->unk_token_id_ = token_id;
1785 else if (token_content == this->bos_token_) this->bos_token_id_ = token_id;
1786 else if (token_content == this->eos_token_) this->eos_token_id_ = token_id;
1787 else if (token_content == this->pad_token_) this->pad_token_id_ = token_id;
1788
1789 Logger::debug("load_vocab_from_json: Processed added_token: '" + token_content + "' with ID " +
1790 std::to_string(token_id));
1791 }
1792 }
1793 }
1794
1795 } else if (vocab_json.is_object()) {
1796 Logger::info("load_vocab_from_json: Detected plain vocabulary format (direct map).");
1797 size_t max_id = 0;
1798 for (auto it = vocab_json.begin(); it != vocab_json.end(); ++it) {
1799 int id = it.value().get<int>();
1800 if (id < 0) continue;
1801 if (static_cast<size_t>(id) > max_id) {
1802 max_id = static_cast<size_t>(id);
1803 }
1804 }
1805 id_to_token_vec.resize(max_id + 1, "<unk>");
1806
1807 for (auto it = vocab_json.begin(); it != vocab_json.end(); ++it) {
1808 std::string token = it.key();
1809 int id = it.value().get<int>();
1810 if (id < 0) {
1811 Logger::warning("load_vocab_from_json: Skipping token with negative ID: " + token);
1812 continue;
1813 }
1814 token_to_id_map[token] = id;
1815 if (static_cast<size_t>(id) < id_to_token_vec.size()) {
1816 id_to_token_vec[id] = token;
1817 }
1818
1819 if (token == this->unk_token_) this->unk_token_id_ = id;
1820 else if (token == this->bos_token_) this->bos_token_id_ = id;
1821 else if (token == this->eos_token_) this->eos_token_id_ = id;
1822 else if (token == this->pad_token_) this->pad_token_id_ = id;
1823 }
1824 } else {
1825 throw std::runtime_error("load_vocab_from_json: Vocabulary JSON has an unsupported format.");
1826 }
1827
1828 for (size_t i = 0; i < id_to_token_vec.size(); ++i) {
1829 if (id_to_token_vec[i].empty() || id_to_token_vec[i] == "<unk>") {
1830 auto added_it = this->id_to_added_token_.find(static_cast<int>(i));
1831 if (added_it != this->id_to_added_token_.end()) {
1832 id_to_token_vec[i] = added_it->second;
 1833        } else if (id_to_token_vec[i].empty()) {
 1834          id_to_token_vec[i] = "<missing_id_" + std::to_string(i) + ">";
1835 }
1836 }
1837 }
1838
1839 Logger::info("load_vocab_from_json: Loaded vocabulary with " +
1840 std::to_string(token_to_id_map.size()) + " unique token strings and " +
1841 std::to_string(id_to_token_vec.size()) + " ID entries.");
1842 Logger::debug("load_vocab_from_json: Special tokens after JSON load: UNK_ID=" + std::to_string(unk_token_id_) +
1843 " ('" + unk_token_ + "'), BOS_ID=" + std::to_string(bos_token_id_) +
1844 " ('" + bos_token_ + "'), EOS_ID=" + std::to_string(eos_token_id_) +
1845 " ('" + eos_token_ + "'), PAD_ID=" + std::to_string(pad_token_id_) +
1846 " ('" + pad_token_ + "')");
1847
1848 } catch (const json::exception& e) {
1849 throw std::runtime_error("Error parsing vocabulary JSON from " + vocab_path + ": " + e.what());
1850 } catch (const std::exception& e) {
1851 throw std::runtime_error("Error loading vocabulary from " + vocab_path + ": " + std::string(e.what()));
1852 }
1853}
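
// Illustrative sketch of the two accepted JSON shapes (hypothetical content):
//
//   // HuggingFace tokenizer.json style:
//   // { "model": { "vocab": { "hi": 5 } },
//   //   "added_tokens": [ { "content": "<s>", "id": 1 } ] }
//
//   // Plain token -> id map:
//   // { "hi": 5, "<s>": 1 }
//
// Either way, id_to_token_vec is sized to max_id + 1; IDs never assigned a token
// keep the "<unk>" placeholder unless an added-token entry fills them in later.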
1854
1855void Tokenizer::load_sentencepiece_model(const std::string& model_path) {
1856 Logger::warning("load_sentencepiece_model: Loading from SentencePiece model file ('" + model_path + "') is currently not implemented.");
1858}
1859
1860void Tokenizer::load_bpe_merges_from_json(const std::string& tokenizer_json_path) {
1861 try {
1862 std::ifstream file(tokenizer_json_path);
1863 if (!file.is_open()) {
1864 throw std::runtime_error("load_bpe_merges_from_json: Failed to open BPE merges file: " + tokenizer_json_path);
1865 }
1866
1867 json model_json;
1868 file >> model_json;
1869
1870 bpe_merges_.clear(); // Ensure merges map is empty before loading
1871
1872 // Check for HuggingFace tokenizer.json structure first
1873 // Merges are typically under model.merges
1874 if (model_json.contains("model") && model_json["model"].is_object()) {
1875 const auto& model_section = model_json["model"];
1876 if (model_section.contains("merges") && model_section["merges"].is_array()) {
1877 Logger::info("load_bpe_merges_from_json: Detected HuggingFace tokenizer.json format with BPE merges from: " + tokenizer_json_path);
1878 const auto& merges = model_section["merges"];
1879 int rank = 0; // Use index as rank for merges from HF JSON
1880 for (const auto& merge_entry_json : merges) {
1881 if (merge_entry_json.is_string()) {
1882 std::string merge_entry = merge_entry_json.get<std::string>();
1883 size_t space_pos = merge_entry.find(' ');
1884
1885 // Expecting format "part1 part2"
1886 if (space_pos != std::string::npos && space_pos > 0 && space_pos < merge_entry.length() - 1) {
1887 std::string first = merge_entry.substr(0, space_pos);
1888 std::string second = merge_entry.substr(space_pos + 1);
1889 // Combine without the space to form the key for the map
1890 std::string pair_key = first + second;
1891 bpe_merges_[pair_key] = rank++;
1892 } else {
1893 Logger::warning("load_bpe_merges_from_json: Skipping malformed merge rule: '" + merge_entry + "' from " + tokenizer_json_path);
1894 }
1895 } else {
1896 Logger::warning("load_bpe_merges_from_json: Merge entry is not a string, skipping. File: " + tokenizer_json_path);
1897 }
1898 }
1899 } else {
1900 // Handle case where tokenizer.json doesn't have expected BPE structure
1901 Logger::warning("load_bpe_merges_from_json: HuggingFace format detected, but no 'model.merges' array found in model section of: " + tokenizer_json_path);
1902 }
1903 }
1904 // Fallback: Check for a simple top-level "merges" array (less common format)
1905 else if (model_json.contains("merges") && model_json["merges"].is_array()) {
1906 Logger::info("load_bpe_merges_from_json: Detected simple top-level 'merges' array format in: " + tokenizer_json_path);
1907 const auto& merges = model_json["merges"];
1908 int rank = 0;
1909 for (const auto& merge_entry_json : merges) {
1910 if (merge_entry_json.is_string()) {
1911 std::string merge_entry = merge_entry_json.get<std::string>();
1912 size_t space_pos = merge_entry.find(' ');
1913 if (space_pos != std::string::npos && space_pos > 0 && space_pos < merge_entry.length() - 1) {
1914 std::string first = merge_entry.substr(0, space_pos);
1915 std::string second = merge_entry.substr(space_pos + 1);
1916 std::string pair_key = first + second;
1917 bpe_merges_[pair_key] = rank++;
1918 } else {
1919 Logger::warning("load_bpe_merges_from_json: Skipping malformed merge rule from top-level array: '" + merge_entry + "' from " + tokenizer_json_path);
1920 }
1921 } else {
1922 Logger::warning("load_bpe_merges_from_json: Merge entry in top-level array is not a string, skipping. File: " + tokenizer_json_path);
1923 }
1924 }
1925 } else {
1926 // If neither format is found
1927 throw std::runtime_error(
1928 "load_bpe_merges_from_json: Unsupported BPE model format: no 'model.merges' or top-level 'merges' array found in '" + tokenizer_json_path + "'");
1929 }
1930
1931 if (bpe_merges_.empty()) {
1932 Logger::warning("load_bpe_merges_from_json: No BPE merges were loaded from the file: " + tokenizer_json_path);
1933 } else {
1934 Logger::info("load_bpe_merges_from_json: Loaded " + std::to_string(bpe_merges_.size()) +
1935 " BPE merges with ranks from " + tokenizer_json_path);
1936 }
1937
1938 } catch (const json::exception& e) {
1939 throw std::runtime_error("Error parsing BPE merges JSON from " + tokenizer_json_path + ": " + e.what());
1940 } catch (const std::exception& e) {
1941 throw std::runtime_error("An unexpected error occurred while loading BPE merges from " + tokenizer_json_path + ": " + std::string(e.what()));
1942 }
1943}
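
// Illustrative sketch of the expected merges layout (hypothetical content):
//
//   // { "model": { "merges": [ "Ġ t", "h e", "Ġt he" ] } }
//
// Each rule "left right" is stored keyed on the concatenation, with the array
// index as its rank, so "h e" above becomes bpe_merges_["he"] = 1 (a lower rank
// means the merge is preferred earlier).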
1944
1945std::string Tokenizer::capitalize_first_letter(std::string s) const { // Added Tokenizer:: scope and const
1946 if (s.empty()) return s;
1947
1948
1949 size_t first_letter_pos = 0;
1950 const std::string sp_space = "\xE2\x96\x81"; // SentencePiece space U+2581
1951
1952 // Check if the string starts with the SentencePiece space
 1953  // (checked on the input s; a mutable copy is made below)
1954 if (s.rfind(sp_space, 0) == 0) {
1955 // If it does, the actual first letter is after the space prefix
1956 if (s.length() > sp_space.length()) {
1957 first_letter_pos = sp_space.length();
1958 } else {
1959 // String is just the space prefix, nothing to capitalize
1960 return s;
1961 }
1962 }
1963
1964 // Capitalize the character at the determined position
1965 // Create result string here to modify
1966 std::string result = s;
1967 if (first_letter_pos < result.length()) {
1968 result[first_letter_pos] =
1969 std::toupper(static_cast<unsigned char>(result[first_letter_pos]));
1970 }
1971
1972 return result;
1973}
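
// Usage sketch: capitalize_first_letter("▁hello") returns "▁Hello"; the ▁ prefix
// is skipped and the following byte is upcased. Note that std::toupper operates
// on a single byte, so a word starting with a multi-byte UTF-8 letter is left
// unchanged in the default "C" locale.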
1974
1975std::vector<std::string> Tokenizer::bpe_tokenize(const std::string& text) const {
1976 Logger::debug("[Original bpe_tokenize for SentencePiece] Entered. bpe_merges_ size: " + std::to_string(bpe_merges_.size()));
1977 std::vector<std::string> all_final_tokens;
1978 const std::string sp_space_prefix = "\xE2\x96\x81"; // SentencePiece space U+2581
1979
1980 std::vector<std::string> pieces;
1981 std::string current_piece;
1982 bool last_char_was_space = true;
1983
1984 for (char c : text) {
1985 if (std::isspace(static_cast<unsigned char>(c))) {
1986 if (!current_piece.empty()) {
1987 pieces.push_back(current_piece);
1988 current_piece.clear();
1989 }
1990 pieces.push_back(std::string(1, c));
1991 last_char_was_space = true;
1992 } else {
1993 current_piece += c;
1994 last_char_was_space = false;
1995 }
1996 }
1997 if (!current_piece.empty()) {
1998 pieces.push_back(current_piece);
1999 }
2000
2001 Logger::debug("[Original bpe_tokenize for SentencePiece] Split text into " + std::to_string(pieces.size()) + " pieces (words/spaces).");
2002
2003 bool next_word_needs_prefix = true;
2004
2005 for (const std::string& piece : pieces) {
2006 if (piece.empty()) continue;
2007
2008 bool piece_is_whitespace = std::all_of(piece.begin(), piece.end(),
2009 [](char c) { return std::isspace(static_cast<unsigned char>(c)); });
2010
2011 if (piece_is_whitespace) {
2012 next_word_needs_prefix = true;
2013 Logger::debug("[Original bpe_tokenize for SentencePiece] Piece '" + piece + "' is whitespace. Setting prefix flag.");
2014 continue;
2015 }
2016
2017 std::string word_to_process = piece;
2018 if (next_word_needs_prefix) {
2019 word_to_process = sp_space_prefix + word_to_process;
2020 Logger::debug("[Original bpe_tokenize for SentencePiece] Prefixed word: '" + piece + "' -> '" + word_to_process + "'");
2021 next_word_needs_prefix = false;
2022 } else {
2023 Logger::debug("[Original bpe_tokenize for SentencePiece] Processing word without prefix: '" + word_to_process + "'");
2024 }
2025
2026 std::vector<std::string> chars;
2027 for (size_t i = 0; i < word_to_process.size();) {
2028 size_t bytes = unicode_char_len(word_to_process[i]);
2029 if (i + bytes <= word_to_process.size()) {
2030 chars.push_back(word_to_process.substr(i, bytes));
2031 } else {
2032 Logger::warning("[Original bpe_tokenize for SentencePiece] Invalid UTF-8 near: '" + word_to_process.substr(i) + "'");
2033 chars.push_back(word_to_process.substr(i, 1));
2034 bytes = 1;
2035 }
2036 i += bytes;
2037 }
2038
2039 if (chars.empty()) {
2040 Logger::warning("[Original bpe_tokenize for SentencePiece] Word '" + word_to_process + "' produced no chars.");
2041 continue;
2042 }
2043
2044 bool changes = true;
2045 while (changes && chars.size() > 1) {
2046 changes = false;
2047 int best_rank = std::numeric_limits<int>::max();
2048 int best_i = -1;
2049
2050 for (size_t i = 0; i < chars.size() - 1; ++i) {
2051 std::string pair = chars[i] + chars[i + 1];
2052 auto it = bpe_merges_.find(pair);
2053 if (it != bpe_merges_.end() && it->second < best_rank) {
2054 best_rank = it->second;
2055 best_i = i;
2056 }
2057 }
2058
2059 if (best_i >= 0) {
2060 std::string merged = chars[best_i] + chars[best_i + 1];
2061 chars[best_i] = merged;
2062 chars.erase(chars.begin() + best_i + 1);
2063 changes = true;
2064 Logger::debug("[Original bpe_tokenize for SentencePiece] Applied merge: '" + merged + "' with rank " +
2065 std::to_string(best_rank));
2066 }
2067 }
2068 all_final_tokens.insert(all_final_tokens.end(), chars.begin(), chars.end());
2069 }
2070
2071 Logger::debug("[Original bpe_tokenize for SentencePiece] Final token count: " + std::to_string(all_final_tokens.size()));
2072 return all_final_tokens;
2073}
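
// Worked example with hypothetical ranks: the word "hi" is prefixed to "▁hi"
// and split into {"▁", "h", "i"}. Suppose
//   bpe_merges_ = { {"▁h", 0}, {"hi", 5}, {"▁hi", 9} };
// Pass 1 sees candidate pairs "▁h" (rank 0) and "hi" (rank 5); the lower rank
// wins, giving {"▁h", "i"}. Pass 2 merges "▁hi" (rank 9), leaving the single
// token {"▁hi"}, at which point no pairs remain and the loop stops.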
2074
2075const std::string& Tokenizer::get_gguf_chat_template() const {
2076 return gguf_chat_template_;
2077}
2078
2079// Helper to sort token map by key length (descending) for longest match
2080static std::vector<std::pair<std::string, int>> sort_tokens_by_length_desc(const std::unordered_map<std::string, int>& tokens_map) {
2081 std::vector<std::pair<std::string, int>> sorted_tokens;
2082 for (const auto& pair : tokens_map) {
2083 sorted_tokens.push_back(pair);
2084 }
2085 std::sort(sorted_tokens.begin(), sorted_tokens.end(),
2086 [](const auto& a, const auto& b) {
2087 return a.first.length() > b.first.length();
2088 });
2089 return sorted_tokens;
2090}
2091
2092// The bpe_tokenize_to_ids is now specifically for Tiktoken-like BPE (Llama 3)
2093// It assumes that if this function is called, the tokenizer_family_ is LLAMA3_TIKTOKEN.
2094std::vector<int> Tokenizer::bpe_tokenize_to_ids(const std::string& text,
2095 bool add_bos_token_param,
2096 bool add_eos_token_param,
2097 bool ignore_merges_param) const {
2098 Logger::debug(std::string("[bpe_tokenize_to_ids] Starting Tiktoken BPE tokenization for text length: ") + std::to_string(text.length()) +
2099 ", add_bos=" + std::to_string(add_bos_token_param) +
2100 ", add_eos=" + std::to_string(add_eos_token_param) +
2101 ", ignore_merges=" + std::to_string(ignore_merges_param) );
2102
2103 std::vector<int> output_ids;
2104
2105 if (add_bos_token_param) {
2106 if (bos_token_id_ == -1) {
2107 Logger::warning("[bpe_tokenize_to_ids] BOS token requested but bos_token_id_ is -1.");
2108 } else {
2109 output_ids.push_back(bos_token_id_);
2110 Logger::debug(std::string("[bpe_tokenize_to_ids] Added BOS token: ") + std::to_string(bos_token_id_));
2111 }
2112 }
2113
2114 const auto sorted_special_tokens = sort_tokens_by_length_desc(this->added_tokens_);
2115
2116 // TikToken regex pattern string
2117 const std::string tiktoken_pattern_str =
2118 R"(<\|[^|]+\||[[:alnum:]]+|\.(?![<|])|[^\s<|]+|\s+)"; // Updated regex
2119
2120 // Compile with boost::xpressive::sregex and icase flag
2121 const boost::xpressive::sregex tiktoken_pattern_ = boost::xpressive::sregex::compile(
2122 tiktoken_pattern_str,
2123 boost::xpressive::regex_constants::icase
2124 );
2125
2126 size_t current_idx = 0;
2127 while (current_idx < text.length()) {
2128 bool special_match_found = false;
2129 if (!sorted_special_tokens.empty()) {
2130 for (const auto& special_pair : sorted_special_tokens) {
2131 const std::string& special_text = special_pair.first;
2132 int special_id = special_pair.second;
2133 if (text.compare(current_idx, special_text.length(), special_text) == 0) {
2134 output_ids.push_back(special_id);
2135 Logger::debug("[bpe_tokenize_to_ids] Matched special token: '" + special_text + "' -> ID: " + std::to_string(special_id));
2136 current_idx += special_text.length();
2137 special_match_found = true;
2138 break;
2139 }
2140 }
2141 }
2142
2143 if (special_match_found) {
2144 continue;
2145 }
2146
2147 if (current_idx >= text.length()) break;
2148
2149 std::string remaining_text_view_str = text.substr(current_idx);
2150 boost::xpressive::smatch word_match;
2151
2152 if (!boost::xpressive::regex_search(remaining_text_view_str, word_match, tiktoken_pattern_, boost::xpressive::regex_constants::match_continuous)) {
2153 Logger::debug(std::string("[bpe_tokenize_to_ids] No more regex-matchable words at pos ") + std::to_string(current_idx) + ". Remainder: '" + remaining_text_view_str + "'");
2154 if (!remaining_text_view_str.empty()) {
2155 Logger::warning(std::string("[bpe_tokenize_to_ids] Regex could not process remainder. Processing byte-by-byte: '") + remaining_text_view_str + "'");
2156 for (char c : remaining_text_view_str) {
2157 std::string byte_str(1, c);
2158 auto it = token_to_id_.find(byte_str);
2159 if (it != token_to_id_.end()) {
2160 output_ids.push_back(it->second);
2161 } else {
2162 if (byte_char_to_id_.count(c)) {
2163 output_ids.push_back(byte_char_to_id_.at(c));
2164 } else if (unk_token_id_ != -1) {
2165 output_ids.push_back(unk_token_id_);
2166 Logger::warning(std::string("[bpe_tokenize_to_ids] Unrecognized byte '") + byte_str + std::string("' replaced with UNK."));
2167 } else {
2168 Logger::error(std::string("[bpe_tokenize_to_ids] Unrecognized byte '") + byte_str + std::string("' and no UNK token defined. Skipping."));
2169 }
2170 }
2171 }
2172 }
2173 current_idx = text.length();
2174 break;
2175 }
2176
2177 std::string original_word = word_match.str(0);
2178
2179 if (original_word.empty()){
2180 Logger::warning("[bpe_tokenize_to_ids] Regex search succeeded but matched an empty string. Advancing one char from pos " + std::to_string(current_idx));
2181 size_t advance_len = unicode_char_len(text[current_idx]);
2182 if (advance_len == 0) advance_len = 1;
2183
2184 std::string problematic_char_str = text.substr(current_idx, advance_len);
2185 auto it_char = token_to_id_.find(problematic_char_str);
2186 if (it_char != token_to_id_.end()) {
2187 output_ids.push_back(it_char->second);
2188 } else if (advance_len == 1 && byte_char_to_id_.count(problematic_char_str[0])) {
2189 output_ids.push_back(byte_char_to_id_.at(problematic_char_str[0]));
2190 } else if (unk_token_id_ != -1) {
2191 output_ids.push_back(unk_token_id_);
2192 Logger::debug("[bpe_tokenize_to_ids] Added UNK for unmatchable leading char after empty regex match: '" + problematic_char_str + "'");
2193 }
2194 current_idx += advance_len;
2195 continue;
2196 }
2197
2198 // Check if the entire original_word (matched by regex) is a known token (especially for <|...|> cases)
2199 auto direct_match_it = token_to_id_.find(original_word);
2200 if (direct_match_it != token_to_id_.end()) {
2201 output_ids.push_back(direct_match_it->second);
2202 Logger::debug("[bpe_tokenize_to_ids] Regex-matched word '" + original_word + "' is a direct token ID: " + std::to_string(direct_match_it->second));
2203 current_idx += original_word.length();
2204 continue;
2205 }
2206
2207 Logger::debug(std::string("[bpe_tokenize_to_ids] Processing regex-derived word for BPE: '") + original_word + "'");
2208
2209 // Convert leading space of original_word to BPE_SPACE_CHAR (Ġ) for Tiktoken-style BPE
2210 // This is crucial if the vocabulary expects space-prefixed tokens like "Ġword".
2211 std::string word_to_process = original_word;
2212 if (!word_to_process.empty() && word_to_process[0] == ' ') {
2213 if (word_to_process.length() > 1) {
2214 word_to_process = BPE_SPACE_CHAR + word_to_process.substr(1);
2215 } else { // Word is just a single space
2216 word_to_process = BPE_SPACE_CHAR;
2217 }
2218 Logger::debug(std::string("[bpe_tokenize_to_ids] Converted leading space. Word for BPE: '") + word_to_process + "'");
2219 }
2220
2221 if (ignore_merges_param) { // If ignore_merges is true, try direct lookup first
2222 auto it_direct = token_to_id_.find(word_to_process);
2223 if (it_direct != token_to_id_.end()) {
2224 output_ids.push_back(it_direct->second);
2225 Logger::debug(std::string("[bpe_tokenize_to_ids] Found word directly (ignore_merges): '") + word_to_process + "' -> ID: " + std::to_string(it_direct->second));
2226 current_idx += original_word.length();
2227 continue;
2228 }
2229 Logger::debug(std::string("[bpe_tokenize_to_ids] ignore_merges=true, but word \'") + word_to_process + "\' not in vocab directly. Proceeding with BPE char split (unusual for tiktoken special words).");
2230
2231 }
2232
2233 std::vector<llm_symbol> symbols;
2234 symbols.reserve(word_to_process.length());
2235 size_t offset = 0;
2236 while (offset < word_to_process.length()) {
2237 size_t char_len = unicode_char_len(word_to_process[offset]);
2238 if (offset + char_len > word_to_process.length()) {
2239 Logger::error("[bpe_tokenize_to_ids] Invalid UTF-8 sequence in word: '" + word_to_process + "' at offset " + std::to_string(offset));
2240 symbols.clear();
2241 break;
2242 }
 2243      // Each llm_symbol stores a pointer into word_to_process plus its byte length `n`.
 2244      // The `prev` and `next` indices form the linked list traversed during BPE merging.
2245 symbols.emplace_back(llm_symbol{-1, -1, word_to_process.data() + offset, char_len});
2246 offset += char_len;
2247 }
2248
2249 if (symbols.empty() && !word_to_process.empty()) {
2250 Logger::warning("[bpe_tokenize_to_ids] Word '" + word_to_process + "' resulted in no symbols. Skipping this word's BPE.");
2251 if (unk_token_id_ != -1 && !original_word.empty()){
2252 output_ids.push_back(unk_token_id_);
2253 }
2254 current_idx += original_word.length();
2255 continue;
2256 }
2257 if (symbols.empty() && word_to_process.empty()){
2258 current_idx += original_word.length();
2259 continue;
2260 }
2261
2262 for (size_t i = 0; i < symbols.size(); ++i) {
2263 symbols[i].prev = (i > 0) ? (i - 1) : -1;
2264 symbols[i].next = (i < symbols.size() - 1) ? (i + 1) : -1;
2265 }
2266
2267 // Use std::priority_queue for merges
2268 std::priority_queue<std::pair<int, int>,
2269 std::vector<std::pair<int, int>>,
2270 std::greater<std::pair<int, int>>> merge_queue;
2271
2272 for (size_t i = 0; i + 1 < symbols.size(); ++i) {
2273 // add_bigram_to_queue expects data pointer, symbols vector, index of first symbol in pair, and queue
2274 add_bigram_to_queue_refactored(word_to_process.data(), symbols, i, merge_queue);
2275 }
2276
2277 while (!merge_queue.empty()) {
2278 auto top = merge_queue.top();
2279 merge_queue.pop();
2280
2281 int rank = top.first;
2282 int p1_idx = top.second;
2283
2284 if (symbols[p1_idx].n == 0) continue;
2285 int p2_idx = symbols[p1_idx].next;
2286 if (p2_idx == -1 || symbols[p2_idx].n == 0) continue;
2287
2288
2289 symbols[p1_idx].n += symbols[p2_idx].n;
2290 symbols[p2_idx].n = 0;
2291 symbols[p1_idx].next = symbols[p2_idx].next;
2292 if (symbols[p1_idx].next != -1) {
2293 symbols[symbols[p1_idx].next].prev = p1_idx;
2294 }
2295
2296
2297 // Add new bigrams
2298 if (symbols[p1_idx].prev != -1) {
2299 add_bigram_to_queue_refactored(word_to_process.data(), symbols, symbols[p1_idx].prev, merge_queue);
2300 }
2301 if (symbols[p1_idx].next != -1) {
2302 add_bigram_to_queue_refactored(word_to_process.data(), symbols, p1_idx, merge_queue);
2303 }
2304 }
2305
2306 std::vector<int> final_word_ids;
2307 if (!symbols.empty()) {
2308 for (int i = 0; i != -1; i = symbols[i].next) {
2309 const llm_symbol & symbol = symbols[i];
2310 if (symbol.n == 0) continue;
2311
2312 std::string s(symbol.text, symbol.n);
2313 std::string lookup_s = s;
2314
2315 const auto token_it = token_to_id_.find(lookup_s);
2316
2317 if (token_it != token_to_id_.end()) {
2318 final_word_ids.push_back(token_it->second);
2319 } else {
2320 Logger::warning(std::string("[bpe_tokenize_to_ids] Symbol not found in vocab: '") + lookup_s + "'. Attempting byte-level tokenization.");
2321 for (char c_char : lookup_s) {
2322 auto byte_map_it = byte_char_to_id_.find(c_char);
2323 if (byte_map_it != byte_char_to_id_.end()){
2324 final_word_ids.push_back(byte_map_it->second);
2325 } else {
2326 if (unk_token_id_ != -1) {
2327 final_word_ids.push_back(unk_token_id_);
2328 } else {
2329 Logger::error(std::string("[bpe_tokenize_to_ids] Unhandled char '") + std::string(1, c_char) + "' and no UNK token ID.");
2330 }
2331 }
2332 }
2333 }
2334 }
2335 } else if (!word_to_process.empty()) {
2336 Logger::warning(std::string("[bpe_tokenize_to_ids] Word '") + word_to_process + std::string("' yielded no final symbols. UNK if available."));
2337 if (unk_token_id_ != -1){ final_word_ids.push_back(unk_token_id_); }
2338 }
2339
2340 if (final_word_ids.empty() && !original_word.empty()) {
2341 Logger::warning(std::string("[bpe_tokenize_to_ids] Word '") + original_word + "' resulted in no tokens. Adding UNK.");
2342 if (unk_token_id_ != -1) { output_ids.push_back(unk_token_id_); }
2343 } else {
2344 output_ids.insert(output_ids.end(), final_word_ids.begin(), final_word_ids.end());
2345 }
2346 current_idx += original_word.length();
2347 }
2348
2349 if (add_eos_token_param) {
2350 if (eos_token_id_ == -1) {
2351 Logger::warning("[bpe_tokenize_to_ids] EOS token requested but eos_token_id_ is -1.");
2352 } else {
2353 output_ids.push_back(eos_token_id_); // Corrected to output_ids
2354 Logger::debug(std::string("[bpe_tokenize_to_ids] Added EOS token: ") + std::to_string(eos_token_id_));
2355 }
2356 }
2357 Logger::debug("[bpe_tokenize_to_ids] Finished Tiktoken BPE tokenization. Total IDs: " + std::to_string(output_ids.size()));
2358 return output_ids;
2359}
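
// Usage sketch (IDs and vocab contents hypothetical): for "Hello world" the
// regex yields the words "Hello", " " and "world". Assuming the lone space is
// not itself a vocab entry, it is rewritten to the Ġ token (BPE_SPACE_CHAR)
// before lookup, so the result is the IDs of "Hello", "Ġ" and "world".
//
//   std::vector<int> ids = tok.bpe_tokenize_to_ids("Hello world",
//                                                  /*add_bos_token_param=*/true,
//                                                  /*add_eos_token_param=*/false,
//                                                  /*ignore_merges_param=*/false);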
2360
 2361// Helper function to add a potential bigram to the priority queue (refactored for the current llm_symbol structure).
 2362// Each llm_symbol carries a pointer into the word's character data plus a length n; text_data_base is retained for interface compatibility, but the symbols' own pointers are used below.
2363void Tokenizer::add_bigram_to_queue_refactored(const char* text_data_base,
2364 const std::vector<llm_symbol>& symbols,
2365 llm_symbol::index first_symbol_idx,
2366 std::priority_queue<std::pair<int, int>,
2367 std::vector<std::pair<int, int>>,
2368 std::greater<std::pair<int, int>>>& work_queue) const {
2369 if (first_symbol_idx < 0 || static_cast<size_t>(first_symbol_idx) >= symbols.size()) {
2370 Logger::error(std::string("[ADD_BIGRAM_REFACTORED] Invalid first_symbol_idx: ") + std::to_string(first_symbol_idx));
2371 return;
2372 }
2373
2374 const llm_symbol& s1 = symbols[first_symbol_idx];
2375 llm_symbol::index s2_idx = s1.next;
2376
2377 if (s2_idx < 0 || static_cast<size_t>(s2_idx) >= symbols.size() || s2_idx <= first_symbol_idx) {
2378 return;
2379 }
2380 const llm_symbol& s2 = symbols[s2_idx];
2381
2382 if (s1.n == 0 || s2.n == 0) {
2383 return;
2384 }
2385
2386 std::string token_left_str(s1.text, s1.n);
2387 std::string token_right_str(s2.text, s2.n);
2388
2389 std::vector<std::string> merge_attempts;
2390
2391 // First priority: If we see Ġ, try it with a following space
2392 if (token_left_str == BPE_SPACE_CHAR) {
2393 merge_attempts.push_back(BPE_SPACE_CHAR + " " + token_right_str);
2394 Logger::debug("[ADD_BIGRAM] Attempting Ġ+space merge: '" + (BPE_SPACE_CHAR + " " + token_right_str) + "'");
2395 }
2396
2397 // Second priority: Standard merge without space
2398 merge_attempts.push_back(token_left_str + token_right_str);
2399 Logger::debug("[ADD_BIGRAM] Attempting standard merge: '" + (token_left_str + token_right_str) + "'");
2400
2401 // Third priority: If left token starts with Ġ but isn't just Ġ, try with space
2402 if (token_left_str.rfind(BPE_SPACE_CHAR, 0) == 0 && token_left_str != BPE_SPACE_CHAR) {
2403 merge_attempts.push_back(token_left_str + " " + token_right_str);
2404 Logger::debug("[ADD_BIGRAM] Attempting Ġword+space merge: '" + (token_left_str + " " + token_right_str) + "'");
2405 }
2406
2407 // Fourth priority: Special case for character splits with space
2408 if (token_left_str.length() == 2 && token_right_str.length() == 1) {
2409 std::string attempt = token_left_str.substr(0, 1) + " " + token_right_str;
2410 merge_attempts.push_back(attempt);
2411 Logger::debug("[ADD_BIGRAM] Attempting char split merge: '" + attempt + "'");
2412 }
2413
2414 int best_rank = std::numeric_limits<int>::max();
2415 bool found_merge = false;
2416 std::string matched_merge;
2417
2418 for (const auto& merge_attempt : merge_attempts) {
2419 auto it = bpe_merges_.find(merge_attempt);
2420 if (it != bpe_merges_.end() && it->second < best_rank) {
2421 best_rank = it->second;
2422 found_merge = true;
2423 matched_merge = merge_attempt;
2424 }
2425 }
2426
2427 if (found_merge) {
2428 work_queue.push({best_rank, first_symbol_idx});
2429 Logger::debug("[ADD_BIGRAM] Found merge: '" + matched_merge + "' with rank " + std::to_string(best_rank));
2430 } else {
2431 Logger::debug("[ADD_BIGRAM] No valid merges found for attempts with left='" + token_left_str +
2432 "' right='" + token_right_str + "'");
2433 }
2434}
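
// Note on the queue ordering used above: std::greater over std::pair<int, int>
// turns the priority_queue into a min-heap ordered by (rank, symbol index), so
// the lowest-rank candidate merge is always popped first. A tiny standalone check:
//
//   std::priority_queue<std::pair<int, int>, std::vector<std::pair<int, int>>,
//                       std::greater<std::pair<int, int>>> q;
//   q.push({5, 0}); q.push({1, 3}); q.push({1, 1});
//   // q.top() == std::pair<int, int>{1, 1}: lowest rank first, then lowest index.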
2435
2436