int main(int argc, char** argv) {
  if (argc > 1 && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help")) {
    // ... (usage/help text elided in this excerpt) ...
  }

  if (argc < 3) {
    std::cerr << "ERROR: Missing required arguments." << std::endl;
    return 1;
  }

  std::string model_path_or_dir = argv[1];
  std::string tokenizer_path = argv[2];

  // num_threads has a default value assigned in lines elided from this excerpt.
  try {
    if (argc > 3) {
      num_threads = std::stoi(argv[3]);
    } else {
      Logger::warning("Number of threads not provided, using default: " + std::to_string(num_threads));
    }
  } catch (const std::exception& e) {
    Logger::warning("Could not parse num_threads from argv[3]: '" +
                    (argc > 3 ? std::string(argv[3]) : "<not provided>") +
                    "'. Using default: " + std::to_string(num_threads));
  }

  if (argc < 5) {
    std::cerr << "ERROR: Missing mode argument (prompt|chat|batch) after num_threads." << std::endl;
    return 1;
  }
  std::string mode_str = argv[4];
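  // Example invocation (the binary name, paths, and values below are illustrative placeholders,
  // not defaults taken from this program); the mode and optional flags are parsed below:
  //   ./tinyllama /path/to/model.gguf /path/to/tokenizer.json 4 chat \
  //       --system-prompt "You are a helpful assistant." --max-tokens 512 --temperature 0.7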
  std::string system_prompt_str = "";
  std::string user_prompt_str = "Hello, world!";
  int max_tokens = 256;
  int n_gpu_layers = -1;
  bool use_mmap = true;
  bool use_kv_quant = false;
  bool use_batch_generation = false;

  std::vector<std::string> batch_prompts;
  int max_batch_size = 8;

  float temperature = 0.1f;
  // (Defaults for top_k and top_p are set in lines elided from this excerpt.)

  int current_arg_idx = 5;
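  // Optional flags recognized by the loop below (long/short forms): --system-prompt/-sp,
  // --max-tokens/-mt, --n-gpu-layers/-ngl, --use-mmap, --temperature/-t, --top-k/-k,
  // --top-p/-p, --use-kv-quant/-kvq, --use-batch-generation/-ubg, --batch-prompts/-bp,
  // --max-batch-size/-mbs.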
  while (current_arg_idx < argc) {
    std::string arg = argv[current_arg_idx];
    if (arg == "--system-prompt" || arg == "-sp") {
      if (current_arg_idx + 1 < argc) {
        system_prompt_str = argv[current_arg_idx + 1];
        current_arg_idx += 2;
      } else {
        std::cerr << "ERROR: --system-prompt requires a value." << std::endl;
        return 1;
      }
    } else if (arg == "--max-tokens" || arg == "-mt") {
      if (current_arg_idx + 1 < argc) {
        try { max_tokens = std::stoi(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid max_tokens: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --max-tokens requires a value." << std::endl; return 1; }
    } else if (arg == "--n-gpu-layers" || arg == "-ngl") {
      if (current_arg_idx + 1 < argc) {
        try { n_gpu_layers = std::stoi(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid n_gpu_layers: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --n-gpu-layers requires a value." << std::endl; return 1; }
    } else if (arg == "--use-mmap") {
      if (current_arg_idx + 1 < argc) {
        std::string mmap_str_val = argv[current_arg_idx + 1];
        std::transform(mmap_str_val.begin(), mmap_str_val.end(), mmap_str_val.begin(), ::tolower);
        if (mmap_str_val == "false" || mmap_str_val == "0") use_mmap = false;
        else if (mmap_str_val == "true" || mmap_str_val == "1") use_mmap = true;
        else { std::cerr << "ERROR: Invalid use_mmap value." << std::endl; return 1; }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --use-mmap requires a value." << std::endl; return 1; }
    } else if (arg == "--temperature" || arg == "-t") {
      if (current_arg_idx + 1 < argc) {
        try { temperature = std::stof(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid temperature: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --temperature requires a value." << std::endl; return 1; }
    } else if (arg == "--top-k" || arg == "-k") {
      if (current_arg_idx + 1 < argc) {
        try { top_k = std::stoi(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid top_k: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --top-k requires a value." << std::endl; return 1; }
    } else if (arg == "--top-p" || arg == "-p") {
      if (current_arg_idx + 1 < argc) {
        try { top_p = std::stof(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid top_p: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --top-p requires a value." << std::endl; return 1; }
    } else if (arg == "--use-kv-quant" || arg == "-kvq") {
      if (current_arg_idx + 1 < argc) {
        std::string kvq_str_val = argv[current_arg_idx + 1];
        std::transform(kvq_str_val.begin(), kvq_str_val.end(), kvq_str_val.begin(), ::tolower);
        if (kvq_str_val == "false" || kvq_str_val == "0") use_kv_quant = false;
        else if (kvq_str_val == "true" || kvq_str_val == "1") use_kv_quant = true;
        else { std::cerr << "ERROR: Invalid use_kv_quant value: " << argv[current_arg_idx + 1] << std::endl; return 1; }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --use-kv-quant requires a value." << std::endl; return 1; }
    } else if (arg == "--use-batch-generation" || arg == "-ubg") {
      if (current_arg_idx + 1 < argc) {
        std::string ubg_str_val = argv[current_arg_idx + 1];
        std::transform(ubg_str_val.begin(), ubg_str_val.end(), ubg_str_val.begin(), ::tolower);
        if (ubg_str_val == "false" || ubg_str_val == "0") use_batch_generation = false;
        else if (ubg_str_val == "true" || ubg_str_val == "1") use_batch_generation = true;
        else { std::cerr << "ERROR: Invalid use_batch_generation value: " << argv[current_arg_idx + 1] << std::endl; return 1; }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --use-batch-generation requires a value." << std::endl; return 1; }
    } else if (arg == "--batch-prompts" || arg == "-bp") {
      current_arg_idx++;  // skip the flag itself before collecting prompts
      while (current_arg_idx < argc && argv[current_arg_idx][0] != '-') {
        batch_prompts.push_back(std::string(argv[current_arg_idx]));
        current_arg_idx++;
      }
      if (batch_prompts.empty()) {
        std::cerr << "ERROR: --batch-prompts requires at least one prompt." << std::endl;
        return 1;
      }
    } else if (arg == "--max-batch-size" || arg == "-mbs") {
      if (current_arg_idx + 1 < argc) {
        try { max_batch_size = std::stoi(argv[current_arg_idx + 1]); }
        catch (const std::exception& e) { Logger::error("Invalid max_batch_size: " + std::string(argv[current_arg_idx + 1])); }
        current_arg_idx += 2;
      } else { std::cerr << "ERROR: --max-batch-size requires a value." << std::endl; return 1; }
    } else {
      if (user_prompt_str == "Hello, world!") {
        // (body elided in the excerpt; presumably the bare token becomes the user prompt)
        user_prompt_str = arg;
        current_arg_idx++;
      } else if (argv[current_arg_idx][0] != '-') {
        std::cerr << "ERROR: Unexpected positional argument: " << argv[current_arg_idx] << std::endl;
        return 1;
      } else {
        std::cerr << "ERROR: Unknown option: " << argv[current_arg_idx] << std::endl;
        return 1;
      }
    }
  }
  Logger::info("Using model path/directory: " + model_path_or_dir);
  Logger::info("Num threads: " + std::to_string(num_threads));
  Logger::info("System Prompt: \"" + system_prompt_str + "\"");
  Logger::info("Default User Prompt: \"" + user_prompt_str + "\"");
  Logger::info("Max tokens: " + std::to_string(max_tokens));
  Logger::info("N GPU Layers: " + std::to_string(n_gpu_layers));
  Logger::info(std::string("Use mmap: ") + (use_mmap ? "true" : "false"));
  Logger::info("Temperature: " + std::to_string(temperature));
  Logger::info(std::string("Use KVCache Quantization: ") + (use_kv_quant ? "true" : "false"));
  Logger::info(std::string("Use Batch Generation: ") + (use_batch_generation ? "true" : "false"));

  // Batch mode always enables batch generation and honors --max-batch-size; other modes run with a batch size of 1.
  int session_max_batch_size = (mode_str == "batch") ? max_batch_size : 1;
  bool session_use_batch_generation = (mode_str == "batch") ? true : use_batch_generation;

  try {
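    // TinyLlamaSession is constructed with, in order: model path/dir, tokenizer path, thread count,
    // number of GPU layers, the mmap flag, the KV-cache quantization flag, the batch-generation
    // flag, and the maximum batch size.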
    tinyllama::TinyLlamaSession session(model_path_or_dir, tokenizer_path, num_threads, n_gpu_layers,
                                        use_mmap, use_kv_quant, session_use_batch_generation,
                                        session_max_batch_size);
    Logger::info("TinyLlamaSession initialized successfully.");
    bool apply_qa_formatting_decision;
    if (/* Llama 3 model or GGUF chat template detected (check elided in this excerpt) */) {
      apply_qa_formatting_decision = false;
      Logger::info("[Main.cpp] Llama 3 model or GGUF chat template detected. Internal Q&A prompt formatting will be DISABLED.");
    } else {
      apply_qa_formatting_decision = true;
      Logger::info("[Main.cpp] Non-Llama 3 model and no GGUF chat template. Internal Q&A prompt formatting will be ENABLED.");
    }
    Logger::info("[Main.cpp] Mode: '" + mode_str +
                 "'. Final decision for apply_qa_formatting_decision: " +
                 std::string(apply_qa_formatting_decision ? "true" : "false"));
    if (mode_str == "prompt") {
      std::string generated_text =
          session.generate(user_prompt_str, max_tokens, temperature, top_k, top_p,
                           system_prompt_str, apply_qa_formatting_decision);
      std::cout << generated_text << std::endl;
    } else if (mode_str == "chat") {
      std::cout << "Entering chat mode. System Prompt: \"" << system_prompt_str
                << "\". Type 'exit', 'quit' to end." << std::endl;
      std::string current_user_message;

      // If an initial user prompt was supplied (or a system prompt overrides the default), answer it first.
      if (!user_prompt_str.empty() && (user_prompt_str != "Hello, world!" || !system_prompt_str.empty())) {
        current_user_message = user_prompt_str;
        std::cout << "You: " << current_user_message << std::endl;
        std::string ai_response = session.generate(current_user_message, max_tokens, temperature, top_k, top_p,
                                                   system_prompt_str, apply_qa_formatting_decision);
        std::cout << "AI: " << ai_response << std::endl;
      }

      while (true) {
        std::cout << "You: ";
        std::getline(std::cin, current_user_message);
        if (current_user_message == "exit" || current_user_message == "quit") {
          break;
        }
        if (current_user_message.empty()) {
          continue;
        }
        std::string ai_response = session.generate(current_user_message, max_tokens, temperature, top_k, top_p,
                                                   system_prompt_str, apply_qa_formatting_decision);
        std::cout << "AI: " << ai_response << std::endl;
      }
    } else if (mode_str == "batch") {
      if (batch_prompts.empty()) {
        std::cerr << "ERROR: Batch mode requires prompts. Use --batch-prompts \"prompt1\" \"prompt2\" ..." << std::endl;
        return 1;
      }
      if (batch_prompts.size() > static_cast<size_t>(max_batch_size)) {
        std::cerr << "ERROR: Number of prompts (" << batch_prompts.size()
                  << ") exceeds max batch size (" << max_batch_size << ")" << std::endl;
        return 1;
      }
      Logger::info("[Batch Mode] Processing " + std::to_string(batch_prompts.size()) + " prompts in batch...");
      try {
        // The batch-generation call on `session` is partially elided in this excerpt. As shown, it is passed:
        //   batch_prompts, max_tokens, temperature, top_k, top_p,
        //   system_prompt_str, apply_qa_formatting_decision
        // and its per-prompt outputs are collected in `results`, printed alongside batch_prompts below.

        std::cout << "\n" << std::string(80, '=') << std::endl;
        std::cout << "BATCH PROCESSING RESULTS" << std::endl;
        std::cout << std::string(80, '=') << std::endl;

        for (size_t i = 0; i < batch_prompts.size(); ++i) {
          std::cout << "\n--- Prompt " << (i + 1) << " ---" << std::endl;
          std::cout << "Input: " << batch_prompts[i] << std::endl;
          std::cout << "Output: " << results[i] << std::endl;
        }
      } catch (const std::exception& e) {
        Logger::error("[Batch Mode] Error during batch processing: " + std::string(e.what()));
      }
    } else {
      std::cerr << "ERROR: Invalid mode '" << mode_str
                << "'. Expected 'prompt', 'chat', or 'batch'." << std::endl;
      return 1;
    }
  } catch (const std::exception& e) {