TinyLlama.cpp 1.0
A lightweight C++ implementation of the TinyLlama language model
server.cpp
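// Silence MSVC warnings (implicit conversions 4244/4267, deprecation 4996) triggered inside httplib.h.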
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4244)
#pragma warning(disable : 4267)

#pragma warning(disable : 4996)
#endif

#include "httplib.h"

#ifdef _MSC_VER
#pragma warning(pop)
#endif

#include <filesystem>
#include <memory>
#include <nlohmann/json.hpp>
#include <string>
#include <thread>
#include <vector>

#include "api.h"
#include "logger.h"
#include "tokenizer.h"
#include "model_macros.h"

using json = nlohmann::json;

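// Usage: <server binary> [model_dir] [port] [host] [www_path]
// Defaults: "data", 8080, "localhost", "./www"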
int main(int argc, char** argv) {
  std::string model_dir = "data";
  std::string host = "localhost";
  int port = 8080;
  std::string www_path = "./www";

  if (argc > 1) {
    model_dir = argv[1];
  }
  if (argc > 2) {
    port = std::stoi(argv[2]);
  }
  if (argc > 3) {
    host = argv[3];
  }
  if (argc > 4) {
    www_path = argv[4];
  }

  Logger::info("Starting TinyLlama Chat Server...");

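  // Load the model once up front; the session is captured by reference in the
  // request handlers registered below.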
  std::shared_ptr<tinyllama::TinyLlamaSession> session;
  try {
    Logger::info("Loading model from: " + model_dir);
    session = std::make_shared<tinyllama::TinyLlamaSession>(model_dir, "tokenizer.json", 4, -1, true);
    Logger::info("Model loaded successfully.");
  } catch (const std::exception& e) {
    Logger::error(std::string("Failed to load model: ") + e.what());
    return 1;
  }

  httplib::Server svr;

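  // Serve the bundled web client from www_path at the site root, if the directory exists.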
  if (std::filesystem::exists(www_path) &&
      std::filesystem::is_directory(www_path)) {
    Logger::info("Serving static files from: " + www_path);
    bool mount_ok = svr.set_mount_point("/", www_path);
    if (!mount_ok) {
      Logger::error("Failed to mount static file directory: " + www_path);
      return 1;
    }
  } else {
    Logger::info("Static file directory not found: " + www_path +
                 ". Web client will not be served.");
  }

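  // POST /chat: expects a JSON body with a required "user_input" string plus optional
  // "max_new_tokens", "temperature", "top_k", and "top_p"; responds with {"reply": "..."}.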
  svr.Post("/chat", [&session](const httplib::Request& req,
                               httplib::Response& res) {
    Logger::info("Received request for /chat");
    res.set_header("Access-Control-Allow-Origin", "*");
    res.set_header("Access-Control-Allow-Methods", "POST, OPTIONS");
    res.set_header("Access-Control-Allow-Headers", "Content-Type");

    std::string user_input_from_client;

    float temperature = 0.1f;  // Lower temperature for more focused chat responses
    int max_new_tokens = 60;
    int top_k = 40;      // Default top-k value
    float top_p = 0.9f;  // Default top-p value

    try {
      json req_json = json::parse(req.body);
      if (req_json.contains("user_input")) {
        user_input_from_client = req_json["user_input"].get<std::string>();
      } else {
        throw std::runtime_error("Missing 'user_input' field in request JSON");
      }

      if (req_json.contains("max_new_tokens"))
        max_new_tokens = req_json["max_new_tokens"].get<int>();
      if (req_json.contains("temperature"))
        temperature = req_json["temperature"].get<float>();
      if (req_json.contains("top_k"))
        top_k = req_json["top_k"].get<int>();
      if (req_json.contains("top_p"))
        top_p = req_json["top_p"].get<float>();

      Logger::info("Processing user input: " +
                   user_input_from_client.substr(0, 100) + "...");

      const ModelConfig& config = session->get_config();
      std::string prompt_for_session_generate;
      bool use_q_a_format_for_session_generate = false;

      const Tokenizer* tokenizer = session->get_tokenizer();

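      // Choose the prompt format: raw input for GGUF Llama 3 models, legacy Q:/A:
      // formatting for other GGUF models, and the tokenizer's chat template for
      // safetensors models.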
      if (config.is_gguf_file_loaded) {
        prompt_for_session_generate = user_input_from_client;
        // Check for Llama 3 tokenizer family to disable Q&A for it
        // (the exact TokenizerFamily enumerator name below is assumed)
        if (config.tokenizer_family == TokenizerFamily::LLAMA3_TIKTOKEN) {
          use_q_a_format_for_session_generate = false;
          Logger::info("GGUF Llama 3 model detected (via tokenizer_family). Q&A prompt formatting will be DISABLED for session generate.");
        } else {
          use_q_a_format_for_session_generate = true;
          Logger::info(
              "GGUF (Non-Llama 3) model detected. Using Q:A: format via session->generate.");
        }
      } else {
        std::string system_prompt_text = "You are a helpful AI.";
        if (tokenizer) {
          prompt_for_session_generate = tokenizer->apply_chat_template(
              user_input_from_client, system_prompt_text, config);
          Logger::info(
              "Safetensors model detected. Applied chat template via "
              "tokenizer. Prompt: " +
              prompt_for_session_generate.substr(0, 200) + "...");
        } else {
          Logger::error(
              "CRITICAL: Tokenizer not available for Safetensors model in "
              "server. Cannot apply chat template.");

          prompt_for_session_generate = user_input_from_client;
        }
        use_q_a_format_for_session_generate = false;
      }

      std::string reply = session->generate(
          prompt_for_session_generate, max_new_tokens, temperature, top_k, top_p, "",
          use_q_a_format_for_session_generate);
      Logger::info("Generated reply: " + reply.substr(0, 50) + "...");

      json res_json;
      res_json["reply"] = reply;

      res.set_content(res_json.dump(), "application/json");
      Logger::info("Response sent successfully.");

    } catch (const json::parse_error& e) {
      Logger::error("JSON parsing error: " + std::string(e.what()));
      res.status = 400;
      json err_json;
      err_json["error"] = "Invalid JSON format: " + std::string(e.what());
      res.set_content(err_json.dump(), "application/json");
    } catch (const std::exception& e) {
      Logger::error("Generation error: " + std::string(e.what()));
      res.status = 500;
      json err_json;
      err_json["error"] = "Internal server error: " + std::string(e.what());
      res.set_content(err_json.dump(), "application/json");
    }
  });

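  // Answer CORS preflight requests for /chat with the same allow-* headers.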
  svr.Options("/chat", [](const httplib::Request& req, httplib::Response& res) {
    res.set_header("Access-Control-Allow-Origin", "*");
    res.set_header("Access-Control-Allow-Headers", "Content-Type");
    res.set_header("Access-Control-Allow-Methods", "POST, OPTIONS");
    res.status = 204;
  });

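  // Roughly half the hardware threads, but at least one; in this file the value is
  // only used in the log line below.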
  unsigned int num_threads = SAFE_MAX(1u, std::thread::hardware_concurrency() / 2);
  Logger::info("Starting server on " + host + ":" + std::to_string(port) +
               " with " + std::to_string(num_threads) + " threads.");

  svr.listen(host.c_str(), port);

  Logger::info("Server stopped.");
  return 0;
}
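Example client (an illustrative sketch, not part of the project): the snippet below posts a prompt to the /chat endpoint using the same httplib and nlohmann::json libraries the server relies on, targeting the default localhost:8080. The JSON field names mirror the keys parsed by the handler above; the prompt text and the file name client_example.cpp are made up for the example.

// client_example.cpp -- minimal /chat client sketch (hypothetical, for illustration only)
#include <iostream>
#include <string>

#include "httplib.h"
#include <nlohmann/json.hpp>

int main() {
  // Default host/port used by server.cpp when no arguments are given.
  httplib::Client cli("localhost", 8080);

  nlohmann::json req_json;
  req_json["user_input"] = "What is the capital of France?";  // required field
  req_json["max_new_tokens"] = 60;  // optional overrides, mirroring the server-side defaults
  req_json["temperature"] = 0.1f;
  req_json["top_k"] = 40;
  req_json["top_p"] = 0.9f;

  auto res = cli.Post("/chat", req_json.dump(), "application/json");
  if (!res) {
    std::cerr << "Request failed (error code " << static_cast<int>(res.error()) << ")\n";
    return 1;
  }
  if (res->status != 200) {
    std::cerr << "Server returned HTTP " << res->status << ": " << res->body << "\n";
    return 1;
  }

  nlohmann::json res_json = nlohmann::json::parse(res->body);
  std::cout << res_json["reply"].get<std::string>() << "\n";
  return 0;
}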