Skip to content

Commit

Permalink
force oai endpoints to return json
Browse files Browse the repository at this point in the history
  • Loading branch information
LostRuins committed Oct 2, 2023
1 parent 0c47e79 commit 23b9d3a
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions koboldcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ def do_GET(self):
global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens
self.path = self.path.rstrip('/')
response_body = None
force_json = False

if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
if args.stream and not "streaming=1" in self.path:
Expand Down Expand Up @@ -585,6 +586,7 @@ def do_GET(self):

elif self.path.endswith('/v1/models') or self.path.endswith('/models'):
response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
force_json = True

elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')):
response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode())
Expand All @@ -598,7 +600,7 @@ def do_GET(self):
else:
self.send_response(200)
self.send_header('Content-Length', str(len(response_body)))
self.end_headers()
self.end_headers(force_json=force_json)
self.wfile.write(response_body)
return

Expand All @@ -607,6 +609,7 @@ def do_POST(self):
content_length = int(self.headers['Content-Length'])
body = self.rfile.read(content_length)
self.path = self.path.rstrip('/')
force_json = False

if self.path.endswith(('/api/extra/tokencount')):
try:
Expand Down Expand Up @@ -686,6 +689,7 @@ def do_POST(self):

if self.path.endswith('/v1/completions') or self.path.endswith('/completions'):
api_format = 3
force_json = True

if api_format>0:
genparams = None
Expand All @@ -707,7 +711,7 @@ def do_POST(self):
# Headers are already sent when streaming
if not kai_sse_stream_flag:
self.send_response(200)
self.end_headers()
self.end_headers(force_json=force_json)
self.wfile.write(json.dumps(gen).encode())
except:
print("Generate: The response could not be sent, maybe connection was terminated?")
Expand All @@ -728,11 +732,11 @@ def do_HEAD(self):
self.send_response(200)
self.end_headers()

def end_headers(self):
def end_headers(self, force_json=False):
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', '*')
self.send_header('Access-Control-Allow-Headers', '*')
if "/api" in self.path:
if "/api" in self.path or force_json:
if self.path.endswith("/stream"):
self.send_header('Content-type', 'text/event-stream')
self.send_header('Content-type', 'application/json')
Expand Down

0 comments on commit 23b9d3a

Please sign in to comment.