Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ServiceBus] Adjust AutoLockRenewer to only allow registration of intended types (ReceivedMessage and ServiceBusSession) #14600

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from .._servicebus_receiver import ServiceBusReceiver
from .._servicebus_session import ServiceBusSession
from .message import ServiceBusReceivedMessage
from ..exceptions import AutoLockRenewFailed, AutoLockRenewTimeout, ServiceBusError
from .utils import renewable_start_time, utc_now

if TYPE_CHECKING:
from typing import Callable, Union, Optional
from .message import ServiceBusReceivedMessage
from typing import Callable, Union, Optional, Awaitable
LockRenewFailureCallback = Callable[[Union[ServiceBusSession, ServiceBusReceivedMessage],
Optional[Exception]], None]
Renewable = Union[ServiceBusSession, ServiceBusReceivedMessage]
Expand Down Expand Up @@ -144,6 +144,10 @@ def register(self, receiver, renewable, timeout=300, on_lock_renew_failure=None)

:rtype: None
"""
if not isinstance(renewable, (ServiceBusReceivedMessage, ServiceBusSession)):
raise TypeError("AutoLockRenewer only supports registration of types "
"azure.servicebus.ServiceBusReceivedMessage (via a receiver's receive methods) and "
"azure.servicebus.ServiceBusSession (via a session receiver's property receiver.session).")
if self._shutdown.is_set():
raise ServiceBusError("The AutoLockRenewer has already been shutdown. Please create a new instance for"
" auto lock renewing.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import datetime
import time
import logging
import functools
Expand Down Expand Up @@ -48,6 +47,7 @@


if TYPE_CHECKING:
import datetime
from azure.core.credentials import TokenCredential

_LOGGER = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ def register(
Default value is None (no callback).
:rtype: None
"""
if not isinstance(renewable, (ServiceBusReceivedMessage, ServiceBusSession)):
raise TypeError("AutoLockRenewer only supports registration of types "
"azure.servicebus.ServiceBusReceivedMessage (via a receiver's receive methods) and "
"azure.servicebus.aio.ServiceBusSession "
"(via a session receiver's property receiver.session).")
if self._shutdown.is_set():
raise ServiceBusError("The AutoLockRenewer has already been shutdown. Please create a new instance for"
" auto lock renewing.")
Expand Down
13 changes: 11 additions & 2 deletions sdk/servicebus/azure-servicebus/tests/async_tests/mocks_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta

from azure.servicebus._common.utils import utc_now
from azure.servicebus import ServiceBusReceivedMessage

class MockReceiver:
def __init__(self):
Expand All @@ -13,7 +14,7 @@ async def renew_message_lock(self, message):
message.locked_until_utc = message.locked_until_utc + timedelta(seconds=message._lock_duration)


class MockReceivedMessage:
class MockReceivedMessage(ServiceBusReceivedMessage):
def __init__(self, prevent_renew_lock=False, exception_on_renew_lock=False):
self._lock_duration = 2

Expand All @@ -29,4 +30,12 @@ def __init__(self, prevent_renew_lock=False, exception_on_renew_lock=False):
def _lock_expired(self):
if self.locked_until_utc and self.locked_until_utc <= utc_now():
return True
return False
return False

@property
def locked_until_utc(self):
return self._locked_until_utc

@locked_until_utc.setter
def locked_until_utc(self, value):
self._locked_until_utc = value
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@
MessageContentTooLarge,
OperationTimeoutError
)
from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer
from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer, AzureTestCase
from servicebus_preparer import CachedServiceBusNamespacePreparer, CachedServiceBusQueuePreparer, ServiceBusQueuePreparer
from utilities import get_logger, print_message, sleep_until_expired
from mocks_async import MockReceivedMessage
from mocks_async import MockReceivedMessage, MockReceiver

_logger = get_logger(logging.DEBUG)

