Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions tests/unit/test_sql_execution.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import base64
import datetime
import io
import json
import os
import secrets
import unittest
from unittest import TestCase, mock

import duckdb
import pandas as pd
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from parameterized import parameterized

from deepnote_toolkit.sql.sql_execution import (
Expand Down Expand Up @@ -234,12 +238,21 @@ def test_sql_executed_with_audit_comment_with_semicolon(
def test_execute_sql_with_connection_json_with_snowflake_private_key(
self, mock_execute_sql_with_caching
):
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
private_key_b64 = base64.b64encode(
private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
).decode("utf-8")

template = "SELECT * FROM table"
sql_alchemy_json = json.dumps(
{
"url": "snowflake://CHRISARTMANN@2nginys-hu78995?warehouse=&role=&application=Deepnote_Workspaces",
"url": "snowflake://test@test?warehouse=&role=&application=Deepnote_Workspaces",
"params": {
"snowflake_private_key": "LS0tLS1CRUdJTiBQUklWQVRFIEtFWS0tLS0tCk1JSUV2Z0lCQURBTkJna3Foa2lHOXcwQkFRRUZBQVNDQktnd2dnU2tBZ0VBQW9JQkFRQzFqcUV2dVJ0ajd5bDgKZ29PMTNqWkErak1yU1lReklUWnYva09vVUJvS3dlVFZWKzhYWXgwQzF0QmdRRXUycFZHUUR6ZmY1RFhyU3NBOAo4bWpSczROU3k3aVNpZDdlLzg2QTdoYXRpQ2Q0SDg1aEtabzFxaHRGOW9ob2dwa3Z5NDRZc2RQVWNvYlBNMjFRClB5bkxDdzBOM2kxRklJSjVzY2xRY2Q3ZDNmK3hqZU5SK3M0QkozZTN0SFNaUFB0a25ZZ3EzTUNKaFVNZGoxODQKZ2x4VUFSYnpIbzVFaUJhVmlqckY1R2tSeXpIZFE1MnRORUhaOEllckFxWUJUYVk4bkQwNlJzRHNKUUNPcXIxcQo2bnUzcjBZbmNsNmZyOFgxTTJ3ZU04YWZqL3lRYThnM1Q1dG9CTTYxbVpBRm1uUzcvTDhqdTMvMmRucnZKTkkyCi95OE1TN2RiQWdNQkFBRUNnZ0VBRkl2Z1N6NzVqZVVDU1pMdVJqdGE0Y2p1MkN2cENCaEVIdHg1aFB5N3F4S24Kb1BVNG01UklpTiszNlRkSUw1S2ttUUdxamNNM3p1UkF2bnBOeVIySzg3M0EvNDhXdlI4dlJ0YXE2UUhnT3U5RAozZFJsY3BsSTg0YWpoL1dhVWNHMExRWHRpN3lzNjVuNFR2MDI3NWJWUFYweXU2dE5nMmw1VE45TGNoOERmQnVsCkhHRzJhN2lXWWowUVllUEVKZ3ArMzRVVXBzbjZEdnFmRVUzOGZiN3hHclhiaW5YendleGRLbGNvK21hdmFKYW4KSE1WSzZ3RlgzdzVsbGlQaW5tbFFhOXFueGl5amJ3THBObGgzTFVmbjRnRnQxZllwNWVVV0E2a3FJTm5UMDAzcQpRcWtnTVJ6enQ0MVdTb0FvbDlXZnpDN2NNdm5YSzhkVjkxNlVKN3d0d1FLQmdRRGFyenJkaGRIY0R2NjVYdm1jCldhSlNHQ3ZRdDhMbjUwK3liTVNGNFRXUnMvWGJibFhLUExiWHB2MTgyRVB1U0M1Umt1dVZjZDNlVXV2SmZzNGMKRHF3TGFIS2pIZ09lbE5jd0E1U3Z0eWdBTDRTNjBrcHB3TlNBQ1BPSDhHQ0haUlNuNDgrbVpMWnBINEFCWGwrQgpIaDgyWCt1RWR4M3I4UHUyenVRWUFnZ0UwUUtCZ1FEVWlaSHZhM2ZHcnoyNnRpQ2xYelFFZGRNOW1oTEdPaWtCCjN2Y3dVUWk1V3kwdFZtalpyN0JmWGxzVmY3OU5qRndjK1BqL0RYZjVCS29RcUY2ekZVdW9saG1yVVpSUWVmYk0KUC84K3ZxM015MXN6cXFZa1F5SVBoSDRiT1RWVFFmZ21vbWRPUGdPeVczLzJjSVJkb1RIS2c3L2NnY2lrSnoyNwo0cnlsTVZOMGF3S0JnUUNQSTUyRFBFRjJLZlovUFhSaTY2UzgyWWRzY2F2SkFYWUFFd083b2dMZllRenZXVlFjCk1RdDVNcHUvYVF0bDM2YzV5OUlhR3RNZjMrVG9HZkV0R2tsd21paFhMcUV0M3J6UGQ3aU9IM08yVTJRc3FOTCsKVDdLSUw5Ty95ZzVVOFV2STdPdVJQV0RNaEVyVUdvS20wQ0djQk1MekRNandFK2VlNitNTzk5MXAwUUtCZ1FDVgphLzZBZjJLZStiY0JYR2dKTzZ4N2NrYkg2VmxIcWI0SXhiT3RjVnNieldFdW5iQnJVdHhCd0RsekhQUG0xa1l3ClRFM3FLcEx0TEgxUDVyOWxVaFIxK3NraksrQ0V6NnBXSUt3WGRjRUUyUGRPbEt2bmxKY09wOHhzNFVSL08wTDIKRG5sb2hhcmRxdnlFeXNnVWQyNWsvVWxYQXB1SDVOcS9EQUlxZFVwQjd3S0JnR0tXYlltNE9kL2FGNzlsQUlGbwpEc09VU3YvUFhIQkNOM2xOSFYvaGF4dDlPKzhGL0NMQTBYVnpaRjdESDdScXBXd0NrL2xLa0c5c2lUNE1jK2JJCk9UZEtXOFdOYXptZ1dyVmpQYU9jdENrbnN5SUFJK0phZnB5Qys3cGZkVFdpL1dUUTFneTVUN3JjRFFDTURRZG0KbzlES2NRc3dPeldFa05qOW9NeExUTUlBCi0tLS0tRU5EIFBSSVZBVEUgS0VZLS0tLS0="
"snowflake_private_key": private_key_b64,
},
"param_style": "pyformat",
}
Expand All @@ -249,12 +262,19 @@ def test_execute_sql_with_connection_json_with_snowflake_private_key(

args, _ = mock_execute_sql_with_caching.call_args

# the private key is converted to DER format
expected_private_key_der = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

self.assertEqual(args[0], "SELECT * FROM table")
self.assertEqual(
args[2],
{
"url": "snowflake://CHRISARTMANN@2nginys-hu78995?warehouse=&role=&application=Deepnote_Workspaces",
"params": {"connect_args": {"private_key": mock.ANY}},
"url": "snowflake://test@test?warehouse=&role=&application=Deepnote_Workspaces",
"params": {"connect_args": {"private_key": expected_private_key_der}},
"param_style": "pyformat",
},
)
Expand All @@ -263,14 +283,25 @@ def test_execute_sql_with_connection_json_with_snowflake_private_key(
def test_execute_sql_with_connection_json_with_snowflake_encrypted_private_key(
self, mock_execute_sql_with_caching
):
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
private_key_passphrase = secrets.token_urlsafe(16)
private_key_b64 = base64.b64encode(
private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.BestAvailableEncryption(
private_key_passphrase.encode("utf-8")
),
)
).decode("utf-8")

template = "SELECT * FROM table"
sql_alchemy_json = json.dumps(
{
"url": "snowflake://CHRISARTMANN@2nginys-hu78995?warehouse=&role=&application=Deepnote_Workspaces",
"url": "snowflake://test@test?warehouse=&role=&application=Deepnote_Workspaces",
"params": {
# This is a base-64 encoded private key that's been encrypted with the passphrase.
"snowflake_private_key": "LS0tLS1CRUdJTiBFTkNSWVBURUQgUFJJVkFURSBLRVktLS0tLQpNSUlGSkRCV0Jna3Foa2lHOXcwQkJRMHdTVEF4QmdrcWhraUc5dzBCQlF3d0pBUVFKZUdPYkVYdjY2VXFGem5BCmNBUzBFd0lDQ0FBd0RBWUlLb1pJaHZjTkFna0ZBREFVQmdncWhraUc5dzBEQndRSWJNT1BRQTE5cEhFRWdnVEkKc0NzQnpHOVlqRVRtZS80V2kxYWpyYkFCeDJPSU9UbXBONnJTWEVDcWkxeWNVb3JlWXlxKy9wRVF5WVNzNzBMdQpJenBpVk5zenJZRDAxOFZlOTlXZVU0UGdYb1QrN0dDWWFDWDZERGVIbVhCNzlzSkFmcWYvKzZiRTJUZEIyY2FkCmJ5bjNsOXZBbEd6UmVTSmtrcG5BdnR2YlhpZXpuaGFlWXE1SEtESmFBRnZNL0htdFlsemVCY3JNQW91NmlsSUYKRjgyQzF2TTVTekVRUmhxblN1RU03K2NVTUd4c2R6aHZoSTkwWWRRY0NJVXNsTnE2YkhGUHIyTmFoeGwvRGxKZQozRnY5M3hZSG5pa0tCN01hdGJNK1FQcGVBMnJXOFl4aWdRVTd2WHc2R1lNYkdvS1QwTDczZlo4dXh6UzcxeENJCndVMFR1S0tuN0ZjRW9CaXZIQ0NoNEJpZUpjSk5xdzJNSGJDc2ZmTXlSNUZOZUFSOG5tNjZuSUlURGdLb3NJczQKMVZWTjlKeUNDVk16WGxrNk80bkhaU2JZK0llblNWdE9kaGNpR2U2Zmp5TEFTMlJwekJOMExKb1FueDVnK1JpNgpMTGh1RE1GSDdybFZnT3hWeFY3RWNFeHJuZWF5R1E5andDYnNVNktENFpNeVhXWnhMNVdMbzdHU21lWkZHZmZNCk9MNnBWZXhhdjVpaXQzRzBHUkNNSWhZRjhMb2xidHRXYzB4dldrSjQ2UTJ5am9lcFVwa3IrQ0Q0THNraERkUCsKZ3RFVzEzc2xqbHNCbUxTbzdPMnRlY3ZqTnV4RlFpSU4vZ0x6bHZESTFMQUZVc0k3T1NWOVNRWmRUZXZ1SnZNUgpncDc0Q3N5dWV6US9rZUJhMStWYkltWVRsMDE1czNIeldrSHRjcTg0dS90cm56N1JqTFhKNUdobG9XSTRjL1BwCk15UXlYMWZqMnlzaURERWE4M05wTGNBZWNOYS9mNmpjVm5kRnBWWmlVcWtVaE1JajdPT0VwQUdjNVhILy9ObWYKRFp6UUxFd0xReFR6dlhZaDR4TGRkS3ZFNkovZXI3RUpMczFISGRyeDlxaG85T3dEUjdxZWRNQWxjVFp1Mm1OWgpJREpnM2J6aUF3RlF0MXRxVjV3Rkd6cGNMdFFtblo4dm9od2kzZWx2b2RkaUI2WGpPY3N2QmtqWFRjaEZvSEJyCmdCQUxSU1h6dmxObnN2ZG45UnlCYUtlUnJOL3RKYnhKeHdHa2ZYUkJWZzlTV2R6ZVNsNDE4OXJGaDkrbWR3Y2kKbmcrUTBPZkp6elNTY3cxYXpQVmc0NEd3ejBSTEdwS2FIZmpiSDM3SzltTDJQckJvYmZLdEd4SnNmZ0dxMTFZbAplb3I4bUszbEV3Qm5XL1NQd3dzcEl4QXhoN0x0VEVacTh0V2tjQlMrUWlmQ2Z4L1VVcWQyeGRFNk1XT29JT2ZsCkc1Ti9XdUdhMTREcUZKSDl6R0t4VTBLKzA1K0YwRUZkei9Tc01PQUNtM1VKQXc1UVhUR0ZEV3UyK1M3SW5IRXMKd1R4UGhMWjhKQnZjNHdoS0g3WGNkaEl0UUZhaHY4MWpoNEUwS1ZmUXRvejNEOVowdEFSRnF5K1NIM2JrWlJNSgprczFJWlJ6S1JLdi9ZUlpuaDVYK3VPc2N2Wm1NOHFrU3ZOZWtqdGV6WVo2NjAxWWNnMnRGcDFxYm4xODNVakFVCjZyRHFBQTVKOVRpZ3VyK1Q0TmhKUnZYSEM0Tkp2aklXdXNXSzVWR0E5ZlhKa1UrL1B4anBaOVF4QkozWUxraisKMHpjdU5FdHV5TUN6MFlhcWdRSmd4RUd1cEFrQXlNUlVTM1pvbk5jaEsySkJhekcyN2dGemtMSzF5RTdmQU41SQo0VW1jQ0luOWJkQlMybUFvYXlQU0VURFo2SHNSSU1NbDV3cG5ORkpaM3U4LytraTJjeDlQeXBYbHl3WWUzb0I3CjdnZGw3VjJpL1pHSjNtWld2OUFnZGJwWVBpTkVDSUduTGVpQUdYL2dzRXNlUG9xUERSUURuWGdUZjZOODBLbmUKdm5rb3hsdEE0emRGMVFManU1aXpreUx0NGZZV0dZV2pTLytMN3NqSDhxYktydjRQZFcxbWw1MW16ZEZrQ1pmdAowaEFCbWJMRDV4UkZMSVBoT2tHVjFiRmxQVGFoVHdEbAotLS0tLUVORCBFTkNSWVBURUQgUFJJVkFURSBLRVktLS0tLQ==",
"snowflake_private_key_passphrase": "0*>f-REO1p#1N[1/^",
"snowflake_private_key": private_key_b64,
"snowflake_private_key_passphrase": private_key_passphrase,
},
"param_style": "pyformat",
}
Expand All @@ -280,12 +311,19 @@ def test_execute_sql_with_connection_json_with_snowflake_encrypted_private_key(

args, _ = mock_execute_sql_with_caching.call_args

# The private key is loaded with passphrase (decrypted) then converted to DER format without encryption
expected_private_key_der = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

self.assertEqual(args[0], "SELECT * FROM table")
self.assertEqual(
args[2],
{
"url": "snowflake://CHRISARTMANN@2nginys-hu78995?warehouse=&role=&application=Deepnote_Workspaces",
"params": {"connect_args": {"private_key": mock.ANY}},
"url": "snowflake://test@test?warehouse=&role=&application=Deepnote_Workspaces",
"params": {"connect_args": {"private_key": expected_private_key_der}},
"param_style": "pyformat",
},
)
Expand Down