
Commit

Fix merge issues
javiermtorres committed Feb 20, 2025
1 parent 3dbaaed commit e68373e
Showing 2 changed files with 18 additions and 45 deletions.
61 changes: 17 additions & 44 deletions lumigator/backend/backend/services/jobs.py
@@ -5,7 +5,7 @@
 from io import BytesIO, StringIO
 from pathlib import Path
 from typing import Any
-from urllib.parse import urljoin, urlparse
+from urllib.parse import urljoin
 from uuid import UUID
 
 import loguru
@@ -118,39 +118,24 @@ def generate_config(self, request: JobCreate, record_id: UUID, dataset_path: str
                 output_field=request.job_config.output_field or "predictions",
             ),
         )
-        model_parsed = urlparse(request.job_config.model)
-        # Maybe use just the protocol to decide?
-        if model_parsed.scheme == "oai":
-            job_config.inference_server = InferenceServerConfig(
-                base_url=_set_model_type(request),
-                engine=request.job_config.model,
-                # FIXME Inferences may not always be summarizations!
-                system_prompt=request.job_config.system_prompt or settings.DEFAULT_SUMMARIZER_PROMPT,
-                max_retries=3,
-            )
-            job_config.params = SamplingParameters(
-                max_tokens=request.job_config.max_tokens,
-                frequency_penalty=request.job_config.frequency_penalty,
-                temperature=request.job_config.temperature,
-                top_p=request.job_config.top_p,
-            )
-        if model_parsed.scheme == "mistral":
-            job_config.inference_server = InferenceServerConfig(
-                base_url=_set_model_type(request),
-                engine=request.job_config.model,
-                system_prompt=request.job_config.system_prompt or settings.DEFAULT_SUMMARIZER_PROMPT,
-                max_retries=3,
-            )
-            job_config.params = SamplingParameters(
-                max_tokens=request.job_config.max_tokens,
-                frequency_penalty=request.job_config.frequency_penalty,
-                temperature=request.job_config.temperature,
-                top_p=request.job_config.top_p,
-            )
+        if request.job_config.provider == "hf":
+            # Custom logic: if provider is hf, we run the hf model inside the ray job
+            job_config.hf_pipeline = HuggingFacePipelineConfig(
+                model_name_or_path=request.job_config.model,
+                task=request.job_config.task,
+                accelerator=request.job_config.accelerator,
+                revision=request.job_config.revision,
+                use_fast=request.job_config.use_fast,
+                trust_remote_code=request.job_config.trust_remote_code,
+                torch_dtype=request.job_config.torch_dtype,
+                max_new_tokens=500,
+            )
-        if model_parsed.scheme == "llamafile":
+        else:
+            # It will be a pass through to LiteLLM
             job_config.inference_server = InferenceServerConfig(
-                base_url=_set_model_type(request),
-                engine=request.job_config.model,
+                base_url=request.job_config.base_url if request.job_config.base_url else None,
+                model=request.job_config.model,
+                provider=request.job_config.provider,
                 system_prompt=request.job_config.system_prompt or settings.DEFAULT_SUMMARIZER_PROMPT,
                 max_retries=3,
             )
@@ -160,18 +145,6 @@ def generate_config(self, request: JobCreate, record_id: UUID, dataset_path: str
                 temperature=request.job_config.temperature,
                 top_p=request.job_config.top_p,
             )
-        else:
-            # Pending fix for apis
-            job_config.hf_pipeline = HuggingFacePipelineConfig(
-                model_uri=request.job_config.model,
-                task=request.job_config.task,
-                accelerator=request.job_config.accelerator,
-                revision=request.job_config.revision,
-                use_fast=request.job_config.use_fast,
-                trust_remote_code=request.job_config.trust_remote_code,
-                torch_dtype=request.job_config.torch_dtype,
-                max_new_tokens=500,
-            )
         return job_config
 
     def store_as_dataset(self) -> bool:
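For context, the change above drops the old URL-scheme parsing (oai / mistral / llamafile) in favor of a single branch on request.job_config.provider: "hf" builds a HuggingFacePipelineConfig so the model runs inside the Ray job, and any other provider is handed to LiteLLM through an InferenceServerConfig. The sketch below is a minimal, simplified illustration of that dispatch; it uses trimmed stand-in classes and a hypothetical route_inference helper rather than the real pydantic schemas and generate_config method, which carry more fields (task, accelerator, sampling parameters, and so on).

from dataclasses import dataclass


# Trimmed stand-ins for the schemas in lumigator/jobs/inference/schemas.py;
# field names follow the diff above, but most fields are omitted here.
@dataclass
class HuggingFacePipelineConfig:
    model_name_or_path: str
    max_new_tokens: int = 500


@dataclass
class InferenceServerConfig:
    model: str
    provider: str
    base_url: str | None = None
    max_retries: int = 3


def route_inference(model: str, provider: str, base_url: str | None = None):
    """Mirror the new branching: 'hf' runs the model locally inside the Ray job,
    anything else is a pass-through to LiteLLM."""
    if provider == "hf":
        return HuggingFacePipelineConfig(model_name_or_path=model)
    # base_url stays None unless the request supplied one, as in the diff.
    return InferenceServerConfig(model=model, provider=provider, base_url=base_url or None)


print(route_inference("facebook/bart-large-cnn", "hf"))
print(route_inference("gpt-4o-mini", "openai", "https://api.openai.com/v1"))

The helper name and the trimmed fields are illustrative only; in the actual service the logic lives in generate_config and mutates job_config in place.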
2 changes: 1 addition & 1 deletion lumigator/jobs/inference/schemas.py
@@ -37,7 +37,7 @@ class SamplingParameters(BaseModel):
 
 
 class HuggingFacePipelineConfig(BaseModel, arbitrary_types_allowed=True):
-    model_uri: str
+    model_name_or_path: str
     revision: str
     use_fast: bool
     trust_remote_code: bool
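The schema rename from model_uri to model_name_or_path matches the usual Hugging Face convention, where the same field accepts either a Hub model id or a local path. Below is a hedged example of how such a config might be consumed inside the inference job; the pipeline call is standard transformers API usage and is not code from this commit, and the model id and field values are placeholders.

from transformers import pipeline

# Hypothetical consumer of the renamed field; values are placeholders.
cfg = {
    "model_name_or_path": "facebook/bart-large-cnn",  # Hub id or local path
    "revision": "main",
    "use_fast": True,
    "trust_remote_code": False,
}
summarizer = pipeline(
    task="summarization",
    model=cfg["model_name_or_path"],
    revision=cfg["revision"],
    use_fast=cfg["use_fast"],
    trust_remote_code=cfg["trust_remote_code"],
)
print(summarizer("Lumigator can now run Hugging Face models inside the Ray job.")[0])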
