{
  if (argc > 1 && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help")) {
    print_usage(argv[0]);
    return 0;
  }

  if (argc < 5) {
    std::cerr << "ERROR: Missing required arguments." << std::endl;
    print_usage(argv[0]);
    return 1;
  }

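  // Example invocations (binary and file names are illustrative only):
  //   ./tinyllama ./model.gguf ./tokenizer.json 8 prompt "Why is the sky blue?" --max-tokens 128
  //   ./tinyllama ./model.gguf ./tokenizer.json 8 batch --batch-prompts "Q1" "Q2" --max-batch-size 4
  // Positional order: <model_path_or_dir> <tokenizer_path> <num_threads> <prompt|chat|batch> [user_prompt]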
  std::string model_path_or_dir = argv[1];
  std::string tokenizer_path = argv[2];
  int num_threads = 4;
  try {
    if (argc > 3) {
      num_threads = std::stoi(argv[3]);
    } else {
      Logger::warning("Number of threads not provided, using default: " +
                      std::to_string(num_threads));
    }
  } catch (const std::exception& e) {
    Logger::warning("Could not parse num_threads from argv[3]: '" +
                    (argc > 3 ? std::string(argv[3]) : "<not provided>") +
                    "'. Using default: " + std::to_string(num_threads));
  }

  if (argc < 5) {
    std::cerr << "ERROR: Missing mode argument (prompt|chat|batch) after num_threads." << std::endl;
    print_usage(argv[0]);
    return 1;
  }
  std::string mode_str = argv[4];

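  // Generation defaults; each of these can be overridden by the optional flags parsed below.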
  std::string system_prompt_str = "";
  std::string user_prompt_str = "Hello, world!";
  int max_tokens = 256;
  int n_gpu_layers = -1;
  bool use_mmap = true;
  bool use_kv_quant = false;
  bool use_batch_generation = false;

  // Batch-mode parameters.
  std::vector<std::string> batch_prompts;
  int max_batch_size = 8;

  // Sampling parameters.
  float temperature = 0.1f;
  int top_k = 40;
  float top_p = 0.9f;

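  // Optional flags. Every flag below takes exactly one value, except --batch-prompts,
  // which consumes all following arguments up to the next token starting with '-'.
  // The first token that is not a recognized flag is treated as the user prompt.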
  int current_arg_idx = 5;
  while (current_arg_idx < argc) {
    std::string arg = argv[current_arg_idx];
    if (arg == "--system-prompt" || arg == "-sp") {
      if (current_arg_idx + 1 < argc) {
        system_prompt_str = argv[current_arg_idx + 1];
        current_arg_idx += 2;
      } else {
        std::cerr << "ERROR: --system-prompt requires a value." << std::endl;
        return 1;
      }
    } else if (arg == "--max-tokens" || arg == "-mt") {
      if (current_arg_idx + 1 < argc) {
        try { max_tokens = std::stoi(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid max_tokens: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --max-tokens requires a value." << std::endl; return 1; }
    } else if (arg == "--n-gpu-layers" || arg == "-ngl") {
      if (current_arg_idx + 1 < argc) {
        try { n_gpu_layers = std::stoi(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid n_gpu_layers: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --n-gpu-layers requires a value." << std::endl; return 1; }
    } else if (arg == "--use-mmap") {
      if (current_arg_idx + 1 < argc) {
        std::string mmap_str_val = argv[current_arg_idx + 1];
        std::transform(mmap_str_val.begin(), mmap_str_val.end(), mmap_str_val.begin(), ::tolower);
        if (mmap_str_val == "false" || mmap_str_val == "0") use_mmap = false;
        else if (mmap_str_val == "true" || mmap_str_val == "1") use_mmap = true;
        else { std::cerr << "ERROR: Invalid use_mmap value." << std::endl; return 1; }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --use-mmap requires a value." << std::endl; return 1; }
    } else if (arg == "--temperature" || arg == "-t") {
      if (current_arg_idx + 1 < argc) {
        try { temperature = std::stof(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid temperature: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --temperature requires a value." << std::endl; return 1; }
    } else if (arg == "--top-k" || arg == "-k") {
      if (current_arg_idx + 1 < argc) {
        try { top_k = std::stoi(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid top_k: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --top-k requires a value." << std::endl; return 1; }
    } else if (arg == "--top-p" || arg == "-p") {
      if (current_arg_idx + 1 < argc) {
        try { top_p = std::stof(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid top_p: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --top-p requires a value." << std::endl; return 1; }
    } else if (arg == "--use-kv-quant" || arg == "-kvq") {
      if (current_arg_idx + 1 < argc) {
        std::string kvq_str_val = argv[current_arg_idx + 1];
        std::transform(kvq_str_val.begin(), kvq_str_val.end(), kvq_str_val.begin(), ::tolower);
        if (kvq_str_val == "false" || kvq_str_val == "0") use_kv_quant = false;
        else if (kvq_str_val == "true" || kvq_str_val == "1") use_kv_quant = true;
        else { std::cerr << "ERROR: Invalid use_kv_quant value: " << argv[current_arg_idx + 1] << std::endl; return 1; }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --use-kv-quant requires a value." << std::endl; return 1; }
    } else if (arg == "--use-batch-generation" || arg == "-ubg") {
      if (current_arg_idx + 1 < argc) {
        std::string ubg_str_val = argv[current_arg_idx + 1];
        std::transform(ubg_str_val.begin(), ubg_str_val.end(), ubg_str_val.begin(), ::tolower);
        if (ubg_str_val == "false" || ubg_str_val == "0") use_batch_generation = false;
        else if (ubg_str_val == "true" || ubg_str_val == "1") use_batch_generation = true;
        else { std::cerr << "ERROR: Invalid use_batch_generation value: " << argv[current_arg_idx + 1] << std::endl; return 1; }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --use-batch-generation requires a value." << std::endl; return 1; }
    } else if (arg == "--batch-prompts" || arg == "-bp") {
      // Collect every following argument as a prompt until the next token that starts
      // with '-' (a prompt that itself begins with '-' therefore ends the list).
      current_arg_idx++;
      while (current_arg_idx < argc && argv[current_arg_idx][0] != '-') {
        batch_prompts.push_back(std::string(argv[current_arg_idx]));
        current_arg_idx++;
      }
      if (batch_prompts.empty()) {
        std::cerr << "ERROR: --batch-prompts requires at least one prompt." << std::endl;
        return 1;
      }
    } else if (arg == "--max-batch-size" || arg == "-mbs") {
      if (current_arg_idx + 1 < argc) {
        try { max_batch_size = std::stoi(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid max_batch_size: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --max-batch-size requires a value." << std::endl; return 1; }
    } else {
      if (argv[current_arg_idx][0] == '-') {
        std::cerr << "ERROR: Unknown option: " << argv[current_arg_idx] << std::endl;
        return 1;
      } else if (user_prompt_str == "Hello, world!") {
        // The first free-standing argument replaces the default user prompt.
        user_prompt_str = argv[current_arg_idx];
      } else {
        std::cerr << "ERROR: Unexpected positional argument: " << argv[current_arg_idx] << std::endl;
        return 1;
      }
      current_arg_idx++;
    }
  }

  Logger::info("Using model path/directory: " + model_path_or_dir);
  Logger::info("Num threads: " + std::to_string(num_threads));
  Logger::info("System Prompt: \"" + system_prompt_str + "\"");
  Logger::info("Default User Prompt: \"" + user_prompt_str + "\"");
  Logger::info("Max tokens: " + std::to_string(max_tokens));
  Logger::info("N GPU Layers: " + std::to_string(n_gpu_layers));
  Logger::info(std::string("Use mmap: ") + (use_mmap ? "true" : "false"));
  Logger::info("Temperature: " + std::to_string(temperature));
  Logger::info(std::string("Use KVCache Quantization: ") + (use_kv_quant ? "true" : "false"));
  Logger::info(std::string("Use Batch Generation: ") + (use_batch_generation ? "true" : "false"));

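  // Construct the session. Batch mode always enables batch generation and uses the
  // requested max batch size; the other modes run with a batch size of 1 unless
  // --use-batch-generation was passed explicitly.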
  try {
    int session_max_batch_size = (mode_str == "batch") ? max_batch_size : 1;
    bool session_use_batch_generation = (mode_str == "batch") ? true : use_batch_generation;

    tinyllama::TinyLlamaSession session(model_path_or_dir, tokenizer_path, num_threads,
                                        n_gpu_layers, use_mmap, use_kv_quant,
                                        session_use_batch_generation, session_max_batch_size);
    Logger::info("TinyLlamaSession initialized successfully.");

    bool apply_qa_formatting_decision;
    // Llama 3 models and GGUF chat templates carry their own prompt format, so the internal
    // Q&A wrapping is disabled for them. NOTE: the tokenizer-family check below is an
    // assumption about how "is this a Llama 3 model" is queried; the exact accessor and
    // enum value may differ in this codebase.
    if (session.get_config().tokenizer_family == TokenizerFamily::LLAMA3_TIKTOKEN ||
        (session.get_tokenizer() && !session.get_tokenizer()->get_gguf_chat_template().empty())) {
      apply_qa_formatting_decision = false;
      Logger::info("[Main.cpp] Llama 3 model or GGUF chat template detected. Internal Q&A prompt formatting will be DISABLED.");
    } else {
      apply_qa_formatting_decision = true;
      Logger::info("[Main.cpp] Non-Llama 3 model and no GGUF chat template. Internal Q&A prompt formatting will be ENABLED.");
    }

    Logger::info("[Main.cpp] Mode: '" + mode_str +
                 "'. Final decision for apply_qa_formatting_decision: " +
                 std::string(apply_qa_formatting_decision ? "true" : "false"));

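    // Dispatch on mode: "prompt" runs a single completion, "chat" starts an interactive
    // loop on stdin, and "batch" generates completions for all --batch-prompts at once.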
    if (mode_str == "prompt") {
      std::string generated_text =
          session.generate(user_prompt_str, max_tokens, temperature, top_k, top_p,
                           system_prompt_str, apply_qa_formatting_decision);
      std::cout << generated_text << std::endl;
    } else if (mode_str == "chat") {
      std::cout << "Entering chat mode. System Prompt: \"" << system_prompt_str
                << "\". Type 'exit', 'quit' to end." << std::endl;
      std::string current_user_message;

      // Seed the conversation with the provided prompt when one was given
      // (or when a system prompt is set).
      if (!user_prompt_str.empty() && (user_prompt_str != "Hello, world!" || !system_prompt_str.empty())) {
        current_user_message = user_prompt_str;
        std::cout << "You: " << current_user_message << std::endl;
        std::string ai_response = session.generate(current_user_message, max_tokens, temperature,
                                                   top_k, top_p, system_prompt_str,
                                                   apply_qa_formatting_decision);
        std::cout << "AI: " << ai_response << std::endl;
      }

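      // Interactive loop: read one line per turn; "exit" or "quit" ends the session,
      // empty input just re-prompts.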
      while (true) {
        std::cout << "You: ";
        std::getline(std::cin, current_user_message);
        if (current_user_message == "exit" || current_user_message == "quit") {
          break;
        }
        if (current_user_message.empty()) {
          continue;
        }
        std::string ai_response = session.generate(current_user_message, max_tokens, temperature,
                                                   top_k, top_p, system_prompt_str,
                                                   apply_qa_formatting_decision);
        std::cout << "AI: " << ai_response << std::endl;
      }
    } else if (mode_str == "batch") {
      if (batch_prompts.empty()) {
        std::cerr << "ERROR: Batch mode requires prompts. Use --batch-prompts \"prompt1\" \"prompt2\" ..." << std::endl;
        return 1;
      }

      if (batch_prompts.size() > static_cast<size_t>(max_batch_size)) {
        std::cerr << "ERROR: Number of prompts (" << batch_prompts.size()
                  << ") exceeds max batch size (" << max_batch_size << ")" << std::endl;
        return 1;
      }

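      // Generate completions for all prompts in one call, then print each input/output pair.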
      Logger::info("[Batch Mode] Processing " + std::to_string(batch_prompts.size()) +
                   " prompts in batch...");

      try {
        std::vector<std::string> results = session.generate_batch(
            batch_prompts, max_tokens, temperature, top_k, top_p,
            system_prompt_str, apply_qa_formatting_decision);

        std::cout << "\n" << std::string(80, '=') << std::endl;
        std::cout << "BATCH PROCESSING RESULTS" << std::endl;
        std::cout << std::string(80, '=') << std::endl;

        for (size_t i = 0; i < batch_prompts.size(); ++i) {
          std::cout << "\n--- Prompt " << (i + 1) << " ---" << std::endl;
          std::cout << "Input: " << batch_prompts[i] << std::endl;
          std::cout << "Output: " << results[i] << std::endl;
        }

      } catch (const std::exception& e) {
        Logger::error("[Batch Mode] Error during batch processing: " + std::string(e.what()));
        return 1;
      }
    } else {
      std::cerr << "ERROR: Invalid mode '" << mode_str << "'. Expected 'prompt', 'chat', or 'batch'." << std::endl;
      return 1;
    }

  } catch (const std::exception& e) {
    Logger::error("Unhandled exception: " + std::string(e.what()));
    return 1;
  }

  return 0;
}