Headless operation for docker #408

Merged: 9 commits, Jun 11, 2024
Changes from all commits
7 changes: 6 additions & 1 deletion refact_webgui/webgui/selfhost_fastapi_completions.py
@@ -276,7 +276,12 @@ async def _coding_assistant_caps(self):
log(f"Your refact-lsp version is deprecated, finetune is unavailable. Please update your plugin.")
return Response(content=json.dumps(self._caps_base_data(), indent=4), media_type="application/json")

async def _caps(self, authorization: str = Header(None)):
async def _caps(self, authorization: str = Header(None), user_agent: str = Header(None)):
if isinstance(user_agent, str):
m = re.match(r"^refact-lsp (\d+)\.(\d+)\.(\d+)$", user_agent)
if m:
major, minor, patch = map(int, m.groups())
log("user version %d.%d.%d" % (major, minor, patch))
data = self._caps_base_data()
running = running_models_and_loras(self._model_assigner)

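For context, the _caps handler now reads the User-Agent header and only parses a version out of it when the client identifies itself as refact-lsp; any other agent falls through untouched. A standalone sketch of that parsing step (the helper name and sample strings below are illustrative, not part of the PR):

    import re

    def parse_refact_lsp_version(user_agent):
        # Returns (major, minor, patch) for strings like "refact-lsp 0.9.7", else None.
        if not isinstance(user_agent, str):
            return None
        m = re.match(r"^refact-lsp (\d+)\.(\d+)\.(\d+)$", user_agent)
        if m:
            return tuple(map(int, m.groups()))
        return None

    print(parse_refact_lsp_version("refact-lsp 0.9.7"))  # (0, 9, 7)
    print(parse_refact_lsp_version("Mozilla/5.0"))       # None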
4 changes: 2 additions & 2 deletions refact_webgui/webgui/webgui.py
@@ -149,9 +149,9 @@ def setup_logger():
class CustomHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
if boring1.match(log_entry):
if boring1.search(log_entry):
return
if boring2.match(log_entry):
if boring2.search(log_entry):
return
sys.stderr.write(log_entry)
sys.stderr.write("\n")
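The filter change from match() to search() matters because these "boring" patterns can occur anywhere in a formatted log entry, not only at its start. A small sketch of the difference (the pattern and log line are invented for illustration):

    import re

    boring1 = re.compile(r"GET /favicon\.ico")

    line = '127.0.0.1 - "GET /favicon.ico HTTP/1.1" 200'
    print(bool(boring1.match(line)))   # False: match() anchors at the beginning of the string
    print(bool(boring1.search(line)))  # True: search() scans the whole string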
25 changes: 17 additions & 8 deletions self_hosting_machinery/inference/inference_embeddings.py
@@ -16,6 +16,10 @@
from self_hosting_machinery.inference.lora_loader_mixin import LoraLoaderMixin


def log(*args):
logging.getLogger("MODEL").info(*args)


class InferenceEmbeddings(InferenceBase, LoraLoaderMixin):
def __init__(
self,
@@ -34,7 +38,7 @@ def __init__(
for local_files_only in [True, False]:
try:
# WARNING: this may not work if you have no access to the web as it may try to download tokenizer
logging.getLogger("MODEL").info("loading model local_files_only=%i" % local_files_only)
log("loading model local_files_only=%i" % local_files_only)
if local_files_only:
self._model = SentenceTransformer(
os.path.join(self.cache_dir, self._model_dir),
@@ -70,18 +74,23 @@ def model_dict(self) -> Dict[str, Any]:
def cache_dir(self) -> str:
return env.DIR_WEIGHTS

def infer(self, request: Dict[str, Any], upload_proxy: Any, upload_proxy_args: Dict, log=print):

def infer(self, request: Dict[str, Any], upload_proxy: Any, upload_proxy_args: Dict):
request_id = request["id"]
try:
inputs = request["inputs"]
B = len(inputs)
log("embeddings B=%d" % B)
upload_proxy_args["ts_prompt"] = time.time()
if request_id in upload_proxy.check_cancelled():
return

t0 = time.time()
files = {
"results": json.dumps(self._model.encode(request["inputs"]).tolist()),
"results": json.dumps(self._model.encode(inputs).tolist()),
}

log("/embeddings %0.3fs" % (time.time() - t0))
# 8 => 0.141s 0.023s
# 64 => 0.166s 0.060s
# 128 => 0.214s 0.120s *1024 => 1.600s
upload_proxy_args["ts_batch_finished"] = time.time()
finish_reason = 'DONE'
upload_proxy.upload_result(
@@ -94,5 +103,5 @@ def infer(self, request: Dict[str, Any], upload_proxy: Any, upload_proxy_args: D
)

except Exception as e: # noqa
logging.getLogger("MODEL").error(e)
logging.getLogger("MODEL").error(traceback.format_exc())
log(e)
log(traceback.format_exc())
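One detail worth keeping in mind for this file: model loading is attempted with local files first and only then with downloads allowed, as the warning comment above notes. A generic sketch of that try-local-then-remote pattern, under the assumption that a failed local attempt simply falls through to the next one (the loader below is a stand-in, not the SentenceTransformer call):

    def load_model(local_files_only):
        # Stand-in loader; the real code constructs a SentenceTransformer here.
        if local_files_only:
            raise FileNotFoundError("weights not cached yet")
        return object()

    model = None
    for local_files_only in (True, False):
        try:
            print("loading model local_files_only=%i" % local_files_only)
            model = load_model(local_files_only)
            break
        except Exception as e:  # noqa
            print("load failed: %r" % e)
    assert model is not None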
5 changes: 3 additions & 2 deletions self_hosting_machinery/inference/inference_worker.py
@@ -43,7 +43,7 @@ def worker_loop(model_name: str, models_db: Dict[str, Any], supported_models: Di
dummy_call = {
'id': 'emb-legit-42',
'function': 'embeddings',
'inputs': "Common Knowledge",
'inputs': 128*["A"*8000], # max size validated at 9000 chars, 128 batch size
'created': time.time(),
}
else:
@@ -73,7 +73,8 @@ def check_cancelled(*args, **kwargs):
return set()

log("STATUS test batch")
inference_model.infer(dummy_call, DummyUploadProxy, {})
for _ in range(2):
inference_model.infer(dummy_call, DummyUploadProxy, {})
if compile:
return

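The test batch now exercises the embeddings model at its configured limits, 128 inputs of 8000 characters each (the inline comment mentions a 9000-character validation cap), and it is run twice, which in practice also warms up any lazy initialization before real requests arrive. A stripped-down sketch of that warm-up loop with a stand-in infer function:

    import time

    MAX_BATCH = 128
    MAX_CHARS = 8000  # stays under the 9000-char validation limit mentioned in the diff

    def infer_stub(call):
        # Stand-in for InferenceEmbeddings.infer(); the real code encodes call["inputs"].
        return len(call["inputs"])

    dummy_call = {
        "id": "emb-warmup",
        "function": "embeddings",
        "inputs": MAX_BATCH * ["A" * MAX_CHARS],
        "created": time.time(),
    }

    for attempt in range(2):  # the second pass shows steady-state latency
        t0 = time.time()
        infer_stub(dummy_call)
        print("warmup pass %d took %0.3fs" % (attempt, time.time() - t0))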
10 changes: 6 additions & 4 deletions self_hosting_machinery/inference/stream_results.py
@@ -53,7 +53,7 @@ def url_complain_doesnt_work():


def model_guid_allowed_characters(name):
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
return re.sub(r"[^a-zA-Z0-9_\.]", "_", name)
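The sanitizer now also keeps dots, so version-like fragments in model GUIDs survive instead of being flattened to underscores. A quick before/after with an invented model name:

    import re

    name = "thenlper/gte-base-1.5"
    print(re.sub(r"[^a-zA-Z0-9_]", "_", name))    # thenlper_gte_base_1_5  (old behavior)
    print(re.sub(r"[^a-zA-Z0-9_\.]", "_", name))  # thenlper_gte_base_1.5  (new behavior)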


def validate_description_dict(
@@ -108,10 +108,12 @@ def completions_wait_batch(req_session: requests.Session, my_desc, verbose=False
if json_resp is None:
return "ERROR", []
t1 = time.time()
logger.info("%0.1fms %s %s" % (1000*(t1 - t0), url, termcolor.colored(json_resp.get("retcode", "no retcode"), "green")))
retcode = json_resp.get("retcode", "ERROR")
if retcode != "WAIT":
logger.info("%0.1fms %s %s" % (1000*(t1 - t0), url, termcolor.colored(retcode, "green")))
if verbose or "retcode" not in json_resp:
logger.warning("%s unrecognized json: %s" % (url, json.dumps(json_resp, indent=4)))
return json_resp.get("retcode", "ERROR"), json_resp.get("batch", [])
return retcode, json_resp.get("batch", [])
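Logging is now limited to polls that return something other than WAIT, so an idle worker no longer emits a line on every poll. A condensed sketch of that gating; the URL and response dicts below are placeholders:

    import time

    def report(json_resp, url, t0):
        retcode = json_resp.get("retcode", "ERROR")
        if retcode != "WAIT":
            # Only non-WAIT responses are worth a log line.
            print("%0.1fms %s %s" % (1000 * (time.time() - t0), url, retcode))
        return retcode, json_resp.get("batch", [])

    url = "http://example.invalid/wait-batch"  # placeholder, not the real endpoint
    t0 = time.time()
    print(report({"retcode": "WAIT"}, url, t0))                             # stays silent
    print(report({"retcode": "OK", "batch": [{"id": "call-1"}]}, url, t0))  # gets logged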


def head_and_tail(base: str, modified: str):
@@ -240,7 +242,7 @@ def upload_result(
progress[original_batch[b]["id"]] = tmp
upload_dict["progress"] = progress
upload_dict["check_cancelled"] = [call["id"] for call in original_batch]
upload_dict["model_name"] = description_dict["model"]
upload_dict["model_name"] = description_dict["model"].replace("/vllm", "")
self.upload_q.put(copy.deepcopy(upload_dict))
if DEBUG_UPLOAD_NOT_SEPARATE_PROCESS:
_upload_results_loop(self.upload_q, self.cancelled_q)
19 changes: 6 additions & 13 deletions self_hosting_machinery/scripts/first_run.py
@@ -5,13 +5,13 @@
from refact_utils.scripts import env


def copy_watchdog_configs_if_first_run_detected(model_assigner: ModelAssigner):
def assign_gpus_if_first_run_detected(model_assigner: ModelAssigner):
if not os.path.exists(env.CONFIG_ENUM_GPUS):
enum_gpus.enum_gpus()
model_assigner.first_run()
model_assigner.first_run() # has models_to_watchdog_configs() inside


def convert_old_configs(model_assigner: ModelAssigner):
def convert_old_configs():
# longthink.cfg and openai_api_worker.cfg are deprecated watchdog configs
old_longthink = os.path.join(env.DIR_WATCHDOG_D, "longthink.cfg")
if os.path.exists(old_longthink):
@@ -20,16 +20,9 @@ def convert_old_configs(model_assigner: ModelAssigner):
if os.path.exists(openai_watchdog_cfg_fn):
os.unlink(openai_watchdog_cfg_fn)

for gpu in range(16):
fn = os.path.join(env.DIR_WATCHDOG_D, "model-gpu%d.cfg" % gpu)
if not os.path.exists(fn):
continue
text = open(fn).read()

model_assigner.models_to_watchdog_configs()
Member comment on this change: this is needed if we remove something from the models list (for example, deprecated models).

if __name__ == '__main__':
convert_old_configs()
model_assigner = ModelAssigner()
convert_old_configs(model_assigner)
copy_watchdog_configs_if_first_run_detected(model_assigner)
assign_gpus_if_first_run_detected(model_assigner)
model_assigner.models_to_watchdog_configs() # removes deprecated models
20 changes: 11 additions & 9 deletions self_hosting_machinery/watchdog/docker_watchdog.py
@@ -8,9 +8,7 @@
import time
import uuid
import psutil

from pathlib import Path

from typing import Dict, Optional, List

from refact_utils.scripts import env
@@ -251,6 +249,12 @@ def maybe_send_usr1(self, sigkill_timeout=30):
pass
try:
self.p.kill()
itself = psutil.Process(self.p.pid)
for child in itself.children(recursive=True):
try:
child.kill()
except psutil.NoSuchProcess:
pass
except psutil.NoSuchProcess:
pass
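Killing only the tracked process can leave its child processes running; the added block walks the process tree and kills them as well. A self-contained sketch of the same psutil pattern, spawning a throwaway child just for demonstration:

    import subprocess
    import sys

    import psutil

    p = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])

    try:
        p.kill()
        itself = psutil.Process(p.pid)
        for child in itself.children(recursive=True):
            try:
                child.kill()
            except psutil.NoSuchProcess:
                pass
    except psutil.NoSuchProcess:
        pass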

@@ -304,7 +308,6 @@ def __str__(self):
f" remove: {self.remove_this}\n" \
f" status: {self.status_from_stderr}\n"


tracked: Dict[str, TrackedJob] = {}
watchdog_templates = list(Path(env.DIR_WATCHDOG_TEMPLATES).iterdir())

@@ -450,12 +453,6 @@ def first_run():


def main_loop():
# Generate a random SMALLCLOUD_API_KEY, it will be inherited by subprocesses,
# this allows inference_worker to authorize on the local web server (both use
# this variable), and work safely even if we expose http port to the world.
os.environ["SMALLCLOUD_API_KEY"] = str(uuid.uuid4())

first_run()
while 1:
main_loop_body()
time.sleep(1)
@@ -468,4 +465,9 @@ def main_loop():


if __name__ == '__main__':
# Generate a random SMALLCLOUD_API_KEY, it will be inherited by subprocesses,
# this allows inference_worker to authorize on the local web server (both use
# this variable), and work safely even if we expose http port to the world.
os.environ["SMALLCLOUD_API_KEY"] = str(uuid.uuid4())
first_run()
main_loop()
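Generating SMALLCLOUD_API_KEY moves from main_loop() to the __main__ guard, so it is set once when the watchdog starts as a script, before first_run() and the main loop, and every subprocess it spawns still inherits the key through the environment. A tiny sketch of that inheritance (the key value is random here, as in the PR):

    import os
    import subprocess
    import sys
    import uuid

    os.environ["SMALLCLOUD_API_KEY"] = str(uuid.uuid4())

    # Children spawned with the default environment see the same key.
    out = subprocess.check_output(
        [sys.executable, "-c", "import os; print(os.environ['SMALLCLOUD_API_KEY'])"]
    )
    print(out.decode().strip() == os.environ["SMALLCLOUD_API_KEY"])  # True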