From 30250962db2404559190b9f2e5bb55652deaae21 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Tue, 16 Apr 2024 11:49:06 +0200 Subject: [PATCH] update simple-salesforce type hints to support 1.12.6 (#39047) --- airflow/providers/salesforce/hooks/salesforce.py | 6 ++++-- airflow/providers/salesforce/operators/bulk.py | 16 +++++++++------- airflow/providers/salesforce/provider.yaml | 4 +--- generated/provider_dependencies.json | 2 +- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/airflow/providers/salesforce/hooks/salesforce.py b/airflow/providers/salesforce/hooks/salesforce.py index 12b220bc57089..6b2bc0e058bf7 100644 --- a/airflow/providers/salesforce/hooks/salesforce.py +++ b/airflow/providers/salesforce/hooks/salesforce.py @@ -27,7 +27,7 @@ import logging import time from functools import cached_property -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any, Iterable, cast from simple_salesforce import Salesforce, api @@ -36,6 +36,7 @@ if TYPE_CHECKING: import pandas as pd from requests import Session + from simple_salesforce.api import SFType log = logging.getLogger(__name__) @@ -190,8 +191,9 @@ def describe_object(self, obj: str) -> dict: :return: the description of the Salesforce object. """ conn = self.get_conn() + sftype: SFType = cast("SFType", conn.__getattr__(obj)) - return conn.__getattr__(obj).describe() + return sftype.describe() def get_available_fields(self, obj: str) -> list[str]: """ diff --git a/airflow/providers/salesforce/operators/bulk.py b/airflow/providers/salesforce/operators/bulk.py index a467601565660..85839fd9f1be4 100644 --- a/airflow/providers/salesforce/operators/bulk.py +++ b/airflow/providers/salesforce/operators/bulk.py @@ -16,12 +16,13 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable, cast from airflow.models import BaseOperator from airflow.providers.salesforce.hooks.salesforce import SalesforceHook if TYPE_CHECKING: + from simple_salesforce.bulk import SFBulkHandler from typing_extensions import Literal from airflow.utils.context import Context @@ -88,29 +89,30 @@ def execute(self, context: Context): """ sf_hook = SalesforceHook(salesforce_conn_id=self.salesforce_conn_id) conn = sf_hook.get_conn() + bulk: SFBulkHandler = cast("SFBulkHandler", conn.__getattr__("bulk")) - result = [] + result: Iterable = [] if self.operation == "insert": - result = conn.bulk.__getattr__(self.object_name).insert( + result = bulk.__getattr__(self.object_name).insert( data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial ) elif self.operation == "update": - result = conn.bulk.__getattr__(self.object_name).update( + result = bulk.__getattr__(self.object_name).update( data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial ) elif self.operation == "upsert": - result = conn.bulk.__getattr__(self.object_name).upsert( + result = bulk.__getattr__(self.object_name).upsert( data=self.payload, external_id_field=self.external_id_field, batch_size=self.batch_size, use_serial=self.use_serial, ) elif self.operation == "delete": - result = conn.bulk.__getattr__(self.object_name).delete( + result = bulk.__getattr__(self.object_name).delete( data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial ) elif self.operation == "hard_delete": - result = conn.bulk.__getattr__(self.object_name).hard_delete( + result = bulk.__getattr__(self.object_name).hard_delete( data=self.payload, batch_size=self.batch_size, use_serial=self.use_serial ) diff --git a/airflow/providers/salesforce/provider.yaml b/airflow/providers/salesforce/provider.yaml index 05e34e87895d6..5efa276b5cb6f 100644 --- a/airflow/providers/salesforce/provider.yaml +++ b/airflow/providers/salesforce/provider.yaml @@ -55,9 +55,7 @@ versions: dependencies: - apache-airflow>=2.6.0 - # simple-salesforce 1.12.6 breaks static checks, so we limit it to <1.12.6 for now - # https://github.com/apache/airflow/pull/39045 - - simple-salesforce>=1.0.0,<1.12.6 + - simple-salesforce>=1.0.0 # In pandas 2.2 minimal version of the sqlalchemy is 2.0 # https://pandas.pydata.org/docs/whatsnew/v2.2.0.html#increased-minimum-versions-for-dependencies # However Airflow not fully supports it yet: https://github.com/apache/airflow/issues/28723 diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 8771cab541a92..841f3764674e3 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -970,7 +970,7 @@ "deps": [ "apache-airflow>=2.6.0", "pandas>=1.2.5,<2.2", - "simple-salesforce>=1.0.0,<1.12.6" + "simple-salesforce>=1.0.0" ], "devel-deps": [], "cross-providers-deps": [],