diff --git a/dashboard/modules/job/common.py b/dashboard/modules/job/common.py index a5d927541016e..0bc1d0c8f26fd 100644 --- a/dashboard/modules/job/common.py +++ b/dashboard/modules/job/common.py @@ -189,13 +189,26 @@ def __init__(self, gcs_aio_client: GcsAioClient): self._gcs_aio_client = gcs_aio_client assert _internal_kv_initialized() - async def put_info(self, job_id: str, job_info: JobInfo): - await self._gcs_aio_client.internal_kv_put( + async def put_info( + self, job_id: str, job_info: JobInfo, overwrite: bool = True + ) -> bool: + """Put job info to the internal kv store. + + Args: + job_id: The job id. + job_info: The job info. + overwrite: Whether to overwrite the existing job info. + + Returns: + True if a new key is added. + """ + added_num = await self._gcs_aio_client.internal_kv_put( self.JOB_DATA_KEY.format(job_id=job_id).encode(), json.dumps(job_info.to_json()).encode(), - True, + overwrite, namespace=ray_constants.KV_NAMESPACE_JOB, ) + return added_num == 1 async def get_info(self, job_id: str, timeout: int = 30) -> Optional[JobInfo]: serialized_info = await self._gcs_aio_client.internal_kv_get( diff --git a/dashboard/modules/job/job_manager.py b/dashboard/modules/job/job_manager.py index fbe608507a46c..f05bdff210444 100644 --- a/dashboard/modules/job/job_manager.py +++ b/dashboard/modules/job/job_manager.py @@ -802,8 +802,6 @@ async def submit_job( entrypoint_num_gpus = 0 if submission_id is None: submission_id = generate_job_id() - elif await self._job_info_client.get_status(submission_id) is not None: - raise RuntimeError(f"Job {submission_id} already exists.") logger.info(f"Starting job with submission_id: {submission_id}") job_info = JobInfo( @@ -816,7 +814,14 @@ async def submit_job( entrypoint_num_gpus=entrypoint_num_gpus, entrypoint_resources=entrypoint_resources, ) - await self._job_info_client.put_info(submission_id, job_info) + new_key_added = await self._job_info_client.put_info( + submission_id, job_info, overwrite=False + ) + if not new_key_added: + raise ValueError( + f"Job with submission_id {submission_id} already exists. " + "Please use a different submission_id." + ) # Wait for the actor to start up asynchronously so this call always # returns immediately and we can catch errors with the actor starting diff --git a/dashboard/modules/job/tests/test_job_manager.py b/dashboard/modules/job/tests/test_job_manager.py index 5e0d841d22bbc..952caa96c26ca 100644 --- a/dashboard/modules/job/tests/test_job_manager.py +++ b/dashboard/modules/job/tests/test_job_manager.py @@ -326,12 +326,45 @@ async def test_pass_job_id(job_manager): ) # Check that the same job_id is rejected. - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): await job_manager.submit_job( entrypoint="echo hello", submission_id=submission_id ) +@pytest.mark.asyncio +async def test_simultaneous_submit_job(job_manager): + """Test that we can submit multiple jobs at once.""" + job_ids = await asyncio.gather( + job_manager.submit_job(entrypoint="echo hello"), + job_manager.submit_job(entrypoint="echo hello"), + job_manager.submit_job(entrypoint="echo hello"), + ) + + for job_id in job_ids: + await async_wait_for_condition_async_predicate( + check_job_succeeded, job_manager=job_manager, job_id=job_id + ) + + +@pytest.mark.asyncio +async def test_simultaneous_with_same_id(job_manager): + """Test that we can submit multiple jobs at once with the same id. + + The second job should raise a friendly error. + """ + with pytest.raises(ValueError) as excinfo: + await asyncio.gather( + job_manager.submit_job(entrypoint="echo hello", submission_id="1"), + job_manager.submit_job(entrypoint="echo hello", submission_id="1"), + ) + assert "Job with submission_id 1 already exists" in str(excinfo.value) + # Check that the (first) job can still succeed. + await async_wait_for_condition_async_predicate( + check_job_succeeded, job_manager=job_manager, job_id="1" + ) + + @pytest.mark.asyncio class TestShellScriptExecution: async def test_submit_basic_echo(self, job_manager): diff --git a/python/ray/_private/gcs_utils.py b/python/ray/_private/gcs_utils.py index 7e1dd51181e2f..86faf4600132a 100644 --- a/python/ray/_private/gcs_utils.py +++ b/python/ray/_private/gcs_utils.py @@ -520,6 +520,20 @@ async def internal_kv_put( namespace: Optional[bytes], timeout: Optional[float] = None, ) -> int: + """Put a key-value pair into the GCS. + + Args: + key: The key to put. + value: The value to put. + overwrite: Whether to overwrite the value if the key already exists. + namespace: The namespace to put the key-value pair into. + timeout: The timeout in seconds. + + Returns: + The number of keys added. If overwrite is True, this will be 1 if the + key was added and 0 if the key was updated. If overwrite is False, + this will be 1 if the key was added and 0 if the key already exists. + """ logger.debug(f"internal_kv_put {key!r} {value!r} {overwrite} {namespace!r}") req = gcs_service_pb2.InternalKVPutRequest( namespace=namespace,