Skip to content

Commit 0635dc0

Browse files
committed
Implement server mode.
This new mode works by first loading the model then listening for TCP connections on a port. When a connection is received, arguments will be parsed using a simple protocol: - First the number of arguments will be read followed by a newline character. - Then each argument will be read, separated by the 0 byte. - With this we build an argument vector, similar to what is passed to the program entry point. We pass this to gpt_params_parse. Finally `run` will be executed with the input/output streams connected to the socket. Signed-off-by: Thiago Padilha <thiago@padilha.cc>
1 parent 19fa30a commit 0635dc0

9 files changed

+331
-2
lines changed

CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ add_executable(main
244244
run.cpp)
245245
target_link_libraries(main PRIVATE llama ggml utils)
246246

247+
if(NOT WIN32)
248+
target_sources(main PRIVATE tcp_server.cpp)
249+
endif()
250+
247251
add_executable(quantize quantize.cpp)
248252
target_link_libraries(quantize PRIVATE llama ggml utils)
249253

Makefile

+5-2
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,14 @@ utils.o: utils.cpp utils.h
229229
run.o: run.cpp run.h
230230
$(CXX) $(CXXFLAGS) -c run.cpp -o run.o
231231

232+
tcp_server.o: tcp_server.cpp tcp_server.h
233+
$(CXX) $(CXXFLAGS) -c tcp_server.cpp -o tcp_server.o
234+
232235
clean:
233236
rm -f *.o main quantize
234237

