TinyLlama.cpp 1.0
A lightweight C++ implementation of the TinyLlama language model
server.cpp File Reference

HTTP server implementation for TinyLlama chat interface.

#include "httplib.h"
#include <filesystem>
#include <memory>
#include <nlohmann/json.hpp>
#include <string>
#include <thread>
#include <vector>
#include "api.h"
#include "logger.h"
#include "tokenizer.h"
#include "model_macros.h"


Typedefs

using json = nlohmann::json
 

Functions

int main (int argc, char **argv)
 

Detailed Description

HTTP server implementation for TinyLlama chat interface.

This server provides a REST API for interacting with TinyLlama models. It handles both GGUF and SafeTensors models, applying the appropriate prompt formatting for each:

- GGUF models whose tokenizer family is Llama 3 (tiktoken): the user input is passed to generation unchanged, with Q&A prompt formatting disabled.
- Other GGUF models: the user input is passed through and the session applies its Q:A: prompt format.
- SafeTensors models: the tokenizer's chat template is applied to the user input together with a default system prompt.

A condensed sketch of this branching follows; the full handler appears in the source listing further down.
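The sketch below condenses that branching into a single hypothetical helper (respond() is not part of this file); it assumes the session API shown in the listing, the generation defaults mirror the documented ones, and the exact scoping of the LLAMA3_TIKTOKEN enumerator is an assumption rather than a confirmed name.

// Condensed sketch (not the verbatim source) of the /chat handler's prompt formatting.
#include <string>
#include "api.h"        // tinyllama::TinyLlamaSession, as included by server.cpp
#include "tokenizer.h"  // Tokenizer::apply_chat_template

std::string respond(tinyllama::TinyLlamaSession& session, const std::string& input) {
  const ModelConfig& cfg = session.get_config();
  std::string prompt;
  bool use_qa_format = false;
  if (cfg.is_gguf_file_loaded) {
    prompt = input;  // GGUF: raw user text is passed through
    // Llama 3 (tiktoken) tokenizer family disables Q:A: formatting; other GGUF models keep it
    use_qa_format =
        (cfg.tokenizer_family != ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN);
  } else {
    // SafeTensors: the tokenizer builds the prompt from its chat template
    prompt = session.get_tokenizer()->apply_chat_template(
        input, "You are a helpful AI.", cfg);
  }
  // Defaults mirror the documented ones: 60 tokens, temperature 0.1, top-k 40, top-p 0.9
  return session.generate(prompt, 60, 0.1f, 40, 0.9f, "", use_qa_format);
}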

The server exposes a /chat endpoint that accepts POST requests with a JSON body:

{
  "user_input": "string",   // Required: The prompt text
  "temperature": float,     // Optional: Sampling temperature (default: 0.1)
  "max_new_tokens": int,    // Optional: Max tokens to generate (default: 60)
  "top_k": int,             // Optional: Top-K sampling parameter (default: 40)
  "top_p": float            // Optional: Top-P sampling parameter (default: 0.9)
}
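A successful request returns a JSON object with a single "reply" field; a malformed body produces a 400 response and a generation failure a 500 response, each carrying an "error" field (see the handler in the listing below). As a rough illustration only, a client could call the endpoint with the same cpp-httplib and nlohmann/json libraries this server uses; the host, port, and prompt text are placeholders.

// Hypothetical client-side sketch for the /chat endpoint; host, port, and prompt are placeholders.
#include "httplib.h"
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
  nlohmann::json body = {
      {"user_input", "What is the capital of France?"},  // required
      {"temperature", 0.1},                              // optional (default 0.1)
      {"max_new_tokens", 60},                            // optional (default 60)
      {"top_k", 40},                                     // optional (default 40)
      {"top_p", 0.9}                                     // optional (default 0.9)
  };

  httplib::Client cli("localhost", 8080);
  auto res = cli.Post("/chat", body.dump(), "application/json");
  if (res && res->status == 200) {
    std::cout << nlohmann::json::parse(res->body)["reply"] << std::endl;
  } else {
    std::cerr << "Request to /chat failed" << std::endl;
  }
  return 0;
}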

Usage: tinyllama_server [model_path] [port] [host] [www_path]

  model_path: Path to model directory or .gguf file (default: data)
  port:       Server port (default: 8080)
  host:       Host to bind to (default: localhost)
  www_path:   Path to static web files (default: ./www)
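For example, an invocation along the lines of

  tinyllama_server models/tinyllama.gguf 8080 0.0.0.0 ./www

(the model path is a placeholder) would load that GGUF file, listen on all interfaces at port 8080, and serve the bundled web client from ./www. Note that the positional order is model_path, then port, then host.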

Definition in file server.cpp.

Typedef Documentation

◆ json

using json = nlohmann::json

Definition at line 54 of file server.cpp.

Function Documentation

◆ main()

int main(int argc, char **argv)

Definition at line 56 of file server.cpp.

56 {
57   std::string model_dir = "data";
58   std::string host = "localhost";
59   int port = 8080;
60   std::string www_path = "./www";
61
62   if (argc > 1) {
63     model_dir = argv[1];
64   }
65   if (argc > 2) {
66     port = std::stoi(argv[2]);
67   }
68   if (argc > 3) {
69     host = argv[3];
70   }
71   if (argc > 4) {
72     www_path = argv[4];
73   }
74
75   Logger::info("Starting TinyLlama Chat Server...");
76
77   std::shared_ptr<tinyllama::TinyLlamaSession> session;
78   try {
79     Logger::info("Loading model from: " + model_dir);
80     session = std::make_shared<tinyllama::TinyLlamaSession>(model_dir, "tokenizer.json", 4, -1, true);
81     Logger::info("Model loaded successfully.");
82   } catch (const std::exception& e) {
83     Logger::error(std::string("Failed to load model: ") + e.what());
84     return 1;
85   }
86
87   httplib::Server svr;
88
89   if (std::filesystem::exists(www_path) &&
90       std::filesystem::is_directory(www_path)) {
91     Logger::info("Serving static files from: " + www_path);
92     bool mount_ok = svr.set_mount_point("/", www_path);
93     if (!mount_ok) {
94       Logger::error("Failed to mount static file directory: " + www_path);
95       return 1;
96     }
97   } else {
98     Logger::info("Static file directory not found: " + www_path +
99                  ". Web client will not be served.");
100   }
101
102   svr.Post("/chat", [&session](const httplib::Request& req,
103                                httplib::Response& res) {
104     Logger::info("Received request for /chat");
105     res.set_header("Access-Control-Allow-Origin", "*");
106     res.set_header("Access-Control-Allow-Methods", "POST, OPTIONS");
107     res.set_header("Access-Control-Allow-Headers", "Content-Type");
108
109     std::string user_input_from_client;
110
111     float temperature = 0.1f; // Lower temperature for more focused chat responses
112     int max_new_tokens = 60;
113     int top_k = 40; // Default top-k value
114     float top_p = 0.9f; // Default top-p value
115
116     try {
117       json req_json = json::parse(req.body);
118       if (req_json.contains("user_input")) {
119         user_input_from_client = req_json["user_input"].get<std::string>();
120       } else {
121         throw std::runtime_error("Missing 'user_input' field in request JSON");
122       }
123
124       if (req_json.contains("max_new_tokens"))
125         max_new_tokens = req_json["max_new_tokens"].get<int>();
126       if (req_json.contains("temperature"))
127         temperature = req_json["temperature"].get<float>();
128       if (req_json.contains("top_k"))
129         top_k = req_json["top_k"].get<int>();
130       if (req_json.contains("top_p"))
131         top_p = req_json["top_p"].get<float>();
132
133       Logger::info("Processing user input: " +
134                    user_input_from_client.substr(0, 100) + "...");
135
136       const ModelConfig& config = session->get_config();
137       std::string prompt_for_session_generate;
138       bool use_q_a_format_for_session_generate = false;
139
140       const Tokenizer* tokenizer = session->get_tokenizer();
141
142       if (config.is_gguf_file_loaded) {
143         prompt_for_session_generate = user_input_from_client;
144         // Check for Llama 3 tokenizer family to disable Q&A for it
145         if (config.tokenizer_family == ModelConfig::TokenizerFamily::LLAMA3_TIKTOKEN) {
146           use_q_a_format_for_session_generate = false;
147           Logger::info("GGUF Llama 3 model detected (via tokenizer_family). Q&A prompt formatting will be DISABLED for session generate.");
148         } else {
149           use_q_a_format_for_session_generate = true;
150           Logger::info(
151               "GGUF (Non-Llama 3) model detected. Using Q:A: format via session->generate.");
152         }
153       } else {
154         std::string system_prompt_text = "You are a helpful AI.";
155         if (tokenizer) {
156           prompt_for_session_generate = tokenizer->apply_chat_template(
157               user_input_from_client, system_prompt_text, config);
158           Logger::info(
159               "Safetensors model detected. Applied chat template via "
160               "tokenizer. Prompt: " +
161               prompt_for_session_generate.substr(0, 200) + "...");
162         } else {
163           Logger::error(
164               "CRITICAL: Tokenizer not available for Safetensors model in "
165               "server. Cannot apply chat template.");
166
167           prompt_for_session_generate = user_input_from_client;
168         }
169         use_q_a_format_for_session_generate = false;
170       }
171
172       std::string reply = session->generate(
173           prompt_for_session_generate, max_new_tokens, temperature, top_k, top_p, "",
174           use_q_a_format_for_session_generate);
175       Logger::info("Generated reply: " + reply.substr(0, 50) + "...");
176
177       json res_json;
178       res_json["reply"] = reply;
179
180       res.set_content(res_json.dump(), "application/json");
181       Logger::info("Response sent successfully.");
182
183     } catch (const json::parse_error& e) {
184       Logger::error("JSON parsing error: " + std::string(e.what()));
185       res.status = 400;
186       json err_json;
187       err_json["error"] = "Invalid JSON format: " + std::string(e.what());
188       res.set_content(err_json.dump(), "application/json");
189     } catch (const std::exception& e) {
190       Logger::error("Generation error: " + std::string(e.what()));
191       res.status = 500;
192       json err_json;
193       err_json["error"] = "Internal server error: " + std::string(e.what());
194       res.set_content(err_json.dump(), "application/json");
195     }
196   });
197
198   svr.Options("/chat", [](const httplib::Request& req, httplib::Response& res) {
199     res.set_header("Access-Control-Allow-Origin", "*");
200     res.set_header("Access-Control-Allow-Headers", "Content-Type");
201     res.set_header("Access-Control-Allow-Methods", "POST, OPTIONS");
202     res.status = 204;
203   });
204
205   unsigned int num_threads = SAFE_MAX(1u, std::thread::hardware_concurrency() / 2);
206   Logger::info("Starting server on " + host + ":" + std::to_string(port) +
207                " with " + std::to_string(num_threads) + " threads.");
208
209   svr.listen(host.c_str(), port);
210
211   Logger::info("Server stopped.");
212   return 0;
213 }

References Tokenizer::apply_chat_template(), Logger::error(), Logger::info(), ModelConfig::is_gguf_file_loaded, ModelConfig::LLAMA3_TIKTOKEN, SAFE_MAX, and ModelConfig::tokenizer_family.