@@ -20,7 +20,7 @@ use std::time::Duration;
2020use crate :: { slug:: Slug , transports:: etcd:: Client } ;
2121use async_stream:: stream;
2222use async_trait:: async_trait;
23- use etcd_client:: { EventType , PutOptions , WatchOptions } ;
23+ use etcd_client:: { Compare , CompareOp , EventType , PutOptions , Txn , TxnOp , WatchOptions } ;
2424
2525use super :: { KeyValueBucket , KeyValueStore , StorageError , StorageOutcome } ;
2626
@@ -158,31 +158,44 @@ impl EtcdBucket {
158158 let k = make_key ( & self . bucket_name , key) ;
159159 tracing:: trace!( "etcd create: {k}" ) ;
160160
161- // Does it already exists? For 'create' it shouldn't.
162- let kvs = self
163- . client
164- . kv_get ( k. clone ( ) , None )
165- . await
166- . map_err ( |e| StorageError :: EtcdError ( e. to_string ( ) ) ) ?;
167- if !kvs. is_empty ( ) {
168- let version = kvs. first ( ) . unwrap ( ) . version ( ) ;
169- return Ok ( StorageOutcome :: Exists ( version as u64 ) ) ;
170- }
161+ // Use atomic transaction to check and create in one operation
162+ let put_options = PutOptions :: new ( ) ;
171163
172- // Write it
173- let mut put_resp = self
164+ // Build transaction that creates key only if it doesn't exist
165+ let txn = Txn :: new ( )
166+ . when ( vec ! [ Compare :: version( k. as_str( ) , CompareOp :: Equal , 0 ) ] ) // Atomic check
167+ . and_then ( vec ! [ TxnOp :: put( k. as_str( ) , value, Some ( put_options) ) ] ) // Only if check passes
168+ . or_else ( vec ! [
169+ TxnOp :: get( k. as_str( ) , None ) , // Key exists, get its info
170+ ] ) ;
171+
172+ // Execute the transaction
173+ let result = self
174174 . client
175- . kv_put_with_options ( k, value, Some ( PutOptions :: new ( ) . with_prev_key ( ) ) )
175+ . etcd_client ( )
176+ . kv_client ( )
177+ . txn ( txn)
176178 . await
177179 . map_err ( |e| StorageError :: EtcdError ( e. to_string ( ) ) ) ?;
178- // Check if we overwrite something
179- if put_resp . take_prev_key ( ) . is_some ( ) {
180- // Key created between our get and put
181- return Err ( StorageError :: Retry ) ;
180+
181+ if result . succeeded ( ) {
182+ // Key was created successfully
183+ return Ok ( StorageOutcome :: Created ( 1 ) ) ; // version of new key is always 1
182184 }
183185
184- // version of a new key is always 1
185- Ok ( StorageOutcome :: Created ( 1 ) )
186+ // Key already existed, get its version
187+ if let Some ( etcd_client:: TxnOpResponse :: Get ( get_resp) ) =
188+ result. op_responses ( ) . into_iter ( ) . next ( )
189+ {
190+ if let Some ( kv) = get_resp. kvs ( ) . first ( ) {
191+ let version = kv. version ( ) as u64 ;
192+ return Ok ( StorageOutcome :: Exists ( version) ) ;
193+ }
194+ }
195+ // Shouldn't happen, but handle edge case
196+ Err ( StorageError :: EtcdError (
197+ "Unexpected transaction response" . to_string ( ) ,
198+ ) )
186199 }
187200
188201 async fn update (
@@ -241,3 +254,152 @@ fn make_key(bucket_name: &str, key: &str) -> String {
241254 ]
242255 . join ( "/" )
243256}
257+
// Requires a live etcd server, so gate behind the integration feature in
// addition to cfg(test); under plain `cargo test` (no etcd running) this
// module would fail at connection time, not report a real regression.
#[cfg(all(test, feature = "integration"))]
mod concurrent_create_tests {
    use super::*;
    use crate::{distributed::DistributedConfig, DistributedRuntime, Runtime};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::sync::Barrier;

    /// Entry point: builds the runtime/distributed runtime, then drives the
    /// async test body on the primary executor.
    #[test]
    fn test_concurrent_etcd_create_race_condition() {
        let rt = Runtime::from_settings().unwrap();
        let rt_clone = rt.clone();
        let config = DistributedConfig::from_settings(false);

        rt_clone.primary().block_on(async move {
            let drt = DistributedRuntime::new(rt, config).await.unwrap();
            test_concurrent_create(drt).await.unwrap();
        });
    }

    /// Spawns `num_workers` tasks that all attempt to create the same key and
    /// asserts that exactly one observes `Created` while every other worker
    /// observes `Exists` — i.e. the etcd transaction in `create` is atomic.
    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StorageError> {
        let etcd_client = drt.etcd_client().expect("etcd client should be available");
        let storage = EtcdStorage::new(etcd_client);

        // NOTE(review): `insert` presumably needs exclusive access, hence the
        // Mutex — but holding the lock across the `.await` serializes the
        // inserts, so workers never actually race against etcd; only the
        // transaction's atomicity is exercised. A true concurrency test would
        // give each worker its own bucket handle. TODO confirm receiver type.
        let bucket = Arc::new(tokio::sync::Mutex::new(
            storage
                .get_or_create_bucket("test_concurrent_bucket", None)
                .await?,
        ));

        // Number of concurrent workers.
        let num_workers = 10;
        let barrier = Arc::new(Barrier::new(num_workers));

        // Unique key per run so repeated runs don't collide with leftovers.
        let test_key = format!("concurrent_test_key_{}", uuid::Uuid::new_v4());
        let test_value = "test_value";

        let mut handles = Vec::new();
        // Plain tallies: lock-free atomics instead of a tokio Mutex<i32>,
        // which also keeps the counter type (usize) consistent with
        // `num_workers` for the assertions below.
        let success_count = Arc::new(AtomicUsize::new(0));
        let exists_count = Arc::new(AtomicUsize::new(0));

        for worker_id in 0..num_workers {
            let bucket_clone = bucket.clone();
            let barrier_clone = barrier.clone();
            let key_clone = test_key.clone();
            let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
            let success_count_clone = success_count.clone();
            let exists_count_clone = exists_count.clone();

            let handle = tokio::spawn(async move {
                // Hold every worker at the barrier so they release together.
                barrier_clone.wait().await;

                // All workers try to create the same key at the same time.
                let result = bucket_clone
                    .lock()
                    .await
                    .insert(key_clone, value_clone, 0)
                    .await;

                match result {
                    Ok(StorageOutcome::Created(version)) => {
                        println!(
                            "Worker {} successfully created key with version {}",
                            worker_id, version
                        );
                        success_count_clone.fetch_add(1, Ordering::SeqCst);
                        Ok(version)
                    }
                    Ok(StorageOutcome::Exists(version)) => {
                        println!(
                            "Worker {} found key already exists with version {}",
                            worker_id, version
                        );
                        exists_count_clone.fetch_add(1, Ordering::SeqCst);
                        Ok(version)
                    }
                    Err(e) => {
                        println!("Worker {} got error: {:?}", worker_id, e);
                        Err(e)
                    }
                }
            });

            handles.push(handle);
        }

        // Collect the version each worker observed; a panicked worker task
        // surfaces here through `unwrap` on the JoinHandle.
        let mut results = Vec::new();
        for handle in handles {
            if let Ok(version) = handle.await.unwrap() {
                results.push(version);
            }
        }

        let final_success_count = success_count.load(Ordering::SeqCst);
        let final_exists_count = exists_count.load(Ordering::SeqCst);

        println!(
            "Final counts - Created: {}, Exists: {}",
            final_success_count, final_exists_count
        );

        // CRITICAL ASSERTIONS:
        // 1. Exactly ONE worker should have successfully created the key.
        assert_eq!(
            final_success_count, 1,
            "Exactly one worker should create the key"
        );

        // 2. All other workers should have gotten the "Exists" response.
        assert_eq!(
            final_exists_count,
            num_workers - 1,
            "All other workers should see key exists"
        );

        // 3. Total successful operations should equal the number of workers.
        assert_eq!(
            results.len(),
            num_workers,
            "All workers should complete successfully"
        );

        // 4. Verify the key actually exists in etcd.
        let stored_value = bucket.lock().await.get(&test_key).await?;
        assert!(stored_value.is_some(), "Key should exist in etcd");

        // 5. The stored value should be from one of the workers.
        let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
        assert!(
            stored_str.starts_with(test_value),
            "Stored value should match expected prefix"
        );

        // Clean up so repeated runs start fresh.
        bucket.lock().await.delete(&test_key).await?;

        Ok(())
    }
}
0 commit comments