From 2a0ff1fe49fc84dcbc7df0ab76f840149b419891 Mon Sep 17 00:00:00 2001 From: "Damon P. Cortesi" Date: Wed, 18 May 2016 19:27:19 -0700 Subject: [PATCH] Add support for multiple results in salesforce batch jobs (#1686) * #1685 Add support for multiple results in batch jobs Adds a new method 'get_batch_result_ids' and maintains backwards compatibility with the old method, adding a warning. * Add the ability to merge result sets and basic testing. Note that this only supports CSV since this is what the library does by default. --- luigi/contrib/salesforce.py | 78 ++++++++++++++++----- test/contrib/salesforce_test.py | 117 ++++++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 16 deletions(-) create mode 100644 test/contrib/salesforce_test.py diff --git a/luigi/contrib/salesforce.py b/luigi/contrib/salesforce.py index dadbd86a65..7da43d5d74 100644 --- a/luigi/contrib/salesforce.py +++ b/luigi/contrib/salesforce.py @@ -17,6 +17,7 @@ import time import abc import logging +import warnings import xml.etree.ElementTree as ET from collections import OrderedDict import re @@ -139,6 +140,11 @@ def is_soql_file(self): """Override to True if soql property is a file path.""" return False + @property + def content_type(self): + """Override to use a different content type. (e.g. XML)""" + return "CSV" + def run(self): if self.use_sandbox and not self.sandbox_name: raise Exception("Parameter sf_sandbox_name must be provided when uploading to a Salesforce Sandbox") @@ -170,11 +176,23 @@ def run(self): if 'foreign key relationships not supported' not in status['state_message'].lower(): raise Exception(msg) else: - result_id = sf.get_batch_results(job_id, batch_id) - data = sf.get_batch_result(job_id, batch_id, result_id) - - with open(self.output().fn, 'w') as outfile: - outfile.write(data) + result_ids = sf.get_batch_result_ids(job_id, batch_id) + + # If there's only one result, just download it, otherwise we need to merge the resulting downloads + if len(result_ids) == 1: + data = sf.get_batch_result(job_id, batch_id, result_ids[0]) + with open(self.output().path, 'w') as outfile: + outfile.write(data) + else: + # Download each file to disk, and then merge into one. + # Preferring to do it this way so as to minimize memory consumption. + for i, result_id in enumerate(result_ids): + logger.info("Downloading batch result %s for batch: %s and job: %s" % (result_id, batch_id, job_id)) + with open("%s.%d" % (self.output().path, i), 'w') as outfile: + outfile.write(sf.get_batch_result(job_id, batch_id, result_id)) + + logger.info("Merging results of batch %s" % batch_id) + self.merge_batch_results(result_ids) finally: logger.info("Closing job %s" % job_id) sf.close_job(job_id) @@ -184,11 +202,30 @@ def run(self): data_file = sf.query_all(self.soql) reader = csv.reader(data_file) - with open(self.output().fn, 'w') as outfile: + with open(self.output().path, 'w') as outfile: writer = csv.writer(outfile, dialect='excel') for row in reader: writer.writerow(row) + def merge_batch_results(self, result_ids): + """ + Merges the resulting files of a multi-result batch bulk query. + """ + outfile = open(self.output().path, 'w') + + if self.content_type == 'CSV': + for i, result_id in enumerate(result_ids): + with open("%s.%d" % (self.output().path, i), 'r') as f: + header = f.readline() + if i == 0: + outfile.write(header) + for line in f: + outfile.write(line) + else: + raise Exception("Batch result merging not implemented for %s" % self.content_type) + + outfile.close() + class SalesforceAPI(object): """ @@ -353,15 +390,17 @@ def restful(self, path, params): else: return json_result - def create_operation_job(self, operation, obj, external_id_field_name=None, content_type='CSV'): + def create_operation_job(self, operation, obj, external_id_field_name=None, content_type=None): """ Creates a new SF job that for doing any operation (insert, upsert, update, delete, query) :param operation: delete, insert, query, upsert, update, hardDelete. Must be lowercase. :param obj: Parent SF object :param external_id_field_name: Optional. - :param content_type: XML, CSV, ZIP_CSV, or ZIP_XML. Defaults to CSV """ + if content_type is None: + content_type = self.content_type + if not self.has_active_session(): self.start_session() @@ -419,7 +458,7 @@ def close_job(self, job_id): return response - def create_batch(self, job_id, data, file_type='csv'): + def create_batch(self, job_id, data, file_type=None): """ Creates a batch with either a string of data or a file containing data. @@ -429,13 +468,15 @@ def create_batch(self, job_id, data, file_type='csv'): :param job_id: job_id as returned by 'create_operation_job(...)' :param data: - :param file_type: :return: Returns batch_id """ if not job_id or not self.has_active_session(): raise Exception("Can not create a batch without a valid job_id and an active session.") + if file_type is None: + file_type = self.content_type.lower() + headers = self._get_create_batch_content_headers(file_type) headers['Content-Length'] = len(data) @@ -473,22 +514,27 @@ def block_on_batch(self, job_id, batch_id, sleep_time_seconds=5, max_wait_time_s def get_batch_results(self, job_id, batch_id): """ - Get results of a batch that has completed processing. - If the batch is a CSV file, the response is in CSV format. - If the batch is an XML file, the response is in XML format. + DEPRECATED: Use `get_batch_result_ids` + """ + warnings.warn("get_batch_results is deprecated and only returns one batch result. Please use get_batch_result_ids") + return self.get_batch_result_ids(job_id, batch_id)[0] + + def get_batch_result_ids(self, job_id, batch_id): + """ + Get result IDs of a batch that has completed processing. :param job_id: job_id as returned by 'create_operation_job(...)' :param batch_id: batch_id as returned by 'create_batch(...)' - :return: batch result response as either CSV or XML, dependent on the batch + :return: list of batch result IDs to be used in 'get_batch_result(...)' """ response = requests.get(self._get_batch_results_url(job_id, batch_id), headers=self._get_batch_info_headers()) response.raise_for_status() root = ET.fromstring(response.text) - result = root.find('%sresult' % self.API_NS).text + result_ids = [r.text for r in root.findall('%sresult' % self.API_NS)] - return result + return result_ids def get_batch_result(self, job_id, batch_id, result_id): """ diff --git a/test/contrib/salesforce_test.py b/test/contrib/salesforce_test.py new file mode 100644 index 0000000000..6b7fe278a6 --- /dev/null +++ b/test/contrib/salesforce_test.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# +# Copyright (c) 2016 Simply Measured +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# +# This method will be used by the mock to replace requests.get + +""" +Unit test for the Salesforce contrib package +""" + +from luigi.contrib.salesforce import SalesforceAPI, QuerySalesforce + +from helpers import unittest +import mock +from luigi.mock import MockTarget +from luigi.six import PY3 +import re + + +def mocked_requests_get(*args, **kwargs): + class MockResponse: + def __init__(self, body, status_code): + self.body = body + self.status_code = status_code + + @property + def text(self): + return self.body + + def raise_for_status(self): + return None + + result_list = ( + '' + '123412351236' + '' + ) + return MockResponse(result_list, 200) + +# Keep open around so we can use it in the mock responses +old__open = open + + +def mocked_open(*args, **kwargs): + if re.match("job_data", args[0]): + return MockTarget(args[0]).open(args[1]) + else: + return old__open(*args) + + +class TestSalesforceAPI(unittest.TestCase): + # We patch 'requests.get' with our own method. The mock object is passed in to our test case method. + @mock.patch('requests.get', side_effect=mocked_requests_get) + def test_deprecated_results(self, mock_get): + sf = SalesforceAPI('xx', 'xx', 'xx') + result_id = sf.get_batch_results('job_id', 'batch_id') + self.assertEqual('1234', result_id) + + @mock.patch('requests.get', side_effect=mocked_requests_get) + def test_result_ids(self, mock_get): + sf = SalesforceAPI('xx', 'xx', 'xx') + result_ids = sf.get_batch_result_ids('job_id', 'batch_id') + self.assertEqual(['1234', '1235', '1236'], result_ids) + + +class TestQuerySalesforce(QuerySalesforce): + def output(self): + return MockTarget('job_data.csv') + + @property + def object_name(self): + return 'dual' + + @property + def soql(self): + return "SELECT * FROM %s" % self.object_name + + +class TestSalesforceQuery(unittest.TestCase): + patch_name = '__builtin__.open' + if PY3: + patch_name = 'builtins.open' + + @mock.patch(patch_name, side_effect=mocked_open) + def setUp(self, mock_open): + MockTarget.fs.clear() + self.result_ids = ['a', 'b', 'c'] + + counter = 1 + self.all_lines = "Lines\n" + self.header = "Lines" + for i, id in enumerate(self.result_ids): + filename = "%s.%d" % ('job_data.csv', i) + with MockTarget(filename).open('w') as f: + line = "%d line\n%d line" % ((counter), (counter+1)) + f.write(self.header + "\n" + line + "\n") + self.all_lines += line+"\n" + counter += 2 + + @mock.patch(patch_name, side_effect=mocked_open) + def test_multi_csv_download(self, mock_open): + qsf = TestQuerySalesforce() + + qsf.merge_batch_results(self.result_ids) + self.assertEqual(MockTarget(qsf.output().path).open('r').read(), self.all_lines)