diff --git a/lms/djangoapps/discussion/django_comment_client/base/tests_v2.py b/lms/djangoapps/discussion/django_comment_client/base/tests_v2.py index 7bc84e5038c0..1f5ae7805740 100644 --- a/lms/djangoapps/discussion/django_comment_client/base/tests_v2.py +++ b/lms/djangoapps/discussion/django_comment_client/base/tests_v2.py @@ -180,7 +180,8 @@ def test_flag(self): with mock.patch( "openedx.core.djangoapps.django_comment_common.signals.thread_flagged.send" ) as signal_mock: - response = self.call_view("flag_abuse_for_thread", "update_thread_flag") + with self.captureOnCommitCallbacks(execute=True): + response = self.call_view("flag_abuse_for_thread", "update_thread_flag") self._assert_json_response_contains_group_info(response) self.assertEqual(signal_mock.call_count, 1) response = self.call_view("un_flag_abuse_for_thread", "update_thread_flag") @@ -471,10 +472,15 @@ def setUp(self): def assert_discussion_signals(self, signal, user=None): if user is None: user = self.student + # Use captureOnCommitCallbacks to execute on_commit callbacks during tests, + # since signals are now deferred until after transaction commit. + # Order matters: assert_signal_sent must be outer context so the signal + # fires (via captureOnCommitCallbacks) before the assertion check. with self.assert_signal_sent( views, signal, sender=None, user=user, exclude_args=("post",) ): - yield + with self.captureOnCommitCallbacks(execute=True): + yield def test_create_thread(self): with self.assert_discussion_signals("thread_created"): @@ -1218,7 +1224,8 @@ def test_flag(self): with mock.patch( "openedx.core.djangoapps.django_comment_common.signals.comment_flagged.send" ) as signal_mock: - self.call_view("flag_abuse_for_comment", "update_comment_flag") + with self.captureOnCommitCallbacks(execute=True): + self.call_view("flag_abuse_for_comment", "update_comment_flag") self.assertEqual(signal_mock.call_count, 1) diff --git a/lms/djangoapps/discussion/django_comment_client/base/views.py b/lms/djangoapps/discussion/django_comment_client/base/views.py index 95d5a020108f..14ce9c4b575a 100644 --- a/lms/djangoapps/discussion/django_comment_client/base/views.py +++ b/lms/djangoapps/discussion/django_comment_client/base/views.py @@ -50,6 +50,7 @@ prepare_content, sanitize_body ) +from lms.djangoapps.discussion.rest_api.utils import send_signal_after_commit from openedx.core.djangoapps.django_comment_common.signals import ( comment_created, comment_deleted, @@ -587,7 +588,10 @@ def create_thread(request, course_id, commentable_id): thread.save() - thread_created.send(sender=None, user=user, post=thread) + # Use send_signal_after_commit() to ensure the signal is sent only after the transaction commits. + send_signal_after_commit( + lambda: thread_created.send(sender=None, user=user, post=thread) + ) # patch for backward compatibility to comments service if 'pinned' not in thread.attributes: @@ -598,7 +602,9 @@ def create_thread(request, course_id, commentable_id): if follow: cc_user = cc.User.from_django_user(user) cc_user.follow(thread, course_id) - thread_followed.send(sender=None, user=user, post=thread) + send_signal_after_commit( + lambda: thread_followed.send(sender=None, user=user, post=thread) + ) data = thread.to_dict() @@ -645,7 +651,9 @@ def update_thread(request, course_id, thread_id): thread.save() - thread_edited.send(sender=None, user=user, post=thread) + send_signal_after_commit( + lambda: thread_edited.send(sender=None, user=user, post=thread) + ) track_thread_edited_event(request, course, thread, None) if request.headers.get('x-requested-with') == 'XMLHttpRequest': @@ -688,7 +696,9 @@ def _create_comment(request, course_key, thread_id=None, parent_id=None): ) comment.save(params={"course_id": str(course_key)}) - comment_created.send(sender=None, user=user, post=comment) + send_signal_after_commit( + lambda: comment_created.send(sender=None, user=user, post=comment) + ) followed = post.get('auto_subscribe', 'false').lower() == 'true' @@ -729,7 +739,9 @@ def delete_thread(request, course_id, thread_id): course = get_course_with_access(request.user, 'load', course_key) thread = cc.Thread.find(thread_id) thread.delete(course_id=course_id) - thread_deleted.send(sender=None, user=request.user, post=thread) + send_signal_after_commit( + lambda: thread_deleted.send(sender=None, user=request.user, post=thread) + ) track_thread_deleted_event(request, course, thread) return JsonResponse(prepare_content(thread.to_dict(), course_key)) @@ -751,7 +763,9 @@ def update_comment(request, course_id, comment_id): comment.body = sanitize_body(request.POST["body"]) comment.save(params={"course_id": course_id}) - comment_edited.send(sender=None, user=request.user, post=comment) + send_signal_after_commit( + lambda: comment_edited.send(sender=None, user=request.user, post=comment) + ) track_comment_edited_event(request, course, comment, None) if request.headers.get('x-requested-with') == 'XMLHttpRequest': @@ -776,7 +790,9 @@ def endorse_comment(request, course_id, comment_id): comment.endorsed = endorsed comment.endorsement_user_id = user.id comment.save(params={"course_id": course_id}) - comment_endorsed.send(sender=None, user=user, post=comment) + send_signal_after_commit( + lambda: comment_endorsed.send(sender=None, user=user, post=comment) + ) track_forum_response_mark_event(request, course, comment, endorsed) return JsonResponse(prepare_content(comment.to_dict(), course_key)) @@ -828,7 +844,9 @@ def delete_comment(request, course_id, comment_id): course = get_course_with_access(request.user, 'load', course_key) comment = cc.Comment.find(comment_id) comment.delete(course_id=course_id) - comment_deleted.send(sender=None, user=request.user, post=comment) + send_signal_after_commit( + lambda: comment_deleted.send(sender=None, user=request.user, post=comment) + ) track_comment_deleted_event(request, course, comment) return JsonResponse(prepare_content(comment.to_dict(), course_key)) @@ -847,7 +865,9 @@ def _vote_or_unvote(request, course_id, obj, value='up', undo_vote=False): # (People could theoretically downvote by handcrafting AJAX requests.) else: user.vote(obj, value, course_id) - thread_voted.send(sender=None, user=request.user, post=obj) + send_signal_after_commit( + lambda: thread_voted.send(sender=None, user=request.user, post=obj) + ) track_voted_event(request, course, obj, value, undo_vote) return JsonResponse(prepare_content(obj.to_dict(), course_key)) @@ -861,7 +881,9 @@ def vote_for_comment(request, course_id, comment_id, value): """ comment = cc.Comment.find(comment_id) result = _vote_or_unvote(request, course_id, comment, value) - comment_voted.send(sender=None, user=request.user, post=comment) + send_signal_after_commit( + lambda: comment_voted.send(sender=None, user=request.user, post=comment) + ) return result @@ -914,7 +936,9 @@ def flag_abuse_for_thread(request, course_id, thread_id): thread = cc.Thread.find(thread_id) thread.flagAbuse(user, thread, course_id) track_discussion_reported_event(request, course, thread) - thread_flagged.send(sender='flag_abuse_for_thread', user=request.user, post=thread) + send_signal_after_commit( + lambda: thread_flagged.send(sender='flag_abuse_for_thread', user=request.user, post=thread) + ) return JsonResponse(prepare_content(thread.to_dict(), course_key)) @@ -953,7 +977,9 @@ def flag_abuse_for_comment(request, course_id, comment_id): comment = cc.Comment.find(comment_id) comment.flagAbuse(user, comment, course_id) track_discussion_reported_event(request, course, comment) - comment_flagged.send(sender='flag_abuse_for_comment', user=request.user, post=comment) + send_signal_after_commit( + lambda: comment_flagged.send(sender='flag_abuse_for_comment', user=request.user, post=comment) + ) return JsonResponse(prepare_content(comment.to_dict(), course_key)) @@ -1019,7 +1045,9 @@ def follow_thread(request, course_id, thread_id): # lint-amnesty, pylint: disab course = get_course_by_id(course_key) thread = cc.Thread.find(thread_id) user.follow(thread, course_id=course_id) - thread_followed.send(sender=None, user=request.user, post=thread) + send_signal_after_commit( + lambda: thread_followed.send(sender=None, user=request.user, post=thread) + ) track_thread_followed_event(request, course, thread, True) return JsonResponse({}) @@ -1051,7 +1079,9 @@ def unfollow_thread(request, course_id, thread_id): # lint-amnesty, pylint: dis user = cc.User.from_django_user(request.user) thread = cc.Thread.find(thread_id) user.unfollow(thread, course_id=course_id) - thread_unfollowed.send(sender=None, user=request.user, post=thread) + send_signal_after_commit( + lambda: thread_unfollowed.send(sender=None, user=request.user, post=thread) + ) track_thread_followed_event(request, course, thread, False) return JsonResponse({}) diff --git a/lms/djangoapps/discussion/rest_api/api.py b/lms/djangoapps/discussion/rest_api/api.py index b87852c16cfa..1c0b05735208 100644 --- a/lms/djangoapps/discussion/rest_api/api.py +++ b/lms/djangoapps/discussion/rest_api/api.py @@ -128,6 +128,7 @@ discussion_open_for_user, get_usernames_for_course, get_usernames_from_search_string, + send_signal_after_commit, set_attribute, is_posting_allowed, can_user_notify_all_learners, is_captcha_enabled, get_captcha_site_key_by_platform @@ -1382,7 +1383,9 @@ def _handle_following_field(form_value, user, cc_content, request): else: user.unfollow(cc_content) signal = thread_followed if form_value else thread_unfollowed - signal.send(sender=None, user=user, post=cc_content) + send_signal_after_commit( + lambda: signal.send(sender=None, user=user, post=cc_content) + ) track_thread_followed_event(request, course, cc_content, form_value) @@ -1395,9 +1398,13 @@ def _handle_abuse_flagged_field(form_value, user, cc_content, request): track_discussion_reported_event(request, course, cc_content) if ENABLE_DISCUSSIONS_MFE.is_enabled(course_key): if cc_content.type == 'thread': - thread_flagged.send(sender='flag_abuse_for_thread', user=user, post=cc_content) + send_signal_after_commit( + lambda: thread_flagged.send(sender='flag_abuse_for_thread', user=user, post=cc_content) + ) else: - comment_flagged.send(sender='flag_abuse_for_comment', user=user, post=cc_content) + send_signal_after_commit( + lambda: comment_flagged.send(sender='flag_abuse_for_comment', user=user, post=cc_content) + ) else: remove_all = bool(is_privileged_user(course_key, User.objects.get(id=user.id))) cc_content.unFlagAbuse(user, cc_content, remove_all) @@ -1407,7 +1414,9 @@ def _handle_abuse_flagged_field(form_value, user, cc_content, request): def _handle_voted_field(form_value, cc_content, api_content, request, context): """vote or undo vote on thread/comment""" signal = thread_voted if cc_content.type == 'thread' else comment_voted - signal.send(sender=None, user=context["request"].user, post=cc_content) + send_signal_after_commit( + lambda: signal.send(sender=None, user=context["request"].user, post=cc_content) + ) if form_value: context["cc_requester"].vote(cc_content, "up") api_content["vote_count"] += 1 @@ -1452,7 +1461,9 @@ def _handle_comment_signals(update_data, comment, user, sender=None): """ for key, value in update_data.items(): if key == "endorsed" and value is True: - comment_endorsed.send(sender=sender, user=user, post=comment) + send_signal_after_commit( + lambda: comment_endorsed.send(sender=sender, user=user, post=comment) + ) def create_thread(request, thread_data): @@ -1502,7 +1513,10 @@ def create_thread(request, thread_data): raise ValidationError(dict(list(serializer.errors.items()) + list(actions_form.errors.items()))) serializer.save() cc_thread = serializer.instance - thread_created.send(sender=None, user=user, post=cc_thread, notify_all_learners=notify_all_learners) + # Use send_signal_after_commit() to ensure the signal is sent only after the transaction commits. + send_signal_after_commit( + lambda: thread_created.send(sender=None, user=user, post=cc_thread, notify_all_learners=notify_all_learners) + ) api_thread = serializer.data _do_extra_actions(api_thread, cc_thread, list(thread_data.keys()), actions_form, context, request) @@ -1550,7 +1564,9 @@ def create_comment(request, comment_data): context["cc_requester"].follow(cc_thread) serializer.save() cc_comment = serializer.instance - comment_created.send(sender=None, user=request.user, post=cc_comment) + send_signal_after_commit( + lambda: comment_created.send(sender=None, user=request.user, post=cc_comment) + ) api_comment = serializer.data _do_extra_actions(api_comment, cc_comment, list(comment_data.keys()), actions_form, context, request) track_comment_created_event(request, course, cc_comment, cc_thread["commentable_id"], followed=False, @@ -1586,7 +1602,9 @@ def update_thread(request, thread_id, update_data): if set(update_data) - set(actions_form.fields): serializer.save() # signal to update Teams when a user edits a thread - thread_edited.send(sender=None, user=request.user, post=cc_thread) + send_signal_after_commit( + lambda: thread_edited.send(sender=None, user=request.user, post=cc_thread) + ) api_thread = serializer.data _do_extra_actions(api_thread, cc_thread, list(update_data.keys()), actions_form, context, request) @@ -1635,7 +1653,9 @@ def update_comment(request, comment_id, update_data): # Only save comment object if some of the edited fields are in the comment data, not extra actions if set(update_data) - set(actions_form.fields): serializer.save() - comment_edited.send(sender=None, user=request.user, post=cc_comment) + send_signal_after_commit( + lambda: comment_edited.send(sender=None, user=request.user, post=cc_comment) + ) api_comment = serializer.data _do_extra_actions(api_comment, cc_comment, list(update_data.keys()), actions_form, context, request) _handle_comment_signals(update_data, cc_comment, request.user) @@ -1823,7 +1843,9 @@ def delete_thread(request, thread_id): cc_thread, context = _get_thread_and_context(request, thread_id) if can_delete(cc_thread, context): cc_thread.delete() - thread_deleted.send(sender=None, user=request.user, post=cc_thread) + send_signal_after_commit( + lambda: thread_deleted.send(sender=None, user=request.user, post=cc_thread) + ) track_thread_deleted_event(request, context["course"], cc_thread) else: raise PermissionDenied @@ -1848,7 +1870,9 @@ def delete_comment(request, comment_id): cc_comment, context = _get_comment_and_context(request, comment_id) if can_delete(cc_comment, context): cc_comment.delete() - comment_deleted.send(sender=None, user=request.user, post=cc_comment) + send_signal_after_commit( + lambda: comment_deleted.send(sender=None, user=request.user, post=cc_comment) + ) track_comment_deleted_event(request, context["course"], cc_comment) else: raise PermissionDenied diff --git a/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py b/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py index 900d52017c5e..df4ac947bf0d 100644 --- a/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py +++ b/lms/djangoapps/discussion/rest_api/tests/test_api_v2.py @@ -273,7 +273,8 @@ def test_basic(self, mock_emit): with self.assert_signal_sent( api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners") ): - actual = create_thread(self.request, self.minimal_data) + with self.captureOnCommitCallbacks(execute=True): + actual = create_thread(self.request, self.minimal_data) expected = self.expected_thread_data( { "id": "test_id", @@ -352,7 +353,8 @@ def test_basic_in_blackout_period_with_user_access(self, mock_emit): with self.assert_signal_sent( api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners") ): - actual = create_thread(self.request, self.minimal_data) + with self.captureOnCommitCallbacks(execute=True): + actual = create_thread(self.request, self.minimal_data) expected = self.expected_thread_data( { "author_label": "Moderator", @@ -428,7 +430,8 @@ def test_title_truncation(self, mock_emit): with self.assert_signal_sent( api, "thread_created", sender=None, user=self.user, exclude_args=("post", "notify_all_learners") ): - create_thread(self.request, data) + with self.captureOnCommitCallbacks(execute=True): + create_thread(self.request, data) event_name, event_data = mock_emit.call_args[0] assert event_name == "edx.forum.thread.created" assert event_data == { @@ -678,7 +681,8 @@ def test_success(self, parent_id, mock_emit): with self.assert_signal_sent( api, "comment_created", sender=None, user=self.user, exclude_args=("post",) ): - actual = create_comment(self.request, data) + with self.captureOnCommitCallbacks(execute=True): + actual = create_comment(self.request, data) expected = { "id": "test_comment", "thread_id": "test_thread", @@ -785,7 +789,8 @@ def test_success_in_black_out_with_user_access(self, parent_id, mock_emit): with self.assert_signal_sent( api, "comment_created", sender=None, user=self.user, exclude_args=("post",) ): - actual = create_comment(self.request, data) + with self.captureOnCommitCallbacks(execute=True): + actual = create_comment(self.request, data) expected = { "id": "test_comment", "thread_id": "test_thread", @@ -1118,9 +1123,10 @@ def test_basic(self): with self.assert_signal_sent( api, "thread_edited", sender=None, user=self.user, exclude_args=("post",) ): - actual = update_thread( - self.request, "test_thread", {"raw_body": "Edited body"} - ) + with self.captureOnCommitCallbacks(execute=True): + actual = update_thread( + self.request, "test_thread", {"raw_body": "Edited body"} + ) assert actual == self.expected_thread_data( { @@ -1436,13 +1442,13 @@ def test_following(self, old_following, new_following, mock_emit): self.register_thread() data = {"following": new_following} signal_name = "thread_followed" if new_following else "thread_unfollowed" - mock_path = ( - f"openedx.core.djangoapps.django_comment_common.signals.{signal_name}.send" - ) + # Patch at the api module level where the signal is imported and used + mock_path = f"lms.djangoapps.discussion.rest_api.api.{signal_name}" with mock.patch(mock_path) as signal_patch: - result = update_thread(self.request, "test_thread", data) + with self.captureOnCommitCallbacks(execute=True): + result = update_thread(self.request, "test_thread", data) if old_following != new_following: - self.assertEqual(signal_patch.call_count, 1) + self.assertEqual(signal_patch.send.call_count, 1) assert result["following"] == new_following if old_following == new_following: @@ -1782,9 +1788,10 @@ def test_basic(self, parent_id): with self.assert_signal_sent( api, "comment_edited", sender=None, user=self.user, exclude_args=("post",) ): - actual = update_comment( - self.request, "test_comment", {"raw_body": "Edited body"} - ) + with self.captureOnCommitCallbacks(execute=True): + actual = update_comment( + self.request, "test_comment", {"raw_body": "Edited body"} + ) expected = { "anonymous": False, "anonymous_to_peers": False, @@ -2207,7 +2214,7 @@ def test_raw_body_access(self, role_name, is_thread_author, is_comment_author): ) @ddt.unpack @mock.patch( - "openedx.core.djangoapps.django_comment_common.signals.comment_endorsed.send" + "lms.djangoapps.discussion.rest_api.api.comment_endorsed.send" ) def test_endorsed_access( self, role_name, is_thread_author, thread_type, is_comment_author, endorsed_mock @@ -2226,7 +2233,8 @@ def test_endorsed_access( thread_type == "discussion" or not is_thread_author ) try: - update_comment(self.request, "test_comment", {"endorsed": True}) + with self.captureOnCommitCallbacks(execute=True): + update_comment(self.request, "test_comment", {"endorsed": True}) self.assertEqual(endorsed_mock.call_count, 1) assert not expected_error except ValidationError as err: @@ -2354,7 +2362,8 @@ def test_basic(self, mock_emit): with self.assert_signal_sent( api, "thread_deleted", sender=None, user=self.user, exclude_args=("post",) ): - assert delete_thread(self.request, self.thread_id) is None + with self.captureOnCommitCallbacks(execute=True): + assert delete_thread(self.request, self.thread_id) is None self.check_mock_called("delete_thread") params = { "thread_id": self.thread_id, @@ -2540,7 +2549,8 @@ def test_basic(self, mock_emit): with self.assert_signal_sent( api, "comment_deleted", sender=None, user=self.user, exclude_args=("post",) ): - assert delete_comment(self.request, self.comment_id) is None + with self.captureOnCommitCallbacks(execute=True): + assert delete_comment(self.request, self.comment_id) is None self.check_mock_called("delete_comment") params = { "comment_id": self.comment_id, diff --git a/lms/djangoapps/discussion/rest_api/utils.py b/lms/djangoapps/discussion/rest_api/utils.py index 0f02a0dcdcf2..a2591655adc2 100644 --- a/lms/djangoapps/discussion/rest_api/utils.py +++ b/lms/djangoapps/discussion/rest_api/utils.py @@ -3,13 +3,14 @@ """ import logging from datetime import datetime -from typing import Dict, List +from typing import Callable, Dict, List import requests from crum import get_current_request from django.conf import settings from django.contrib.auth.models import User # lint-amnesty, pylint: disable=imported-auth-user from django.core.paginator import Paginator +from django.db import transaction from django.db.models.functions import Length from pytz import UTC @@ -496,3 +497,24 @@ def get_captcha_site_key_by_platform(platform: str) -> str | None: Get reCAPTCHA site key based on the platform. """ return settings.RECAPTCHA_SITE_KEYS.get(platform, None) + + +def send_signal_after_commit(signal_func: Callable): + """ + Schedule a signal to be sent after the current database transaction commits. + + This helper ensures that signals are only sent after the transaction commits, + preventing race conditions where async tasks (like Celery workers) may try to + access database records before they are visible (especially important for MySQL + backend with transaction isolation). + + Args: + signal_func: A callable that sends the signal. This will be executed + after the transaction commits. + + Example: + send_signal_after_commit( + lambda: thread_created.send(sender=None, user=user, post=thread, notify_all_learners=False) + ) + """ + transaction.on_commit(signal_func)