Skip to content

Commit

Permalink
Add support for multiple results in salesforce batch jobs (#1686)
Browse files Browse the repository at this point in the history
* #1685 Add support for multiple results in batch jobs
Adds a new method 'get_batch_result_ids' and maintains backwards compatibility
with the old method, adding a warning.

* Add the ability to merge result sets and basic testing.
Note that this only supports CSV since this is what the library does by default.
  • Loading branch information
dacort authored and Tarrasch committed May 19, 2016
1 parent 41be6b6 commit 2a0ff1f
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 16 deletions.
78 changes: 62 additions & 16 deletions luigi/contrib/salesforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import time
import abc
import logging
import warnings
import xml.etree.ElementTree as ET
from collections import OrderedDict
import re
Expand Down Expand Up @@ -139,6 +140,11 @@ def is_soql_file(self):
"""Override to True if soql property is a file path."""
return False

@property
def content_type(self):
    """
    Batch content type requested from the Salesforce Bulk API.

    Subclasses may override this to use a different format (e.g. XML).
    """
    return "CSV"

def run(self):
if self.use_sandbox and not self.sandbox_name:
raise Exception("Parameter sf_sandbox_name must be provided when uploading to a Salesforce Sandbox")
Expand Down Expand Up @@ -170,11 +176,23 @@ def run(self):
if 'foreign key relationships not supported' not in status['state_message'].lower():
raise Exception(msg)
else:
result_id = sf.get_batch_results(job_id, batch_id)
data = sf.get_batch_result(job_id, batch_id, result_id)

with open(self.output().fn, 'w') as outfile:
outfile.write(data)
result_ids = sf.get_batch_result_ids(job_id, batch_id)

# If there's only one result, just download it, otherwise we need to merge the resulting downloads
if len(result_ids) == 1:
data = sf.get_batch_result(job_id, batch_id, result_ids[0])
with open(self.output().path, 'w') as outfile:
outfile.write(data)
else:
# Download each file to disk, and then merge into one.
# Preferring to do it this way so as to minimize memory consumption.
for i, result_id in enumerate(result_ids):
logger.info("Downloading batch result %s for batch: %s and job: %s" % (result_id, batch_id, job_id))
with open("%s.%d" % (self.output().path, i), 'w') as outfile:
outfile.write(sf.get_batch_result(job_id, batch_id, result_id))

logger.info("Merging results of batch %s" % batch_id)
self.merge_batch_results(result_ids)
finally:
logger.info("Closing job %s" % job_id)
sf.close_job(job_id)
Expand All @@ -184,11 +202,30 @@ def run(self):
data_file = sf.query_all(self.soql)

reader = csv.reader(data_file)
with open(self.output().fn, 'w') as outfile:
with open(self.output().path, 'w') as outfile:

This comment has been minimized.

Copy link
@dlstadther

dlstadther Jun 21, 2016

Collaborator

Again, I'm sorry for not asking these questions a while back... For what purpose did you change this?

This comment has been minimized.

Copy link
@dlstadther

dlstadther Jun 21, 2016

Collaborator

self.output().fn and self.output().path seem to be equal (at least for me).

This comment has been minimized.

Copy link
@dlstadther

dlstadther Jun 21, 2016

Collaborator

@Tarrasch @erikbern Correct me if i'm wrong. self.output().fn should be used over self.output().path because class variables shouldn't be accessed directly - the property fn was provided to provide the value of path.

self.output().fn calls the property fn from LocalTarget (which returns the value of path, whereas self.output().path calls the self.path object from the __init__() method (as set by LocalTarget due to its call to the super constructor of FileSystemTarget, where self.path=path)

This comment has been minimized.

Copy link
@dlstadther

dlstadther Jun 21, 2016

Collaborator

It seems that mock.py and file.py contradict each other here. For a mock object, you must use self.output().path, but for a file object, the property is self.output().fn - although you can call path directly without error.

It seems like they should use the same property name. Thoughts?

For reference, hadoop, scalding, and sge all use output().path.

writer = csv.writer(outfile, dialect='excel')
for row in reader:
writer.writerow(row)

def merge_batch_results(self, result_ids):
    """
    Merges the resulting files of a multi-result batch bulk query.

    Expects one part file per result id, named '<output path>.<index>',
    and concatenates them into self.output().path, keeping the CSV
    header from the first part only.

    :param result_ids: list of result IDs as returned by 'get_batch_result_ids(...)'
    :raises Exception: if self.content_type is anything other than 'CSV'
    """
    # Fail before touching the output file so an unsupported content type
    # does not truncate a previously-written result.
    if self.content_type != 'CSV':
        raise Exception("Batch result merging not implemented for %s" % self.content_type)

    # Context manager guarantees the output file is closed even if a
    # part file is missing or unreadable.
    with open(self.output().path, 'w') as outfile:
        for i, _result_id in enumerate(result_ids):
            with open("%s.%d" % (self.output().path, i), 'r') as f:
                header = f.readline()
                # Every part repeats the header; keep it only once.
                if i == 0:
                    outfile.write(header)
                for line in f:
                    outfile.write(line)


class SalesforceAPI(object):
"""
Expand Down Expand Up @@ -353,15 +390,17 @@ def restful(self, path, params):
else:
return json_result

def create_operation_job(self, operation, obj, external_id_field_name=None, content_type='CSV'):
def create_operation_job(self, operation, obj, external_id_field_name=None, content_type=None):
"""
Creates a new SF job that for doing any operation (insert, upsert, update, delete, query)
:param operation: delete, insert, query, upsert, update, hardDelete. Must be lowercase.
:param obj: Parent SF object
:param external_id_field_name: Optional.
:param content_type: XML, CSV, ZIP_CSV, or ZIP_XML. Defaults to CSV
"""
if content_type is None:
content_type = self.content_type

if not self.has_active_session():

This comment has been minimized.

Copy link
@dlstadther

dlstadther Jun 21, 2016

Collaborator

Sorry to just now notice this... but self.content_type is invalid. The property content_type you defined is within QuerySalesforce, not SalesforceAPI.

self.start_session()

Expand Down Expand Up @@ -419,7 +458,7 @@ def close_job(self, job_id):

return response

def create_batch(self, job_id, data, file_type='csv'):
def create_batch(self, job_id, data, file_type=None):
"""
Creates a batch with either a string of data or a file containing data.
Expand All @@ -429,13 +468,15 @@ def create_batch(self, job_id, data, file_type='csv'):
:param job_id: job_id as returned by 'create_operation_job(...)'
:param data:
:param file_type:
:return: Returns batch_id
"""
if not job_id or not self.has_active_session():
raise Exception("Can not create a batch without a valid job_id and an active session.")

if file_type is None:
file_type = self.content_type.lower()

headers = self._get_create_batch_content_headers(file_type)
headers['Content-Length'] = len(data)

Expand Down Expand Up @@ -473,22 +514,27 @@ def block_on_batch(self, job_id, batch_id, sleep_time_seconds=5, max_wait_time_s

def get_batch_results(self, job_id, batch_id):
    """
    DEPRECATED: Use `get_batch_result_ids`

    Kept for backwards compatibility; returns only the FIRST result id
    of the batch, even when the batch produced several result sets.

    :param job_id: job_id as returned by 'create_operation_job(...)'
    :param batch_id: batch_id as returned by 'create_batch(...)'
    :return: the first batch result ID
    """
    # Use the conventional DeprecationWarning category so tooling
    # (and -W filters) can recognize and manage this warning.
    warnings.warn("get_batch_results is deprecated and only returns one batch result. Please use get_batch_result_ids",
                  DeprecationWarning)
    return self.get_batch_result_ids(job_id, batch_id)[0]

def get_batch_result_ids(self, job_id, batch_id):
    """
    Get result IDs of a batch that has completed processing.

    :param job_id: job_id as returned by 'create_operation_job(...)'
    :param batch_id: batch_id as returned by 'create_batch(...)'
    :return: list of batch result IDs to be used in 'get_batch_result(...)'
    """
    url = self._get_batch_results_url(job_id, batch_id)
    headers = self._get_batch_info_headers()
    response = requests.get(url, headers=headers)
    response.raise_for_status()

    # The response is an XML <result-list>; collect the text of every
    # <result> child (a batch may produce multiple result sets).
    document = ET.fromstring(response.text)
    return [node.text for node in document.findall('%sresult' % self.API_NS)]

def get_batch_result(self, job_id, batch_id, result_id):
"""
Expand Down
117 changes: 117 additions & 0 deletions test/contrib/salesforce_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
#
# Copyright (c) 2016 Simply Measured
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
# This method will be used by the mock to replace requests.get

"""
Unit test for the Salesforce contrib package
"""

from luigi.contrib.salesforce import SalesforceAPI, QuerySalesforce

from helpers import unittest
import mock
from luigi.mock import MockTarget
from luigi.six import PY3
import re


def mocked_requests_get(*args, **kwargs):
    """Stand-in for requests.get returning a canned batch result-list document."""
    class MockResponse:
        def __init__(self, body, status_code):
            self.body = body
            self.status_code = status_code

        @property
        def text(self):
            return self.body

        def raise_for_status(self):
            return None

    body = (
        '<result-list xmlns="http://www.force.com/2009/06/asyncapi/dataload">'
        '<result>1234</result><result>1235</result><result>1236</result>'
        '</result-list>'
    )
    return MockResponse(body, 200)

# Keep a reference to the builtin open so the mock can delegate to it.
old__open = open


def mocked_open(*args, **kwargs):
    """Route opens of job_data* paths to a MockTarget; delegate everything else to the real open."""
    if re.match("job_data", args[0]):
        return MockTarget(args[0]).open(args[1])
    else:
        # Forward kwargs too, so calls like open(path, encoding=...) still work.
        return old__open(*args, **kwargs)


class TestSalesforceAPI(unittest.TestCase):
    """Tests for batch-result retrieval on SalesforceAPI with requests.get mocked out."""

    # We patch 'requests.get' with our own method. The mock object is passed in to our test case method.
    @mock.patch('requests.get', side_effect=mocked_requests_get)
    def test_deprecated_results(self, mock_get):
        api = SalesforceAPI('xx', 'xx', 'xx')
        self.assertEqual('1234', api.get_batch_results('job_id', 'batch_id'))

    @mock.patch('requests.get', side_effect=mocked_requests_get)
    def test_result_ids(self, mock_get):
        api = SalesforceAPI('xx', 'xx', 'xx')
        self.assertEqual(['1234', '1235', '1236'], api.get_batch_result_ids('job_id', 'batch_id'))


class TestQuerySalesforce(QuerySalesforce):
    """Minimal concrete QuerySalesforce that writes to an in-memory target."""

    @property
    def object_name(self):
        return 'dual'

    @property
    def soql(self):
        return "SELECT * FROM %s" % self.object_name

    def output(self):
        return MockTarget('job_data.csv')


class TestSalesforceQuery(unittest.TestCase):
    """Tests for merging multi-result batch downloads via merge_batch_results."""

    patch_name = '__builtin__.open'
    if PY3:
        patch_name = 'builtins.open'

    @mock.patch(patch_name, side_effect=mocked_open)
    def setUp(self, mock_open):
        MockTarget.fs.clear()
        self.result_ids = ['a', 'b', 'c']

        counter = 1
        self.all_lines = "Lines\n"
        self.header = "Lines"
        # Write one fake per-result CSV part file per result id; each part
        # repeats the header and contributes two data lines. (The old loop
        # used `for i, id in enumerate(...)` with an unused `id` shadowing
        # the builtin.)
        for i in range(len(self.result_ids)):
            filename = "%s.%d" % ('job_data.csv', i)
            with MockTarget(filename).open('w') as f:
                line = "%d line\n%d line" % (counter, counter + 1)
                f.write(self.header + "\n" + line + "\n")
                self.all_lines += line + "\n"
                counter += 2

    @mock.patch(patch_name, side_effect=mocked_open)
    def test_multi_csv_download(self, mock_open):
        qsf = TestQuerySalesforce()

        qsf.merge_batch_results(self.result_ids)
        # The merged file keeps a single header followed by all data lines.
        self.assertEqual(MockTarget(qsf.output().path).open('r').read(), self.all_lines)

1 comment on commit 2a0ff1f

@dlstadther
Copy link
Collaborator

@dlstadther dlstadther commented on 2a0ff1f Jun 21, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll be submitting a fix to my comments. Basically, if you're going to set the default content_type to csv, then content_type will never be None. I'll submit the PR soon.

I apologize for not catching these breaking changes when the PR was pending.

Please sign in to comment.