Add log streaming to the created job #680

Merged · 6 commits · Dec 18, 2024
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -48,7 +48,8 @@ dependencies = [
     "iterative-telemetry>=0.0.9",
     "platformdirs",
     "dvc-studio-client>=0.21,<1",
-    "tabulate"
+    "tabulate",
+    "websockets"
 ]

 [project.optional-dependencies]
45 changes: 44 additions & 1 deletion src/datachain/remote/studio.py
@@ -2,7 +2,7 @@
 import json
 import logging
 import os
-from collections.abc import Iterable, Iterator
+from collections.abc import AsyncIterator, Iterable, Iterator
 from datetime import datetime, timedelta, timezone
 from struct import unpack
 from typing import (
@@ -11,6 +11,9 @@
     Optional,
     TypeVar,
 )
+from urllib.parse import urlparse, urlunparse
+
+import websockets

 from datachain.config import Config
 from datachain.dataset import DatasetStats
@@ -22,6 +25,7 @@
 DatasetInfoData = Optional[dict[str, Any]]
 DatasetStatsData = Optional[DatasetStats]
 DatasetRowsData = Optional[Iterable[dict[str, Any]]]
+DatasetJobVersionsData = Optional[dict[str, Any]]
 DatasetExportStatus = Optional[dict[str, Any]]
 DatasetExportSignedUrls = Optional[list[str]]
 FileUploadData = Optional[dict[str, Any]]
@@ -231,6 +235,38 @@

         return msgpack.ExtType(code, data)

+    async def tail_job_logs(self, job_id: str) -> AsyncIterator[dict]:
+        """
+        Follow job logs via websocket connection.
+
+        Args:
+            job_id: ID of the job to follow logs for
+
+        Yields:
+            Dict containing either job status updates or log messages
+        """
+        parsed_url = urlparse(self.url)
+        ws_url = urlunparse(parsed_url._replace(scheme="ws"))
+        ws_url = f"{ws_url}/logs/follow/?job_id={job_id}&team_name={self.team}"
+
+        async with websockets.connect(
+            ws_url,
+            additional_headers={"Authorization": f"token {self.token}"},
+        ) as websocket:
+            while True:
+                try:
+                    message = await websocket.recv()
+                    data = json.loads(message)
+
+                    # Yield the parsed message data
+                    yield data
+                except websockets.exceptions.ConnectionClosed:
+                    break
+                except Exception as e:  # noqa: BLE001
+                    logger.error("Error receiving websocket message: %s", e)
+                    break
+
     def ls(self, paths: Iterable[str]) -> Iterator[tuple[str, Response[LsData]]]:
         # TODO: change LsData (response.data value) to be list of lists
         # to handle cases where a path will be expanded (i.e. globs)
@@ -302,6 +338,13 @@
             method="GET",
         )

+    def dataset_job_versions(self, job_id: str) -> Response[DatasetJobVersionsData]:
+        return self._send_request(
+            "datachain/datasets/dataset_job_versions",
+            {"job_id": job_id},
+            method="GET",
+        )
+
     def dataset_stats(self, name: str, version: int) -> Response[DatasetStatsData]:
         response = self._send_request(
             "datachain/datasets/stats",
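For context, here is a minimal sketch of how a caller can consume the new streaming API. The client setup is illustrative only; StudioClient's real constructor arguments are not shown in this diff:

import asyncio

from datachain.remote.studio import StudioClient

async def follow_logs(client: StudioClient, job_id: str) -> None:
    # Each yielded message is a parsed JSON dict; see the handler in
    # src/datachain/studio.py below for the two shapes this PR expects.
    async for message in client.tail_job_logs(job_id):
        if "logs" in message:
            for log in message["logs"]:
                print(log["message"], end="")
        elif "job" in message:
            print(f"Job status: {message['job']['status']}")

# Hypothetical usage, assuming an already-configured client:
# asyncio.run(follow_logs(client, "some-job-id"))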
29 changes: 28 additions & 1 deletion src/datachain/studio.py
@@ -1,3 +1,4 @@
+import asyncio
 import os
 from typing import TYPE_CHECKING, Optional

@@ -230,8 +231,34 @@
     if not response.data:
         raise DataChainError("Failed to create job")

-    print(f"Job {response.data.get('job', {}).get('id')} created")
+    job_id = response.data.get("job", {}).get("id")
+    print(f"Job {job_id} created")
     print("Open the job in Studio at", response.data.get("job", {}).get("url"))
+    print("=" * 40)
+
+    # Sync usage
+    async def _run():
+        async for message in client.tail_job_logs(job_id):
+            if "logs" in message:
+                for log in message["logs"]:
+                    print(log["message"], end="")
+            elif "job" in message:
+                print(f"\n>>>> Job is now in {message['job']['status']} status.")
+
+    asyncio.run(_run())
+
+    response = client.dataset_job_versions(job_id)
+    if not response.ok:
+        raise_remote_error(response.message)
+
+    response_data = response.data
+    if response_data:
+        dataset_versions = response_data.get("dataset_versions", [])
+        print("\n\n>>>> Dataset versions created during the job:")
+        for version in dataset_versions:
+            print(f" - {version.get('dataset_name')}@v{version.get('version')}")
+    else:
+        print("No dataset versions created during the job.")


def upload_files(client: StudioClient, files: list[str]) -> list[str]:
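Judging from the branches in _run(), the websocket payloads take one of two JSON shapes. The field values below are illustrative, not captured from a real server:

# A batch of log lines, printed as-is:
{"logs": [{"message": "Processing chunk 1/10\n"}]}

# A job status transition:
{"job": {"status": "completed"}}

Note that asyncio.run(_run()) blocks until the websocket closes, so the dataset_job_versions lookup runs only after the job's log stream has ended.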
35 changes: 34 additions & 1 deletion tests/test_cli_studio.py
@@ -1,4 +1,7 @@
+from unittest.mock import MagicMock
+
 import requests_mock
+import websockets
 from dvc_studio_client.auth import AuthorizationExpiredError
 from tabulate import tabulate

@@ -8,6 +11,24 @@
 from datachain.utils import STUDIO_URL


+def mocked_connect(url, additional_headers):
+    async def mocked_recv():
+        raise websockets.exceptions.ConnectionClosed("Connection closed")
+
+    async def mocked_send(message):
+        pass
+
+    async def mocked_close():
+        pass
+
+    assert additional_headers == {"Authorization": "token isat_access_token"}
+    mocked_websocket = MagicMock()
+    mocked_websocket.recv = mocked_recv
+    mocked_websocket.send = mocked_send
+    mocked_websocket.close = mocked_close
+    return mocked_websocket


 def test_studio_login_token_check_failed(mocker):
     mocker.patch(
         "dvc_studio_client.auth.get_access_token",
@@ -310,6 +331,9 @@ def test_studio_cancel_job(capsys, mocker):


 def test_studio_run(capsys, mocker, tmp_dir):
+    mocker.patch(
+        "datachain.remote.studio.websockets.connect", side_effect=mocked_connect
+    )
     with Config(ConfigLevel.GLOBAL).edit() as conf:
         conf["studio"] = {"token": "isat_access_token", "team": "team_name"}

@@ -319,6 +343,10 @@ def test_studio_run(capsys, mocker, tmp_dir):
f"{STUDIO_URL}/api/datachain/job",
json={"job": {"id": 1, "url": "https://example.com"}},
)
m.get(
f"{STUDIO_URL}/api/datachain/datasets/dataset_job_versions?job_id=1&team_name=team_name",
json={"dataset_versions": [{"dataset_name": "dataset_name", "version": 1}]},
)

         (tmp_dir / "env_file.txt").write_text("ENV_FROM_FILE=1")
         (tmp_dir / "reqs.txt").write_text("pyjokes")
Expand Down Expand Up @@ -351,7 +379,12 @@ def test_studio_run(capsys, mocker, tmp_dir):
)

out = capsys.readouterr().out
assert out.strip() == "Job 1 created\nOpen the job in Studio at https://example.com"
assert (
out.strip() == "Job 1 created\nOpen the job in Studio at https://example.com\n"
"========================================\n\n\n"
">>>> Dataset versions created during the job:\n"
" - dataset_name@v1"
)

         first_request = m.request_history[0]
         second_request = m.request_history[1]