diff --git a/doc/api/class-mesh-connection.rst b/doc/api/class-mesh-connection.rst new file mode 100644 index 00000000..d1048eba --- /dev/null +++ b/doc/api/class-mesh-connection.rst @@ -0,0 +1,8 @@ + +.. currentmodule:: tarantool.mesh_connection + +class :class:`MeshConnection` +----------------------------- + +.. autoclass:: MeshConnection + diff --git a/doc/index.rst b/doc/index.rst index c89da773..346c656c 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -40,6 +40,7 @@ API Reference api/module-tarantool.rst api/class-connection.rst + api/class-mesh-connection.rst api/class-space.rst api/class-response.rst diff --git a/doc/index.ru.rst b/doc/index.ru.rst index 2708d8ab..ddf1b528 100644 --- a/doc/index.ru.rst +++ b/doc/index.ru.rst @@ -40,6 +40,7 @@ api/module-tarantool.rst api/class-connection.rst + api/class-mesh-connection.rst api/class-space.rst api/class-response.rst diff --git a/tarantool/const.py b/tarantool/const.py index 1ebad494..9d175974 100644 --- a/tarantool/const.py +++ b/tarantool/const.py @@ -86,3 +86,5 @@ RECONNECT_MAX_ATTEMPTS = 10 # Default delay between attempts to reconnect (seconds) RECONNECT_DELAY = 0.1 +# Default cluster nodes list refresh interval (seconds) +CLUSTER_DISCOVERY_DELAY = 60 diff --git a/tarantool/error.py b/tarantool/error.py index f49ba60c..cc66e8c5 100644 --- a/tarantool/error.py +++ b/tarantool/error.py @@ -43,6 +43,12 @@ class InterfaceError(Error): ''' +class ConfigurationError(Error): + ''' + Error of initialization with a user-provided configuration. + ''' + + # Monkey patch os.strerror for win32 if sys.platform == "win32": # Windows Sockets Error Codes (not all, but related on network errors) @@ -152,6 +158,11 @@ class NetworkWarning(UserWarning): pass +class ClusterDiscoveryWarning(UserWarning): + '''Warning related to cluster discovery''' + pass + + # always print this warnings warnings.filterwarnings("always", category=NetworkWarning) @@ -166,6 +177,7 @@ def warn(message, warning_class): line_no = frame.f_lineno warnings.warn_explicit(message, warning_class, module_name, line_no) + _strerror = { 0: ("ER_UNKNOWN", "Unknown error"), 1: ("ER_ILLEGAL_PARAMS", "Illegal parameters, %s"), diff --git a/tarantool/mesh_connection.py b/tarantool/mesh_connection.py index a2d69c56..d1ed851b 100644 --- a/tarantool/mesh_connection.py +++ b/tarantool/mesh_connection.py @@ -4,29 +4,182 @@ between tarantool instances and basic Round-Robin strategy. ''' +import time + + from tarantool.connection import Connection -from tarantool.error import NetworkError +from tarantool.error import ( + warn, + NetworkError, + DatabaseError, + ConfigurationError, + ClusterDiscoveryWarning, +) from tarantool.utils import ENCODING_DEFAULT from tarantool.const import ( + CONNECTION_TIMEOUT, SOCKET_TIMEOUT, RECONNECT_MAX_ATTEMPTS, - RECONNECT_DELAY + RECONNECT_DELAY, + CLUSTER_DISCOVERY_DELAY, +) + +from tarantool.request import ( + RequestCall ) +try: + string_types = basestring +except NameError: + string_types = str + + +def parse_uri(uri): + def parse_error(uri, msg): + msg = 'URI "%s": %s' % (uri, msg) + return None, msg + + if not uri: + return parse_error(uri, 'should not be None or empty string') + if not isinstance(uri, string_types): + return parse_error(uri, 'should be of a string type') + if uri.count(':') != 1: + return parse_error(uri, 'does not match host:port scheme') + + host, port_str = uri.split(':', 1) + if not host: + return parse_error(uri, 'host value is empty') + + try: + port = int(port_str) + except ValueError: + return parse_error(uri, 'port should be a number') + + return {'host': host, 'port': port}, None + + +def validate_address(address): + messages = [] + + if isinstance(address, dict): + if "host" not in address: + messages.append("host key must be set") + elif not isinstance(address["host"], string_types): + messages.append("host value must be string type") + + if "port" not in address: + messages.append("port is not set") + elif not isinstance(address["port"], int): + messages.append("port value must be int type") + elif address["port"] == 0: + messages.append("port value must not be zero") + elif address["port"] > 65535: + messages.append("port value must not be above 65535") + else: + messages.append("address must be a dict") + + if messages: + messages_str = ', '.join(messages) + msg = 'Address %s: %s' % (str(address), messages_str) + return None, msg + + return True, None + class RoundRobinStrategy(object): + """ + Simple round-robin address rotation + """ def __init__(self, addrs): - self.addrs = addrs - self.pos = 0 + self.update(addrs) + + def update(self, new_addrs): + # Verify new_addrs is a non-empty list. + assert new_addrs and isinstance(new_addrs, list) + + # Remove duplicates. + new_addrs_unique = [] + for addr in new_addrs: + if addr not in new_addrs_unique: + new_addrs_unique.append(addr) + new_addrs = new_addrs_unique + + # Save a current address if any. + if 'pos' in self.__dict__ and 'addrs' in self.__dict__: + current_addr = self.addrs[self.pos] + else: + current_addr = None + + # Determine a position of a current address (if any) in + # the new addresses list. + if current_addr and current_addr in new_addrs: + new_pos = new_addrs.index(current_addr) + else: + new_pos = -1 + + self.addrs = new_addrs + self.pos = new_pos def getnext(self): - tmp = self.pos self.pos = (self.pos + 1) % len(self.addrs) - return self.addrs[tmp] + return self.addrs[self.pos] class MeshConnection(Connection): - def __init__(self, addrs, + ''' + Represents a connection to a cluster of Tarantool servers. + + This class uses Connection to connect to one of the nodes of the cluster. + The initial list of nodes is passed to the constructor in 'addrs' parameter. + The class set in 'strategy_class' parameter is used to select a node from + the list and switch nodes in case of unavailability of the current node. + + 'cluster_discovery_function' param of the constructor sets the name of a + stored Lua function used to refresh the list of available nodes. The + function takes no parameters and returns a list of strings in format + 'host:port'. A generic function for getting the list of nodes looks like + this: + + .. code-block:: lua + + function get_cluster_nodes() + return { + '192.168.0.1:3301', + '192.168.0.2:3302', + -- ... + } + end + + You may put in this list whatever you need depending on your cluster + topology. Chances are you'll want to make the list of nodes from nodes' + replication config. Here is an example for it: + + .. code-block:: lua + + local uri_lib = require('uri') + + function get_cluster_nodes() + local nodes = {} + + local replicas = box.cfg.replication + + for i = 1, #replicas do + local uri = uri_lib.parse(replicas[i]) + + if uri.host and uri.service then + table.insert(nodes, uri.host .. ':' .. uri.service) + end + end + + -- if your replication config doesn't contain the current node + -- you have to add it manually like this: + table.insert(nodes, '192.168.0.1:3301') + + return nodes + end + ''' + + def __init__(self, host=None, port=None, user=None, password=None, socket_timeout=SOCKET_TIMEOUT, @@ -34,32 +187,156 @@ def __init__(self, addrs, reconnect_delay=RECONNECT_DELAY, connect_now=True, encoding=ENCODING_DEFAULT, - strategy_class=RoundRobinStrategy): - self.nattempts = 2 * len(addrs) + 1 + call_16=False, + connection_timeout=CONNECTION_TIMEOUT, + addrs=None, + strategy_class=RoundRobinStrategy, + cluster_discovery_function=None, + cluster_discovery_delay=CLUSTER_DISCOVERY_DELAY): + if addrs is None: + addrs = [] + else: + # Don't change user provided arguments. + addrs = addrs[:] + + if host and port: + addrs.insert(0, {'host': host, 'port': port}) + + # Verify that at least one address is provided. + if not addrs: + raise ConfigurationError( + 'Neither "host" and "port", nor "addrs" arguments are set') + + # Verify addresses. + for addr in addrs: + ok, msg = validate_address(addr) + if not ok: + raise ConfigurationError(msg) + + self.strategy_class = strategy_class self.strategy = strategy_class(addrs) + addr = self.strategy.getnext() host = addr['host'] port = addr['port'] - super(MeshConnection, self).__init__(host=host, - port=port, - user=user, - password=password, - socket_timeout=socket_timeout, - reconnect_max_attempts=reconnect_max_attempts, - reconnect_delay=reconnect_delay, - connect_now=connect_now, - encoding=encoding) + + self.cluster_discovery_function = cluster_discovery_function + self.cluster_discovery_delay = cluster_discovery_delay + self.last_nodes_refresh = 0 + + super(MeshConnection, self).__init__( + host=host, + port=port, + user=user, + password=password, + socket_timeout=socket_timeout, + reconnect_max_attempts=reconnect_max_attempts, + reconnect_delay=reconnect_delay, + connect_now=connect_now, + encoding=encoding, + call_16=call_16, + connection_timeout=connection_timeout) + + def connect(self): + super(MeshConnection, self).connect() + if self.connected and self.cluster_discovery_function: + self._opt_refresh_instances() def _opt_reconnect(self): - nattempts = self.nattempts - while nattempts > 0: + ''' + Attempt to connect "reconnect_max_attempts" times to each + available address. + ''' + + last_error = None + for _ in range(len(self.strategy.addrs)): try: super(MeshConnection, self)._opt_reconnect() + last_error = None break - except NetworkError: - nattempts -= 1 + except NetworkError as e: + last_error = e addr = self.strategy.getnext() - self.host = addr['host'] - self.port = addr['port'] - else: - raise NetworkError + self.host = addr["host"] + self.port = addr["port"] + + if last_error: + raise last_error + + def _opt_refresh_instances(self): + ''' + Refresh list of tarantool instances in a cluster. + Reconnect if a current instance was gone from the list. + ''' + now = time.time() + + if not self.connected or not self.cluster_discovery_function or \ + now - self.last_nodes_refresh < self.cluster_discovery_delay: + return + + # Call a cluster discovery function w/o reconnection. If + # something going wrong: warn about that and ignore. + request = RequestCall(self, self.cluster_discovery_function, (), + self.call_16) + try: + resp = self._send_request_wo_reconnect(request) + except DatabaseError as e: + msg = 'got "%s" error, skipped addresses updating' % str(e) + warn(msg, ClusterDiscoveryWarning) + return + + if not resp.data or not resp.data[0] or \ + not isinstance(resp.data[0], list): + msg = "got incorrect response instead of URI list, " + \ + "skipped addresses updating" + warn(msg, ClusterDiscoveryWarning) + return + + # Validate received address list. + new_addrs = [] + for uri in resp.data[0]: + addr, msg = parse_uri(uri) + if not addr: + warn(msg, ClusterDiscoveryWarning) + continue + + ok, msg = validate_address(addr) + if not ok: + warn(msg, ClusterDiscoveryWarning) + continue + + new_addrs.append(addr) + + if not new_addrs: + msg = "got no correct URIs, skipped addresses updating" + warn(msg, ClusterDiscoveryWarning) + return + + self.strategy.update(new_addrs) + self.last_nodes_refresh = now + + # Disconnect from a current instance if it was gone from + # an instance list and connect to one of new instances. + current_addr = {'host': self.host, 'port': self.port} + if current_addr not in self.strategy.addrs: + self.close() + addr = self.strategy.getnext() + self.host = addr['host'] + self.port = addr['port'] + self._opt_reconnect() + + def _send_request(self, request): + ''' + Update instances list if "cluster_discovery_function" is provided and a + last update was more then "cluster_discovery_delay" seconds ago. + + After that perform a request as usual and return an instance of + `Response` class. + + :param request: object representing a request + :type request: `Request` instance + + :rtype: `Response` instance + ''' + self._opt_refresh_instances() + return super(MeshConnection, self)._send_request(request) diff --git a/unit/suites/__init__.py b/unit/suites/__init__.py index 3f59862e..ead75297 100644 --- a/unit/suites/__init__.py +++ b/unit/suites/__init__.py @@ -8,9 +8,10 @@ from .test_dml import TestSuite_Request from .test_protocol import TestSuite_Protocol from .test_reconnect import TestSuite_Reconnect +from .test_mesh import TestSuite_Mesh test_cases = (TestSuite_Schema, TestSuite_Request, TestSuite_Protocol, - TestSuite_Reconnect) + TestSuite_Reconnect, TestSuite_Mesh) def load_tests(loader, tests, pattern): suite = unittest.TestSuite() diff --git a/unit/suites/test_mesh.py b/unit/suites/test_mesh.py new file mode 100644 index 00000000..dda59a89 --- /dev/null +++ b/unit/suites/test_mesh.py @@ -0,0 +1,258 @@ +# -*- coding: utf-8 -*- + +from __future__ import print_function + +import sys +import unittest +import warnings +from time import sleep +import tarantool +from tarantool.error import ( + ConfigurationError, + ClusterDiscoveryWarning, +) +from .lib.tarantool_server import TarantoolServer + + +def create_server(_id): + srv = TarantoolServer() + srv.script = 'unit/suites/box.lua' + srv.start() + srv.admin("box.schema.user.create('test', {password = 'test', " + + "if_not_exists = true})") + srv.admin("box.schema.user.grant('test', 'execute', 'universe')") + + # Create srv_id function (for testing purposes). + srv.admin("function srv_id() return %s end" % _id) + return srv + + +@unittest.skipIf(sys.platform.startswith("win"), + 'Mesh tests on windows platform are not supported') +class TestSuite_Mesh(unittest.TestCase): + def define_cluster_function(self, func_name, servers): + addresses = [(srv.host, srv.args['primary']) for srv in servers] + addresses_lua = ",".join("'%s:%d'" % address for address in addresses) + func_body = """ + function %s() + return {%s} + end + """ % (func_name, addresses_lua) + for srv in self.servers: + srv.admin(func_body) + + def define_custom_cluster_function(self, func_name, retval): + func_body = """ + function %s() + return %s + end + """ % (func_name, retval) + for srv in self.servers: + srv.admin(func_body) + + @classmethod + def setUpClass(self): + print(' MESH '.center(70, '='), file=sys.stderr) + print('-' * 70, file=sys.stderr) + + def setUp(self): + # Create two servers and extract helpful fields for tests. + self.srv = create_server(1) + self.srv2 = create_server(2) + self.servers = [self.srv, self.srv2] + self.host_1 = self.srv.host + self.port_1 = self.srv.args['primary'] + self.host_2 = self.srv2.host + self.port_2 = self.srv2.args['primary'] + + # Create get_all_nodes() function on servers. + self.get_all_nodes_func_name = 'get_all_nodes' + self.define_cluster_function(self.get_all_nodes_func_name, + self.servers) + + def test_01_contructor(self): + # Verify that an error is risen when no addresses are + # configured (neither with host/port, nor with addrs). + with self.assertRaises(ConfigurationError): + tarantool.MeshConnection() + + # Verify that a bad address given at initialization leads + # to an error. + bad_addrs = [ + {"port": 1234}, # no host + {"host": "localhost"}, # no port + {"host": "localhost", "port": "1234"}, # port is str + ] + for bad_addr in bad_addrs: + with self.assertRaises(ConfigurationError): + con = tarantool.MeshConnection(bad_addr.get('host'), + bad_addr.get('port')) + with self.assertRaises(ConfigurationError): + con = tarantool.MeshConnection(addrs=[bad_addr]) + + # Verify that identical addresses are squashed. + addrs = [{"host": "localhost", "port": 1234}] + con = tarantool.MeshConnection("localhost", 1234, addrs=addrs, + connect_now=False) + self.assertEqual(len(con.strategy.addrs), 1) + + def test_02_discovery_bad_address(self): + retvals = [ + "", + "1", + "'localhost:1234'", + "{}", + "error('raise an error')", + "{'localhost:foo'}", + "{'localhost:0'}", + "{'localhost:65536'}", + "{'localhost:1234:5678'}", + "{':1234'}", + "{'localhost:'}", + ] + for retval in retvals: + func_name = 'bad_cluster_discovery' + self.define_custom_cluster_function(func_name, retval) + con = tarantool.MeshConnection(self.host_1, self.port_1, + user='test', password='test') + con.cluster_discovery_function = func_name + + # Verify that a cluster discovery (that is triggered + # by ping) give one or two warnings. + with warnings.catch_warnings(record=True) as ws: + con.ping() + self.assertTrue(len(ws) in (1, 2)) + for w in ws: + self.assertIs(w.category, ClusterDiscoveryWarning) + + # Verify that incorrect or empty result was discarded. + self.assertEqual(len(con.strategy.addrs), 1) + self.assertEqual(con.strategy.addrs[0]['host'], self.host_1) + self.assertEqual(con.strategy.addrs[0]['port'], self.port_1) + + con.close() + + def test_03_discovery_bad_good_addresses(self): + func_name = 'bad_and_good_addresses' + retval = "{'localhost:', '%s:%d'}" % (self.host_2, self.port_2) + self.define_custom_cluster_function(func_name, retval) + con = tarantool.MeshConnection(self.host_1, self.port_1, + user='test', password='test') + con.cluster_discovery_function = func_name + + # Verify that a cluster discovery (that is triggered + # by ping) give one warning. + with warnings.catch_warnings(record=True) as ws: + con.ping() + self.assertEqual(len(ws), 1) + self.assertIs(ws[0].category, ClusterDiscoveryWarning) + + # Verify that only second address was accepted. + self.assertEqual(len(con.strategy.addrs), 1) + self.assertEqual(con.strategy.addrs[0]['host'], self.host_2) + self.assertEqual(con.strategy.addrs[0]['port'], self.port_2) + + con.close() + + def test_04_discovery_add_address(self): + # Create a mesh connection; pass only the first server + # address. + con = tarantool.MeshConnection( + self.host_1, self.port_1, user='test', password='test', + cluster_discovery_function=self.get_all_nodes_func_name, + connect_now=False) + + # Verify that the strategy has one address that comes from + # the constructor arguments. + self.assertEqual(len(con.strategy.addrs), 1) + con.connect() + + # Verify that we work with the first server. + resp = con.call('srv_id') + self.assertEqual(resp.data and resp.data[0], 1) + + # Verify that the refresh was successful and the strategy + # has 2 addresses. + self.assertEqual(len(con.strategy.addrs), 2) + + con.close() + + def test_05_discovery_delay(self): + # Create a mesh connection, pass only the first server address. + con = tarantool.MeshConnection( + self.host_1, self.port_1, user='test', password='test', + cluster_discovery_function=self.get_all_nodes_func_name, + cluster_discovery_delay=1) + + # Verify that the strategy has two addresses come from + # the function right after connecting. + self.assertEqual(len(con.strategy.addrs), 2) + + # Drop addresses list to the initial state. + con.strategy.update([con.strategy.addrs[0], ]) + + # Verify that the discovery will not be performed until + # 'cluster_discovery_delay' seconds will be passed. + con.ping() + self.assertEqual(len(con.strategy.addrs), 1) + + sleep(1.1) + + # Refresh after cluster_discovery_delay. + con.ping() + self.assertEqual(len(con.strategy.addrs), 2) + + con.close() + + def test_06_reconnection(self): + # Create a mesh connection; pass only the first server + # address. + con = tarantool.MeshConnection( + self.host_1, self.port_1, user='test', password='test', + cluster_discovery_function=self.get_all_nodes_func_name) + + con.last_nodes_refresh = 0 + resp = con.call('srv_id') + self.assertEqual(resp.data and resp.data[0], 1) + + # Verify that the last discovery was successful and the + # strategy has 2 addresses. + self.assertEqual(len(con.strategy.addrs), 2) + + self.srv.stop() + + # Verify that we switched to the second server. + with warnings.catch_warnings(): + # Suppress reconnection warnings. + warnings.simplefilter("ignore") + resp = con.call('srv_id') + self.assertEqual(resp.data and resp.data[0], 2) + + con.close() + + def test_07_discovery_exclude_address(self): + # Define function to get back only second server. + func_name = 'get_second_node' + self.define_cluster_function(func_name, [self.srv2]) + + # Create a mesh connection, pass only the first server address. + con = tarantool.MeshConnection( + self.host_1, self.port_1, user='test', password='test', + cluster_discovery_function=func_name) + + # Verify that discovery was successful and the strategy + # has 1 address. + self.assertEqual(len(con.strategy.addrs), 1) + + # Verify that the current server is second one. + resp = con.call('srv_id') + self.assertEqual(resp.data and resp.data[0], 2) + + con.close() + + def tearDown(self): + self.srv.stop() + self.srv.clean() + + self.srv2.stop() + self.srv2.clean()