From 74d4666a4190541ee11da3ec629a60b9b1ce6d9d Mon Sep 17 00:00:00 2001 From: Andy McCurdy Date: Wed, 31 Mar 2010 19:20:42 -0700 Subject: [PATCH] Pipeines can not optionally be transactions (wrapped in MULTI/EXEC) or not by passing the transaction parameter. This fixes #23. --- redis/client.py | 64 ++++++++++++++++++++++++++++++++++------------- tests/pipeline.py | 22 ++++++++++------ 2 files changed, 60 insertions(+), 26 deletions(-) diff --git a/redis/client.py b/redis/client.py index 496e28f097..1c0e8fa082 100644 --- a/redis/client.py +++ b/redis/client.py @@ -3,6 +3,7 @@ import socket import threading import warnings +from itertools import chain from redis.exceptions import ConnectionError, ResponseError, InvalidResponse from redis.exceptions import RedisError, AuthenticationError @@ -246,8 +247,20 @@ def _get_db(self): return self.connection.db db = property(_get_db) - def pipeline(self): - return Pipeline(self.connection, self.encoding, self.errors) + def pipeline(self, transaction=True): + """ + Return a new pipeline object that can queue multiple commands for + later execution. ``transaction`` indicates whether all commands + should be executed atomically. Apart from multiple atomic operations, + pipelines are useful for batch loading of data as they reduce the + number of back and forth network operations between client and server. + """ + return Pipeline( + self.connection, + transaction, + self.encoding, + self.errors + ) #### COMMAND EXECUTION AND PROTOCOL PARSING #### @@ -1032,8 +1045,9 @@ class Pipeline(Redis): ResponseError exceptions, such as those raised when issuing a command on a key of a different datatype. """ - def __init__(self, connection, charset, errors): + def __init__(self, connection, transaction, charset, errors): self.connection = connection + self.transaction = transaction self.encoding = charset self.errors = errors self.subscribed = False # NOTE not in use, but necessary @@ -1041,7 +1055,6 @@ def __init__(self, connection, charset, errors): def reset(self): self.command_stack = [] - self.execute_command('MULTI') def _execute_command(self, command_name, command, **options): """ @@ -1066,19 +1079,20 @@ def _execute_command(self, command_name, command, **options): self.command_stack.append((command_name, command, options)) return self - def _execute(self, commands): - # build up all commands into a single request to increase network perf - all_cmds = ''.join([c for _1, c, _2 in commands]) + def _execute_transaction(self, commands): + # wrap the commands in MULTI ... EXEC statements to indicate an + # atomic operation + all_cmds = ''.join([c for _1, c, _2 in chain( + (('', 'MULTI\r\n', ''),), + commands, + (('', 'EXEC\r\n', ''),) + )]) self.connection.send(all_cmds, self) - # we only care about the last item in the response, which should be - # the EXEC command - for i in range(len(commands)-1): + # parse off the response for MULTI and all commands prior to EXEC + for i in range(len(commands)+1): _ = self.parse_response('_') - # tell the response parse to catch errors and return them as - # part of the response + # parse the EXEC. we want errors returned as items in the response response = self.parse_response('_', catch_errors=True) - # don't return the results of the MULTI or EXEC command - commands = [(c[0], c[2]) for c in commands[1:-1]] if len(response) != len(commands): raise ResponseError("Wrong number of response items from " "pipline execution") @@ -1087,20 +1101,34 @@ def _execute(self, commands): for r, cmd in zip(response, commands): if not isinstance(r, Exception): if cmd[0] in self.RESPONSE_CALLBACKS: - r = self.RESPONSE_CALLBACKS[cmd[0]](r, **cmd[1]) + r = self.RESPONSE_CALLBACKS[cmd[0]](r, **cmd[2]) data.append(r) return data + def _execute_pipeline(self, commands): + # build up all commands into a single request to increase network perf + all_cmds = ''.join([c for _1, c, _2 in commands]) + self.connection.send(all_cmds, self) + data = [] + for command_name, _, options in commands: + data.append( + self.parse_response(command_name, catch_errors=True, **options) + ) + return data + def execute(self): "Execute all the commands in the current pipeline" - self.execute_command('EXEC') stack = self.command_stack self.reset() + if self.transaction: + execute = self._execute_transaction + else: + execute = self._execute_pipeline try: - return self._execute(stack) + return execute(stack) except ConnectionError: self.connection.disconnect() - return self._execute(stack) + return execute(stack) def select(self, *args, **kwargs): raise RedisError("Cannot select a different database from a pipeline") diff --git a/tests/pipeline.py b/tests/pipeline.py index 9e620e45e1..f9c8dfa887 100644 --- a/tests/pipeline.py +++ b/tests/pipeline.py @@ -5,10 +5,10 @@ class PipelineTestCase(unittest.TestCase): def setUp(self): self.client = redis.Redis(host='localhost', port=6379, db=9) self.client.flushdb() - + def tearDown(self): self.client.flushdb() - + def test_pipeline(self): pipe = self.client.pipeline() pipe.set('a', 'a1').get('a').zadd('z', 'z1', 1).zadd('z', 'z2', 4) @@ -23,14 +23,14 @@ def test_pipeline(self): [('z1', 2.0), ('z2', 4)], ] ) - + def test_pipeline_with_fresh_connection(self): redis.client.connection_manager.connections.clear() self.client = redis.Redis(host='localhost', port=6379, db=9) pipe = self.client.pipeline() pipe.set('a', 'b') self.assertEquals(pipe.execute(), [True]) - + def test_invalid_command_in_pipeline(self): # all commands but the invalid one should be excuted correctly self.client['c'] = 'a' @@ -53,10 +53,16 @@ def test_invalid_command_in_pipeline(self): self.assertEquals(pipe.set('z', 'zzz').execute(), [True]) self.assertEquals(self.client['z'], 'zzz') - def test_pipe_cannot_select(self): + def test_pipeline_cannot_select(self): pipe = self.client.pipeline() self.assertRaises(redis.RedisError, pipe.select, 'localhost', 6379, db=9) - - - + + def test_pipeline_no_transaction(self): + pipe = self.client.pipeline(transaction=False) + pipe.set('a', 'a1').set('b', 'b1').set('c', 'c1') + self.assertEquals(pipe.execute(), [True, True, True]) + self.assertEquals(self.client['a'], 'a1') + self.assertEquals(self.client['b'], 'b1') + self.assertEquals(self.client['c'], 'c1') +