# Property keys understood by the FileIO implementations, sorted alphabetically.
# Keys prefixed with "client." / "table." below are the ones consumed by the
# Parquet Modular Encryption support in pyiceberg.io.pyarrow.
ADLS_ACCOUNT_HOST = "adls.account-host"
ADLS_ACCOUNT_KEY = "adls.account-key"
ADLS_ACCOUNT_NAME = "adls.account-name"
ADLS_CLIENT_ID = "adls.client-id"
ADLS_CLIENT_SECRET = "adls.client-secret"
ADLS_CONNECTION_STRING = "adls.connection-string"
ADLS_SAS_TOKEN = "adls.sas-token"
ADLS_TENANT_ID = "adls.tenant-id"
AWS_ACCESS_KEY_ID = "client.access-key-id"
AWS_REGION = "client.region"
AWS_ROLE_ARN = "client.role-arn"
AWS_ROLE_SESSION_NAME = "client.role-session-name"
AWS_SECRET_ACCESS_KEY = "client.secret-access-key"
AWS_SESSION_TOKEN = "client.session-token"
AZURE_CLIENT_ID = "client.client-id"
AZURE_CLIENT_SECRET = "client.client-secret"
# Backward-compatible alias: the constant was first introduced with a
# lowercase "l" in "ClIENT". Keep it so existing imports continue to work.
AZURE_ClIENT_SECRET = AZURE_CLIENT_SECRET
AZURE_TENANT_ID = "client.tenant-id"
COLUMN_KEY = "table.column-key"
FOOTER_KEY = "table.footer-key"
GCP_OAUTH2_TOKEN = "client.oauth2-token"
GCS_ACCESS = "gcs.access"
GCS_CACHE_TIMEOUT = "gcs.cache-timeout"
GCS_CONSISTENCY = "gcs.consistency"
GCS_DEFAULT_LOCATION = "gcs.default-bucket-location"
GCS_PROJECT_ID = "gcs.project-id"
GCS_REQUESTER_PAYS = "gcs.requester-pays"
GCS_SERVICE_HOST = "gcs.service.host"
GCS_SESSION_KWARGS = "gcs.session-kwargs"
GCS_TOKEN = "gcs.oauth2.token"
GCS_TOKEN_EXPIRES_AT_MS = "gcs.oauth2.token-expires-at"
GCS_VERSION_AWARE = "gcs.version-aware"
HDFS_HOST = "hdfs.host"
HDFS_KERB_TICKET = "hdfs.kerberos_ticket"
HDFS_PORT = "hdfs.port"
HDFS_USER = "hdfs.user"
HF_ENDPOINT = "hf.endpoint"
HF_TOKEN = "hf.token"
KEEP_FOOTER_IN_PLAINTEXT = "table.keep-footer-in-plaintext"
KMS_VENDOR = "client.kms-vendor"
PYARROW_USE_LARGE_TYPES_ON_READ = "pyarrow.use-large-types-on-read"
S3_ACCESS_KEY_ID = "s3.access-key-id"
S3_CONNECT_TIMEOUT = "s3.connect-timeout"
S3_ENDPOINT = "s3.endpoint"
S3_FORCE_VIRTUAL_ADDRESSING = "s3.force-virtual-addressing"
S3_PROXY_URI = "s3.proxy-uri"
S3_REGION = "s3.region"
S3_REQUEST_TIMEOUT = "s3.request-timeout"
S3_RESOLVE_REGION = "s3.resolve-region"
S3_ROLE_ARN = "s3.role-arn"
S3_ROLE_SESSION_NAME = "s3.role-session-name"
S3_SECRET_ACCESS_KEY = "s3.secret-access-key"
S3_SESSION_TOKEN = "s3.session-token"
S3_SIGNER = "s3.signer"
S3_SIGNER_ENDPOINT = "s3.signer.endpoint"
S3_SIGNER_ENDPOINT_DEFAULT = "v1/aws/s3/sign"
S3_SIGNER_URI = "s3.signer.uri"
index 1aaab32dbe..d6e1b6e8f9 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -36,11 +36,22 @@ import uuid import warnings from abc import ABC, abstractmethod +from azure.identity import ClientSecretCredential +from azure.keyvault.keys.crypto import ( + CryptographyClient, + EncryptionAlgorithm, +) +from base64 import ( + b64decode, + b64encode, +) +from boto3 import client from concurrent.futures import Future from copy import copy from dataclasses import dataclass from enum import Enum from functools import lru_cache, singledispatch +from requests import post from typing import ( TYPE_CHECKING, Any, @@ -71,6 +82,12 @@ FileType, FSSpecHandler, ) +from pyarrow.parquet.encryption import ( + CryptoFactory, + EncryptionConfiguration, + KmsClient, + KmsConnectionConfig, +) from sortedcontainers import SortedList from pyiceberg.conversions import to_bytes @@ -91,6 +108,12 @@ AWS_ROLE_SESSION_NAME, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN, + AZURE_CLIENT_ID, + AZURE_ClIENT_SECRET, + AZURE_TENANT_ID, + COLUMN_KEY, + FOOTER_KEY, + GCP_OAUTH2_TOKEN, GCS_DEFAULT_LOCATION, GCS_SERVICE_HOST, GCS_TOKEN, @@ -99,6 +122,8 @@ HDFS_KERB_TICKET, HDFS_PORT, HDFS_USER, + KEEP_FOOTER_IN_PLAINTEXT, + KMS_VENDOR, PYARROW_USE_LARGE_TYPES_ON_READ, S3_ACCESS_KEY_ID, S3_CONNECT_TIMEOUT, @@ -197,9 +222,278 @@ MAP_VALUE_NAME = "value" DOC = "doc" UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"} +# Common object of ParquetModularEncryption class +parquet_modular_encryption = None T = TypeVar("T") +class ParquetModularEncryption: + """ + A class that implements Parquet Modular Encryption. + + Attributes: + ---------- + aws_access_key_id: str + AWS access key ID of an AWS IAM user, where the user has appropriate permissions to access AWS KMS. + This attribute is applicable when kms_vendor attribute is set to "aws". + aws_region_name: str + AWS region name of the AWS KMS. + This attribute is applicable when kms_vendor attribute is set to "aws". 
+ aws_secret_access_key: str + AWS secret access key of an AWS IAM user, where the user has appropriate permissions to access AWS KMS. + This attribute is applicable when kms_vendor attribute is set to "aws". + azure_client_id: str + Client ID of a Microsoft Entra ID application, where the application has appropriate permissions to access Azure Key Vault. + This attribute is applicable when kms_vendor attribute is set to "azure". + azure_client_secret: str + Client secret of a Microsoft Entra ID application, where the application has appropriate permissions to access Azure Key Vault. + This attribute is applicable when kms_vendor attribute is set to "azure". + azure_tenant_id: str + Tenant ID of a Microsoft Entra ID application, where the application has appropriate permissions to access Azure Key Vault. + This attribute is applicable when kms_vendor attribute is set to "azure". + gcp_oauth2_token: str + GCP OAuth2 token, where the token has appropriate permissions to access GCP KMS. + This attribute is applicable when kms_vendor attribute is set to "gcp". + kms_vendor: str + Specifies KMS vendor. + Valid values are "aws", "azure", or "gcp". + This attribute also determines whether Parquet Modular Encryption is required or not. + + Methods: + ------- + get_kms_connection_config() -> KmsConnectionConfig: + Returns an object of pyarrow.parquet.encryption.KmsConnectionConfig class. + get_crypto_factory() -> CryptoFactory: + Returns an object of pyarrow.parquet.encryption.CryptoFactory class. + get_decryption_properties(): + Returns an object of low-level FileDecryptionProperties class. + get_encryption_properties(column_key: Dict[str, List[str]], footer_key: str, keep_footer_in_plaintext: str): + Returns an object of low-level FileEncryptionProperties class. 
+ """ + def __init__(self, aws_access_key_id: str, aws_region_name: str, aws_secret_access_key: str, azure_client_id: str, azure_client_secret: str, azure_tenant_id: str, gcp_oauth2_token: str, kms_vendor: str): + self.aws_access_key_id = aws_access_key_id + self.aws_region_name = aws_region_name + self.aws_secret_access_key = aws_secret_access_key + self.azure_client_id = azure_client_id + self.azure_client_secret = azure_client_secret + self.azure_tenant_id = azure_tenant_id + self.gcp_oauth2_token = gcp_oauth2_token + self.kms_vendor = kms_vendor + + def get_kms_connection_config(self) -> KmsConnectionConfig: + """ + Returns an object of pyarrow.parquet.encryption.KmsConnectionConfig class. + + Returns: + ------- + KmsConnectionConfig + If self.kms_vendor attribute and self.kms_vendor-related attributes are set to valid values. + """ + if self.kms_vendor == None: + raise Exception("client.kms-vendor catalog property is not set.") + if self.kms_vendor == "aws": + if self.aws_access_key_id == None: + raise Exception("client.access-key-id catalog property is not set.") + if self.aws_region_name == None: + raise Exception("client.region catalog property is not set.") + if self.aws_secret_access_key == None: + raise Exception("client.secret-access-key catalog property is not set.") + return KmsConnectionConfig(custom_kms_conf={ + "aws_access_key_id": self.aws_access_key_id, + "aws_region_name": self.aws_region_name, + "aws_secret_access_key": self.aws_secret_access_key + }) + elif self.kms_vendor == "azure": + if self.azure_client_id == None: + raise Exception("client.client-id catalog property is not set.") + if self.azure_client_secret == None: + raise Exception("client.client-secret catalog property is not set.") + if self.azure_tenant_id == None: + raise Exception("client.tenant-id catalog property is not set.") + return KmsConnectionConfig(custom_kms_conf={ + "azure_client_id": self.azure_client_id, + "azure_client_secret": self.azure_client_secret, + 
"azure_tenant_id": self.azure_tenant_id + }) + elif self.kms_vendor == "gcp": + if self.gcp_oauth2_token == None: + raise Exception("client.oauth2-token catalog property is not set.") + return KmsConnectionConfig(custom_kms_conf={"gcp_oauth2_token": self.gcp_oauth2_token}) + else: + raise Exception("client.kms-vendor catalog property is invalid.") + + def get_crypto_factory(self) -> CryptoFactory: + """ + Returns an object of pyarrow.parquet.encryption.CryptoFactory class. + + Returns: + ------- + CryptoFactory + If ParquetModularEncryptionKmsClient class, which inherits pyarrow.parquet.encryption.KmsClient class, is successfully created. + """ + kms_connection_config = self.get_kms_connection_config() + + if self.kms_vendor == "aws": + class ParquetModularEncryptionKmsClient(KmsClient): + def __init__(self, kms_connection_config): + KmsClient.__init__(self) + self.client = client( + aws_access_key_id=kms_connection_config.custom_kms_conf["aws_access_key_id"], + aws_secret_access_key=kms_connection_config.custom_kms_conf["aws_secret_access_key"], + region_name=kms_connection_config.custom_kms_conf["aws_region_name"], + service_name="kms" + ) + + def wrap_key(self, key_bytes, master_key_identifier): + return b64encode(s=self.client.encrypt( + KeyId=master_key_identifier, + Plaintext=key_bytes + )["CiphertextBlob"]) + + def unwrap_key(self, wrapped_key, master_key_identifier): + return self.client.decrypt( + CiphertextBlob=b64decode(s=wrapped_key), + KeyId=master_key_identifier + )["Plaintext"] + + def kms_client_factory(kms_connection_config): + return ParquetModularEncryptionKmsClient(kms_connection_config=kms_connection_config) + + return CryptoFactory(kms_client_factory=kms_client_factory) + elif self.kms_vendor == "azure": + class ParquetModularEncryptionKmsClient(KmsClient): + def __init__(self, kms_connection_config): + KmsClient.__init__(self) + self.client_secret_credential = ClientSecretCredential( + 
client_id=kms_connection_config.custom_kms_conf["azure_client_id"], + client_secret=kms_connection_config.custom_kms_conf["azure_client_secret"], + tenant_id=kms_connection_config.custom_kms_conf["azure_tenant_id"] + ) + + def wrap_key(self, key_bytes, master_key_identifier): + return b64encode(s=CryptographyClient( + credential=self.client_secret_credential, + key="https://%s"%master_key_identifier + ).encrypt( + algorithm=EncryptionAlgorithm.rsa_oaep_256, + plaintext=key_bytes + ).ciphertext) + + def unwrap_key(self, wrapped_key, master_key_identifier): + return CryptographyClient( + credential=self.client_secret_credential, + key="https://%s"%master_key_identifier + ).decrypt( + algorithm=EncryptionAlgorithm.rsa_oaep_256, + ciphertext=b64decode(s=wrapped_key) + ).plaintext + + def kms_client_factory(kms_connection_config): + return ParquetModularEncryptionKmsClient(kms_connection_config=kms_connection_config) + + return CryptoFactory(kms_client_factory=kms_client_factory) + else: + class ParquetModularEncryptionKmsClient(KmsClient): + def __init__(self, kms_connection_config): + KmsClient.__init__(self) + self.token = kms_connection_config.custom_kms_conf["gcp_oauth2_token"] + + def wrap_key(self, key_bytes, master_key_identifier): + response_of_encrypting_key = post( + headers={"Authorization": "Bearer %s"%self.token}, + json={"plaintext": b64encode(s=key_bytes).decode()}, + url="https://cloudkms.googleapis.com/v1/%s:encrypt"%master_key_identifier + ) + response_of_encrypting_key.close() + if response_of_encrypting_key.status_code == 200: + return response_of_encrypting_key.json()["ciphertext"].encode() + else: + raise Exception(response_of_encrypting_key.text) + + def unwrap_key(self, wrapped_key, master_key_identifier): + response_of_decrypting_key = post( + headers={"Authorization": "Bearer %s"%self.token}, + json={"ciphertext": wrapped_key}, + url="https://cloudkms.googleapis.com/v1/%s:decrypt"%master_key_identifier + ) + response_of_decrypting_key.close() 
+ if response_of_decrypting_key.status_code == 200: + return b64decode(s=response_of_decrypting_key.json()["plaintext"]) + else: + raise Exception(response_of_decrypting_key.text) + + def kms_client_factory(kms_connection_config): + return ParquetModularEncryptionKmsClient(kms_connection_config=kms_connection_config) + + return CryptoFactory(kms_client_factory=kms_client_factory) + + def get_decryption_properties(self): + """ + Returns an object of low-level FileDecryptionProperties class. + + Returns: + ------- + None or FileDecryptionProperties + None, if self.kms_vendor attribute is found None. Here, "self.kms_vendor = None" means that Parquet Modular Encryption is not required. + FileDecryptionProperties, if pyarrow.parquet.encryption.CryptoFactory.file_decryption_properties() method is successfully created. + """ + if self.kms_vendor == None: + return None + else: + return self.get_crypto_factory().file_decryption_properties(kms_connection_config=self.get_kms_connection_config()) + + def get_encryption_properties(self, column_key: Dict[str, List[str]], footer_key: str, keep_footer_in_plaintext: str): + """ + Returns an object of low-level FileEncryptionProperties class. + + Parameters: + ---------- + column_key: Dict[str, List[str]] + Dictionary of -> master key ID for encrypting column(s): list of column(s). + Example: + {"master_key_id": ["column_name"]} + footer_key: str + Master key ID for encrypting footer. + keep_footer_in_plaintext: str + Specifies whether to keep footer in plaintext or encrypted. + Valid values are "yes" or "no". + "yes" keeps footer in plaintext, whereas "no" keeps footer in encrypted. + + Returns: + ------- + None or FileEncryptionProperties + None, if self.kms_vendor attribute is found None. Here, "self.kms_vendor = None" means that Parquet Modular Encryption is not required. + FileEncryptionProperties, if pyarrow.parquet.encryption.CryptoFactory.get_encryption_properties() method is successfully created. 
+ """ + if self.kms_vendor == None: + return None + else: + if column_key == None: + raise Exception("table.column-key table property is not set.") + if footer_key == None: + raise Exception("table.footer-key table property is not set.") + if keep_footer_in_plaintext == None: + raise Exception("table.keep-footer-in-plaintext table property is not set.") + if keep_footer_in_plaintext == "yes": + return self.get_crypto_factory().file_encryption_properties( + encryption_config=EncryptionConfiguration( + column_keys=eval(column_key), + footer_key=footer_key, + plaintext_footer=True + ), + kms_connection_config=self.get_kms_connection_config() + ) + elif keep_footer_in_plaintext == "no": + return self.get_crypto_factory().file_encryption_properties( + encryption_config=EncryptionConfiguration( + column_keys=eval(column_key), + footer_key=footer_key, + plaintext_footer=False + ), + kms_connection_config=self.get_kms_connection_config() + ) + else: + raise Exception("table.keep-footer-in-plaintext table property is invalid.") @lru_cache def _cached_resolve_s3_region(bucket: str) -> Optional[str]: @@ -382,6 +676,18 @@ def parse_location(location: str) -> Tuple[str, str, str]: def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSystem: """Initialize FileSystem for different scheme.""" + global parquet_modular_encryption + parquet_modular_encryption = ParquetModularEncryption( + aws_access_key_id=self.properties.get(AWS_ACCESS_KEY_ID), + aws_region_name=self.properties.get(AWS_REGION), + aws_secret_access_key=self.properties.get(AWS_SECRET_ACCESS_KEY), + azure_client_id=self.properties.get(AZURE_CLIENT_ID), + azure_client_secret=self.properties.get(AZURE_ClIENT_SECRET), + azure_tenant_id=self.properties.get(AZURE_TENANT_ID), + gcp_oauth2_token=self.properties.get(GCP_OAUTH2_TOKEN), + kms_vendor=self.properties.get(KMS_VENDOR) + ) + if scheme in {"oss"}: return self._initialize_oss_fs() @@ -1394,7 +1700,11 @@ def _task_to_record_batches( partition_spec: 
Optional[PartitionSpec] = None, ) -> Iterator[pa.RecordBatch]: _, _, path = _parse_location(task.file.file_path) - arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8)) + arrow_format = ds.ParquetFileFormat( + buffer_size=(ONE_MEGABYTE * 8), + decryption_properties=parquet_modular_encryption.get_decryption_properties(), + pre_buffer=True + ) with fs.open_input_file(path) as fin: fragment = arrow_format.make_fragment(fin) physical_schema = fragment.physical_schema @@ -2453,9 +2763,15 @@ def write_parquet(task: WriteTask) -> DataFile: ) as writer: writer.write(arrow_table, row_group_size=row_group_size) statistics = data_file_statistics_from_parquet_metadata( - parquet_metadata=writer.writer.metadata, - stats_columns=compute_statistics_plan(file_schema, table_metadata.properties), - parquet_column_mapping=parquet_path_to_id_mapping(file_schema), + parquet_column_mapping=parquet_path_to_id_mapping(file_schema), + parquet_metadata=pq.read_metadata( + decryption_properties=parquet_modular_encryption.get_decryption_properties(), + where=file_path + ), + stats_columns=compute_statistics_plan( + file_schema, + table_metadata.properties + ) ) data_file = DataFile.from_args( content=DataFileContent.DATA, @@ -2601,6 +2917,11 @@ def _get_parquet_writer_kwargs(table_properties: Properties) -> Dict[str, Any]: property_name=TableProperties.PARQUET_DICT_SIZE_BYTES, default=TableProperties.PARQUET_DICT_SIZE_BYTES_DEFAULT, ), + "encryption_properties": parquet_modular_encryption.get_encryption_properties( + column_key=table_properties.get(COLUMN_KEY), + footer_key=table_properties.get(FOOTER_KEY), + keep_footer_in_plaintext=table_properties.get(KEEP_FOOTER_IN_PLAINTEXT) + ), "write_batch_size": property_as_int( properties=table_properties, property_name=TableProperties.PARQUET_PAGE_ROW_LIMIT,