Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix concurrent room initialization #255

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 63 additions & 54 deletions jupyter_collaboration/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ class YDocWebSocketHandler(WebSocketHandler, JupyterHandler):

_message_queue: asyncio.Queue[Any]
_background_tasks: set[asyncio.Task]
_room_locks: dict[str, asyncio.Lock] = {}

def _room_lock(self, room_id: str) -> asyncio.Lock:
if room_id not in self._room_locks:
self._room_locks[room_id] = asyncio.Lock()
return self._room_locks[room_id]

def create_task(self, aw):
task = asyncio.create_task(aw)
Expand All @@ -70,38 +76,38 @@ async def prepare(self):
# Get room
self._room_id: str = self.request.path.split("/")[-1]

if self._websocket_server.room_exists(self._room_id):
self.room: YRoom = await self._websocket_server.get_room(self._room_id)

else:
if self._room_id.count(":") >= 2:
# DocumentRoom
file_format, file_type, file_id = decode_file_path(self._room_id)
if file_id in self._file_loaders:
self._emit(
LogLevel.WARNING,
None,
"There is another collaborative session accessing the same file.\nThe synchronization between rooms is not supported and you might lose some of your changes.",
async with self._room_lock(self._room_id):
if self._websocket_server.room_exists(self._room_id):
self.room: YRoom = await self._websocket_server.get_room(self._room_id)
else:
if self._room_id.count(":") >= 2:
# DocumentRoom
file_format, file_type, file_id = decode_file_path(self._room_id)
if file_id in self._file_loaders:
self._emit(
LogLevel.WARNING,
None,
"There is another collaborative session accessing the same file.\nThe synchronization between rooms is not supported and you might lose some of your changes.",
)

file = self._file_loaders[file_id]
updates_file_path = f".{file_type}:{file_id}.y"
ystore = self._ystore_class(path=updates_file_path, log=self.log)
self.room = DocumentRoom(
self._room_id,
file_format,
file_type,
file,
self.event_logger,
ystore,
self.log,
self._document_save_delay,
)

file = self._file_loaders[file_id]
updates_file_path = f".{file_type}:{file_id}.y"
ystore = self._ystore_class(path=updates_file_path, log=self.log)
self.room = DocumentRoom(
self._room_id,
file_format,
file_type,
file,
self.event_logger,
ystore,
self.log,
self._document_save_delay,
)

else:
# TransientRoom
# it is a transient document (e.g. awareness)
self.room = TransientRoom(self._room_id, self.log)
else:
# TransientRoom
# it is a transient document (e.g. awareness)
self.room = TransientRoom(self._room_id, self.log)

await self._websocket_server.start_room(self.room)
self._websocket_server.add_room(self._room_id, self.room)
Expand Down Expand Up @@ -184,7 +190,8 @@ async def open(self, room_id):

try:
# Initialize the room
await self.room.initialize()
async with self._room_lock(self._room_id):
await self.room.initialize()
self._emit_awareness_event(self.current_user.username, "join")
except Exception as e:
_, _, file_id = decode_file_path(self._room_id)
Expand Down Expand Up @@ -323,29 +330,31 @@ async def _clean_room(self) -> None:
contains a copy of the document. In addition, we remove the file if there is no rooms
subscribed to it.
"""
assert isinstance(self.room, DocumentRoom)

if self._cleanup_delay is None:
return

await asyncio.sleep(self._cleanup_delay)

# Remove the room from the websocket server
self.log.info("Deleting Y document from memory: %s", self.room.room_id)
self._websocket_server.delete_room(room=self.room)

# Clean room
del self.room
self.log.info("Room %s deleted", self._room_id)
self._emit(LogLevel.INFO, "clean", "Room deleted.")

# Clean the file loader if there are not rooms using it
_, _, file_id = decode_file_path(self._room_id)
file = self._file_loaders[file_id]
if file.number_of_subscriptions == 0:
self.log.info("Deleting file %s", file.path)
await self._file_loaders.remove(file_id)
self._emit(LogLevel.INFO, "clean", "Loader deleted.")
async with self._room_lock(self._room_id):
assert isinstance(self.room, DocumentRoom)

if self._cleanup_delay is None:
return

await asyncio.sleep(self._cleanup_delay)

# Remove the room from the websocket server
self.log.info("Deleting Y document from memory: %s", self.room.room_id)
self._websocket_server.delete_room(room=self.room)

# Clean room
del self.room
self.log.info("Room %s deleted", self._room_id)
self._emit(LogLevel.INFO, "clean", "Room deleted.")

# Clean the file loader if there are not rooms using it
_, _, file_id = decode_file_path(self._room_id)
file = self._file_loaders[file_id]
if file.number_of_subscriptions == 0:
self.log.info("Deleting file %s", file.path)
await self._file_loaders.remove(file_id)
self._emit(LogLevel.INFO, "clean", "Loader deleted.")
del self._room_locks[self._room_id]

def check_origin(self, origin):
"""
Expand Down
102 changes: 50 additions & 52 deletions jupyter_collaboration/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(
self._save_delay = save_delay

self._update_lock = asyncio.Lock()
self._initialization_lock = asyncio.Lock()
self._cleaner: asyncio.Task | None = None
self._saving_document: asyncio.Task | None = None
self._messages: dict[str, asyncio.Lock] = {}
Expand Down Expand Up @@ -89,64 +88,63 @@ async def initialize(self) -> None:
It is important to set the ready property in the parent class (`self.ready = True`),
this setter will subscribe for updates on the shared document.
"""
async with self._initialization_lock:
if self.ready: # type: ignore[has-type]
return
if self.ready: # type: ignore[has-type]
return

self.log.info("Initializing room %s", self._room_id)
self.log.info("Initializing room %s", self._room_id)

model = await self._file.load_content(self._file_format, self._file_type)
model = await self._file.load_content(self._file_format, self._file_type)

async with self._update_lock:
# try to apply Y updates from the YStore for this document
read_from_source = True
if self.ystore is not None:
try:
await self.ystore.apply_updates(self.ydoc)
self._emit(
LogLevel.INFO,
"load",
"Content loaded from the store {}".format(
self.ystore.__class__.__qualname__
),
)
self.log.info(
"Content in room %s loaded from the ystore %s",
self._room_id,
self.ystore.__class__.__name__,
)
read_from_source = False
except YDocNotFound:
# YDoc not found in the YStore, create the document from the source file (no change history)
pass

if not read_from_source:
# if YStore updates and source file are out-of-sync, resync updates with source
if self._document.source != model["content"]:
# TODO: Delete document from the store.
self._emit(
LogLevel.INFO, "initialize", "The file is out-of-sync with the ystore."
)
self.log.info(
"Content in file %s is out-of-sync with the ystore %s",
self._file.path,
self.ystore.__class__.__name__,
)
read_from_source = True

if read_from_source:
self._emit(LogLevel.INFO, "load", "Content loaded from disk.")
async with self._update_lock:
# try to apply Y updates from the YStore for this document
read_from_source = True
if self.ystore is not None:
try:
await self.ystore.apply_updates(self.ydoc)
self._emit(
LogLevel.INFO,
"load",
"Content loaded from the store {}".format(
self.ystore.__class__.__qualname__
),
)
self.log.info(
"Content in room %s loaded from the ystore %s",
self._room_id,
self.ystore.__class__.__name__,
)
read_from_source = False
except YDocNotFound:
# YDoc not found in the YStore, create the document from the source file (no change history)
pass

if not read_from_source:
# if YStore updates and source file are out-of-sync, resync updates with source
if self._document.source != model["content"]:
# TODO: Delete document from the store.
self._emit(
LogLevel.INFO, "initialize", "The file is out-of-sync with the ystore."
)
self.log.info(
"Content in room %s loaded from file %s", self._room_id, self._file.path
"Content in file %s is out-of-sync with the ystore %s",
self._file.path,
self.ystore.__class__.__name__,
)
self._document.source = model["content"]
read_from_source = True

if read_from_source:
self._emit(LogLevel.INFO, "load", "Content loaded from disk.")
self.log.info(
"Content in room %s loaded from file %s", self._room_id, self._file.path
)
self._document.source = model["content"]

if self.ystore:
await self.ystore.encode_state_as_update(self.ydoc)
if self.ystore:
await self.ystore.encode_state_as_update(self.ydoc)

self._document.dirty = False
self.ready = True
self._emit(LogLevel.INFO, "initialize", "Room initialized")
self._document.dirty = False
self.ready = True
self._emit(LogLevel.INFO, "initialize", "Room initialized")

def _emit(self, level: LogLevel, action: str | None = None, msg: str | None = None) -> None:
data = {"level": level.value, "room": self._room_id, "path": self._file.path}
Expand Down
24 changes: 23 additions & 1 deletion tests/test_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
# Distributed under the terms of the Modified BSD License.

import sys
from time import time

if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points

import pytest
from anyio import sleep
from anyio import create_task_group, sleep
from pycrdt_websocket import WebsocketProvider

jupyter_ydocs = {ep.name: ep.load() for ep in entry_points(group="jupyter_ydoc")}
Expand Down Expand Up @@ -37,3 +38,24 @@ async def test_dirty(
jupyter_ydoc.dirty = True
await sleep(rtc_document_save_delay * 1.5)
assert not jupyter_ydoc.dirty


async def test_room_concurrent_initialization(
rtc_create_file,
rtc_connect_doc_client,
):
file_format = "text"
file_type = "file"
file_path = "dummy.txt"
await rtc_create_file(file_path)

async def connect(file_format, file_type, file_path):
async with await rtc_connect_doc_client(file_format, file_type, file_path) as ws:
pass

t0 = time()
async with create_task_group() as tg:
tg.start_soon(connect, file_format, file_type, file_path)
tg.start_soon(connect, file_format, file_type, file_path)
t1 = time()
assert t1 - t0 < 0.5
Loading