TinyLlama.cpp 1.0
A lightweight C++ implementation of the TinyLlama language model
Loading...
Searching...
No Matches
safetensors_loader.cpp
Go to the documentation of this file.
2#include "model.h"
3#include "logger.h"
4#include "model_macros.h" // For SAFE_MIN, SAFE_MAX (may be needed by cpu_f16_to_float32)
5
#include <algorithm>
#include <cctype>
#include <cstdint>    // uint16_t/uint32_t in the f16/bf16 converters
#include <cstring>    // memcpy, strerror (previously relied on transitive includes)
#include <filesystem>
#include <fstream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>    // std::swap (Shard move assignment)
#include <vector>

#include <nlohmann/json.hpp>
16
17
18#ifndef _WIN32
19#include <sys/stat.h>
20#include <cerrno> // For strerror
21#else
22#ifndef NOMINMAX
23#define NOMINMAX
24#endif
25#include <windows.h>
26#endif
27
28#ifdef __AVX2__
29#include <immintrin.h>
30#endif
// Widen a bfloat16 bit pattern to IEEE-754 float32.
// bf16 is the top 16 bits of a float32 (same sign/exponent layout), so the
// conversion is just placing the raw bits in the upper half of a 32-bit word.
inline float cpu_bf16_to_float32(uint16_t bf16_raw) {
    const uint32_t f32_bits = static_cast<uint32_t>(bf16_raw) << 16;
    float out;
    std::memcpy(&out, &f32_bits, sizeof(out)); // bit-cast without aliasing UB
    return out;
}
// Convert an IEEE-754 binary16 (half-precision) bit pattern to float32.
// Every value class is handled explicitly: normals, signed zero, subnormals
// (renormalized into the f32 range), and Inf/NaN (NaN mantissa payload is
// preserved by shifting it into the f32 mantissa field).
inline float cpu_f16_to_float32(uint16_t f16_raw) {
    // Field masks and exponent biases for the two formats.
    const uint32_t sign_mask_f16 = 0x8000;
    const uint32_t exp_mask_f16 = 0x7C00;
    const uint32_t mant_mask_f16 = 0x03FF;
    const int32_t exp_bias_f16 = 15;
    const int32_t exp_bias_f32 = 127;

    // Split into sign / exponent / mantissa; the sign moves straight from
    // bit 15 to bit 31.
    uint32_t sign_f32 = (static_cast<uint32_t>(f16_raw & sign_mask_f16)) << 16;
    int32_t exp_f16 = (f16_raw & exp_mask_f16) >> 10;
    uint32_t mant_f16 = (f16_raw & mant_mask_f16);

    uint32_t f32_bits;

    if (exp_f16 == 0x1F) { // F16 NaN or Inf
        f32_bits = sign_f32 | 0x7F800000U | (mant_f16 << 13); // Propagate mantissa for NaN
    } else if (exp_f16 == 0) { // F16 zero or subnormal
        if (mant_f16 == 0) { // Zero
            f32_bits = sign_f32;
        } else { // Subnormal F16 to normal or subnormal F32
            // Renormalize: shift the mantissa up until its leading 1 lands in
            // bit 10 (the implicit-1 position), tracking the exponent
            // adjustment in `s`.
            int32_t s = -1;
            mant_f16 <<= 1;
            while ((mant_f16 & 0x0400) == 0) {
                mant_f16 <<= 1;
                s--;
            }
            mant_f16 &= 0x03FF; // Clear leading 1
            int32_t f32_exp_val = (1 - exp_bias_f16) + s + exp_bias_f32;
            if (f32_exp_val <= 0) { // Result is subnormal F32 or zero
                // NOTE(review): unreachable for genuine f16 inputs — the
                // smallest f16 subnormal (2^-24) is still a normal f32
                // (min normal 2^-126) — kept as defensive code.
                int32_t shift = 1 - f32_exp_val;
                if (shift > 23) { // Underflow to zero
                    f32_bits = sign_f32;
                } else {
                    f32_bits = sign_f32 | ((mant_f16 << 13) >> shift) ;
                }
            } else { // Result is normal F32
                f32_bits = sign_f32 | (static_cast<uint32_t>(f32_exp_val) << 23) | (mant_f16 << 13);
            }
        }
    } else { // Normal F16
        // Rebias the exponent and widen the mantissa from 10 to 23 bits.
        int32_t f32_exp = exp_f16 - exp_bias_f16 + exp_bias_f32;
        f32_bits = sign_f32 | (static_cast<uint32_t>(f32_exp) << 23) | (mant_f16 << 13);
    }

    // memcpy bit-cast avoids strict-aliasing violations.
    float result;
    memcpy(&result, &f32_bits, sizeof(float));
    return result;
}
84
85Shard::Shard(const std::string& fp) : file_path(fp) {
86 Logger::info("Shard: Initializing for file: " + file_path);
87#ifdef _WIN32
88 file_handle_ = CreateFileA(file_path.c_str(), GENERIC_READ, FILE_SHARE_READ,
89 NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
90 if (file_handle_ == INVALID_HANDLE_VALUE) {
91 throw std::runtime_error("Shard: Failed to open file (Windows): " + file_path + " Error: " + std::to_string(GetLastError()));
92 }
93
94 LARGE_INTEGER size_li;
95 if (!GetFileSizeEx(file_handle_, &size_li)) {
96 CloseHandle(file_handle_);
97 file_handle_ = INVALID_HANDLE_VALUE;
98 throw std::runtime_error("Shard: Failed to get file size (Windows): " + file_path);
99 }
100 file_size = static_cast<size_t>(size_li.QuadPart);
101 if (file_size == 0) {
102 CloseHandle(file_handle_);
103 file_handle_ = INVALID_HANDLE_VALUE;
104 throw std::runtime_error("Shard: File is empty: " + file_path);
105 }
106
107 mapping_handle_ = CreateFileMapping(file_handle_, NULL, PAGE_READONLY, 0, 0, NULL);
108 if (mapping_handle_ == NULL) {
109 CloseHandle(file_handle_);
110 file_handle_ = INVALID_HANDLE_VALUE;
111 throw std::runtime_error("Shard: Failed to create file mapping (Windows): " + file_path + " Error: " + std::to_string(GetLastError()));
112 }
113
114 mapped_data = MapViewOfFile(mapping_handle_, FILE_MAP_READ, 0, 0, file_size);
115 if (mapped_data == nullptr) {
116 CloseHandle(mapping_handle_);
117 mapping_handle_ = NULL;
118 CloseHandle(file_handle_);
119 file_handle_ = INVALID_HANDLE_VALUE;
120 throw std::runtime_error("Shard: Failed to map view of file (Windows): " + file_path + " Error: " + std::to_string(GetLastError()));
121 }
122#else // POSIX
123 fd_ = open(file_path.c_str(), O_RDONLY);
124 if (fd_ == -1) {
125 throw std::runtime_error("Shard: Failed to open file: " + file_path + " Error: " + strerror(errno));
126 }
127
128 struct stat sb;
129 if (fstat(fd_, &sb) == -1) {
130 close(fd_);
131 fd_ = -1;
132 throw std::runtime_error("Shard: Failed to get file size: " + file_path + " Error: " + strerror(errno));
133 }
134 file_size = sb.st_size;
135 if (file_size == 0) {
136 close(fd_);
137 fd_ = -1;
138 throw std::runtime_error("Shard: File is empty: " + file_path);
139 }
140
141 mapped_data = mmap(NULL, file_size, PROT_READ, MAP_SHARED, fd_, 0);
142 if (mapped_data == MAP_FAILED) {
143 close(fd_);
144 fd_ = -1;
145 mapped_data = nullptr;
146 throw std::runtime_error("Shard: Failed to memory map file: " + file_path + " Error: " + strerror(errno));
147 }
148#endif
149 Logger::debug("Shard: Successfully mapped file: " + file_path + ", size: " + std::to_string(file_size));
150
151 if (file_size < 8) {
152 throw std::runtime_error("Shard: File too small (" + std::to_string(file_size) + " bytes) to be a valid SafeTensors shard (min 8 bytes for metadata length): " + file_path);
153 }
154 metadata_size = *reinterpret_cast<const uint64_t*>(mapped_data);
155
156 if (metadata_size == 0) {
157 throw std::runtime_error("Shard: Metadata size is 0 in file header: " + file_path);
158 }
159 if (8 + metadata_size > file_size) {
160 throw std::runtime_error("Shard: Declared metadata size (" + std::to_string(metadata_size) + ") plus header (8 bytes) exceeds file size (" + std::to_string(file_size) + ") in: " + file_path);
161 }
162 metadata_ptr = static_cast<const uint8_t*>(mapped_data) + 8;
164 Logger::debug("Shard: Metadata size from header: " + std::to_string(metadata_size) + " for " + file_path);
165}
166
168 Logger::debug("Shard: Cleaning up for file: " + (file_path.empty() ? "(moved or uninitialized)" : file_path) );
169#ifdef _WIN32
170 if (mapped_data != nullptr) {
171 if (!UnmapViewOfFile(mapped_data)) {
172 Logger::error("Shard: Failed to unmap view of file (Windows) for \"" + file_path + "\" Error: " + std::to_string(GetLastError()));
173 }
174 }
175 if (mapping_handle_ != NULL) {
176 if (!CloseHandle(mapping_handle_)) {
177 Logger::error("Shard: Failed to close mapping handle (Windows) for \"" + file_path + "\" Error: " + std::to_string(GetLastError()));
178 }
179 }
180 if (file_handle_ != INVALID_HANDLE_VALUE) {
181 if (!CloseHandle(file_handle_)) {
182 Logger::error("Shard: Failed to close file handle (Windows) for \"" + file_path + "\" Error: " + std::to_string(GetLastError()));
183 }
184 }
185 mapped_data = nullptr;
186 file_handle_ = INVALID_HANDLE_VALUE;
187 mapping_handle_ = NULL;
188#else // POSIX
189 if (mapped_data != nullptr && mapped_data != MAP_FAILED) {
190 if (munmap(mapped_data, file_size) == -1) {
191 Logger::error("Shard: Failed to munmap file: \"" + file_path + "\" Error: " + strerror(errno));
192 }
193 }
194 if (fd_ != -1) {
195 if (close(fd_) == -1) {
196 Logger::error("Shard: Failed to close file descriptor for \"" + file_path + "\" Error: " + strerror(errno));
197 }
198 }
199 mapped_data = nullptr;
200 fd_ = -1;
201#endif
202}
203
// Move constructor: steal the mapping, size/metadata pointers, and the
// platform handle from `other`, then reset `other` to the empty state so
// its destructor is a no-op (it will not unmap/close the stolen resources).
Shard::Shard(Shard&& other) noexcept
    : file_path(std::move(other.file_path)),
      mapped_data(other.mapped_data),
      file_size(other.file_size),
      metadata_size(other.metadata_size),
      metadata_ptr(other.metadata_ptr),
      tensor_data_block_ptr(other.tensor_data_block_ptr)
