From e3d5bae222498bbe56f7d11203efb42802c7fe70 Mon Sep 17 00:00:00 2001
From: Amrit Ghimire
Date: Tue, 10 Dec 2024 02:11:32 +0545
Subject: [PATCH] Add log streaming to the created job

When a job is created in Studio, stream its logs over a websocket and
print updates from Studio until the job finishes, then report the
dataset versions the job created.

Related PR: https://github.com/iterative/studio/pull/11068
---
 pyproject.toml                 |  3 ++-
 src/datachain/remote/studio.py | 45 +++++++++++++++++++++++++++++++++-
 src/datachain/studio.py        | 29 +++++++++++++++++++++-
 3 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 37604bbba..58aea130f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,7 +48,8 @@ dependencies = [
     "iterative-telemetry>=0.0.9",
     "platformdirs",
     "dvc-studio-client>=0.21,<1",
-    "tabulate"
+    "tabulate",
+    "websockets>=14"
 ]
 
 [project.optional-dependencies]
diff --git a/src/datachain/remote/studio.py b/src/datachain/remote/studio.py
index d98acdedc..c0da8bf35 100644
--- a/src/datachain/remote/studio.py
+++ b/src/datachain/remote/studio.py
@@ -2,7 +2,7 @@
 import json
 import logging
 import os
-from collections.abc import Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable, Iterator
 from datetime import datetime, timedelta, timezone
 from struct import unpack
 from typing import (
@@ -11,6 +11,9 @@
     Optional,
     TypeVar,
 )
+from urllib.parse import urlparse, urlunparse
+
+import websockets
 
 from datachain.config import Config
 from datachain.dataset import DatasetStats
@@ -22,6 +25,7 @@
 DatasetInfoData = Optional[dict[str, Any]]
 DatasetStatsData = Optional[DatasetStats]
 DatasetRowsData = Optional[Iterable[dict[str, Any]]]
+DatasetJobVersionsData = Optional[dict[str, Any]]
 DatasetExportStatus = Optional[dict[str, Any]]
 DatasetExportSignedUrls = Optional[list[str]]
 FileUploadData = Optional[dict[str, Any]]
@@ -231,6 +235,38 @@ def _unpacker_hook(code, data):
 
         return msgpack.ExtType(code, data)
 
+    async def tail_job_logs(self, job_id: str) -> AsyncIterator[dict]:
+        """
+        Follow job logs via a websocket connection.
+
+        Args:
+            job_id: ID of the job to follow logs for
+
+        Yields:
+            Dicts containing either job status updates or log messages.
+        """
+        parsed_url = urlparse(self.url)
+        ws_scheme = "wss" if parsed_url.scheme == "https" else "ws"
+        ws_url = urlunparse(parsed_url._replace(scheme=ws_scheme))
+        ws_url = f"{ws_url}/logs/follow/?job_id={job_id}&team_name={self.team}"
+
+        async with websockets.connect(
+            ws_url,
+            additional_headers={"Authorization": f"token {self.token}"},
+        ) as websocket:
+            while True:
+                try:
+                    message = await websocket.recv()
+                    data = json.loads(message)
+                    # Yield the parsed payload: a job status update or logs.
+                    yield data
+                except websockets.exceptions.ConnectionClosed:
+                    # The server closes the socket when the job finishes.
+                    break
+                except Exception as e:  # noqa: BLE001
+                    logger.error("Error receiving websocket message: %s", e)
+                    break
+
     def ls(self, paths: Iterable[str]) -> Iterator[tuple[str, Response[LsData]]]:
         # TODO: change LsData (response.data value) to be list of lists
         # to handle cases where a path will be expanded (i.e. globs)
@@ -302,6 +338,13 @@ def dataset_rows_chunk(
             method="GET",
         )
 
+    def dataset_job_versions(self, job_id: str) -> Response[DatasetJobVersionsData]:
+        return self._send_request(
+            "datachain/datasets/dataset_job_versions",
+            {"job_id": job_id},
+            method="GET",
+        )
+
     def dataset_stats(self, name: str, version: int) -> Response[DatasetStatsData]:
         response = self._send_request(
             "datachain/datasets/stats",
diff --git a/src/datachain/studio.py b/src/datachain/studio.py
index e3b250e3d..73d181f3d 100644
--- a/src/datachain/studio.py
+++ b/src/datachain/studio.py
@@ -1,3 +1,4 @@
+import asyncio
 import os
 from typing import TYPE_CHECKING, Optional
 
@@ -227,8 +228,34 @@ def create_job(
     if not response.data:
         raise DataChainError("Failed to create job")
 
-    print(f"Job {response.data.get('job', {}).get('id')} created")
+    job_id = response.data.get("job", {}).get("id")
+    print(f"Job {job_id} created")
     print("Open the job in Studio at", response.data.get("job", {}).get("url"))
+    print("=" * 40)
+
+    # Drive the async log stream from this synchronous CLI entry point.
+    async def _run():
+        async for message in client.tail_job_logs(job_id):
+            if "logs" in message:
+                for log in message["logs"]:
+                    print(log["message"], end="")
+            elif "job" in message:
+                print(f"\n>>>> Job is now in {message['job']['status']} status.")
+
+    asyncio.run(_run())
+
+    response = client.dataset_job_versions(job_id)
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    response_data = response.data
+    if response_data:
+        dataset_versions = response_data.get("dataset_versions", [])
+        print("\n\n>>>> Dataset versions created during the job:")
+        for version in dataset_versions:
+            print(f" - {version.get('dataset_name')}@v{version.get('version')}")
+    else:
+        print("No dataset versions created during the job.")
 
 
 def upload_files(client: StudioClient, files: list[str]) -> list[str]:
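
For reference, below is a minimal standalone sketch of consuming the new
/logs/follow/ websocket endpoint outside the CLI. The endpoint path, query
parameters, auth header, and message shapes ("logs" / "job" keys) mirror
the handling in this patch; the base URL, token, team, and job id are
placeholder assumptions, not real values.

    import asyncio
    import json

    import websockets


    async def follow_logs(base_ws_url, token, team, job_id):
        # Same endpoint, query params, and auth header as
        # StudioClient.tail_job_logs in this patch.
        ws_url = f"{base_ws_url}/logs/follow/?job_id={job_id}&team_name={team}"
        async with websockets.connect(
            ws_url,
            additional_headers={"Authorization": f"token {token}"},
        ) as websocket:
            # Iteration ends when the server closes the socket,
            # i.e. when the job reaches a final status.
            async for message in websocket:
                data = json.loads(message)
                if "logs" in data:
                    for log in data["logs"]:
                        print(log["message"], end="")
                elif "job" in data:
                    print(f"\n>>>> Job is now in {data['job']['status']} status.")


    # Placeholder values, for illustration only.
    asyncio.run(
        follow_logs("wss://studio.example.com/api", "<token>", "<team>", "<job-id>")
    )

Bridging with asyncio.run keeps create_job synchronous while still consuming
the async generator; since the stream only ends once Studio closes the
websocket, dataset_job_versions can be queried immediately afterwards.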