Expand Down Expand Up @@ -1135,7 +1135,7 @@ async def test_queue_message_settle_through_mgmt_link_due_to_broken_receiver_lin
assert len(messages) == 1
await receiver.complete_message(messages[0])

@pytest.mark.asyncio
@AzureTestCase.await_prepared_test
async def test_async_queue_mock_auto_lock_renew_callback(self):
results = []
errors = []
Expand All @@ -1144,11 +1144,16 @@ async def callback_mock(renewable, error):
if error:
errors.append(error)

receiver = MockReceiver()
auto_lock_renew = AutoLockRenewer()
auto_lock_renew._renew_period = 1 # So we can run the test fast.
async with auto_lock_renew: # Check that it is called when the object expires for any reason (silent renew failure)
with pytest.raises(TypeError):
auto_lock_renew.register(receiver, renewable=Exception()) # an arbitrary invalid type.

auto_lock_renew = AutoLockRenewer()
auto_lock_renew._renew_period = 1 # So we can run the test fast.
async with auto_lock_renew: # Check that it is called when the object expires for any reason (silent renew failure)
message = MockReceivedMessage(prevent_renew_lock=True)
auto_lock_renew.register(renewable=message, on_lock_renew_failure=callback_mock)
auto_lock_renew.register(receiver, renewable=message, on_lock_renew_failure=callback_mock)
await asyncio.sleep(3)
assert len(results) == 1 and results[-1]._lock_expired == True
assert not errors
Expand All @@ -1157,8 +1162,8 @@ async def callback_mock(renewable, error):
del errors[:]
auto_lock_renew = AutoLockRenewer()
auto_lock_renew._renew_period = 1
async with auto_lock_renew: # Check that in normal operation it does not get called
auto_lock_renew.register(renewable=MockReceivedMessage(), on_lock_renew_failure=callback_mock)
async with auto_lock_renew: # Check that in normal operation it does not get called
auto_lock_renew.register(receiver, renewable=MockReceivedMessage(), on_lock_renew_failure=callback_mock)
await asyncio.sleep(3)
assert not results
assert not errors
Expand All @@ -1167,9 +1172,9 @@ async def callback_mock(renewable, error):
del errors[:]
auto_lock_renew = AutoLockRenewer()
auto_lock_renew._renew_period = 1
async with auto_lock_renew: # Check that when a message is settled, it will not get called even after expiry
async with auto_lock_renew: # Check that when a message is settled, it will not get called even after expiry
message = MockReceivedMessage(prevent_renew_lock=True)
auto_lock_renew.register(renewable=message, on_lock_renew_failure=callback_mock)
auto_lock_renew.register(receiver, renewable=message, on_lock_renew_failure=callback_mock)
message._settled = True
await asyncio.sleep(3)
assert not results
Expand All @@ -1181,7 +1186,7 @@ async def callback_mock(renewable, error):
auto_lock_renew._renew_period = 1
async with auto_lock_renew: # Check that it is called when there is an overt renew failure
message = MockReceivedMessage(exception_on_renew_lock=True)
auto_lock_renew.register(renewable=message, on_lock_renew_failure=callback_mock)
auto_lock_renew.register(receiver, renewable=message, on_lock_renew_failure=callback_mock)
await asyncio.sleep(3)
assert len(results) == 1 and results[-1]._lock_expired == True
assert errors[-1]
Expand All @@ -1190,9 +1195,9 @@ async def callback_mock(renewable, error):
del errors[:]
auto_lock_renew = AutoLockRenewer()
auto_lock_renew._renew_period = 1
async with auto_lock_renew: # Check that it is not called when the renewer is shutdown
async with auto_lock_renew: # Check that it is not called when the renewer is shutdown
message = MockReceivedMessage(prevent_renew_lock=True)
auto_lock_renew.register(renewable=message, on_lock_renew_failure=callback_mock)
auto_lock_renew.register(receiver, renewable=message, on_lock_renew_failure=callback_mock)
await auto_lock_renew.close()
await asyncio.sleep(3)
assert not results
Expand All @@ -1202,35 +1207,35 @@ async def callback_mock(renewable, error):
del errors[:]
auto_lock_renew = AutoLockRenewer()
auto_lock_renew._renew_period = 1
async with auto_lock_renew: # Check that it is not called when the receiver is shutdown
async with auto_lock_renew: # Check that it is not called when the receiver is shutdown
message = MockReceivedMessage(prevent_renew_lock=True)
auto_lock_renew.register(renewable=message, on_lock_renew_failure=callback_mock)
auto_lock_renew.register(receiver, renewable=message, on_lock_renew_failure=callback_mock)
message._receiver._running = False
await asyncio.sleep(3)
assert not results
assert not errors


