56int main(
int argc,
char** argv) {
57 std::string model_dir =
"data";
58 std::string host =
"localhost";
60 std::string www_path =
"./www";
66 port = std::stoi(argv[2]);
77 std::shared_ptr<tinyllama::TinyLlamaSession> session;
80 session = std::make_shared<tinyllama::TinyLlamaSession>(model_dir,
"tokenizer.json", 4, -1,
true);
82 }
catch (
const std::exception& e) {
83 Logger::error(std::string(
"Failed to load model: ") + e.what());
89 if (std::filesystem::exists(www_path) &&
90 std::filesystem::is_directory(www_path)) {
92 bool mount_ok = svr.set_mount_point(
"/", www_path);
94 Logger::error(
"Failed to mount static file directory: " + www_path);
98 Logger::info(
"Static file directory not found: " + www_path +
99 ". Web client will not be served.");
102 svr.Post(
"/chat", [&session](
const httplib::Request& req,
103 httplib::Response& res) {
105 res.set_header(
"Access-Control-Allow-Origin",
"*");
106 res.set_header(
"Access-Control-Allow-Methods",
"POST, OPTIONS");
107 res.set_header(
"Access-Control-Allow-Headers",
"Content-Type");
109 std::string user_input_from_client;
111 float temperature = 0.1f;
112 int max_new_tokens = 60;
117 json req_json = json::parse(req.body);
118 if (req_json.contains(
"user_input")) {
119 user_input_from_client = req_json[
"user_input"].get<std::string>();
121 throw std::runtime_error(
"Missing 'user_input' field in request JSON");
124 if (req_json.contains(
"max_new_tokens"))
125 max_new_tokens = req_json[
"max_new_tokens"].get<
int>();
126 if (req_json.contains(
"temperature"))
127 temperature = req_json[
"temperature"].get<float>();
128 if (req_json.contains(
"top_k"))
129 top_k = req_json[
"top_k"].get<int>();
130 if (req_json.contains(
"top_p"))
131 top_p = req_json[
"top_p"].get<float>();
134 user_input_from_client.substr(0, 100) +
"...");
137 std::string prompt_for_session_generate;
138 bool use_q_a_format_for_session_generate =
false;
140 const Tokenizer* tokenizer = session->get_tokenizer();
143 prompt_for_session_generate = user_input_from_client;
146 use_q_a_format_for_session_generate =
false;
147 Logger::info(
"GGUF Llama 3 model detected (via tokenizer_family). Q&A prompt formatting will be DISABLED for session generate.");
149 use_q_a_format_for_session_generate =
true;
151 "GGUF (Non-Llama 3) model detected. Using Q:A: format via session->generate.");
154 std::string system_prompt_text =
"You are a helpful AI.";
157 user_input_from_client, system_prompt_text, config);
159 "Safetensors model detected. Applied chat template via "
160 "tokenizer. Prompt: " +
161 prompt_for_session_generate.substr(0, 200) +
"...");
164 "CRITICAL: Tokenizer not available for Safetensors model in "
165 "server. Cannot apply chat template.");
167 prompt_for_session_generate = user_input_from_client;
169 use_q_a_format_for_session_generate =
false;
172 std::string reply = session->generate(
173 prompt_for_session_generate, max_new_tokens, temperature, top_k, top_p,
"",
174 use_q_a_format_for_session_generate);
175 Logger::info(
"Generated reply: " + reply.substr(0, 50) +
"...");
178 res_json[
"reply"] = reply;
180 res.set_content(res_json.dump(),
"application/json");
183 }
catch (
const json::parse_error& e) {
184 Logger::error(
"JSON parsing error: " + std::string(e.what()));
187 err_json[
"error"] =
"Invalid JSON format: " + std::string(e.what());
188 res.set_content(err_json.dump(),
"application/json");
189 }
catch (
const std::exception& e) {
193 err_json[
"error"] =
"Internal server error: " + std::string(e.what());
194 res.set_content(err_json.dump(),
"application/json");
198 svr.Options(
"/chat", [](
const httplib::Request& req, httplib::Response& res) {
199 res.set_header(
"Access-Control-Allow-Origin",
"*");
200 res.set_header(
"Access-Control-Allow-Headers",
"Content-Type");
201 res.set_header(
"Access-Control-Allow-Methods",
"POST, OPTIONS");
205 unsigned int num_threads =
SAFE_MAX(1u, std::thread::hardware_concurrency() / 2);
206 Logger::info(
"Starting server on " + host +
":" + std::to_string(port) +
207 " with " + std::to_string(num_threads) +
" threads.");
209 svr.listen(host.c_str(), port);