diff --git a/cffi/cdefs.h b/cffi/cdefs.h index 70f09fa..c060203 100755 --- a/cffi/cdefs.h +++ b/cffi/cdefs.h @@ -85,6 +85,8 @@ int sr_session_stop(sr_session_ctx_t *); int sr_session_switch_ds(sr_session_ctx_t *, sr_datastore_t); sr_datastore_t sr_session_get_ds(sr_session_ctx_t *); sr_conn_ctx_t *sr_session_get_connection(sr_session_ctx_t *); +uint32_t sr_session_get_event_nc_id(sr_session_ctx_t *); +const char *sr_session_get_event_user(sr_session_ctx_t *); int sr_get_error(sr_session_ctx_t *, const sr_error_info_t **); int sr_set_error(sr_session_ctx_t *, const char *, const char *, ...); diff --git a/sysrepo/session.py b/sysrepo/session.py index 5c71fc1..4e5eee5 100755 --- a/sysrepo/session.py +++ b/sysrepo/session.py @@ -121,6 +121,30 @@ def set_error(self, xpath: Optional[str], message: str): lib.sr_set_error, self.cdata, str2c(xpath), str2c("%s"), str2c(message) ) + def get_netconf_id(self) -> int: + """ + It can only be called on an implicit sysrepo.Session (i.e., it can only be + called from an event callback) + + :returns: the NETCONF session ID set for the event originator sysrepo session + """ + if not self.is_implicit: + raise SysrepoUnsupportedError( + "can only report netconf id on implicit sessions" + ) + return lib.sr_session_get_event_nc_id(self.cdata) + + def get_user(self) -> str: + """ + It can only be called on an implicit sysrepo.Session (i.e., it can only be + called from an event callback) + + :returns: the effective username of the event originator sysrepo session + """ + if not self.is_implicit: + raise SysrepoUnsupportedError("can only report user on implicit sessions") + return c2str(lib.sr_session_get_event_user(self.cdata)) + def get_ly_ctx(self) -> libyang.Context: """ :returns: @@ -152,6 +176,13 @@ def get_ly_ctx(self) -> libyang.Context: have changed. :arg private_data: Private context opaque to sysrepo used when subscribing. + :arg kwargs (optional): + If the callback was registered with the argument extra_info=True (see + Session.subscribe_module_change), then extra keyword arguments are passed when + calling the callback: + * netconf_id: the NETCONF session ID set for the event originator + sysrepo session + * user: the effective username of the event originator sysrepo session When event is one of ("update", "change"), if the callback raises an exception, the changes will be rejected and the error will be forwarded to the client that made the @@ -174,7 +205,8 @@ def subscribe_module_change( enabled: bool = False, private_data: Any = None, asyncio_register: bool = False, - include_implicit_defaults: bool = True + include_implicit_defaults: bool = True, + extra_info: bool = False ) -> None: """ Subscribe for changes made in the specified module. @@ -210,6 +242,10 @@ def subscribe_module_change( monitored read file descriptors. Implies `no_thread=True`. :arg include_implicit_defaults: Include implicit default nodes in changes. + :arg extra_info: + When True, the given callback is called with extra keyword arguments + containing extra information of the sysrepo session that gave origin to the + event (see ModuleChangeCallbackType for more details) """ if self.is_implicit: raise SysrepoUnsupportedError("cannot subscribe with implicit sessions") @@ -220,6 +256,7 @@ def subscribe_module_change( private_data, asyncio_register=asyncio_register, include_implicit_defaults=include_implicit_defaults, + extra_info=extra_info, ) sub_p = ffi.new("sr_subscription_ctx_t **") @@ -253,6 +290,13 @@ def subscribe_module_change( module operational data. :arg private_data: Private context opaque to sysrepo used when subscribing. + :arg kwargs (optional): + If the callback was registered with the argument extra_info=True (see + Session.subscribe_module_change), then extra keyword arguments are passed when + calling the callback: + * netconf_id: the NETCONF session ID set for the event originator + sysrepo session + * user: the effective username of the event originator sysrepo session The callback is expected to return a python dictionary containing the operational data. The dictionary should be in the libyang "dict" format. It will be parsed to a @@ -272,7 +316,8 @@ def subscribe_oper_data_request( no_thread: bool = False, private_data: Any = None, asyncio_register: bool = False, - strict: bool = False + strict: bool = False, + extra_info: bool = False ) -> None: """ Register for providing operational data at the given xpath. @@ -296,13 +341,21 @@ def subscribe_oper_data_request( :arg strict: Reject the whole data returned by callback if it contains elements without schema definition. + :arg extra_info: + When True, the given callback is called with extra keyword arguments + containing extra information of the sysrepo session that gave origin to the + event (see OperDataCallbackType for more details) """ if self.is_implicit: raise SysrepoUnsupportedError("cannot subscribe with implicit sessions") _check_subscription_callback(callback, self.OperDataCallbackType) sub = Subscription( - callback, private_data, asyncio_register=asyncio_register, strict=strict + callback, + private_data, + asyncio_register=asyncio_register, + strict=strict, + extra_info=extra_info, ) sub_p = ffi.new("sr_subscription_ctx_t **") @@ -369,6 +422,13 @@ def subscribe_oper_data_request( will be called with 'abort'. :arg private_data: Private context opaque to sysrepo used when subscribing. + :arg kwargs (optional): + If the callback was registered with the argument extra_info=True (see + Session.subscribe_module_change), then extra keyword arguments are passed when + calling the callback: + * netconf_id: the NETCONF session ID set for the event originator + sysrepo session + * user: the effective username of the event originator sysrepo session The callback is expected to return a python dictionary containing the RPC output data. The dictionary should be in the libyang "dict" format and must only contain @@ -393,7 +453,8 @@ def subscribe_rpc_call( private_data: Any = None, asyncio_register: bool = False, strict: bool = False, - include_implicit_defaults: bool = True + include_implicit_defaults: bool = True, + extra_info: bool = False ) -> None: """ Subscribe for the delivery of an RPC/action. @@ -418,6 +479,10 @@ def subscribe_rpc_call( schema definition. :arg include_implicit_defaults: Include implicit defaults into input parameters passed to callbacks. + :arg extra_info: + When True, the given callback is called with extra keyword arguments + containing extra information of the sysrepo session that gave origin to the + event (see RpcCallbackType for more details) """ if self.is_implicit: raise SysrepoUnsupportedError("cannot subscribe with implicit sessions") @@ -429,6 +494,7 @@ def subscribe_rpc_call( asyncio_register=asyncio_register, strict=strict, include_implicit_defaults=include_implicit_defaults, + extra_info=extra_info, ) sub_p = ffi.new("sr_subscription_ctx_t **") @@ -480,6 +546,13 @@ def subscribe_rpc_call( Timestamp of the notification as an unsigned 32-bits integer. :arg private_data: Private context opaque to sysrepo used when subscribing. + :arg kwargs (optional): + If the callback was registered with the argument extra_info=True (see + Session.subscribe_module_change), then extra keyword arguments are passed when + calling the callback: + * netconf_id: the NETCONF session ID set for the event originator + sysrepo session + * user: the effective username of the event originator sysrepo session """ def subscribe_notification( @@ -492,7 +565,8 @@ def subscribe_notification( stop_time: int = 0, no_thread: bool = False, asyncio_register: bool = False, - private_data: Any = None + private_data: Any = None, + extra_info: bool = False ) -> None: """ Subscribe for the delivery of a notification. @@ -516,6 +590,10 @@ def subscribe_notification( read file descriptors. Implies no_thread=True. :arg private_data: Private context passed to the callback function, opaque to sysrepo. + :arg extra_info: + When True, the given callback is called with extra keyword arguments + containing extra information of the sysrepo session that gave origin to the + event (see RpcCallbackType for more details) """ if self.is_implicit: @@ -526,6 +604,7 @@ def subscribe_notification( callback, private_data, asyncio_register=asyncio_register, + extra_info=extra_info, ) sub_p = ffi.new("sr_subscription_ctx_t **") diff --git a/sysrepo/subscription.py b/sysrepo/subscription.py index a204458..8bb1776 100755 --- a/sysrepo/subscription.py +++ b/sysrepo/subscription.py @@ -33,6 +33,7 @@ def __init__( asyncio_register: bool = False, strict: bool = False, include_implicit_defaults: bool = True, + extra_info: bool = False, ): """ :arg callback: @@ -49,6 +50,10 @@ def __init__( :arg include_implicit_defaults: If True, include implicit default nodes into Change objects passed to module change callbacks and into input parameters passed to RPC/action callbacks. + :arg extra_info: + When True, the given callback is called with extra keyword arguments + containing extra information of the sysrepo session that gave origin to the + event """ if is_async_func(callback) and not asyncio_register: raise ValueError( @@ -59,6 +64,7 @@ def __init__( self.asyncio_register = asyncio_register self.strict = strict self.include_implicit_defaults = include_implicit_defaults + self.extra_info = extra_info if asyncio_register: self.loop = asyncio.get_event_loop() else: @@ -214,6 +220,13 @@ def module_change_callback(session, module, xpath, event, req_id, priv): callback = subscription.callback private_data = subscription.private_data event_name = EVENT_NAMES[event] + if subscription.extra_info: + extra_info = { + "netconf_id": session.get_netconf_id(), + "user": session.get_user(), + } + else: + extra_info = {} if is_async_func(callback): task_id = (event, req_id) @@ -230,7 +243,7 @@ def module_change_callback(session, module, xpath, event, req_id, priv): ) ) task = subscription.loop.create_task( - callback(event_name, req_id, changes, private_data) + callback(event_name, req_id, changes, private_data, **extra_info) ) task.add_done_callback( functools.partial(subscription.task_done, task_id, event_name) @@ -257,7 +270,7 @@ def module_change_callback(session, module, xpath, event, req_id, priv): include_implicit_defaults=subscription.include_implicit_defaults, ) ) - callback(event_name, req_id, changes, private_data) + callback(event_name, req_id, changes, private_data, **extra_info) return lib.SR_ERR_OK @@ -328,12 +341,21 @@ def oper_data_callback(session, module, xpath, req_xpath, req_id, parent, priv): subscription = ffi.from_handle(priv) callback = subscription.callback private_data = subscription.private_data + if subscription.extra_info: + extra_info = { + "netconf_id": session.get_netconf_id(), + "user": session.get_user(), + } + else: + extra_info = {} if is_async_func(callback): task_id = req_id if task_id not in subscription.tasks: - task = subscription.loop.create_task(callback(req_xpath, private_data)) + task = subscription.loop.create_task( + callback(req_xpath, private_data, **extra_info) + ) task.add_done_callback( functools.partial(subscription.task_done, task_id, "oper") ) @@ -349,7 +371,7 @@ def oper_data_callback(session, module, xpath, req_xpath, req_id, parent, priv): oper_data = task.result() else: - oper_data = callback(req_xpath, private_data) + oper_data = callback(req_xpath, private_data, **extra_info) if isinstance(oper_data, dict): # convert oper_data to a libyang.DNode object @@ -438,13 +460,20 @@ def rpc_callback(session, xpath, input_node, event, req_id, output_node, priv): ).values() ) ) + if subscription.extra_info: + extra_info = { + "netconf_id": session.get_netconf_id(), + "user": session.get_user(), + } + else: + extra_info = {} if is_async_func(callback): task_id = (event, req_id) if task_id not in subscription.tasks: task = subscription.loop.create_task( - callback(xpath, input_dict, event_name, private_data) + callback(xpath, input_dict, event_name, private_data, **extra_info) ) task.add_done_callback( functools.partial(subscription.task_done, task_id, event_name) @@ -461,7 +490,9 @@ def rpc_callback(session, xpath, input_node, event, req_id, output_node, priv): output_dict = task.result() else: - output_dict = callback(xpath, input_dict, event_name, private_data) + output_dict = callback( + xpath, input_dict, event_name, private_data, **extra_info + ) if event != lib.SR_EV_RPC: # May happen when there are multiple callback registered for the @@ -543,15 +574,27 @@ def event_notif_tree_callback(session, notif_type, notif, timestamp, priv): ).values() ) ) + if subscription.extra_info: + extra_info = { + "netconf_id": session.get_netconf_id(), + "user": session.get_user(), + } + else: + extra_info = {} + if is_async_func(callback): task = subscription.loop.create_task( - callback(xpath, notif_type, notif_dict, timestamp, private_data) + callback( + xpath, notif_type, notif_dict, timestamp, private_data, **extra_info + ) ) task.add_done_callback( functools.partial(subscription.task_done, None, "notif") ) else: - callback(xpath, notif_type, notif_dict, timestamp, private_data) + callback( + xpath, notif_type, notif_dict, timestamp, private_data, **extra_info + ) except BaseException: # ATTENTION: catch all exceptions! diff --git a/tests/test_session.py b/tests/test_session.py index b3631f1..5552619 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -125,3 +125,11 @@ def iface(name, field): } }, ) + + def test_get_netconf_id_and_get_user_are_only_available_in_implicit_session(self): + with self.conn.start_session("running") as sess: + with self.assertRaises(sysrepo.SysrepoUnsupportedError): + sess.get_netconf_id() + + with self.assertRaises(sysrepo.SysrepoUnsupportedError): + sess.get_user() diff --git a/tests/test_subs_module_change.py b/tests/test_subs_module_change.py index f999c25..6565e36 100644 --- a/tests/test_subs_module_change.py +++ b/tests/test_subs_module_change.py @@ -1,6 +1,7 @@ # Copyright (c) 2020 6WIND S.A. # SPDX-License-Identifier: BSD-3-Clause +import getpass import logging import os import unittest @@ -21,17 +22,24 @@ def setUpClass(cls): with sysrepo.SysrepoConnection() as conn: conn.install_module(YANG_FILE, enabled_features=["turbo"]) cls.conn = sysrepo.SysrepoConnection(err_on_sched_fail=True) - cls.sess = cls.conn.start_session() @classmethod def tearDownClass(cls): - cls.sess.stop() cls.conn.remove_module("sysrepo-example") cls.conn.disconnect() # reconnect to make sure module is removed with sysrepo.SysrepoConnection(err_on_sched_fail=True): pass + def setUp(self): + with self.conn.start_session("running") as sess: + sess.delete_item("/sysrepo-example:conf") + sess.apply_changes() + self.sess = self.conn.start_session() + + def tearDown(self): + self.sess.stop() + def test_module_change_sub(self): priv = object() current_config = {} @@ -200,3 +208,36 @@ def module_change_cb(event, req_id, changes, private_data): sent_config, "sysrepo-example", strict=True, wait=True ) self.assertEqual(current_config, sent_config) + + def test_module_change_sub_with_extra_info(self): + priv = object() + calls = [] + + def module_change_cb(event, req_id, changes, private_data, **kwargs): + self.assertIn(event, ("change", "done", "abort")) + self.assertIsInstance(req_id, int) + self.assertIsInstance(changes, list) + self.assertIs(private_data, priv) + self.assertIn("user", kwargs) + self.assertEqual(getpass.getuser(), kwargs["user"]) + self.assertIn("netconf_id", kwargs) + self.assertIsInstance(kwargs["netconf_id"], int) + calls.append((event, req_id, changes, private_data, kwargs)) + + self.sess.subscribe_module_change( + "sysrepo-example", + "/sysrepo-example:conf", + module_change_cb, + private_data=priv, + extra_info=True, + ) + + with self.conn.start_session("running") as ch_sess: + sent_config = {"conf": {"system": {"hostname": "bar"}}} + ch_sess.replace_config( + sent_config, "sysrepo-example", strict=True, wait=True + ) + # Successful change callbacks are called twice: + # * once with event "change" + # * once with event "done" + self.assertEqual(2, len(calls)) diff --git a/tests/test_subs_notification.py b/tests/test_subs_notification.py index 5bbf679..e3d2d91 100755 --- a/tests/test_subs_notification.py +++ b/tests/test_subs_notification.py @@ -1,6 +1,7 @@ # Copyright (c) 2020 6WIND S.A. # SPDX-License-Identifier: BSD-3-Clause +import getpass import logging import os import threading @@ -32,30 +33,49 @@ def tearDownClass(cls): with sysrepo.SysrepoConnection(err_on_sched_fail=True): pass - def _test_notification_sub(self, notif_xpath: str, notif_dict: typing.Dict): + def _test_notification_sub( + self, + notif_xpath: str, + notif_dict: typing.Dict, + request_extra_info: bool = False, + ): priv = object() callback_called = threading.Event() - def notif_cb(xpath, notification_type, notification, timestamp, private_data): + def notif_cb( + xpath, notification_type, notification, timestamp, private_data, **kwargs + ): self.assertEqual(xpath, notif_xpath) self.assertEqual(notification_type, "realtime") self.assertEqual(notification, notif_dict) self.assertIsInstance(timestamp, int) self.assertAlmostEqual(timestamp, int(time.time()), delta=5) self.assertEqual(private_data, priv) + if request_extra_info: + self.assertIn("user", kwargs) + self.assertEqual(getpass.getuser(), kwargs["user"]) + self.assertIn("netconf_id", kwargs) + self.assertIsInstance(kwargs["netconf_id"], int) + else: + self.assertEqual(0, len(kwargs)) callback_called.set() - with self.conn.start_session() as sess: - sess.subscribe_notification( - "sysrepo-example", notif_xpath, notif_cb, private_data=priv + with self.conn.start_session() as listening_session: + if request_extra_info: + kwargs = {"extra_info": True} + else: + kwargs = {} + listening_session.subscribe_notification( + "sysrepo-example", notif_xpath, notif_cb, private_data=priv, **kwargs ) - sess.notification_send(notif_xpath, notif_dict) - self.assertTrue( - callback_called.wait(timeout=1), - "Timed-out while waiting for the notification callback to be called", - ) + with self.conn.start_session() as sending_session: + sending_session.notification_send(notif_xpath, notif_dict) + self.assertTrue( + callback_called.wait(timeout=1), + "Timed-out while waiting for the notification callback to be called", + ) def test_notification_top_level(self): self._test_notification_sub( @@ -68,3 +88,10 @@ def test_notification_nested_in_data_node(self): notif_xpath="/sysrepo-example:state/state-changed", notif_dict={"message": "Some state changed"}, ) + + def test_notification_sub_with_extra_info(self): + self._test_notification_sub( + notif_xpath="/sysrepo-example:state/state-changed", + notif_dict={"message": "Some state changed"}, + request_extra_info=True, + ) diff --git a/tests/test_subs_oper.py b/tests/test_subs_oper.py index 7c3e736..1f6da00 100644 --- a/tests/test_subs_oper.py +++ b/tests/test_subs_oper.py @@ -1,6 +1,7 @@ # Copyright (c) 2020 6WIND S.A. # SPDX-License-Identifier: BSD-3-Clause +import getpass import logging import os import unittest @@ -21,17 +22,21 @@ def setUpClass(cls): with sysrepo.SysrepoConnection() as conn: conn.install_module(YANG_FILE, enabled_features=["turbo"]) cls.conn = sysrepo.SysrepoConnection(err_on_sched_fail=True) - cls.sess = cls.conn.start_session() @classmethod def tearDownClass(cls): - cls.sess.stop() cls.conn.remove_module("sysrepo-example") cls.conn.disconnect() # reconnect to make sure module is removed with sysrepo.SysrepoConnection(err_on_sched_fail=True): pass + def setUp(self): + self.sess = self.conn.start_session() + + def tearDown(self): + self.sess.stop() + def test_oper_sub(self): priv = object() state = None @@ -69,3 +74,31 @@ def oper_data_cb(xpath, private_data): state = {"state": {"invalid": True}} with self.assertRaises(sysrepo.SysrepoCallbackFailedError): op_sess.get_data("/sysrepo-example:state") + + def test_oper_sub_with_extra_info(self): + priv = object() + calls = [] + + def oper_data_cb(xpath, private_data, **kwargs): + self.assertEqual(xpath, "/sysrepo-example:state") + self.assertEqual(private_data, priv) + self.assertIn("user", kwargs) + self.assertEqual(getpass.getuser(), kwargs["user"]) + self.assertIn("netconf_id", kwargs) + self.assertIsInstance(kwargs["netconf_id"], int) + calls.append((xpath, private_data, kwargs)) + return {"state": {}} + + self.sess.subscribe_oper_data_request( + "sysrepo-example", + "/sysrepo-example:state", + oper_data_cb, + private_data=priv, + strict=True, + extra_info=True, + ) + + with self.conn.start_session("operational") as op_sess: + oper_data = op_sess.get_data("/sysrepo-example:state") + self.assertEqual(len(calls), 1) + self.assertEqual(oper_data, {"state": {}}) diff --git a/tests/test_subs_rpc.py b/tests/test_subs_rpc.py index d9a6181..ab6a806 100644 --- a/tests/test_subs_rpc.py +++ b/tests/test_subs_rpc.py @@ -1,6 +1,7 @@ # Copyright (c) 2020 6WIND S.A. # SPDX-License-Identifier: BSD-3-Clause +import getpass import logging import os import unittest @@ -134,3 +135,30 @@ def module_change_cb(event, req_id, changes, private_data): self.assertEqual(len(calls), 1) self.assertEqual(calls[0], (xpath, {"duration": 1}, "rpc", priv)) del calls[:] + + def test_rpc_sub_with_extra_info(self): + priv = object() + calls = [] + rpc_xpath = "/sysrepo-example:poweroff" + + def rpc_cb(xpath, input_params, event, private_data, **kwargs): + self.assertEqual(rpc_xpath, xpath) + self.assertEqual(input_params, {"behaviour": "success"}) + self.assertEqual(event, "rpc") + self.assertIs(private_data, priv) + self.assertIn("user", kwargs) + self.assertEqual(getpass.getuser(), kwargs["user"]) + self.assertIn("netconf_id", kwargs) + self.assertIsInstance(kwargs["netconf_id"], int) + calls.append((xpath, input_params, event, private_data)) + return {"message": "bye bye"} + + with self.conn.start_session() as sess: + sess.subscribe_rpc_call( + rpc_xpath, rpc_cb, private_data=priv, strict=True, extra_info=True + ) + + with self.conn.start_session() as rpc_sess: + output = rpc_sess.rpc_send(rpc_xpath, {"behaviour": "success"}) + self.assertEqual(len(calls), 1) + self.assertEqual(output, {"message": "bye bye"})