Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass extra session info to subscription callbacks that requested so #27

Merged
merged 1 commit into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cffi/cdefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ int sr_session_stop(sr_session_ctx_t *);
int sr_session_switch_ds(sr_session_ctx_t *, sr_datastore_t);
sr_datastore_t sr_session_get_ds(sr_session_ctx_t *);
sr_conn_ctx_t *sr_session_get_connection(sr_session_ctx_t *);
uint32_t sr_session_get_event_nc_id(sr_session_ctx_t *);
const char *sr_session_get_event_user(sr_session_ctx_t *);
int sr_get_error(sr_session_ctx_t *, const sr_error_info_t **);
int sr_set_error(sr_session_ctx_t *, const char *, const char *, ...);

Expand Down
89 changes: 84 additions & 5 deletions sysrepo/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,30 @@ def set_error(self, xpath: Optional[str], message: str):
lib.sr_set_error, self.cdata, str2c(xpath), str2c("%s"), str2c(message)
)

def get_netconf_id(self) -> int:
"""
It can only be called on an implicit sysrepo.Session (i.e., it can only be
called from an event callback)

:returns: the NETCONF session ID set for the event originator sysrepo session
"""
if not self.is_implicit:
raise SysrepoUnsupportedError(
"can only report netconf id on implicit sessions"
)
return lib.sr_session_get_event_nc_id(self.cdata)

def get_user(self) -> str:
"""
It can only be called on an implicit sysrepo.Session (i.e., it can only be
called from an event callback)

:returns: the effective username of the event originator sysrepo session
"""
if not self.is_implicit:
raise SysrepoUnsupportedError("can only report user on implicit sessions")
return c2str(lib.sr_session_get_event_user(self.cdata))

