diff --git a/src/datacustomcode/cli.py b/src/datacustomcode/cli.py index 8f9d950..437674a 100644 --- a/src/datacustomcode/cli.py +++ b/src/datacustomcode/cli.py @@ -83,6 +83,7 @@ def zip(path: str): @click.option("--name", required=True) @click.option("--version", default="0.0.1") @click.option("--description", default="Custom Data Transform Code") +@click.option("--profile", default="default") @click.option( "--cpu-size", default="CPU_2XL", @@ -96,7 +97,9 @@ def zip(path: str): Choose based on your workload requirements.""", ) -def deploy(path: str, name: str, version: str, description: str, cpu_size: str): +def deploy( + path: str, name: str, version: str, description: str, cpu_size: str, profile: str +): from datacustomcode.credentials import Credentials from datacustomcode.deploy import TransformationJobMetadata, deploy_full @@ -122,7 +125,7 @@ def deploy(path: str, name: str, version: str, description: str, cpu_size: str): computeType=COMPUTE_TYPES[cpu_size], ) try: - credentials = Credentials.from_available() + credentials = Credentials.from_available(profile=profile) except ValueError as e: click.secho( f"Error: {e}", @@ -192,7 +195,13 @@ def scan(filename: str, config: str, dry_run: bool, no_requirements: bool): @click.argument("entrypoint") @click.option("--config-file", default=None) @click.option("--dependencies", default=[], multiple=True) -def run(entrypoint: str, config_file: Union[str, None], dependencies: List[str]): +@click.option("--profile", default="default") +def run( + entrypoint: str, + config_file: Union[str, None], + dependencies: List[str], + profile: str, +): from datacustomcode.run import run_entrypoint - run_entrypoint(entrypoint, config_file, dependencies) + run_entrypoint(entrypoint, config_file, dependencies, profile) diff --git a/src/datacustomcode/config.yaml b/src/datacustomcode/config.yaml index 270640c..0ed02db 100644 --- a/src/datacustomcode/config.yaml +++ b/src/datacustomcode/config.yaml @@ -1,8 +1,12 @@ reader_config: type_config_name: QueryAPIDataCloudReader + options: + credentials_profile: default writer_config: type_config_name: PrintDataCloudWriter + options: + credentials_profile: default spark_config: app_name: DC Custom Code Python SDK Testing diff --git a/src/datacustomcode/credentials.py b/src/datacustomcode/credentials.py index 1689512..d7db5e2 100644 --- a/src/datacustomcode/credentials.py +++ b/src/datacustomcode/credentials.py @@ -65,11 +65,11 @@ def from_env(cls) -> Credentials: ) from exc @classmethod - def from_available(cls) -> Credentials: + def from_available(cls, profile: str = "default") -> Credentials: if os.environ.get("SFDC_USERNAME"): return cls.from_env() if os.path.exists(INI_FILE): - return cls.from_ini() + return cls.from_ini(profile=profile) raise ValueError( "Credentials not found in env or ini file. " "Run `datacustomcode configure` to create a credentials file." diff --git a/src/datacustomcode/io/reader/query_api.py b/src/datacustomcode/io/reader/query_api.py index f41e767..b6b87fa 100644 --- a/src/datacustomcode/io/reader/query_api.py +++ b/src/datacustomcode/io/reader/query_api.py @@ -75,9 +75,11 @@ class QueryAPIDataCloudReader(BaseDataCloudReader): CONFIG_NAME = "QueryAPIDataCloudReader" - def __init__(self, spark: SparkSession) -> None: + def __init__( + self, spark: SparkSession, credentials_profile: str = "default" + ) -> None: self.spark = spark - credentials = Credentials.from_available() + credentials = Credentials.from_available(profile=credentials_profile) self._conn = SalesforceCDPConnection( credentials.login_url, diff --git a/src/datacustomcode/io/writer/print.py b/src/datacustomcode/io/writer/print.py index 7b9ffd4..87129d9 100644 --- a/src/datacustomcode/io/writer/print.py +++ b/src/datacustomcode/io/writer/print.py @@ -26,10 +26,17 @@ class PrintDataCloudWriter(BaseDataCloudWriter): CONFIG_NAME = "PrintDataCloudWriter" def __init__( - self, spark: SparkSession, reader: Optional[QueryAPIDataCloudReader] = None + self, + spark: SparkSession, + reader: Optional[QueryAPIDataCloudReader] = None, + credentials_profile: str = "default", ) -> None: super().__init__(spark) - self.reader = QueryAPIDataCloudReader(self.spark) if reader is None else reader + self.reader = ( + QueryAPIDataCloudReader(self.spark, credentials_profile) + if reader is None + else reader + ) def validate_dataframe_columns_against_dlo( self, diff --git a/src/datacustomcode/run.py b/src/datacustomcode/run.py index 031703b..f25136c 100644 --- a/src/datacustomcode/run.py +++ b/src/datacustomcode/run.py @@ -20,7 +20,10 @@ def run_entrypoint( - entrypoint: str, config_file: Union[str, None], dependencies: List[str] + entrypoint: str, + config_file: Union[str, None], + dependencies: List[str], + profile: str, ) -> None: """Run the entrypoint script with the given config and dependencies. @@ -28,7 +31,14 @@ def run_entrypoint( entrypoint: The entrypoint script to run. config_file: The config file to use. dependencies: The dependencies to import. + profile: The profile to use. """ + if profile != "default": + if config.reader_config and hasattr(config.reader_config, "options"): + config.reader_config.options["credentials_profile"] = profile + if config.writer_config and hasattr(config.writer_config, "options"): + config.writer_config.options["credentials_profile"] = profile + if config_file: config.load(config_file) for dependency in dependencies: diff --git a/tests/test_credentials.py b/tests/test_credentials.py index e9462e3..e2f1bcc 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -244,3 +244,70 @@ def test_update_ini_new_profile(self): # Check that existing profile was not modified assert mock_config["existing"]["username"] == "existing_user" + + def test_from_available_with_custom_profile(self): + """Test that from_available uses custom profile when specified.""" + ini_content = """ + [default] + username = default_user + password = default_pass + client_id = default_client_id + client_secret = default_secret + login_url = https://default.login.url + + [custom_profile] + username = custom_user + password = custom_pass + client_id = custom_client_id + client_secret = custom_secret + login_url = https://custom.login.url + """ + + with ( + patch("datacustomcode.credentials.INI_FILE", "fake_path"), + patch("os.path.exists", return_value=True), + patch("builtins.open", mock_open(read_data=ini_content)), + ): + # Mock the configparser behavior for reading the file + mock_config = configparser.ConfigParser() + mock_config.read_string(ini_content) + + with patch.object(configparser, "ConfigParser", return_value=mock_config): + # Test default profile + creds_default = Credentials.from_available() + assert creds_default.username == "default_user" + assert creds_default.login_url == "https://default.login.url" + + # Test custom profile + creds_custom = Credentials.from_available(profile="custom_profile") + assert creds_custom.username == "custom_user" + assert creds_custom.password == "custom_pass" + assert creds_custom.client_id == "custom_client_id" + assert creds_custom.client_secret == "custom_secret" + assert creds_custom.login_url == "https://custom.login.url" + + def test_from_available_fallback_to_default(self): + """Test that from_available falls back to default when no profile specified.""" + ini_content = """ + [default] + username = default_user + password = default_pass + client_id = default_client_id + client_secret = default_secret + login_url = https://default.login.url + """ + + with ( + patch("datacustomcode.credentials.INI_FILE", "fake_path"), + patch("os.path.exists", return_value=True), + patch("builtins.open", mock_open(read_data=ini_content)), + ): + # Mock the configparser behavior for reading the file + mock_config = configparser.ConfigParser() + mock_config.read_string(ini_content) + + with patch.object(configparser, "ConfigParser", return_value=mock_config): + # Test that no profile parameter defaults to "default" + creds = Credentials.from_available() + assert creds.username == "default_user" + assert creds.login_url == "https://default.login.url" diff --git a/tests/test_run.py b/tests/test_run.py index cadafd7..8f3f966 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -96,6 +96,7 @@ def test_run_entrypoint_preserves_config(test_config_file, test_entrypoint_file) entrypoint=test_entrypoint_file, config_file=test_config_file, dependencies=[], + profile="default", ) # Check that config was maintained @@ -180,6 +181,7 @@ def test_run_entrypoint_with_dependencies(): entrypoint=entrypoint_file, config_file=config_file, dependencies=[module_name], + profile="default", ) # Verify dependency was imported and used