@pytest.mark.asyncio
@AzureTestCase.await_prepared_test
async def test_async_queue_mock_no_reusing_auto_lock_renew(self):
auto_lock_renew = AutoLockRenewer()
auto_lock_renew._renew_period = 1

receiver = MockReceiver()
async with auto_lock_renew:
auto_lock_renew.register(renewable=MockReceivedMessage())
auto_lock_renew.register(receiver, renewable=MockReceivedMessage())
await asyncio.sleep(3)

with pytest.raises(ServiceBusError):
async with auto_lock_renew:
pass

with pytest.raises(ServiceBusError):
auto_lock_renew.register(renewable=MockReceivedMessage())
auto_lock_renew.register(receiver, renewable=MockReceivedMessage())

auto_lock_renew = AutoLockRenewer()
auto_lock_renew._renew_period = 1

auto_lock_renew.register(renewable=MockReceivedMessage())
auto_lock_renew.register(receiver, renewable=MockReceivedMessage())
time.sleep(3)

await auto_lock_renew.close()
Expand All @@ -1240,7 +1245,7 @@ async def test_async_queue_mock_no_reusing_auto_lock_renew(self):
pass

with pytest.raises(ServiceBusError):
auto_lock_renew.register(renewable=MockReceivedMessage())
auto_lock_renew.register(receiver, renewable=MockReceivedMessage())

@pytest.mark.liveTest
@pytest.mark.live_test_only
Expand Down
13 changes: 11 additions & 2 deletions sdk/servicebus/azure-servicebus/tests/mocks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta

from azure.servicebus._common.utils import utc_now
from azure.servicebus import ServiceBusReceivedMessage


class MockReceiver:
Expand All @@ -14,7 +15,7 @@ def renew_message_lock(self, message):
message.locked_until_utc = message.locked_until_utc + timedelta(seconds=message._lock_duration)


class MockReceivedMessage:
class MockReceivedMessage(ServiceBusReceivedMessage):
def __init__(self, prevent_renew_lock=False, exception_on_renew_lock=False):
self._lock_duration = 2

Expand All @@ -31,4 +32,12 @@ def __init__(self, prevent_renew_lock=False, exception_on_renew_lock=False):
def _lock_expired(self):
if self.locked_until_utc and self.locked_until_utc <= utc_now():
return True
return False
return False

@property
def locked_until_utc(self):
return self._locked_until_utc

@locked_until_utc.setter
def locked_until_utc(self, value):
self._locked_until_utc = value
4 changes: 4 additions & 0 deletions sdk/servicebus/azure-servicebus/tests/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,6 +1412,10 @@ def callback_mock(renewable, error):
errors.append(error)

receiver = MockReceiver()
auto_lock_renew = AutoLockRenewer()
with pytest.raises(TypeError):
auto_lock_renew.register(Exception()) # an arbitrary invalid type.
yunhaoling marked this conversation as resolved.
Show resolved Hide resolved

auto_lock_renew = AutoLockRenewer()
auto_lock_renew._renew_period = 1 # So we can run the test fast.
with auto_lock_renew: # Check that it is called when the object expires for any reason (silent renew failure)
Expand Down