def get_ly_ctx(self) -> libyang.Context:
"""
:returns:
Expand Down Expand Up @@ -152,6 +176,13 @@ def get_ly_ctx(self) -> libyang.Context:
have changed.
:arg private_data:
Private context opaque to sysrepo used when subscribing.
:arg kwargs (optional):
If the callback was registered with the argument extra_info=True (see
Session.subscribe_module_change), then extra keyword arguments are passed when
calling the callback:
* netconf_id: the NETCONF session ID set for the event originator
sysrepo session
* user: the effective username of the event originator sysrepo session

When event is one of ("update", "change"), if the callback raises an exception, the
changes will be rejected and the error will be forwarded to the client that made the
Expand All @@ -174,7 +205,8 @@ def subscribe_module_change(
enabled: bool = False,
private_data: Any = None,
asyncio_register: bool = False,
include_implicit_defaults: bool = True
include_implicit_defaults: bool = True,
extra_info: bool = False
) -> None:
"""
Subscribe for changes made in the specified module.
Expand Down Expand Up @@ -210,6 +242,10 @@ def subscribe_module_change(
monitored read file descriptors. Implies `no_thread=True`.
:arg include_implicit_defaults:
Include implicit default nodes in changes.
:arg extra_info:
When True, the given callback is called with extra keyword arguments
containing extra information of the sysrepo session that gave origin to the
event (see ModuleChangeCallbackType for more details)
"""
if self.is_implicit:
raise SysrepoUnsupportedError("cannot subscribe with implicit sessions")
Expand All @@ -220,6 +256,7 @@ def subscribe_module_change(
private_data,
asyncio_register=asyncio_register,
include_implicit_defaults=include_implicit_defaults,
extra_info=extra_info,
)
sub_p = ffi.new("sr_subscription_ctx_t **")

Expand Down Expand Up @@ -253,6 +290,13 @@ def subscribe_module_change(
module operational data.
:arg private_data:
Private context opaque to sysrepo used when subscribing.
:arg kwargs (optional):
If the callback was registered with the argument extra_info=True (see
Session.subscribe_module_change), then extra keyword arguments are passed when
calling the callback:
* netconf_id: the NETCONF session ID set for the event originator
sysrepo session
* user: the effective username of the event originator sysrepo session

The callback is expected to return a python dictionary containing the operational
data. The dictionary should be in the libyang "dict" format. It will be parsed to a
Expand All @@ -272,7 +316,8 @@ def subscribe_oper_data_request(
no_thread: bool = False,
private_data: Any = None,
asyncio_register: bool = False,
strict: bool = False
strict: bool = False,
extra_info: bool = False
) -> None:
"""
Register for providing operational data at the given xpath.
Expand All @@ -296,13 +341,21 @@ def subscribe_oper_data_request(
:arg strict:
Reject the whole data returned by callback if it contains elements without
schema definition.
:arg extra_info:
When True, the given callback is called with extra keyword arguments
containing extra information of the sysrepo session that gave origin to the
event (see OperDataCallbackType for more details)
"""
if self.is_implicit:
raise SysrepoUnsupportedError("cannot subscribe with implicit sessions")
_check_subscription_callback(callback, self.OperDataCallbackType)

sub = Subscription(
callback, private_data, asyncio_register=asyncio_register, strict=strict
callback,
private_data,
asyncio_register=asyncio_register,
strict=strict,
extra_info=extra_info,
)
sub_p = ffi.new("sr_subscription_ctx_t **")

Expand Down Expand Up @@ -369,6 +422,13 @@ def subscribe_oper_data_request(
will be called with 'abort'.
:arg private_data:
Private context opaque to sysrepo used when subscribing.
:arg kwargs (optional):
If the callback was registered with the argument extra_info=True (see
Session.subscribe_module_change), then extra keyword arguments are passed when
calling the callback:
* netconf_id: the NETCONF session ID set for the event originator
sysrepo session
* user: the effective username of the event originator sysrepo session

The callback is expected to return a python dictionary containing the RPC output
data. The dictionary should be in the libyang "dict" format and must only contain
Expand All @@ -393,7 +453,8 @@ def subscribe_rpc_call(
private_data: Any = None,
asyncio_register: bool = False,
strict: bool = False,
include_implicit_defaults: bool = True
include_implicit_defaults: bool = True,
extra_info: bool = False
) -> None:
"""
Subscribe for the delivery of an RPC/action.
Expand All @@ -418,6 +479,10 @@ def subscribe_rpc_call(
schema definition.
:arg include_implicit_defaults:
Include implicit defaults into input parameters passed to callbacks.
:arg extra_info:
When True, the given callback is called with extra keyword arguments
containing extra information of the sysrepo session that gave origin to the
event (see RpcCallbackType for more details)
"""
if self.is_implicit:
raise SysrepoUnsupportedError("cannot subscribe with implicit sessions")
Expand All @@ -429,6 +494,7 @@ def subscribe_rpc_call(
asyncio_register=asyncio_register,
strict=strict,
include_implicit_defaults=include_implicit_defaults,
extra_info=extra_info,
)
sub_p = ffi.new("sr_subscription_ctx_t **")

Expand Down Expand Up @@ -480,6 +546,13 @@ def subscribe_rpc_call(
Timestamp of the notification as an unsigned 32-bits integer.
:arg private_data:
Private context opaque to sysrepo used when subscribing.
:arg kwargs (optional):
If the callback was registered with the argument extra_info=True (see
Session.subscribe_module_change), then extra keyword arguments are passed when
calling the callback:
* netconf_id: the NETCONF session ID set for the event originator
sysrepo session
* user: the effective username of the event originator sysrepo session
"""

def subscribe_notification(
Expand All @@ -492,7 +565,8 @@ def subscribe_notification(
stop_time: int = 0,
no_thread: bool = False,
asyncio_register: bool = False,
private_data: Any = None
private_data: Any = None,
extra_info: bool = False
) -> None:
"""
Subscribe for the delivery of a notification.
Expand All @@ -516,6 +590,10 @@ def subscribe_notification(
read file descriptors. Implies no_thread=True.
:arg private_data:
Private context passed to the callback function, opaque to sysrepo.
:arg extra_info:
When True, the given callback is called with extra keyword arguments
containing extra information of the sysrepo session that gave origin to the
event (see RpcCallbackType for more details)
"""

if self.is_implicit:
Expand All @@ -526,6 +604,7 @@ def subscribe_notification(
callback,
private_data,
asyncio_register=asyncio_register,
extra_info=extra_info,
)

sub_p = ffi.new("sr_subscription_ctx_t **")
Expand Down
59 changes: 51 additions & 8 deletions sysrepo/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
asyncio_register: bool = False,
strict: bool = False,
include_implicit_defaults: bool = True,
extra_info: bool = False,
):
"""
:arg callback:
Expand All @@ -49,6 +50,10 @@ def __init__(
:arg include_implicit_defaults:
If True, include implicit default nodes into Change objects passed to module
change callbacks and into input parameters passed to RPC/action callbacks.
:arg extra_info:
When True, the given callback is called with extra keyword arguments
containing extra information of the sysrepo session that gave origin to the
event
"""
if is_async_func(callback) and not asyncio_register:
raise ValueError(
Expand All @@ -59,6 +64,7 @@ def __init__(
self.asyncio_register = asyncio_register
self.strict = strict
self.include_implicit_defaults = include_implicit_defaults
self.extra_info = extra_info
if asyncio_register:
self.loop = asyncio.get_event_loop()
else:
Expand Down Expand Up @@ -214,6 +220,13 @@ def module_change_callback(session, module, xpath, event, req_id, priv):
callback = subscription.callback
private_data = subscription.private_data
event_name = EVENT_NAMES[event]
if subscription.extra_info:
extra_info = {
"netconf_id": session.get_netconf_id(),
"user": session.get_user(),
}
else:
extra_info = {}

if is_async_func(callback):
task_id = (event, req_id)
Expand All @@ -230,7 +243,7 @@ def module_change_callback(session, module, xpath, event, req_id, priv):
)
)
task = subscription.loop.create_task(
callback(event_name, req_id, changes, private_data)
callback(event_name, req_id, changes, private_data, **extra_info)
)
task.add_done_callback(
functools.partial(subscription.task_done, task_id, event_name)
Expand All @@ -257,7 +270,7 @@ def module_change_callback(session, module, xpath, event, req_id, priv):
include_implicit_defaults=subscription.include_implicit_defaults,
)
)
callback(event_name, req_id, changes, private_data)
callback(event_name, req_id, changes, private_data, **extra_info)

return lib.SR_ERR_OK

Expand Down Expand Up @@ -328,12 +341,21 @@ def oper_data_callback(session, module, xpath, req_xpath, req_id, parent, priv):
subscription = ffi.from_handle(priv)
callback = subscription.callback
private_data = subscription.private_data
if subscription.extra_info:
extra_info = {
"netconf_id": session.get_netconf_id(),
"user": session.get_user(),
}
else:
extra_info = {}

if is_async_func(callback):
task_id = req_id

if task_id not in subscription.tasks:
task = subscription.loop.create_task(callback(req_xpath, private_data))
task = subscription.loop.create_task(
callback(req_xpath, private_data, **extra_info)
)
task.add_done_callback(
functools.partial(subscription.task_done, task_id, "oper")
)
Expand All @@ -349,7 +371,7 @@ def oper_data_callback(session, module, xpath, req_xpath, req_id, parent, priv):
oper_data = task.result()

else:
oper_data = callback(req_xpath, private_data)
oper_data = callback(req_xpath, private_data, **extra_info)

if isinstance(oper_data, dict):
# convert oper_data to a libyang.DNode object
Expand Down Expand Up @@ -438,13 +460,20 @@ def rpc_callback(session, xpath, input_node, event, req_id, output_node, priv):
).values()
)
)
if subscription.extra_info:
extra_info = {
"netconf_id": session.get_netconf_id(),
"user": session.get_user(),
}
else:
extra_info = {}

