Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 843 | Files: download files saved during Program execution #896

Merged
merged 12 commits into from
Sep 1, 2023
92 changes: 92 additions & 0 deletions client/quantum_serverless/core/files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# This code is a Qiskit project.
#
# (C) Copyright IBM 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.


"""
===============================================
Provider (:mod:`quantum_serverless.core.files`)
==============================================

.. currentmodule:: quantum_serverless.core.files

Quantum serverless files
========================

.. autosummary::
:toctree: ../stubs/

"""
import os.path
import uuid
from typing import List, Optional

import requests
from opentelemetry import trace
from tqdm import tqdm

from quantum_serverless.core.constants import REQUESTS_TIMEOUT
from quantum_serverless.utils.json import safe_json_request


class GatewayFilesClient:
"""GatewayFilesClient."""

def __init__(self, host: str, token: str, version: str):
"""Files client for Gateway service.

Args:
host: gateway host
version: gateway version
token: authorization token
"""
self.host = host
self.version = version
self._token = token

def download(self, file: str, directory: str) -> Optional[str]:
"""Downloads file."""
tracer = trace.get_tracer("client.tracer")
psschwei marked this conversation as resolved.
Show resolved Hide resolved
with tracer.start_as_current_span("files.download"):
with requests.get(
f"{self.host}/api/{self.version}/files/download/",
params={"file": file},
stream=True,
headers={"Authorization": f"Bearer {self._token}"},
timeout=REQUESTS_TIMEOUT,
) as req:
req.raise_for_status()

total_size_in_bytes = int(req.headers.get("content-length", 0))
chunk_size = 8192
progress_bar = tqdm(
total=total_size_in_bytes, unit="iB", unit_scale=True
)
file_name = f"downloaded_{str(uuid.uuid4())[:8]}_{file}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this download into an unique file name? Avoiding override or appending?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, to avoid overriding and we also return file name too, if user want to do something with it programmatically

with open(os.path.join(directory, file_name), "wb") as f:
for chunk in req.iter_content(chunk_size=chunk_size):
progress_bar.update(len(chunk))
f.write(chunk)
progress_bar.close()
return file_name