235-
main: main.cpp ggml.o llama.o utils.o run.o
236-
$(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o -o main $(LDFLAGS)
238+
main: main.cpp ggml.o llama.o utils.o run.o tcp_server.o
239+
$(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o tcp_server.o -o main $(LDFLAGS)
237240
@echo "\x1b[36mrun ./main -h for help\x1b[0m"
238241

239242
quantize: quantize.cpp ggml.o llama.o utils.o

chat_tcp_client.sh

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#!/usr/bin/env bash
2+
3+
PORT=${PORT:-8080}
4+
PROMPT="${PROMPT:-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
5+
6+
User:Hello, Bob.
7+
Bob:Hello. How may I help you today?
8+
User:Please tell me the largest city in Europe.
9+
Bob:Sure. The largest city in Europe is Moscow, the capital of Russia.
10+
User:"}"
11+
RPROMPT="${RPROMPT:-"User:"}"
12+
N_PREDICT="${N_PREDICT:-"4096"}"
13+
REPEAT_PENALTY="${REPEAT_PENALTY:-"1.0"}"
14+
N_THREADS="${N_THREADS:-"4"}"
15+
16+
# Open connection to the chat server
17+
exec 3<>/dev/tcp/127.0.0.1/${PORT}
18+
19+
# Pass the arguments. The protocol is really simple:
20+
# 1. Pass the number of arguments followed by a linefeed
21+
# 2. Pass the arguments, with each being followed by "0"
22+
(
23+
echo -en "12\n"
24+
echo -en "-t\x00"
25+
echo -en "$N_THREADS\x00"
26+
echo -en "-n\x00"
27+
echo -en "$N_PREDICT\x00"
28+
echo -en "--repeat_penalty\x00"
29+
echo -en "$REPEAT_PENALTY\x00"
30+
echo -en "--color\x00"
31+
echo -en "-i\x00"
32+
echo -en "-r\x00"
33+
echo -en "$RPROMPT\x00"
34+
echo -en "-p\x00"
35+
echo -en "$PROMPT\x00"
36+
) >&3
37+
38+
trap exit TERM
39+
40+
# When we have passed the arguments, start printing socket data to the screen.
41+
# This is done in a background job because we also want to send data when
42+
# running in interactive mode.
43+
cat <&3 && echo "(disconnected, press \"enter\" twice to exit)" &
44+
cat >&3
45+
wait

chat_tcp_server.sh

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/usr/bin/env bash
2+
3+
PORT=${PORT:-8080}
4+
MODEL=${MODEL:-models/7B/ggml-model-q4_0.bin}
5+
6+
./main -l ${PORT} -m $MODEL

main.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "run.h"
22
#include "ggml.h"
3+
#include "tcp_server.h"
34

45
#include <iostream>
56

@@ -125,5 +126,11 @@ int main(int argc, char ** argv) {
125126
exit(0);
126127
}
127128

129+
#ifndef _WIN32
130+
if (params.listen_port != "") {
131+
return listen_tcp(ctx, params);
132+
}
133+
#endif
134+
128135
return run(ctx, params, std::cin, stdout, stderr);
129136
}

tcp_server.cpp

+245
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
#include "tcp_server.h"
2+
#include "llama.h"
3+
#include "utils.h"
4+
5+
#include <iostream>
6+
7+
#include <stdarg.h>
8+
#include <stdio.h>
9+
#include <stdlib.h>
10+
#include <stdbool.h>
11+
#include <string.h>
12+
#include <errno.h>
13+
14+
#include <signal.h>
15+
#include <unistd.h>
16+
#include <sys/wait.h>
17+
18+
#include <sys/types.h>
19+
#include <sys/socket.h>
20+
#include <arpa/inet.h>
21+
#include <netdb.h>
22+
23+
class PosixStream : public std::istream {
24+
public:
25+
PosixStream(int fd) : std::istream(&buf), buf(fd) {}
26+
~PosixStream() { close(buf.get_fd()); }
27+
28+
private:
29+
class PosixStreamBuf : public std::streambuf {
30+
public:
31+
PosixStreamBuf(int fd) : fd(fd) {}
32+
int get_fd() const { return fd; }
33+
34+
protected:
35+
virtual int_type underflow() {
36+
if (gptr() < egptr()) {
37+
return traits_type::to_int_type(*gptr());
38+
}
39+
40+
ssize_t num_read = ::read(fd, buffer, BUFFER_SIZE);
41+
if (num_read <= 0) {
42+
return traits_type::eof();
43+
}
44+
45+
setg(buffer, buffer, buffer + num_read);
46+
return traits_type::to_int_type(*gptr());
47+
}
48+
49+
private:
50+
static const int BUFFER_SIZE = 1024;
51+
int fd;
52+
char buffer[BUFFER_SIZE];
53+
};
54+
55+
PosixStreamBuf buf;
56+
};
57+
58+
void die(const char *msg, ...)
59+
{
60+
va_list ap;
61+
62+
va_start(ap, msg);
63+
vfprintf(stderr, msg, ap);
64+
va_end(ap);
65+
fputc('\n', stderr);
66+
exit(1);
67+
}
68+
69+
static char *read_argument(uint8_t **param_buf, size_t *param_buf_size, FILE *instream) {
70+
bool done = false;
71+
uint8_t *buf = *param_buf;
72+
size_t bufsize = *param_buf_size;
73+
size_t bufpos = 0;
74+
while (!done) {
75+
if (bufpos == bufsize) {
76+
bufsize += 1024;
77+
buf = (uint8_t *)realloc(buf, bufsize);
78+
if (!buf) {
79+
die("failed to allocate memory");
80+
}
81+
}
82+
83+
int c = fgetc(instream);
84+
if (c == EOF) {
85+
die("unexpected EOF client socket");
86+
}
87+
buf[bufpos++] = (uint8_t)c;
88+
if (c == 0) {
89+
// done reading argument
90+
break;
91+
}
92+
}
93+
*param_buf = buf;
94+
*param_buf_size = bufsize;
95+
return strdup((char *)buf);
96+
}
97+
98+
static int read_arguments(int argc, char **argv, FILE *instream) {
99+
int i = 1;
100+
size_t param_buf_size = 0;
101+
uint8_t *param_buf = nullptr;
102+
103+
for (i = 1; i < argc; i++) {
104+
argv[i] = read_argument(&param_buf, &param_buf_size, instream);
105+
}
106+
107+
free(param_buf);
108+
return i;
109+
}
110+
111+
static int serve_model(llama_context * ctx,
112+
gpt_params params,
113+
int sock_fd)
114+
{
115+
int argc;
116+
char **argv;
117+
FILE *instream = fdopen(sock_fd, "r");
118+
FILE *outstream = fdopen(sock_fd, "w");
119+
setvbuf(instream, NULL, _IONBF, 0);
120+
121+
// start by reading the parameter count
122+
if (fscanf(instream, "%d\n", &argc) != 1) {
123+
fprintf(outstream, "Error: First line must be character count\n");
124+
fflush(outstream);
125+
return 1;
126+
}
127+
128+
argc += 1; // add one extra argument to emulate the program command line
129+
argv = (char **)malloc(argc * sizeof *argv);
130+
argv[0] = nullptr;
131+
if (read_arguments(argc, argv, instream) != argc) {
132+
fprintf(outstream, "Error: Failed to read arguments\n");
133+
fflush(outstream);
134+
}
135+
136+
if (gpt_params_parse(argc, argv, params) == false) {
137+
fprintf(outstream, "Error: Failed to parse parameters\n");
138+
fflush(outstream);
139+
return 1;
140+
}
141+
142+
for (int i = 1; i < argc; i++) {
143+
free(argv[i]);
144+
}
145+
free(argv);
146+
147+
PosixStream tcp_instream(sock_fd);
148+
149+
return run(ctx, params, tcp_instream, outstream, outstream);
150+
}
151+
152+
int listen_tcp(llama_context * ctx, gpt_params params) {
153+
int listen_fd;
154+
int status;
155+
pid_t child;
156+
struct addrinfo hints;
157+
struct addrinfo *servinfo, *p;
158+
int yes = 1;
159+
160+
memset(&hints, 0, sizeof hints);
161+
hints.ai_family = AF_INET;
162+
hints.ai_socktype = SOCK_STREAM;
163+
hints.ai_flags = AI_PASSIVE;
164+
165+
// This should only ever listen on a loopback address. Access from outside
166+
// should be proxied via socat or similar software
167+
status = getaddrinfo("127.0.0.1", params.listen_port.c_str(), &hints, &servinfo);
168+
if (status) {
169+
die("getaddrinfo error: %s", gai_strerror(status));
170+
}
171+
172+
// bind to the first addrinfo we can from the getaddrinfo results
173+
for (p = servinfo; p != NULL; p = p->ai_next) {
174+
listen_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
175+
if (listen_fd == -1) {
176+
perror("server: socket");
177+
continue;
178+
}
179+
180+
if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &yes, sizeof yes)) {
181+
die("setsockopt error: %s", params.listen_port.c_str(), strerror(errno));
182+
}
183+
184+
if (bind(listen_fd, p->ai_addr, p->ai_addrlen) == 0) {
185+
struct sockaddr_in addr_in;
186+
socklen_t addr_in_len = sizeof(addr_in);
187+
memset(&addr_in, 0, addr_in_len);
188+
getsockname(listen_fd, (struct sockaddr*)&addr_in, &addr_in_len);
189+
190+
printf("Listening on %s:%d\n", inet_ntoa(addr_in.sin_addr), ntohs(addr_in.sin_port));
191+
break;
192+
}
193+
194+
close(listen_fd);
195+
perror("server: bind");
196+
}
197+
198+
freeaddrinfo(servinfo);
199+
200+
if (p == NULL) {
201+
die("failed to bind: %s", strerror(errno));
202+
}
203+
204+
if (listen(listen_fd, 20)) {
205+
die("listen error: %s", strerror(errno));
206+
}
207+
// Don't track child processes, so ignore SIGCHLD to prevent zombies
208+
signal(SIGCHLD, SIG_IGN);
209+
210+
for (;;) {
211+
struct sockaddr_in client_addr;
212+
socklen_t client_addr_len = 0;
213+
memset(&client_addr, 0, sizeof(client_addr));
214+
215+
int sock_fd = accept(listen_fd,
216+
(struct sockaddr *)&client_addr,
217+
&client_addr_len);
218+
if (sock_fd < 0) {
219+
fprintf(stderr, "accept error: %s\n", strerror(errno));
220+
break;
221+
}
222+
223+
child = fork();
224+
if (child == 0) {
225+
// close the listen_fd since we won't use it in the child
226+
close(listen_fd);
227+
int ret = serve_model(ctx, params, sock_fd);
228+
close(sock_fd);
229+
return ret;
230+
} else {
231+
// close the client since we won't use it in the server
232+
close(sock_fd);
233+
sock_fd = 0;
234+
}
235+
}
236+
close(listen_fd);
237+
238+
// ignore SIGTERM since we'll send it to the group
239+
signal(SIGTERM, SIG_IGN);
240+
// tell children to exit
241+
kill(0, SIGTERM);
242+
// wait for children to terminate
243+
wait(&status);
244+
return 0;
245+
}

