Skip to content

Commit

Permalink
Fix test``
Browse files Browse the repository at this point in the history
  • Loading branch information
amritghimire committed Dec 10, 2024
1 parent b79d46f commit 8e3a5d8
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion tests/test_cli_studio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from unittest.mock import MagicMock

import requests_mock
import websockets
from dvc_studio_client.auth import AuthorizationExpiredError
from tabulate import tabulate

Expand All @@ -8,6 +11,24 @@
from datachain.utils import STUDIO_URL


def mocked_connect(url, additional_headers):
async def mocked_recv():
raise websockets.exceptions.ConnectionClosed("Connection closed")

async def mocked_send(message):
pass

async def mocked_close():
pass

assert additional_headers == {"Authorization": "token isat_access_token"}
mocked_websocket = MagicMock()
mocked_websocket.recv = mocked_recv
mocked_websocket.send = mocked_send
mocked_websocket.close = mocked_close
return mocked_websocket


def test_studio_login_token_check_failed(mocker):
mocker.patch(
"dvc_studio_client.auth.get_access_token",
Expand Down Expand Up @@ -292,6 +313,9 @@ def test_studio_rm_dataset(capsys, mocker):


def test_studio_run(capsys, mocker, tmp_dir):
mocker.patch(
"datachain.remote.studio.websockets.connect", side_effect=mocked_connect
)
with Config(ConfigLevel.GLOBAL).edit() as conf:
conf["studio"] = {"token": "isat_access_token", "team": "team_name"}

Expand All @@ -301,6 +325,10 @@ def test_studio_run(capsys, mocker, tmp_dir):
f"{STUDIO_URL}/api/datachain/job",
json={"job": {"id": 1, "url": "https://example.com"}},
)
m.get(
f"{STUDIO_URL}/api/datachain/datasets/dataset_job_versions?job_id=1&team_name=team_name",
json={"dataset_versions": [{"dataset_name": "dataset_name", "version": 1}]},
)

(tmp_dir / "env_file.txt").write_text("ENV_FROM_FILE=1")
(tmp_dir / "reqs.txt").write_text("pyjokes")
Expand Down Expand Up @@ -333,7 +361,12 @@ def test_studio_run(capsys, mocker, tmp_dir):
)

out = capsys.readouterr().out
assert out.strip() == "Job 1 created\nOpen the job in Studio at https://example.com"
assert (
out.strip() == "Job 1 created\nOpen the job in Studio at https://example.com\n"
"========================================\n\n\n"
">>>> Dataset versions created during the job:\n"
" - dataset_name@v1"
)

first_request = m.request_history[0]
second_request = m.request_history[1]
Expand Down

0 comments on commit 8e3a5d8

Please sign in to comment.