if is_async_func(callback):
task_id = (event, req_id)

if task_id not in subscription.tasks:
task = subscription.loop.create_task(
callback(xpath, input_dict, event_name, private_data)
callback(xpath, input_dict, event_name, private_data, **extra_info)
)
task.add_done_callback(
functools.partial(subscription.task_done, task_id, event_name)
Expand All @@ -461,7 +490,9 @@ def rpc_callback(session, xpath, input_node, event, req_id, output_node, priv):
output_dict = task.result()

else:
output_dict = callback(xpath, input_dict, event_name, private_data)
output_dict = callback(
xpath, input_dict, event_name, private_data, **extra_info
)

if event != lib.SR_EV_RPC:
# May happen when there are multiple callback registered for the
Expand Down Expand Up @@ -543,15 +574,27 @@ def event_notif_tree_callback(session, notif_type, notif, timestamp, priv):
).values()
)
)
if subscription.extra_info:
extra_info = {
"netconf_id": session.get_netconf_id(),
"user": session.get_user(),
}
else:
extra_info = {}

if is_async_func(callback):
task = subscription.loop.create_task(
callback(xpath, notif_type, notif_dict, timestamp, private_data)
callback(
xpath, notif_type, notif_dict, timestamp, private_data, **extra_info
)
)
task.add_done_callback(
functools.partial(subscription.task_done, None, "notif")
)
else:
callback(xpath, notif_type, notif_dict, timestamp, private_data)
callback(
xpath, notif_type, notif_dict, timestamp, private_data, **extra_info
)

except BaseException:
# ATTENTION: catch all exceptions!
Expand Down
8 changes: 8 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,11 @@ def iface(name, field):
}
},
)

def test_get_netconf_id_and_get_user_are_only_available_in_implicit_session(self):
with self.conn.start_session("running") as sess:
with self.assertRaises(sysrepo.SysrepoUnsupportedError):
sess.get_netconf_id()

with self.assertRaises(sysrepo.SysrepoUnsupportedError):
sess.get_user()
Loading