Skip to content

Commit

Permalink
move npartitions parameter to correct function
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv committed Dec 6, 2023
1 parent 186b71f commit d9399d8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
28 changes: 16 additions & 12 deletions crossfit/backend/torch/hf/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ def __init__(
self.max_wait_seconds = max_wait_seconds
self.max_tokens = max_tokens

def __enter__(self):
self.inference_server = get_tgi_app_def(
self.runner = None
self.app_handle = None

def start(self):
inference_server = get_tgi_app_def(
self.path_or_name,
image_name=self.image_name,
image_version=self.image_version,
Expand All @@ -56,21 +59,21 @@ def __enter__(self):
self.runner = runner.get_runner()

self.app_handle = self.runner.run(
self.inference_server,
inference_server,
scheduler="local_docker",
)

self.status = self.runner.status(self.app_handle)

self.container_name = self.app_handle.split("/")[-1]
self.local_docker_client = self.runner._scheduler_instances["local_docker"]._docker_client
self.networked_containers = self.local_docker_client.networks.get("torchx").attrs[
container_name = self.app_handle.split("/")[-1]
local_docker_client = self.runner._scheduler_instances["local_docker"]._docker_client
networked_containers = local_docker_client.networks.get("torchx").attrs[
"Containers"
]

self.ip_address = None
for _, container_config in self.networked_containers.items():
if self.container_name in container_config["Name"]:
for _, container_config in networked_containers.items():
if container_name in container_config["Name"]:
self.ip_address = container_config["IPv4Address"].split("/")[0]
break
if not self.ip_address:
Expand All @@ -94,6 +97,9 @@ def __enter__(self):

logger.info(self.status)

def __enter__(self):
self.start()

return self

def __exit__(self, exc_type, exc_value, traceback):
Expand Down Expand Up @@ -123,10 +129,8 @@ def infer(self, data, col: Optional[str] = None):
output_col = "generated_text"
npartitions = getattr(data, "npartitions", self.num_gpus)
ddf = dask_cudf.from_cudf(
cudf.DataFrame(
{input_col: data, output_col: generated_text},
npartitions=npartitions,
)
cudf.DataFrame({input_col: data, output_col: generated_text}),
npartitions=npartitions,
)
return ddf

Expand Down
4 changes: 1 addition & 3 deletions examples/text_generation_with_tgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ def parse_arguments():
parser.add_argument(
"--dataset", default="beir/fiqa", help="Dataset to load (default: beir/fiqa)"
)
parser.add_argument(
"--tiny-sample", default=True, action="store_true", help="Use tiny sample dataset"
)
parser.add_argument("--tiny-sample", default=True, action="store_true", help="Use tiny sample dataset")
parser.add_argument(
"--num-gpus", type=int, default=2, help="Number of GPUs to use (default: 1)"
)
Expand Down

0 comments on commit d9399d8

Please sign in to comment.