TinyLlama.cpp 1.0
A lightweight C++ implementation of the TinyLlama language model
main.cpp
39#include <algorithm>
40#include <cctype>
41#include <cstdio>
42#include <iomanip>
43#include <iostream>
44#include <memory>
45#include <sstream>
46#include <stdexcept>
47#include <string>
48#include <vector>
49
50#include "api.h"
51#include "logger.h"
52
53std::string trim_whitespace(const std::string& s) {
54 auto start = std::find_if_not(
55 s.begin(), s.end(), [](unsigned char c) { return std::isspace(c); });
56 auto end = std::find_if_not(s.rbegin(), s.rend(), [](unsigned char c) {
57 return std::isspace(c);
58 }).base();
59 return (start < end ? std::string(start, end) : std::string());
60}
61
62void print_usage(const char* program_name) {
63 std::cout << "Usage: " << program_name
64 << " <model_path> <tokenizer_path> <num_threads> <prompt|chat|batch> "
65 "[--system-prompt <system_prompt_string>] [initial_user_prompt] [max_tokens] [n_gpu_layers] [use_mmap] [temperature] [top_k] [top_p] [use_kv_quant] [use_batch_generation] [--batch-prompts \"prompt1\" \"prompt2\" ...] [--max-batch-size N]"
66 << std::endl;
67 std::cout << "\nArguments:\n"
68 " model_path : Path to the model file (.gguf) or directory (SafeTensors).\n"
69 " tokenizer_path : Path to the tokenizer file.\n"
70 " num_threads : Number of threads to use for generation.\n"
71 " prompt|chat|batch : 'prompt' for single prompt, 'chat' for chat mode, 'batch' for batch processing.\n"
72 " --system-prompt : (Optional) System prompt to guide the model. Default: empty.\n"
73 " initial_user_prompt : (Optional) Initial user prompt string. Default: \"Hello, world!\".\n"
74 " max_tokens : (Optional) Maximum number of tokens to generate. Default: 256.\n"
75 " n_gpu_layers : (Optional) Number of layers to offload to GPU (-1 for all, 0 for none). Default: -1.\n"
76 " use_mmap : (Optional) Use mmap for GGUF files ('true' or 'false'). Default: true.\n"
77 " temperature : (Optional) Sampling temperature. Default: 0.1.\n"
78 " top_k : (Optional) Top-K sampling parameter (0 to disable). Default: 40.\n"
79 " top_p : (Optional) Top-P/nucleus sampling parameter (0.0-1.0). Default: 0.9.\n"
80 " use_kv_quant : (Optional) Use INT8 KVCache quantization on GPU ('true' or 'false'). Default: false.\n"
81 " use_batch_generation: (Optional) Use GPU batch generation for tokens ('true' or 'false'). Default: false.\n"
82 " --batch-prompts : (For batch mode) Multiple prompts in quotes, e.g., \"prompt1\" \"prompt2\" ...\n"
83 " --max-batch-size : (For batch mode) Maximum batch size. Default: 8.\n"
84 << std::endl;
85}
86
87int main(int argc, char** argv) {
88 if (argc > 1 && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help")) {
89 print_usage(argv[0]);
90 return 0;
91 }
92
93 if (argc < 5) { // Minimum required: model_path, tokenizer_path, num_threads, mode
94 std::cerr << "ERROR: Missing required arguments." << std::endl;
95 print_usage(argv[0]);
96 return 1;
97 }
98
99 std::string model_path_or_dir = argv[1];
100 std::string tokenizer_path = argv[2];
101 int num_threads = 4; // Default, will be overwritten if provided by argv[3]
102 try {
103 if (argc > 3) { // Ensure argv[3] (number of threads) exists
104 num_threads = std::stoi(argv[3]);
105 } else {
106 Logger::warning("Number of threads not provided, using default: " + std::to_string(num_threads));
107 }
108 } catch (const std::exception& e) {
109 Logger::warning("Could not parse num_threads from argv[3]: '" + (argc > 3 ? std::string(argv[3]) : "<not provided>") + "'. Using default: " + std::to_string(num_threads));
110 }
111
112 if (argc < 5) { // Check after attempting to parse num_threads
113 std::cerr << "ERROR: Missing mode argument (prompt|chat|batch) after num_threads." << std::endl;
114 print_usage(argv[0]);
115 return 1;
116 }
117 std::string mode_str = argv[4]; // Mode (prompt|chat|batch) is argv[4]
118
119 std::string system_prompt_str = ""; // Default empty system prompt
120 std::string user_prompt_str = "Hello, world!"; // Default user prompt
121 int max_tokens = 256; // Default for steps
122 int n_gpu_layers = -1; // Default: all layers on GPU
123 bool use_mmap = true; // Default: use mmap
124 bool use_kv_quant = false; // Default: do not use KVCache quantization
125 bool use_batch_generation = false; // Default: do not use batch generation
126
127 // Batch processing specific variables
128 std::vector<std::string> batch_prompts;
129 int max_batch_size = 8; // Default max batch size
130
131 // Default sampling params for generate
132 float temperature = 0.1f; // Default temperature
133 int top_k = 40; // Default top-k sampling parameter
134 float top_p = 0.9f; // Default top-p/nucleus sampling parameter
135
136 int current_arg_idx = 5;
137 while(current_arg_idx < argc) {
138 std::string arg = argv[current_arg_idx];
139 if (arg == "--system-prompt" || arg == "-sp") {
140 if (current_arg_idx + 1 < argc) {
141 system_prompt_str = argv[current_arg_idx + 1];
142 current_arg_idx += 2;
143 } else {
144 std::cerr << "ERROR: --system-prompt requires a value." << std::endl;
145 print_usage(argv[0]);
146 return 1;
147 }
148 } else if (arg == "--max-tokens" || arg == "-mt") {
149 if (current_arg_idx + 1 < argc) {
150 try { max_tokens = std::stoi(argv[current_arg_idx+1]); }
151 catch (const std::exception& e) { Logger::error("Invalid max_tokens: " + std::string(argv[current_arg_idx+1])); }
152 current_arg_idx += 2;
153 } else { std::cerr << "ERROR: --max-tokens requires a value." << std::endl; return 1;}
154 } else if (arg == "--n-gpu-layers" || arg == "-ngl") {
155 if (current_arg_idx + 1 < argc) {
156 try { n_gpu_layers = std::stoi(argv[current_arg_idx+1]); }
157 catch (const std::exception& e) { Logger::error("Invalid n_gpu_layers: " + std::string(argv[current_arg_idx+1])); }
158 current_arg_idx += 2;
159 } else { std::cerr << "ERROR: --n-gpu-layers requires a value." << std::endl; return 1;}
160 } else if (arg == "--use-mmap") {
161 if (current_arg_idx + 1 < argc) {
162 std::string mmap_str_val = argv[current_arg_idx+1];
163 std::transform(mmap_str_val.begin(), mmap_str_val.end(), mmap_str_val.begin(), ::tolower);
164 if (mmap_str_val == "false" || mmap_str_val == "0") use_mmap = false;
165 else if (mmap_str_val == "true" || mmap_str_val == "1") use_mmap = true;
166 else { std::cerr << "ERROR: Invalid use_mmap value." << std::endl; return 1; }
167 current_arg_idx += 2;
168 } else { std::cerr << "ERROR: --use-mmap requires a value." << std::endl; return 1;}
169 } else if (arg == "--temperature" || arg == "-t") {
170 if (current_arg_idx + 1 < argc) {
171 try { temperature = std::stof(argv[current_arg_idx+1]); }
172 catch (const std::exception& e) { Logger::error("Invalid temperature: " + std::string(argv[current_arg_idx+1]));}
173 current_arg_idx += 2;
174 } else { std::cerr << "ERROR: --temperature requires a value." << std::endl; return 1;}
175 } else if (arg == "--top-k" || arg == "-k") {
176 if (current_arg_idx + 1 < argc) {
177 try { top_k = std::stoi(argv[current_arg_idx+1]); }
178 catch (const std::exception& e) { Logger::error("Invalid top_k: " + std::string(argv[current_arg_idx+1])); }
179 current_arg_idx += 2;
180 } else { std::cerr << "ERROR: --top-k requires a value." << std::endl; return 1;}
181 } else if (arg == "--top-p" || arg == "-p") {
182 if (current_arg_idx + 1 < argc) {
183 try { top_p = std::stof(argv[current_arg_idx+1]); }
184 catch (const std::exception& e) { Logger::error("Invalid top_p: " + std::string(argv[current_arg_idx+1])); }
185 current_arg_idx += 2;
186 } else { std::cerr << "ERROR: --top-p requires a value." << std::endl; return 1;}
187 } else if (arg == "--use-kv-quant" || arg == "-kvq") {
188 if (current_arg_idx + 1 < argc) {
189 std::string kvq_str_val = argv[current_arg_idx+1];
190 std::transform(kvq_str_val.begin(), kvq_str_val.end(), kvq_str_val.begin(), ::tolower);
191 if (kvq_str_val == "false" || kvq_str_val == "0") use_kv_quant = false;
192 else if (kvq_str_val == "true" || kvq_str_val == "1") use_kv_quant = true;
193 else { std::cerr << "ERROR: Invalid use_kv_quant value: " << argv[current_arg_idx+1] << std::endl; return 1; }
194 current_arg_idx += 2;
195 } else { std::cerr << "ERROR: --use-kv-quant requires a value." << std::endl; return 1;}
196 } else if (arg == "--use-batch-generation" || arg == "-ubg") {
197 if (current_arg_idx + 1 < argc) {
198 std::string ubg_str_val = argv[current_arg_idx+1];
199 std::transform(ubg_str_val.begin(), ubg_str_val.end(), ubg_str_val.begin(), ::tolower);
200 if (ubg_str_val == "false" || ubg_str_val == "0") use_batch_generation = false;
201 else if (ubg_str_val == "true" || ubg_str_val == "1") use_batch_generation = true;
202 else { std::cerr << "ERROR: Invalid use_batch_generation value: " << argv[current_arg_idx+1] << std::endl; return 1; }
203 current_arg_idx += 2;
204 } else { std::cerr << "ERROR: --use-batch-generation requires a value." << std::endl; return 1;}
205 } else if (arg == "--batch-prompts" || arg == "-bp") {
206 current_arg_idx++; // Move past the --batch-prompts argument
207 while (current_arg_idx < argc && argv[current_arg_idx][0] != '-') {
208 batch_prompts.push_back(std::string(argv[current_arg_idx]));
209 current_arg_idx++;
210 }
211 if (batch_prompts.empty()) {
212 std::cerr << "ERROR: --batch-prompts requires at least one prompt." << std::endl;
213 return 1;
214 }
215 } else if (arg == "--max-batch-size" || arg == "-mbs") {
216 if (current_arg_idx + 1 < argc) {
217 try { max_batch_size = std::stoi(argv[current_arg_idx+1]); }
218 catch (const std::exception& e) { Logger::error("Invalid max_batch_size: " + std::string(argv[current_arg_idx+1])); }
219 current_arg_idx += 2;
220 } else { std::cerr << "ERROR: --max-batch-size requires a value." << std::endl; return 1;}
221 } else {
222 if (user_prompt_str == "Hello, world!") {
223 user_prompt_str = trim_whitespace(argv[current_arg_idx]);
224 } else if (argv[current_arg_idx][0] != '-') {
225 std::cerr << "ERROR: Unexpected positional argument: " << argv[current_arg_idx] << std::endl;
226 print_usage(argv[0]);
227 return 1;
228 } else {
229 std::cerr << "ERROR: Unknown option: " << argv[current_arg_idx] << std::endl;
230 print_usage(argv[0]);
231 return 1;
232 }
233 current_arg_idx++;
234 }
235 }
236
237 Logger::info("Using model path/directory: " + model_path_or_dir);
238 Logger::info("Tokenizer path: " + tokenizer_path);
239 Logger::info("Num threads: " + std::to_string(num_threads));
240 Logger::info("Mode: " + mode_str);
241 Logger::info("System Prompt: \"" + system_prompt_str + "\"");
242 Logger::info("Default User Prompt: \"" + user_prompt_str + "\"");
243 Logger::info("Max tokens: " + std::to_string(max_tokens));
244 Logger::info("N GPU Layers: " + std::to_string(n_gpu_layers));
245 Logger::info(std::string("Use mmap: ") + (use_mmap ? "true" : "false"));
246 Logger::info("Temperature: " + std::to_string(temperature));
247 Logger::info("Top-K: " + std::to_string(top_k));
248 Logger::info("Top-P: " + std::to_string(top_p));
249 Logger::info(std::string("Use KVCache Quantization: ") + (use_kv_quant ? "true" : "false"));
250 Logger::info(std::string("Use Batch Generation: ") + (use_batch_generation ? "true" : "false"));
251
252 try {
253 // For batch mode, we need to determine max_batch_size before creating session
254 int session_max_batch_size = (mode_str == "batch") ? max_batch_size : 1;
255 bool session_use_batch_generation = (mode_str == "batch") ? true : use_batch_generation;
256
257 tinyllama::TinyLlamaSession session(model_path_or_dir, tokenizer_path, num_threads, n_gpu_layers, use_mmap, use_kv_quant, session_use_batch_generation, session_max_batch_size);
258 Logger::info("TinyLlamaSession initialized successfully.");
259
260 const ModelConfig& config = session.get_config();
261 bool apply_qa_formatting_decision; // This will be true if no advanced template is used
262
263 if (config.tokenizer_family == TokenizerFamily::LLAMA3_TIKTOKEN || // Llama 3 family check (exact enum value assumed)
264 (session.get_tokenizer() && !session.get_tokenizer()->get_gguf_chat_template().empty())) { // Use getter
265 apply_qa_formatting_decision = false; // Llama 3 or GGUF template handles formatting
266 Logger::info("[Main.cpp] Llama 3 model or GGUF chat template detected. Internal Q&A prompt formatting will be DISABLED.");
267 } else {
268 apply_qa_formatting_decision = true; // Default for other models without GGUF template
269 Logger::info("[Main.cpp] Non-Llama 3 model and no GGUF chat template. Internal Q&A prompt formatting will be ENABLED.");
270 }
271
272 Logger::info("[Main.cpp] Mode: '" + mode_str + "'. Final decision for apply_qa_formatting_decision: " + std::string(apply_qa_formatting_decision ? "true" : "false"));
273
274 if (mode_str == "prompt") {
275 std::string generated_text =
276 session.generate(user_prompt_str, max_tokens, temperature, top_k, top_p, system_prompt_str, apply_qa_formatting_decision);
277 std::cout << generated_text << std::endl;
278 } else if (mode_str == "chat") {
279 std::cout << "Entering chat mode. System Prompt: \"" << system_prompt_str << "\". Type 'exit', 'quit' to end." << std::endl;
280 std::string current_user_message;
281
282 if (!user_prompt_str.empty() && (user_prompt_str != "Hello, world!" || !system_prompt_str.empty() )) {
283 current_user_message = user_prompt_str;
284 std::cout << "You: " << current_user_message << std::endl;
285 std::string ai_response = session.generate(current_user_message, max_tokens, temperature, top_k, top_p, system_prompt_str, apply_qa_formatting_decision);
286 std::cout << "AI: " << ai_response << std::endl;
287 }
288
289 while (true) {
290 std::cout << "You: ";
291 std::getline(std::cin, current_user_message);
292 if (current_user_message == "exit" || current_user_message == "quit") {
293 break;
294 }
295 if (current_user_message.empty()) {
296 continue;
297 }
298 std::string ai_response = session.generate(current_user_message, max_tokens, temperature, top_k, top_p, system_prompt_str, apply_qa_formatting_decision);
299 std::cout << "AI: " << ai_response << std::endl;
300 }
301 } else if (mode_str == "batch") {
302 // Batch processing logic
303 if (batch_prompts.empty()) {
304 std::cerr << "ERROR: Batch mode requires prompts. Use --batch-prompts \"prompt1\" \"prompt2\" ..." << std::endl;
305 print_usage(argv[0]);
306 return 1;
307 }
308
309 if (batch_prompts.size() > static_cast<size_t>(max_batch_size)) {
310 std::cerr << "ERROR: Number of prompts (" << batch_prompts.size()
311 << ") exceeds max batch size (" << max_batch_size << ")" << std::endl;
312 return 1;
313 }
314
315 Logger::info("[Batch Mode] Processing " + std::to_string(batch_prompts.size()) + " prompts in batch...");
316
317 try {
318 std::vector<std::string> results = session.generate_batch(
319 batch_prompts, max_tokens, temperature, top_k, top_p,
320 system_prompt_str, apply_qa_formatting_decision
321 );
322
323 std::cout << "\n" << std::string(80, '=') << std::endl;
324 std::cout << "BATCH PROCESSING RESULTS" << std::endl;
325 std::cout << std::string(80, '=') << std::endl;
326
327 for (size_t i = 0; i < batch_prompts.size(); ++i) {
328 std::cout << "\n--- Prompt " << (i + 1) << " ---" << std::endl;
329 std::cout << "Input: " << batch_prompts[i] << std::endl;
330 std::cout << "Output: " << results[i] << std::endl;
331 }
332
333 } catch (const std::exception& e) {
334 Logger::error("[Batch Mode] Error during batch processing: " + std::string(e.what()));
335 return 1;
336 }
337 } else {
338 std::cerr << "ERROR: Invalid mode '" << mode_str << "'. Expected 'prompt', 'chat', or 'batch'." << std::endl;
339 print_usage(argv[0]);
340 return 1;
341 }
342
343 } catch (const std::exception& e) {
344 Logger::error("Error: " + std::string(e.what()));
345 return 1;
346 }
347
348 return 0;
349}

Referenced symbols:

Logger::warning: static void warning(const std::string &message). (logger.cpp:139)
Logger::info: static void info(const std::string &message). (logger.cpp:135)
Logger::error: static void error(const std::string &message). (logger.cpp:143)
Tokenizer::get_gguf_chat_template: const std::string &get_gguf_chat_template() const.
tinyllama::TinyLlamaSession: Represents an active TinyLlama session holding the loaded model and tokenizer. (api.h:26)
TinyLlamaSession::get_tokenizer: const Tokenizer *get_tokenizer() const. (api.h:105)
TinyLlamaSession::generate: std::string generate(const std::string &prompt, int steps = 128, float temperature = 0.1f, int top_k = 40, float top_p = 0.9f, const std::string &system_prompt = "", bool apply_q_a_format = false). Generates text based on a given prompt. (api.cpp:433)
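
A minimal usage sketch for TinyLlamaSession::generate, mirroring the constructor and generate() call sites in the listing above (lines 257 and 276); the model and tokenizer paths are placeholders, and the constructor argument order is taken from the call on line 257.

    #include "api.h"

    #include <iostream>
    #include <string>

    int main() {
        // Same argument order as the call site above: model path, tokenizer path,
        // threads, GPU layers (-1 = all), mmap, KV-cache quantization,
        // batch generation flag, max batch size.
        tinyllama::TinyLlamaSession session(
            "path/to/model.gguf", "path/to/tokenizer.model",
            /*num_threads=*/4, /*n_gpu_layers=*/-1, /*use_mmap=*/true,
            /*use_kv_quant=*/false, /*use_batch_generation=*/false, /*max_batch_size=*/1);

        // Sampling values match the CLI defaults in this file.
        std::string reply = session.generate(
            "Hello, world!", /*steps=*/256, /*temperature=*/0.1f, /*top_k=*/40,
            /*top_p=*/0.9f, /*system_prompt=*/"", /*apply_q_a_format=*/true);

        std::cout << reply << std::endl;
        return 0;
    }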
TinyLlamaSession::generate_batch: std::vector<std::string> generate_batch(const std::vector<std::string> &prompts, int steps = 128, float temperature = 0.1f, int top_k = 40, float top_p = 0.9f, const std::string &system_prompt = "", bool apply_q_a_format = false). Generates text for multiple prompts in a single batch (parallel processing). (api.cpp:780)
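
A companion sketch for generate_batch, following the batch-mode path in main() above: the session is constructed with batch generation enabled and a max batch size no smaller than the number of prompts. Prompts and paths are placeholders.

    #include "api.h"

    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
        std::vector<std::string> prompts = {"First prompt", "Second prompt"};

        // Batch mode in main() forces batch generation on and sizes the session
        // to the requested maximum batch size.
        tinyllama::TinyLlamaSession session(
            "path/to/model.gguf", "path/to/tokenizer.model",
            /*num_threads=*/4, /*n_gpu_layers=*/-1, /*use_mmap=*/true,
            /*use_kv_quant=*/false, /*use_batch_generation=*/true, /*max_batch_size=*/8);

        std::vector<std::string> results = session.generate_batch(
            prompts, /*steps=*/256, /*temperature=*/0.1f, /*top_k=*/40,
            /*top_p=*/0.9f, /*system_prompt=*/"", /*apply_q_a_format=*/true);

        for (size_t i = 0; i < results.size(); ++i) {
            std::cout << "Prompt " << (i + 1) << ": " << prompts[i] << "\n"
                      << "Output: " << results[i] << std::endl;
        }
        return 0;
    }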
TinyLlamaSession::get_config: const ModelConfig &get_config() const. (api.h:106)
logger.h: Logging utilities for the TinyLlama implementation.
trim_whitespace: std::string trim_whitespace(const std::string &s). (main.cpp:53)
main: int main(int argc, char **argv). (main.cpp:87)
print_usage: void print_usage(const char *program_name). (main.cpp:62)
ModelConfig: Model configuration structure holding architecture and hyperparameters. (model.h:80)
ModelConfig::tokenizer_family: TokenizerFamily tokenizer_family. (model.h:117)
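
The snippet below condenses the formatting decision made in main() above into a small helper; "ask" is a hypothetical name, and the sketch keeps only the GGUF chat-template check (the Llama 3 tokenizer-family check would sit in the same condition).

    #include "api.h"

    #include <string>

    // Hypothetical helper: skip the CLI's built-in Q&A prompt formatting whenever
    // the loaded model ships its own GGUF chat template, as main() does above.
    std::string ask(tinyllama::TinyLlamaSession& session,
                    const std::string& user_message,
                    const std::string& system_prompt) {
        bool apply_qa_format = true;
        if (session.get_tokenizer() &&
            !session.get_tokenizer()->get_gguf_chat_template().empty()) {
            apply_qa_format = false;  // the model's chat template drives the prompt layout
        }
        return session.generate(user_message, /*steps=*/256, /*temperature=*/0.1f,
                                /*top_k=*/40, /*top_p=*/0.9f, system_prompt,
                                apply_qa_format);
    }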