{
  std::string model_dir = "data";
  std::string host = "localhost";
  int port = 8080;
  std::string www_path = "./www";

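  // Optional positional arguments override the defaults above, in order:
  //   <model_dir> <port> <host> <www_path>
  // e.g. ./server data 8080 0.0.0.0 ./www   (binary name is illustrative)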
  if (argc > 1) {
    model_dir = argv[1];
  }
  if (argc > 2) {
    port = std::stoi(argv[2]);
  }
  if (argc > 3) {
    host = argv[3];
  }
  if (argc > 4) {
    www_path = argv[4];
  }

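  // Load the model once at startup. The session is captured by reference in
  // the request handlers below, so a load failure is fatal for the server.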
  std::shared_ptr<tinyllama::TinyLlamaSession> session;
  try {
    session = std::make_shared<tinyllama::TinyLlamaSession>(
        model_dir, "tokenizer.json", 4, -1, true);
  } catch (const std::exception& e) {
    Logger::error(std::string("Failed to load model: ") + e.what());
    return 1;
  }

  httplib::Server svr;

  if (std::filesystem::exists(www_path) &&
      std::filesystem::is_directory(www_path)) {
    bool mount_ok = svr.set_mount_point("/", www_path);
    if (!mount_ok) {
      Logger::error("Failed to mount static file directory: " + www_path);
      return 1;
    }
  } else {
    Logger::info("Static file directory not found: " + www_path +
                 ". Web client will not be served.");
  }

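  // POST /chat expects a JSON body with "user_input" (required) plus optional
  // sampling overrides, and returns {"reply": "..."}. Example request,
  // assuming the default host/port:
  //   curl -X POST http://localhost:8080/chat -H "Content-Type: application/json" -d '{"user_input": "Hello", "max_new_tokens": 60, "temperature": 0.1}'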
  svr.Post("/chat", [&session](const httplib::Request& req,
                               httplib::Response& res) {
    res.set_header("Access-Control-Allow-Origin", "*");
    res.set_header("Access-Control-Allow-Methods", "POST, OPTIONS");
    res.set_header("Access-Control-Allow-Headers", "Content-Type");

    std::string user_input_from_client;

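    // Default sampling parameters; each can be overridden per request via the
    // JSON body parsed below.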
    float temperature = 0.1f;
    int max_new_tokens = 60;
    int top_k = 40;
    float top_p = 0.9f;

    try {
      json req_json = json::parse(req.body);
      if (req_json.contains("user_input")) {
        user_input_from_client = req_json["user_input"].get<std::string>();
      } else {
        throw std::runtime_error("Missing 'user_input' field in request JSON");
      }

      if (req_json.contains("max_new_tokens"))
        max_new_tokens = req_json["max_new_tokens"].get<int>();
      if (req_json.contains("temperature"))
        temperature = req_json["temperature"].get<float>();
      if (req_json.contains("top_k"))
        top_k = req_json["top_k"].get<int>();
      if (req_json.contains("top_p"))
        top_p = req_json["top_p"].get<float>();

      Logger::info("Received user_input (truncated): " +
                   user_input_from_client.substr(0, 100) + "...");

      std::string prompt_for_session_generate;
      bool use_q_a_format_for_session_generate = false;

      const Tokenizer* tokenizer = session->get_tokenizer();

      // NOTE: the config accessor and the field checks below are reconstructed
      // and may not exactly match the TinyLlamaSession / ModelConfig API.
      const ModelConfig& config = session->get_config();
      if (config.is_gguf_loaded) {
        prompt_for_session_generate = user_input_from_client;

        if (config.tokenizer_family == TokenizerFamily::LLAMA3_TIKTOKEN) {
          use_q_a_format_for_session_generate = false;
          Logger::info(
              "GGUF Llama 3 model detected (via tokenizer_family). Q&A prompt formatting will be DISABLED for session generate.");
        } else {
          use_q_a_format_for_session_generate = true;
          Logger::info(
              "GGUF (Non-Llama 3) model detected. Using Q:A: format via session->generate.");
        }
      } else {
        std::string system_prompt_text = "You are a helpful AI.";
        if (tokenizer) {
          prompt_for_session_generate = tokenizer->apply_chat_template(
              user_input_from_client, system_prompt_text, config);
          Logger::info(
              "Safetensors model detected. Applied chat template via "
              "tokenizer. Prompt: " +
              prompt_for_session_generate.substr(0, 200) + "...");
        } else {
          Logger::error(
              "CRITICAL: Tokenizer not available for Safetensors model in "
              "server. Cannot apply chat template.");

          prompt_for_session_generate = user_input_from_client;
        }
        use_q_a_format_for_session_generate = false;
      }

      std::string reply = session->generate(
          prompt_for_session_generate, max_new_tokens, temperature, top_k,
          top_p, "", use_q_a_format_for_session_generate);
      Logger::info("Generated reply: " + reply.substr(0, 50) + "...");

      json res_json;
      res_json["reply"] = reply;

      res.set_content(res_json.dump(), "application/json");

    } catch (const json::parse_error& e) {
      Logger::error("JSON parsing error: " + std::string(e.what()));
      res.status = 400;
      json err_json;
      err_json["error"] = "Invalid JSON format: " + std::string(e.what());
      res.set_content(err_json.dump(), "application/json");
    } catch (const std::exception& e) {
      Logger::error("Error handling /chat request: " + std::string(e.what()));
      res.status = 500;
      json err_json;
      err_json["error"] = "Internal server error: " + std::string(e.what());
      res.set_content(err_json.dump(), "application/json");
    }
  });

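  // Answer CORS preflight (OPTIONS) requests from browsers with 204 No Content.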
  svr.Options("/chat", [](const httplib::Request& req, httplib::Response& res) {
    res.set_header("Access-Control-Allow-Origin", "*");
    res.set_header("Access-Control-Allow-Headers", "Content-Type");
    res.set_header("Access-Control-Allow-Methods", "POST, OPTIONS");
    res.status = 204;
  });

  unsigned int num_threads =
      SAFE_MAX(1u, std::thread::hardware_concurrency() / 2);
  Logger::info("Starting server on " + host + ":" + std::to_string(port) +
               " with " + std::to_string(num_threads) + " threads.");

  svr.listen(host.c_str(), port);

  return 0;
}