Skip to content

Commit

Permalink
SNOW-1650124: Pass Azure SAS Token via params (#2060)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jrose authored Oct 22, 2024
1 parent dbc9284 commit d73afee
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
18 changes: 17 additions & 1 deletion src/snowflake/connector/azure_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import xml.etree.ElementTree as ET
from datetime import datetime, timezone
from logging import getLogger
from logging import Filter, getLogger
from random import choice
from string import hexdigits
from typing import TYPE_CHECKING, Any, NamedTuple
Expand Down Expand Up @@ -39,6 +39,22 @@ class AzureLocation(NamedTuple):
MATDESC = "x-ms-meta-matdesc"


class AzureCredentialFilter(Filter):
LEAKY_FMT = '%s://%s:%s "%s %s %s" %s %s'

def filter(self, record):
if record.msg == AzureCredentialFilter.LEAKY_FMT and len(record.args) == 8:
record.args = (
record.args[:4] + (record.args[4].split("?")[0],) + record.args[5:]
)
return True


getLogger("snowflake.connector.vendored.urllib3.connectionpool").addFilter(
AzureCredentialFilter()
)


class SnowflakeAzureRestClient(SnowflakeStorageClient):
def __init__(
self,
Expand Down
11 changes: 9 additions & 2 deletions test/integ/test_put_get_with_azure_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import sys
import time
from logging import getLogger
from logging import DEBUG, getLogger

import pytest

Expand All @@ -37,8 +37,9 @@
@pytest.mark.parametrize(
"from_path", [True, pytest.param(False, marks=pytest.mark.skipolddriver)]
)
def test_put_get_with_azure(tmpdir, conn_cnx, from_path):
def test_put_get_with_azure(tmpdir, conn_cnx, from_path, caplog):
"""[azure] Puts and Gets a small text using Azure."""
caplog.set_level(DEBUG)
# create a data file
fname = str(tmpdir.join("test_put_get_with_azure_token.txt.gz"))
original_contents = "123,test1\n456,test2\n"
Expand Down Expand Up @@ -85,6 +86,12 @@ def test_put_get_with_azure(tmpdir, conn_cnx, from_path):
file_stream.close()
csr.execute(f"drop table {table_name}")

for line in caplog.text.splitlines():
if "blob.core.windows.net" in line:
assert (
"sig=" not in line
), "connectionpool logger is leaking sensitive information"

files = glob.glob(os.path.join(tmp_dir, "data_*"))
with gzip.open(files[0], "rb") as fd:
contents = fd.read().decode(UTF8)
Expand Down

0 comments on commit d73afee

Please sign in to comment.