tcp_server.h

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include "utils.h"
4+
#include "llama.h"
5+
#include "run.h"
6+
7+
int listen_tcp(llama_context * ctx, gpt_params params);

utils.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
7575
params.ignore_eos = true;
7676
} else if (arg == "--n_parts") {
7777
params.n_parts = std::stoi(argv[++i]);
78+
#ifndef _WIN32
79+
} else if (arg == "-l" || arg == "--listen") {
80+
params.listen_port = argv[++i];
81+
#endif
7882
} else if (arg == "-h" || arg == "--help") {
7983
gpt_print_usage(argc, argv, params);
8084
exit(0);
@@ -122,6 +126,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
122126
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
123127
fprintf(stderr, " -m FNAME, --model FNAME\n");
124128
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
129+
#ifndef _WIN32
130+
fprintf(stderr, " -l PORT, --listen PORT\n");
131+
fprintf(stderr, " Run in TCP mode, listening on PORT\n");
132+
#endif
125133
fprintf(stderr, "\n");
126134
}
127135

utils.h

+4
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ struct gpt_params {
4242
bool instruct = false; // instruction mode (used for Alpaca models)
4343
bool ignore_eos = false; // do not stop generating after eos
4444
bool perplexity = false; // compute perplexity over the prompt
45+
46+
#ifndef _WIN32
47+
std::string listen_port = ""; // TCP port for when running in server mode
48+
#endif
4549
};
4650

4751
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

0 commit comments

Comments
 (0)