#ifdef _WIN32
    , file_handle_(other.file_handle_)
    , mapping_handle_(other.mapping_handle_)
#else
    , fd_(other.fd_)
#endif
{
    // Leave `other` valid-but-empty: all owning members go to their
    // "nothing to release" sentinel values.
    other.mapped_data = nullptr;
    other.file_size = 0;
    other.metadata_size = 0;
    other.metadata_ptr = nullptr;
    other.tensor_data_block_ptr = nullptr;
#ifdef _WIN32
    other.file_handle_ = INVALID_HANDLE_VALUE;
    other.mapping_handle_ = NULL;
#else
    other.fd_ = -1;
#endif
}
230
231Shard& Shard::operator=(Shard&& other) noexcept {
232 if (this != &other) {
233 this->~Shard();
234 file_path = std::move(other.file_path);
235 mapped_data = other.mapped_data;
236 file_size = other.file_size;
237 metadata_size = other.metadata_size;
238 metadata_ptr = other.metadata_ptr;
239 tensor_data_block_ptr = other.tensor_data_block_ptr;
240#ifdef _WIN32
241 file_handle_ = other.file_handle_;
242 mapping_handle_ = other.mapping_handle_;
243#else
244 fd_ = other.fd_;
245#endif
246 other.mapped_data = nullptr;
247 other.file_size = 0;
248 other.metadata_size = 0;
249 other.metadata_ptr = nullptr;
250 other.tensor_data_block_ptr = nullptr;
251#ifdef _WIN32
252 other.file_handle_ = INVALID_HANDLE_VALUE;
253 other.mapping_handle_ = NULL;
254#else
255 other.fd_ = -1;
256#endif
257 }
258 return *this;
259}
260
261const uint8_t* Shard::get_tensor_raw_data(size_t local_offset, size_t n_bytes) const {
262#ifdef _WIN32
263 if (!mapped_data || mapped_data == NULL || !tensor_data_block_ptr) {
264#else // POSIX
265 if (!mapped_data || mapped_data == MAP_FAILED || !tensor_data_block_ptr) {
266#endif
267 throw std::logic_error("Shard not properly mapped or initialized to get tensor data: " + file_path);
268 }
269 const uint8_t* data_start = tensor_data_block_ptr + local_offset;
270 const uint8_t* shard_data_block_end = tensor_data_block_ptr + (file_size - (8 + metadata_size));
271
272 if (data_start < tensor_data_block_ptr || data_start + n_bytes > shard_data_block_end || n_bytes > (file_size - (8 + metadata_size))) {
273 throw std::out_of_range(
274 "Tensor data (local_offset: " + std::to_string(local_offset) +
275 ", n_bytes: " + std::to_string(n_bytes) +
276 ") out of bounds for data block of shard: " + file_path +
277 ". Shard data block size: " + std::to_string(file_size - (8 + metadata_size)) + " bytes."
278 );
279 }
280 return data_start;
281}
282
283SafeTensorsLoader::SafeTensorsLoader(const std::string& model_load_path)
284 : model_load_path_(model_load_path), is_sharded_(false) {
285 Logger::info("SafeTensorsLoader: Initializing for path: " + model_load_path_);
286 std::filesystem::path path_obj(model_load_path_);
287
288 if (!std::filesystem::exists(path_obj)){
289 throw std::runtime_error("SafeTensorsLoader: Provided model_load_path does not exist: " + model_load_path_);
290 }
291
292 if (std::filesystem::is_directory(path_obj)) {
293 Logger::info("SafeTensorsLoader: Path is a directory. Attempting to load from directory.");
295 } else if (std::filesystem::is_regular_file(path_obj)) {
296 Logger::info("SafeTensorsLoader: Path is a single file. Loading single file.");
297 std::string file_key = path_obj.filename().string();
299 is_sharded_ = false;
300 } else {
301 throw std::runtime_error("SafeTensorsLoader: model_load_path is not a valid file or directory: " + model_load_path_);
302 }
303
304 if (tensors_.empty() && loaded_shards_.empty()) {
305 Logger::warning("SafeTensorsLoader: Initialization complete, but no tensors were loaded and no shards mapped. Check model path and format: " + model_load_path_);
306 } else {
307 Logger::info("SafeTensorsLoader: Initialization complete. Total unique tensors mapped: " + std::to_string(tensors_.size()) +
308 " from " + std::to_string(loaded_shards_.size()) + " shard(s).");
309 }
310}
311
313 Logger::info("SafeTensorsLoader: Destructing. Clearing " + std::to_string(loaded_shards_.size()) + " loaded shards.");
314 loaded_shards_.clear();
315 Logger::info("SafeTensorsLoader: All shards cleared.");
316}
317
// Discover and load model shards from a directory.
//
// Resolution order:
//   1. An index file ("model.safetensors.index.json", or the PyTorch-style
//      "pytorch_model.bin.index.json") whose "weight_map" names the shard
//      file owning each tensor — each unique shard is mapped once.
//   2. No index: scan the directory for *.safetensors files. One file is
//      loaded as non-sharded; several are each loaded as individual shards.
//   3. Nothing matched: fall back to a bare "model.safetensors", otherwise
//      log and load nothing.
// Throws on unreadable/unparseable index JSON or a missing "weight_map".
void SafeTensorsLoader::load_from_directory(const std::string& directory_path_str) {
    Logger::debug("SafeTensorsLoader::load_from_directory for '" + directory_path_str + "'.");
    std::filesystem::path dir_p(directory_path_str);
    std::filesystem::path index_json_path_v1 = dir_p / "model.safetensors.index.json";
    std::filesystem::path index_json_path_v2 = dir_p / "pytorch_model.bin.index.json";
    std::filesystem::path actual_index_path;

    // Prefer the native safetensors index over the PyTorch-style one.
    bool index_found = false;
    if (std::filesystem::exists(index_json_path_v1) && std::filesystem::is_regular_file(index_json_path_v1)) {
        actual_index_path = index_json_path_v1;
        index_found = true;
    } else if (std::filesystem::exists(index_json_path_v2) && std::filesystem::is_regular_file(index_json_path_v2)) {
        actual_index_path = index_json_path_v2;
        index_found = true;
    }

    if (index_found) {
        Logger::info("SafeTensorsLoader: Found index file: " + actual_index_path.string());
        is_sharded_ = true;
        std::ifstream f(actual_index_path.string());
        if (!f.is_open()) {
            throw std::runtime_error("SafeTensorsLoader: Failed to open index file: " + actual_index_path.string());
        }
        nlohmann::json index_json_data;
        try {
            index_json_data = nlohmann::json::parse(f);
        } catch (const nlohmann::json::parse_error& e) {
            f.close();
            throw std::runtime_error("SafeTensorsLoader: Failed to parse index JSON from " + actual_index_path.string() + ": " + e.what());
        }
        f.close();

        if (index_json_data.count("weight_map") && index_json_data["weight_map"].is_object()) {
            // First pass: populate tensor_name_to_shard_key_map_ and identify unique shards to load
            std::map<std::string, std::string> unique_shards_to_load; // shard_filename -> full_path
            for (auto const& [tensor_name, shard_filename_json] : index_json_data["weight_map"].items()) {
                if (!shard_filename_json.is_string()) {
                    Logger::warning("SafeTensorsLoader: Shard filename for tensor '" + tensor_name + "' in index is not a string. Skipping.");
                    continue;
                }
                std::string shard_filename = shard_filename_json.get<std::string>();
                tensor_name_to_shard_key_map_[tensor_name] = shard_filename;
                if (unique_shards_to_load.find(shard_filename) == unique_shards_to_load.end()) {
                    unique_shards_to_load[shard_filename] = (dir_p / shard_filename).string();
                }
            }

            // Second pass: load each unique shard and parse its metadata
            for(const auto& pair : unique_shards_to_load){
                const std::string& shard_filename = pair.first;
                const std::string& full_shard_path = pair.second;
                if (loaded_shards_.find(shard_filename) == loaded_shards_.end()) {
                    Logger::info("SafeTensorsLoader: Loading and parsing shard (from index): " + full_shard_path + " (key:"+ shard_filename + ")");
                    load_single_file(full_shard_path, shard_filename);
                } else {
                    // Defensive branch only: unique_shards_to_load is already de-duplicated.
                    Logger::debug("SafeTensorsLoader: Shard '" + shard_filename + "' already loaded/parsed (should not happen if unique_shards logic is correct).");
                }
            }

        } else {
            throw std::runtime_error("SafeTensorsLoader: Index file " + actual_index_path.string() + " does not contain a valid 'weight_map'.");
        }
    } else {
        // No index: discover shard files by extension.
        Logger::info("SafeTensorsLoader: No index file found in " + directory_path_str + ". Scanning for *.safetensors files.");
        std::vector<std::filesystem::path> shard_files;
        for (const auto& entry : std::filesystem::directory_iterator(dir_p)) {
            if (entry.is_regular_file() && entry.path().extension() == ".safetensors") {
                shard_files.push_back(entry.path());
            }
        }

        if (shard_files.empty()) {
            Logger::warning("SafeTensorsLoader: No .safetensors files found directly in directory: " + directory_path_str + ". Checking for model.safetensors as last resort.");
            std::filesystem::path single_model_file = dir_p / "model.safetensors";
            if(std::filesystem::exists(single_model_file) && std::filesystem::is_regular_file(single_model_file)){
                Logger::info("SafeTensorsLoader: Found 'model.safetensors' in directory, loading it as a single non-sharded model.");
                load_single_file(single_model_file.string(), single_model_file.filename().string());
                is_sharded_ = false;
            } else {
                // Nothing to load from this path; caller's constructor logs a warning.
                Logger::info("SafeTensorsLoader: No .safetensors files or index.json found in directory: " + directory_path_str + ". No model weights will be loaded from this path directly.");
            }
        } else if (shard_files.size() == 1) {
            Logger::info("SafeTensorsLoader: Found single .safetensors file: " + shard_files[0].string() + ". Loading as non-sharded.");
            load_single_file(shard_files[0].string(), shard_files[0].filename().string());
            is_sharded_ = false;
        } else {
            Logger::info("SafeTensorsLoader: Found " + std::to_string(shard_files.size()) + " .safetensors files (no index). Loading all as individual shards.");
            is_sharded_ = true;
            for (const auto& p : shard_files) {
                load_single_file(p.string(), p.filename().string());
            }
        }
    }
}
412
413void SafeTensorsLoader::load_single_file(const std::string& file_path, const std::string& shard_key_override) {
414 std::string key_to_use = shard_key_override.empty() ? std::filesystem::path(file_path).filename().string() : shard_key_override;
415 if (key_to_use.empty()) key_to_use = file_path;
416
417 if (loaded_shards_.count(key_to_use)) {
418 Logger::debug("SafeTensorsLoader: Shard/file '" + key_to_use + "' (path: " + file_path + ") already processed/loaded.");
419 return;
420 }
421 Logger::info("SafeTensorsLoader: Loading single file/shard: " + file_path + " with key: " + key_to_use);
422 try {
423 auto shard = std::make_unique<Shard>(file_path);
424 parse_shard_metadata(*shard, key_to_use);
425 loaded_shards_[key_to_use] = std::move(shard);
426 } catch (const std::exception& e) {
427 throw std::runtime_error("SafeTensorsLoader: Error processing file/shard '" + file_path + "' (key: " + key_to_use + "): " + e.what());
428 }
429}
430
431void SafeTensorsLoader::parse_shard_metadata(Shard& shard, const std::string& shard_key) {
432 Logger::debug("SafeTensorsLoader: Parsing metadata for shard: " + shard_key + " (file: " + shard.file_path + ")");
433 if (!shard.metadata_ptr || shard.metadata_size == 0) {
434 throw std::runtime_error("Shard metadata is not available for parsing (nullptr or zero size): " + shard.file_path);
435 }
436 std::string metadata_json_str;
437 try {
438 metadata_json_str.assign(reinterpret_cast<const char*>(shard.metadata_ptr), shard.metadata_size);
439 } catch (const std::length_error& le) {
440 throw std::runtime_error("Error constructing metadata string for shard " + shard.file_path + ": " + le.what());
441 }
442
443 nlohmann::json metadata_root;
444 try {
445 metadata_root = nlohmann::json::parse(metadata_json_str);
446 } catch (const nlohmann::json::parse_error& e) {
447 throw std::runtime_error("Failed to parse metadata JSON for shard " + shard.file_path + " (key: " + shard_key + ") at offset 8, metadata_size: " +
448 std::to_string(shard.metadata_size) + ". Error: " + e.what() +
449 "\nJSON content snippet (first 200 chars): " + metadata_json_str.substr(0, 200));
450 }
451
452 size_t tensors_in_this_shard_count = 0;
453 for (auto const& [tensor_name_str, info_json] : metadata_root.items()) {
454 if (tensor_name_str == "__metadata__") continue;
455
456 TensorInfo tensor_info;
457 tensor_info.name = tensor_name_str;
458 try {
459 tensor_info.dtype = info_json.at("dtype").get<std::string>();
460 std::transform(tensor_info.dtype.begin(), tensor_info.dtype.end(), tensor_info.dtype.begin(),
461 [](unsigned char c){ return static_cast<char>(std::toupper(c)); });
462
463 for (const auto& dim : info_json.at("shape")) {
464 tensor_info.shape.push_back(dim.get<size_t>());
465 }
466 const auto& data_offsets_json = info_json.at("data_offsets");
467 if (!data_offsets_json.is_array() || data_offsets_json.size() != 2) {
468 throw std::runtime_error("Tensor '" + tensor_name_str + "' 'data_offsets' must be an array of two numbers.");
469 }
470 size_t start_offset_in_data_block = data_offsets_json[0].get<size_t>();
471 size_t end_offset_in_data_block = data_offsets_json[1].get<size_t>();
472
473 tensor_info.data_offset = start_offset_in_data_block;
474 tensor_info.nbytes = end_offset_in_data_block - start_offset_in_data_block;
475 tensor_info.shard_key = shard_key;
476
477 if (tensors_.count(tensor_info.name)) {
478 Logger::warning("SafeTensorsLoader: Duplicate tensor name '" + tensor_info.name + "' encountered. " +
479 "Previous shard key: '" + tensors_[tensor_info.name].shard_key + "', New shard key: '" + shard_key + "'. " +
480 "Overwriting with info from current shard being parsed. This can happen with unindexed multi-file loads or inconsistent index files.");
481 }
482 tensors_[tensor_info.name] = tensor_info;
484 tensor_name_to_shard_key_map_[tensor_info.name] = shard_key;
485 }
486
487 tensors_in_this_shard_count++;
488
489 } catch (const nlohmann::json::exception& e) {
490 throw std::runtime_error("Failed to parse tensor info for '" + tensor_name_str + "' in shard " +
491 shard.file_path + " (key: " + shard_key + "): " + e.what());
492 }
493 }
494 Logger::debug("SafeTensorsLoader: Finished parsing metadata for shard: " + shard_key + ". Parsed " + std::to_string(tensors_in_this_shard_count) + " tensor entries from this shard.");
495}
496
497std::vector<std::string> SafeTensorsLoader::tensor_names() const {
498 std::vector<std::string> names;
499 names.reserve(tensors_.size());
500 for (const auto& pair : tensors_) {
501 names.push_back(pair.first);
502 }
503 return names;
504}
505
507 auto it = tensors_.find(name);
508 if (it == tensors_.end()) {
509 throw std::runtime_error("Tensor not found in SafeTensorsLoader metadata: " + name);
510 }
511 return it->second;
512}
513
514const Shard* SafeTensorsLoader::get_shard_for_tensor(const std::string& tensor_name) const {
515 auto map_it = tensor_name_to_shard_key_map_.find(tensor_name);
516 std::string determined_shard_key;
517
518 if (map_it != tensor_name_to_shard_key_map_.end()){
519 determined_shard_key = map_it->second;
520 } else {
521 const auto& tensor_info_direct = get_tensor_info(tensor_name);
522 determined_shard_key = tensor_info_direct.shard_key;
523 }
524
525 if (determined_shard_key.empty()){
526 throw std::logic_error("Internal inconsistency: Could not determine shard key for tensor '" + tensor_name + "'.");
527 }
528
529 auto shard_it = loaded_shards_.find(determined_shard_key);
530 if (shard_it == loaded_shards_.end()) {
531 throw std::logic_error("Internal inconsistency: Shard key '" + determined_shard_key + "' for tensor '" + tensor_name + "' not found in loaded_shards_ map. Tensors map has it, but shard object itself is missing.");
532 }
533 return shard_it->second.get();
534}
535
536std::vector<uint8_t> SafeTensorsLoader::get_tensor_bytes(const std::string& name) const {
537 const TensorInfo& info = get_tensor_info(name);
538 const Shard* shard = get_shard_for_tensor(name);
539
540 const uint8_t* raw_data_ptr = shard->get_tensor_raw_data(info.data_offset, info.nbytes);
541 return convert_tensor_data(raw_data_ptr, info.nbytes, info.dtype);
542}
543
// Load and convert every tensor concurrently using an internal thread pool.
// Thread count: hardware_concurrency, clamped to [1, tensor count] and
// capped at 16. Each task reads from the (read-only) memory-mapped shards,
// so no locking of the maps is needed. Rethrows the first task failure.
std::map<std::string, std::vector<uint8_t>> SafeTensorsLoader::load_all_tensors_parallel() const {
    std::map<std::string, std::vector<uint8_t>> result_map;
    if (tensors_.empty()) {
        Logger::debug("SafeTensorsLoader::load_all_tensors_parallel: No tensors to load.");
        return result_map;
    }

    std::vector<std::future<std::pair<std::string, std::vector<uint8_t>>>> futures;
    // hardware_concurrency() may return 0; the max() keeps at least 1 thread.
    unsigned int n_threads = std::max(1u, std::thread::hardware_concurrency());
    n_threads = std::min(n_threads, static_cast<unsigned int>(tensors_.size()));
    if (n_threads > 16) n_threads = 16; // diminishing returns past this; I/O-bound
    
    ThreadPool pool(n_threads);
    Logger::info("SafeTensorsLoader: Loading all " + std::to_string(tensors_.size()) + " tensors in parallel using " + std::to_string(n_threads) + " threads.");

    for (const auto& pair : tensors_) {
        const std::string& tensor_name = pair.first;
        // Capture the name by value: the lambda outlives this loop iteration.
        futures.push_back(pool.submit([this, tensor_name]() {
            std::vector<uint8_t> data = this->get_tensor_bytes(tensor_name);
            return std::make_pair(tensor_name, std::move(data));
        }));
    }

    for (auto& fut : futures) {
        try {
            std::pair<std::string, std::vector<uint8_t>> tensor_pair = fut.get();
            result_map[tensor_pair.first] = std::move(tensor_pair.second);
        } catch (const std::exception& e) {
            // fut.get() rethrows the task's exception; log it, then propagate.
            Logger::error("SafeTensorsLoader: Error loading a tensor in parallel task: " + std::string(e.what()));
            throw;
        }
    }
    Logger::info("SafeTensorsLoader: Finished loading all tensors in parallel.");
    return result_map;
}
579
580std::vector<uint8_t> SafeTensorsLoader::convert_tensor_data(const uint8_t* data_ptr, size_t n_bytes, const std::string& dtype_str_upper) const {
581 if (dtype_str_upper == "F32") {
582 return std::vector<uint8_t>(data_ptr, data_ptr + n_bytes);
583 } else if (dtype_str_upper == "F16") {
584 size_t num_elements = n_bytes / 2;
585 std::vector<float> f32_vec(num_elements);
586 const uint16_t* f16_ptr = reinterpret_cast<const uint16_t*>(data_ptr);
587 for (size_t i = 0; i < num_elements; ++i) {
588 f32_vec[i] = cpu_f16_to_float32(f16_ptr[i]);
589 }
590 std::vector<uint8_t> bytes_out(num_elements * sizeof(float));
591 memcpy(bytes_out.data(), f32_vec.data(), bytes_out.size());
592 return bytes_out;
593 } else if (dtype_str_upper == "BF16") {
594 size_t num_elements = n_bytes / 2;
595 std::vector<float> f32_vec(num_elements);
596 const uint16_t* bf16_ptr = reinterpret_cast<const uint16_t*>(data_ptr);
597 for (size_t i = 0; i < num_elements; ++i) {
598 f32_vec[i] = cpu_bf16_to_float32(bf16_ptr[i]);
599 }
600 std::vector<uint8_t> bytes_out(num_elements * sizeof(float));
601 memcpy(bytes_out.data(), f32_vec.data(), bytes_out.size());
602 return bytes_out;
603 }
604 throw std::runtime_error("SafeTensorsLoader: Unsupported tensor dtype for conversion: " + dtype_str_upper);
605}
606
// Populate `config_to_populate` from the config.json that accompanies a
// .safetensors model (in the directory itself, or next to a single file).
// Returns false (with a log message) if the file is missing, unreadable, or
// fails to parse; never throws.
bool SafeTensorsLoader::load_model_config_from_json(const std::string& model_path_or_dir_str, ModelConfig& config_to_populate) {
    std::filesystem::path model_fs_path(model_path_or_dir_str);
    std::filesystem::path config_json_path;

    // Resolve where config.json should live relative to the given path.
    if (std::filesystem::is_directory(model_fs_path)) {
        config_json_path = model_fs_path / "config.json";
    } else if (std::filesystem::is_regular_file(model_fs_path)) {
        config_json_path = model_fs_path.parent_path() / "config.json";
    } else {
        Logger::error("SafeTensorsLoader::load_model_config_from_json: Provided model path is not a valid file or directory: " + model_path_or_dir_str);
        return false;
    }
    std::string config_json_path_str = config_json_path.string();

    std::ifstream f(config_json_path_str);
    if (!f.is_open()) {
        Logger::warning("SafeTensorsLoader: config.json not found at: " + config_json_path_str);
        return false;
    }

    try {
        nlohmann::json data = nlohmann::json::parse(f);
        f.close();

        // Core architecture hyperparameters; json::value() supplies the
        // default when the key is absent.
        config_to_populate.hidden_size = data.value("hidden_size", 0);
        config_to_populate.intermediate_size = data.value("intermediate_size", 0);
        config_to_populate.num_attention_heads = data.value("num_attention_heads", 0);
        // GQA: absent num_key_value_heads means full multi-head attention.
        config_to_populate.num_key_value_heads = data.value("num_key_value_heads", config_to_populate.num_attention_heads);
        config_to_populate.num_hidden_layers = data.value("num_hidden_layers", 0);
        config_to_populate.vocab_size = data.value("vocab_size", 0);
        config_to_populate.max_position_embeddings = data.value("max_position_embeddings", 2048);
        config_to_populate.rms_norm_eps = data.value("rms_norm_eps", 1e-5f);
        config_to_populate.rope_theta = data.value("rope_theta", 10000.0f);
        config_to_populate.bos_token_id = data.value("bos_token_id", 1);
        config_to_populate.eos_token_id = data.value("eos_token_id", 2);
        config_to_populate.pad_token_id = data.value("pad_token_id", -1);
        config_to_populate.unk_token_id = data.value("unk_token_id", 0);

        // Architecture string: first entry of "architectures" if present,
        // else fall back to "model_type".
        if (data.contains("architectures") && data["architectures"].is_array() && !data["architectures"].empty()) {
            config_to_populate.architecture = data["architectures"][0].get<std::string>();
        } else {
            config_to_populate.architecture = data.value("model_type", "unknown");
        }
        config_to_populate.model_name = data.value("model_type", config_to_populate.architecture);

        // Heuristics to distinguish Llama 3 (128256-token vocab) from
        // Llama 1/2-style checkpoints.
        bool is_llama3_vocab_size_json = (config_to_populate.vocab_size == 128256);
        bool is_llama3_arch_hint_json = (config_to_populate.architecture.find("LlamaForCausalLM") != std::string::npos &&
                                         config_to_populate.architecture.find("Llama2") == std::string::npos);

        if (is_llama3_vocab_size_json && is_llama3_arch_hint_json) {
            // NOTE(review): a statement was elided here in this view (original
            // line 657) — presumably sets config_to_populate.tokenizer_family
            // for Llama 3; restore from VCS.
            // Llama 3 uses rope_theta = 500000; if config.json left the
            // generic 10000 default, prefer a Llama-3-appropriate value.
            if (config_to_populate.rope_theta == 10000.0f) {
                float llama3_rope_candidate = data.value("rope_theta", 500000.0f);
                if (llama3_rope_candidate > 10000.0f) {
                    config_to_populate.rope_theta = llama3_rope_candidate;
                } else if (config_to_populate.rope_theta == 10000.0f) {
                    config_to_populate.rope_theta = 500000.0f;
                }
            }
        } else if (config_to_populate.vocab_size == 32000 || config_to_populate.architecture.find("Llama") != std::string::npos) {
            // NOTE(review): statement elided in this view (original line 667)
            // — presumably the Llama 1/2 tokenizer_family assignment.
        } else {
            // NOTE(review): statement elided in this view (original line 669)
            // — presumably an "unknown" tokenizer_family assignment.
        }
        // This config came from JSON + safetensors, not a GGUF container.
        config_to_populate.is_gguf_file_loaded = false;

        Logger::info("SafeTensorsLoader: Successfully loaded and parsed model config from: " + config_json_path_str);
        return true;

    } catch (const nlohmann::json::exception& e) {
        Logger::error("SafeTensorsLoader: Failed to parse config.json: " + config_json_path_str + ". Error: " + e.what());
        return false;
    }
    return false;
}
682
683
// Start `num_threads` worker threads. Each worker loops: wait until a task
// is queued or shutdown is requested, pop one task under the queue mutex,
// and run it outside the lock. Workers exit only when stop_ is set AND the
// queue has drained, so tasks submitted before shutdown still complete.
ThreadPool::ThreadPool(size_t num_threads) : stop_(false) {
    for (size_t i = 0; i < num_threads; ++i) {
        workers_.emplace_back([this] {
            while (true) {
                std::function<void()> task;
                {
                    std::unique_lock<std::mutex> lock(this->queue_mutex_);
                    // Wake on new work or shutdown (predicate guards against
                    // spurious wakeups).
                    this->condition_.wait(lock, [this] {
                        return this->stop_ || !this->tasks_.empty();
                    });
                    if (this->stop_ && this->tasks_.empty()) return;
                    task = std::move(this->tasks_.front());
                    this->tasks_.pop();
                }
                // Execute outside the lock so other workers can dequeue.
                if(task) task();
            }
        });
    }
}
703
705 {
706 std::unique_lock<std::mutex> lock(queue_mutex_);
707 stop_ = true;
708 }
709 condition_.notify_all();
710 for (std::thread& worker : workers_) {
711 if (worker.joinable()) {
712 worker.join();
713 }
714 }
715}
static void debug(const std::string &message)
Definition logger.cpp:131
static void warning(const std::string &message)
Definition logger.cpp:139
static void info(const std::string &message)
Definition logger.cpp:135
static void error(const std::string &message)
Definition logger.cpp:143
std::map< std::string, std::unique_ptr< Shard > > loaded_shards_
std::map< std::string, TensorInfo > tensors_
SafeTensorsLoader(const std::string &model_load_path)
Constructs a SafeTensorsLoader.
const Shard * get_shard_for_tensor(const std::string &tensor_name) const
Get the Shard object for a given tensor name.
void load_from_directory(const std::string &directory_path)
Load tensors from a directory, handling index files and multiple shards.
static bool load_model_config_from_json(const std::string &model_path_or_dir, ModelConfig &config_to_populate)
Loads model configuration from a JSON file corresponding to a .safetensors model path.
std::map< std::string, std::string > tensor_name_to_shard_key_map_
void load_single_file(const std::string &file_path, const std::string &shard_key_override="")
Load a single .safetensors file as a shard.
std::map< std::string, std::vector< uint8_t > > load_all_tensors_parallel() const
Load all tensors in parallel.
std::vector< std::string > tensor_names() const
Get a list of all tensor names available in the loaded model.
const TensorInfo & get_tensor_info(const std::string &name) const
Get information about a specific tensor.
std::vector< uint8_t > convert_tensor_data(const uint8_t *data, size_t size, const std::string &dtype) const
Convert raw tensor data to FP32 if needed.
std::vector< uint8_t > get_tensor_bytes(const std::string &name) const
Get the raw bytes for a tensor, converting to FP32 if needed.
~SafeTensorsLoader()
Destructor. Cleans up all memory-mapped shards.
void parse_shard_metadata(Shard &shard, const std::string &shard_key)
Parse the metadata of a shard and populate tensor information.
Thread pool for parallel tensor loading operations.
~ThreadPool()
Destructor that ensures proper cleanup of threads.
std::future< typename std::result_of< F(Args...)>::type > submit(F &&f, Args &&... args)
Submits a task to the thread pool.
ThreadPool(size_t num_threads)
Constructs a thread pool with specified number of threads.
std::queue< std::function< void()> > tasks_
std::vector< std::thread > workers_
std::condition_variable condition_
std::mutex queue_mutex_
Logging utilities for the TinyLlama implementation.
float cpu_f16_to_float32(uint16_t f16_raw)
float cpu_bf16_to_float32(uint16_t bf16_raw)
SafeTensors format loader for efficient tensor loading, supporting single and sharded models.
Model configuration structure holding architecture and hyperparameters.
Definition model.h:80
int hidden_size
Definition model.h:81
int vocab_size
Definition model.h:86
int pad_token_id
Definition model.h:95
std::string architecture
Definition model.h:96
std::string model_name
Definition model.h:97
float rms_norm_eps
Definition model.h:88
int num_attention_heads
Definition model.h:83
int intermediate_size
Definition model.h:82
int eos_token_id
Definition model.h:93
bool is_gguf_file_loaded
Definition model.h:101
float rope_theta
Definition model.h:89
int num_hidden_layers
Definition model.h:85
int num_key_value_heads
Definition model.h:84
int bos_token_id
Definition model.h:92
TokenizerFamily tokenizer_family
Definition model.h:117
int unk_token_id
Definition model.h:94
int max_position_embeddings
Definition model.h:87
Information about a tensor stored in the SafeTensors file(s)
Represents a memory-mapped SafeTensors file (shard).
~Shard()
Destructor. Cleans up memory mapping and file handles.
uint64_t metadata_size
Size of the metadata block in bytes.
const uint8_t * tensor_data_block_ptr
Pointer to the start of the tensor data block.
void * mapped_data
Pointer to the memory-mapped data.
std::string file_path
Path to the shard file.
const uint8_t * metadata_ptr
Pointer to the start of the metadata block.
Shard & operator=(Shard &&other) noexcept
Move assignment operator.
size_t file_size
Size of the mapped file in bytes.
const uint8_t * get_tensor_raw_data(size_t local_offset, size_t n_bytes) const
Get a pointer to the raw tensor data within this shard.
Shard(const std::string &fp)
Construct and memory-map a shard file.