Skip to content

Commit

Permalink
Merge pull request #978 from vespa-engine/thomasht86/wip-version-arg-…
Browse files Browse the repository at this point in the history
…to-deploy

(feat) support `version`-parameter in `VespaCloud().deploy`
  • Loading branch information
thomasht86 authored Dec 4, 2024
2 parents b9def47 + 1af6509 commit 9ac5745
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -666,4 +666,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
176 changes: 124 additions & 52 deletions vespa/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,14 +609,16 @@ def deploy(
self,
instance: Optional[str] = "default",
disk_folder: Optional[str] = None,
max_wait: int = 300,
version: Optional[str] = None,
max_wait: int = 1800,
) -> Vespa:
"""
Deploy the given application package as the given instance in the Vespa Cloud dev environment.
:param instance: Name of this instance of the application, in the Vespa Cloud.
:param disk_folder: Disk folder to save the required Vespa config files. Default to application name
folder within user's current working directory.
:param version: Vespa version to use for deployment. Default is None, which means the latest version. Should only be set on instructions from Vespa team. Must be a valid Vespa version, e.g. "8.435.13".
:param max_wait: Seconds to wait for the deployment.
:return: a Vespa connection instance. Returns a connection to the mtls endpoint. To connect to the token endpoint, use :func:`VespaCloud.get_application(endpoint_type="token")`.
Expand All @@ -627,8 +629,14 @@ def deploy(

region = self.get_dev_region()
job = "dev-" + region
run = self._start_deployment(instance, job, disk_folder, None)
self._follow_deployment(instance, job, run)
run = self._start_deployment(
instance=instance,
job=job,
disk_folder=disk_folder,
application_zip_bytes=None,
version=version,
)
self._follow_deployment(instance, job, run, max_wait)
app: Vespa = self.get_application(
instance=instance, environment="dev", endpoint_type="mtls"
)
Expand Down Expand Up @@ -854,7 +862,11 @@ def wait_for_prod_deployment(
raise TimeoutError(f"Deployment did not finish within {max_wait} seconds. ")

def deploy_from_disk(
self, instance: str, application_root: Path, max_wait: int = 300
self,
instance: str,
application_root: Path,
max_wait: int = 300,
version: Optional[str] = None,
) -> Vespa:
"""
Deploy to dev from a directory tree.
Expand All @@ -864,6 +876,7 @@ def deploy_from_disk(
:param instance: Name of the instance where the application is to be run
:param application_root: Application package directory root
:param max_wait: Seconds to wait for the deployment.
:param version: Vespa version to use for deployment. Default is None, which means the latest version. Must be a valid Vespa version, e.g. "8.435.13".
:return: a Vespa connection instance. Returns a connection to the mtls endpoint. To connect to the token endpoint, use :func:`VespaCloud.get_application(endpoint_type="token")`.
"""
data = BytesIO(self.read_app_package_from_disk(application_root))
Expand All @@ -873,11 +886,13 @@ def deploy_from_disk(
region = self.get_dev_region()
job = "dev-" + region
run = self._start_deployment(
instance, job, disk_folder, application_zip_bytes=data
instance=instance,
job=job,
disk_folder=disk_folder,
application_zip_bytes=data,
version=version,
)
self._follow_deployment(instance, job, run)
run = self._start_deployment(instance, job, disk_folder, None)
self._follow_deployment(instance, job, run)
app: Vespa = self.get_application(
instance=instance, environment="dev", endpoint_type="mtls"
)
Expand Down Expand Up @@ -1208,37 +1223,63 @@ def _try_get_access_token(self) -> str:

return auth["providers"]["auth0"]["systems"]["public"]["access_token"]

def _request_with_access_token(
def _handle_response(
self,
method: str,
path: str,
body: BytesIO = BytesIO(),
headers={},
return_raw_response=False,
response: httpx.Response,
return_raw_response: bool = False,
path: str = "",
) -> Union[dict, httpx.Response]:
if not self.control_plane_access_token:
raise ValueError("Access token not set.")
body.seek(0)
headers = {
"Authorization": "Bearer " + self.control_plane_access_token,
**headers,
}
response = self.get_connection_response_with_retry(method, path, body, headers)
"""Common response handling logic"""
if return_raw_response:
return response

try:
parsed = json.load(response)
except json.JSONDecodeError:
parsed = response.read()

if response.status_code != 200:
print(parsed)
raise HTTPError(
f"HTTP {response.status_code} error: {response.reason_phrase} for {path}"
)
return parsed

def _get_auth_headers(self, additional_headers: dict = {}) -> dict:
"""Create authorization headers"""
if not self.control_plane_access_token:
raise ValueError("Access token not set.")

return {
"Authorization": f"Bearer {self.control_plane_access_token}",
**additional_headers,
}

def _request_with_access_token(
self,
method: str,
path: str,
body: Union[BytesIO, MultipartEncoder] = BytesIO(),
headers: dict = {},
return_raw_response: bool = False,
) -> Union[dict, httpx.Response]:
"""Make authenticated request with access token"""
if hasattr(body, "seek"):
body.seek(0)

auth_headers = self._get_auth_headers(headers)
response = self.get_connection_response_with_retry(
method, path, body, auth_headers
)

return self._handle_response(response, return_raw_response, path)

def _request(
self, method: str, path: str, body: BytesIO = BytesIO(), headers={}
self,
method: str,
path: str,
body: Union[BytesIO, MultipartEncoder] = BytesIO(),
headers: dict = {},
) -> Union[dict, httpx.Response]:
if self.control_plane_auth_method == "access_token":
return self._request_with_access_token(method, path, body, headers)
Expand All @@ -1253,47 +1294,54 @@ def _request_with_api_key(
self,
method: str,
path: str,
body: BytesIO = BytesIO(),
headers={},
return_raw_response=False,
body: Union[BytesIO, MultipartEncoder] = BytesIO(),
headers: dict = {},
return_raw_response: bool = False,
) -> Union[dict, httpx.Response]:
digest = hashes.Hash(hashes.SHA256(), default_backend())
body.seek(0)
digest.update(body.read())

# Handle different body types
if isinstance(body, MultipartEncoder):
# Use the encoded data for hash computation
digest = hashes.Hash(hashes.SHA256(), default_backend())
digest.update(body.to_string()) # This moves the buffer position to the end
body._buffer.seek(0) # Needs to be reset. Otherwise, no data will be sent
# Update the headers to include the Content-Type
headers.update({"Content-Type": body.content_type})
# Read the content of multipart_data into a bytes object
multipart_data_bytes: bytes = body.to_string()
headers.update({"Content-Length": str(len(multipart_data_bytes))})
# Convert multipart_data_bytes to type BytesIO
body_data: BytesIO = BytesIO(multipart_data_bytes)
else:
if hasattr(body, "seek"):
body.seek(0)
digest.update(body.read())
body_data = body
# Create signature
content_hash = standard_b64encode(digest.finalize()).decode("UTF-8")
timestamp = (
datetime.utcnow().isoformat() + "Z"
) # Java's Instant.parse requires the neutral time zone appended
timestamp = datetime.utcnow().isoformat() + "Z"
url = self.base_url + path

canonical_message = method + "\n" + url + "\n" + timestamp + "\n" + content_hash
signature = self.api_key.sign(
canonical_message.encode("UTF-8"), ec.ECDSA(hashes.SHA256())
)
signature_b64 = standard_b64encode(signature).decode("UTF-8")

headers = {
"X-Timestamp": timestamp,
"X-Content-Hash": content_hash,
"X-Key-Id": self.tenant + ":" + self.application + ":" + "default",
"X-Key-Id": f"{self.tenant}:{self.application}:default",
"X-Key": self.api_public_key_bytes,
"X-Authorization": standard_b64encode(signature),
"X-Authorization": signature_b64,
**headers,
}

body.seek(0)
response = self.get_connection_response_with_retry(method, path, body, headers)
if return_raw_response:
return response
try:
parsed = json.load(response)
except json.JSONDecodeError:
parsed = response.read()
if response.status_code != 200:
print(parsed)
raise HTTPError(
f"HTTP {response.status_code} error: {response.reason_phrase} for {url}"
)
return parsed
response = self.get_connection_response_with_retry(
method, path, body_data, headers
)
return self._handle_response(response, return_raw_response, path)

def get_all_endpoints(
self,
Expand Down Expand Up @@ -1538,6 +1586,7 @@ def _start_deployment(
job: str,
disk_folder: str,
application_zip_bytes: Optional[BytesIO] = None,
version: Optional[str] = None,
) -> int:
deploy_path = (
"/application/v4/tenant/{}/application/{}/instance/{}/deploy/{}".format(
Expand All @@ -1551,11 +1600,30 @@ def _start_deployment(
if not application_zip_bytes:
application_zip_bytes = self._to_application_zip(disk_folder=disk_folder)

if version is not None:
# Create multipart form data
form_data = {
"applicationZip": (
"application.zip",
application_zip_bytes,
"application/zip",
),
"deployOptions": (
"",
json.dumps({"vespaVersion": version}),
"application/json",
),
}
multipart = MultipartEncoder(fields=form_data)
headers = {"Content-Type": multipart.content_type}
payload = multipart
else:
# Use existing direct zip upload
headers = {"Content-Type": "application/zip"}
payload = application_zip_bytes

response = self._request(
"POST",
deploy_path,
application_zip_bytes,
{"Content-Type": "application/zip"},
method="POST", path=deploy_path, body=payload, headers=headers
)
message = response.get("message", "No message provided")
print(message, file=self.output)
Expand Down Expand Up @@ -1616,9 +1684,12 @@ def _to_application_zip(self, disk_folder: str) -> BytesIO:

return buffer

def _follow_deployment(self, instance: str, job: str, run: int) -> None:
def _follow_deployment(
self, instance: str, job: str, run: int, max_wait: int = 1800
) -> None:
last = -1
while True:
start = time.time()
while time.time() - start < max_wait:
try:
status, last = self._get_deployment_status(instance, job, run, last)
except RuntimeError:
Expand All @@ -1630,6 +1701,7 @@ def _follow_deployment(self, instance: str, job: str, run: int) -> None:
return
else:
raise RuntimeError("Unexpected status: {}".format(status))
raise TimeoutError(f"Deployment did not finish within {max_wait} seconds.")

def _get_deployment_status(
self, instance: str, job: str, run: int, last: int
Expand Down

0 comments on commit 9ac5745

Please sign in to comment.