diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index d33ebde20..f9a1484bf 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -39,6 +39,7 @@
 from dbt.adapters.spark.python_submissions import (
     JobClusterPythonJobHelper,
     AllPurposeClusterPythonJobHelper,
+    SessionHelper,
 )
 from dbt.adapters.base import BaseRelation
 from dbt.adapters.contracts.relation import RelationType, RelationConfig
@@ -493,6 +494,7 @@ def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
         return {
             "job_cluster": JobClusterPythonJobHelper,
             "all_purpose_cluster": AllPurposeClusterPythonJobHelper,
+            "session": SessionHelper,
         }
 
     def standardize_grants_dict(self, grants_table: "agate.Table") -> dict:
diff --git a/dbt/adapters/spark/python_submissions.py b/dbt/adapters/spark/python_submissions.py
index e3e7cb370..8800b6da7 100644
--- a/dbt/adapters/spark/python_submissions.py
+++ b/dbt/adapters/spark/python_submissions.py
@@ -293,3 +293,32 @@ def submit(self, compiled_code: str) -> None:
             )
         finally:
             context.destroy(context_id)
+
+
+class SessionHelper(PythonJobHelper):
+    """Run a python model in-process via the active SparkSession.
+
+    Used for the "session" submission method: the compiled python model is
+    exec'd directly with the live SparkSession in scope instead of being
+    shipped to a remote cluster like the other helpers.
+    """
+
+    def __init__(self, parsed_model: Dict, credentials: SparkCredentials) -> None:
+        # Stored for parity with the other helpers; a session submission
+        # needs no per-model cluster configuration.
+        self.parsed_model = parsed_model
+        self.credentials = credentials
+
+    def submit(self, compiled_code: str) -> None:
+        """Execute *compiled_code* with the active session bound as ``spark``."""
+        try:
+            from pyspark.sql import SparkSession
+        except ImportError as e:
+            raise DbtRuntimeError("pyspark is required for the session submission method") from e
+        spark = SparkSession.getActiveSession()
+        if spark is None:
+            raise DbtRuntimeError("No active SparkSession found for the session submission method")
+        try:
+            # Model code is dbt-compiled (trusted) python; exec is intentional here.
+            exec(compiled_code, {"spark": spark})
+        except Exception as e:
+            raise DbtRuntimeError(f"Python model failed with traceback as:\n{e}") from e