def list(self) -> List[str]:
"""Returns list of available files to download produced by programs,"""
tracer = trace.get_tracer("client.tracer")
with tracer.start_as_current_span("files.list"):
response_data = safe_json_request(
request=lambda: requests.get(
f"{self.host}/api/{self.version}/files/",
headers={"Authorization": f"Bearer {self._token}"},
timeout=REQUESTS_TIMEOUT,
)
)
return response_data.get("results", [])
16 changes: 16 additions & 0 deletions client/quantum_serverless/core/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
ENV_GATEWAY_PROVIDER_TOKEN,
GATEWAY_PROVIDER_VERSION_DEFAULT,
)
from quantum_serverless.core.files import GatewayFilesClient
from quantum_serverless.core.job import (
Job,
RayJobClient,
Expand Down Expand Up @@ -286,6 +287,14 @@ def run(self, program: Program, arguments: Optional[Dict[str, Any]] = None) -> J

return job_client.run(program, arguments)

def files(self) -> List[str]:
"""Returns list of available files produced by programs to download."""
raise NotImplementedError

def download(self, file: str, directory: str):
"""Download file."""
raise NotImplementedError

def widget(self):
"""Widget for information about provider and jobs."""
return Widget(self).show()
Expand Down Expand Up @@ -343,6 +352,7 @@ def __init__(
self._fetch_token(username, password)

self._job_client = GatewayJobClient(self.host, self._token, self.version)
self._files_client = GatewayFilesClient(self.host, self._token, self.version)

def get_compute_resources(self) -> List[ComputeResource]:
raise NotImplementedError("GatewayProvider does not support resources api yet.")
Expand All @@ -365,6 +375,12 @@ def run(self, program: Program, arguments: Optional[Dict[str, Any]] = None) -> J
def get_jobs(self, **kwargs) -> List[Job]:
return self._job_client.list(**kwargs)

def files(self) -> List[str]:
return self._files_client.list()

def download(self, file: str, directory: str = "./"):
return self._files_client.download(file, directory)

def _fetch_token(self, username: str, password: str):
response_data = safe_json_request(
request=lambda: requests.post(
Expand Down
25 changes: 25 additions & 0 deletions client/quantum_serverless/quantum_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,31 @@ def get_jobs(self, **kwargs):
"""
return self._selected_provider.get_jobs(**kwargs)

def files(self):
"""Returns list of available files to download.

Example:
>>> serverless = QuantumServerless()
>>> serverless.files()

Returns:
list of available files
"""
return self._selected_provider.files()

def download(self, file: str, directory: str = "./"):
psschwei marked this conversation as resolved.
Show resolved Hide resolved
"""Downloads file.
IceKhan13 marked this conversation as resolved.
Show resolved Hide resolved

Example:
>>> serverless = QuantumServerless()
>>> serverless.download('artifact.tar', directory="./")

Args:
file: name of file to download
directory: destination directory. Default: current directory
"""
return self._selected_provider.download(file, directory)

def context(
self,
provider: Optional[Union[str, BaseProvider]] = None,
Expand Down
1 change: 1 addition & 0 deletions client/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ qiskit>=0.44.0
qiskit-ibm-runtime>=0.11.3
redis>=4.6.0, <5.0
cloudpickle>=2.2.1
tqdm>=4.65.0
# opentelemetry
opentelemetry-api>=1.18.0
opentelemetry-sdk>=1.18.0
Expand Down
4 changes: 3 additions & 1 deletion gateway/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def decrypt_env_vars(env_vars: Dict[str, str]) -> Dict[str, str]:
if "token" in key.lower():
try:
env_vars[key] = decrypt_string(value)
except Exception as decryption_error: # pylint: disable=broad-exception-caught
except (
Exception # pylint: disable=broad-exception-caught
) as decryption_error:
logger.error("Cannot decrypt %s. %s", key, decryption_error)
return env_vars
3 changes: 3 additions & 0 deletions gateway/api/v1/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,8 @@
v1_views.JobViewSet,
basename=v1_views.JobViewSet.BASE_NAME,
)
router.register(
r"files", v1_views.FilesViewSet, basename=v1_views.FilesViewSet.BASE_NAME
)

urlpatterns = router.urls
8 changes: 8 additions & 0 deletions gateway/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,11 @@ class JobViewSet(views.JobViewSet): # pylint: disable=too-many-ancestors

def get_serializer_class(self):
return v1_serializers.JobSerializer


class FilesViewSet(views.FilesViewSet):
"""
Files view set.
"""

permission_classes = [permissions.IsAuthenticated, IsOwner]
psschwei marked this conversation as resolved.
Show resolved Hide resolved
66 changes: 65 additions & 1 deletion gateway/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@

Version views inherit from the different views.
"""

import glob
import mimetypes
import os
import json
import logging
from wsgiref.util import FileWrapper

import requests
from allauth.socialaccount.providers.keycloak.views import KeycloakOAuth2Adapter
from dj_rest_auth.registration.views import SocialLoginView
from django.conf import settings
from django.contrib.auth import get_user_model
from django.http import StreamingHttpResponse
from ray.dashboard.modules.job.sdk import JobSubmissionClient
from rest_framework import viewsets, permissions, status
from rest_framework.decorators import action
Expand Down Expand Up @@ -187,6 +190,67 @@ def stop(self, request, pk=None): # pylint: disable=invalid-name,unused-argumen
return Response({"message": message})


class FilesViewSet(viewsets.ViewSet):
"""ViewSet for file operations handling.

Note: only tar files are available for list and download
"""

BASE_NAME = "files"

# @action(methods=["GET"], detail=False)
IceKhan13 marked this conversation as resolved.
Show resolved Hide resolved
def list(self, request):
"""List of available for user files."""
files = []
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.files.list", context=ctx):
user_dir = os.path.join(settings.MEDIA_ROOT, request.user.username)
if os.path.exists(user_dir):
files = [
os.path.basename(path) for path in glob.glob(f"{user_dir}/*.tar")
]
else:
logger.warning(
"Directory %s does not exist for %s.", user_dir, request.user
)

return Response({"results": files})

@action(methods=["GET"], detail=False)
def download(self, request): # pylint: disable=invalid-name
"""Download selected file."""
response = Response(
IceKhan13 marked this conversation as resolved.
Show resolved Hide resolved
{"message": "Requested file was not found."},
status=status.HTTP_404_NOT_FOUND,
)
tracer = trace.get_tracer("gateway.tracer")
ctx = TraceContextTextMapPropagator().extract(carrier=request.headers)
with tracer.start_as_current_span("gateway.files.download", context=ctx):
requested_file_name = request.query_params.get("file")
if requested_file_name is not None:
# look for file in user's folder
user_dir = os.path.join(settings.MEDIA_ROOT, request.user.username)
file_path = os.path.join(user_dir, requested_file_name)

if os.path.exists(user_dir) and os.path.exists(file_path):
filename = os.path.basename(file_path)
chunk_size = 8192
# note: we do not use with statements as Streaming response closing file itself.
response = StreamingHttpResponse(
FileWrapper(
open( # pylint: disable=consider-using-with
file_path, "rb"
),
chunk_size,
),
content_type=mimetypes.guess_type(file_path)[0],
)
response["Content-Length"] = os.path.getsize(file_path)
response["Content-Disposition"] = f"attachment; filename={filename}"
return response


class KeycloakLogin(SocialLoginView):
"""KeycloakLogin."""

Expand Down
80 changes: 80 additions & 0 deletions gateway/tests/api/test_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Tests files api."""
import os
from urllib.parse import quote_plus

from django.urls import reverse
from rest_framework import status
from rest_framework.test import APITestCase


class TestFilesApi(APITestCase):
"""TestProgramApi."""

fixtures = ["tests/fixtures/fixtures.json"]

def test_files_list_non_authorized(self):
"""Tests files list non-authorized."""
url = reverse("v1:files-list")
response = self.client.get(url, format="json")
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

def test_files_list(self):
"""Tests files list."""

media_root = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"resources",
"fake_media",
)

with self.settings(MEDIA_ROOT=media_root):
auth = reverse("rest_login")
response = self.client.post(
auth, {"username": "test_user", "password": "123"}, format="json"
)
token = response.data.get("access")
self.client.credentials(HTTP_AUTHORIZATION="Bearer " + token)
url = reverse("v1:files-list")
response = self.client.get(url, format="json")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, ["artifact.tar"])

def test_non_existing_file_download(self):
psschwei marked this conversation as resolved.
Show resolved Hide resolved
"""Tests downloading non-existing file."""
auth = reverse("rest_login")
response = self.client.post(
auth, {"username": "test_user", "password": "123"}, format="json"
)
token = response.data.get("access")
self.client.credentials(HTTP_AUTHORIZATION="Bearer " + token)
url = reverse("v1:files-download")
response = self.client.get(
url, data={"file": "non_existing.tar"}, format="json"
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
response = self.client.get(url, format="json")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_file_download(self):
"""Tests downloading non-existing file."""
media_root = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"resources",
"fake_media",
)

with self.settings(MEDIA_ROOT=media_root):
auth = reverse("rest_login")
response = self.client.post(
auth, {"username": "test_user", "password": "123"}, format="json"
)
token = response.data.get("access")
self.client.credentials(HTTP_AUTHORIZATION="Bearer " + token)
url = reverse("v1:files-download")
response = self.client.get(
url, data={"file": "artifact.tar"}, format="json"
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue(response.streaming)
Binary file not shown.
Empty file.