
Commit 31828ec

first commit

1 parent 4449f3d

File tree: 1 file changed (+182, -20 lines)
  • lib/runtime/src/storage/key_value_store


lib/runtime/src/storage/key_value_store/etcd.rs

Lines changed: 182 additions & 20 deletions
@@ -20,7 +20,7 @@ use std::time::Duration;
 use crate::{slug::Slug, transports::etcd::Client};
 use async_stream::stream;
 use async_trait::async_trait;
-use etcd_client::{EventType, PutOptions, WatchOptions};
+use etcd_client::{Compare, CompareOp, EventType, PutOptions, Txn, TxnOp, WatchOptions};

 use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
@@ -158,31 +158,44 @@ impl EtcdBucket {
         let k = make_key(&self.bucket_name, key);
         tracing::trace!("etcd create: {k}");

-        // Does it already exists? For 'create' it shouldn't.
-        let kvs = self
-            .client
-            .kv_get(k.clone(), None)
-            .await
-            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
-        if !kvs.is_empty() {
-            let version = kvs.first().unwrap().version();
-            return Ok(StorageOutcome::Exists(version as u64));
-        }
+        // Use atomic transaction to check and create in one operation
+        let put_options = PutOptions::new();

-        // Write it
-        let mut put_resp = self
+        // Build transaction that creates key only if it doesn't exist
+        let txn = Txn::new()
+            .when(vec![Compare::version(k.as_str(), CompareOp::Equal, 0)]) // Atomic check
+            .and_then(vec![TxnOp::put(k.as_str(), value, Some(put_options))]) // Only if check passes
+            .or_else(vec![
+                TxnOp::get(k.as_str(), None), // Key exists, get its info
+            ]);
+
+        // Execute the transaction
+        let result = self
             .client
-            .kv_put_with_options(k, value, Some(PutOptions::new().with_prev_key()))
+            .etcd_client()
+            .kv_client()
+            .txn(txn)
             .await
             .map_err(|e| StorageError::EtcdError(e.to_string()))?;
-        // Check if we overwrite something
-        if put_resp.take_prev_key().is_some() {
-            // Key created between our get and put
-            return Err(StorageError::Retry);
+
+        if result.succeeded() {
+            // Key was created successfully
+            return Ok(StorageOutcome::Created(1)); // version of new key is always 1
         }

-        // version of a new key is always 1
-        Ok(StorageOutcome::Created(1))
+        // Key already existed, get its version
+        if let Some(etcd_client::TxnOpResponse::Get(get_resp)) =
+            result.op_responses().into_iter().next()
+        {
+            if let Some(kv) = get_resp.kvs().first() {
+                let version = kv.version() as u64;
+                return Ok(StorageOutcome::Exists(version));
+            }
+        }
+        // Shouldn't happen, but handle edge case
+        Err(StorageError::EtcdError(
+            "Unexpected transaction response".to_string(),
+        ))
     }

     async fn update(
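
The key idea in this hunk is that Compare::version(key, CompareOp::Equal, 0) only succeeds when the key has never been written, so the existence check and the put execute atomically on the etcd server instead of as a racy get-then-put. For reference, here is a minimal standalone sketch of the same compare-and-create pattern using the etcd-client crate directly, outside the EtcdBucket wrapper; the endpoint, key name, and value are placeholder assumptions, not part of this commit:

use etcd_client::{Client, Compare, CompareOp, Error, Txn, TxnOp, TxnOpResponse};

#[tokio::main]
async fn main() -> Result<(), Error> {
    // Placeholder endpoint: point this at a reachable etcd instance.
    let mut client = Client::connect(["localhost:2379"], None).await?;

    // version == 0 means "key has never been written", so the put runs
    // for exactly one caller; every other caller falls into or_else.
    let txn = Txn::new()
        .when(vec![Compare::version("demo_key", CompareOp::Equal, 0)])
        .and_then(vec![TxnOp::put("demo_key", "demo_value", None)])
        .or_else(vec![TxnOp::get("demo_key", None)]);

    let resp = client.txn(txn).await?;
    if resp.succeeded() {
        println!("created demo_key (a new key always has version 1)");
    } else if let Some(TxnOpResponse::Get(get)) = resp.op_responses().into_iter().next() {
        let version = get.kvs().first().map(|kv| kv.version()).unwrap_or(0);
        println!("demo_key already exists at version {version}");
    }
    Ok(())
}

Because all the concurrent workers in the test below hit this same transaction at once, exactly one put can observe version 0; that is precisely what the test's assertions check.
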
@@ -241,3 +254,152 @@ fn make_key(bucket_name: &str, key: &str) -> String {
     ]
     .join("/")
 }
+
+// #[cfg(feature = "integration")]
+#[cfg(test)]
+mod concurrent_create_tests {
+    use super::*;
+    use crate::{distributed::DistributedConfig, DistributedRuntime, Runtime};
+    use std::sync::Arc;
+    use tokio::sync::Barrier;
+
+    #[test]
+    fn test_concurrent_etcd_create_race_condition() {
+        let rt = Runtime::from_settings().unwrap();
+        let rt_clone = rt.clone();
+        let config = DistributedConfig::from_settings(false);
+
+        rt_clone.primary().block_on(async move {
+            let drt = DistributedRuntime::new(rt, config).await.unwrap();
+            test_concurrent_create(drt).await.unwrap();
+        });
+    }
+
+    async fn test_concurrent_create(drt: DistributedRuntime) -> Result<(), StorageError> {
+        let etcd_client = drt.etcd_client().expect("etcd client should be available");
+        let storage = EtcdStorage::new(etcd_client);
+
+        // Create a bucket for testing
+        let bucket = Arc::new(tokio::sync::Mutex::new(
+            storage
+                .get_or_create_bucket("test_concurrent_bucket", None)
+                .await?,
+        ));
+
+        // Number of concurrent workers
+        let num_workers = 10;
+        let barrier = Arc::new(Barrier::new(num_workers));
+
+        // Shared test data
+        let test_key = format!("concurrent_test_key_{}", uuid::Uuid::new_v4());
+        let test_value = "test_value";
+
+        // Spawn multiple tasks that will all try to create the same key simultaneously
+        let mut handles = Vec::new();
+        let success_count = Arc::new(tokio::sync::Mutex::new(0));
+        let exists_count = Arc::new(tokio::sync::Mutex::new(0));
+
+        for worker_id in 0..num_workers {
+            let bucket_clone = bucket.clone();
+            let barrier_clone = barrier.clone();
+            let key_clone = test_key.clone();
+            let value_clone = format!("{}_from_worker_{}", test_value, worker_id);
+            let success_count_clone = success_count.clone();
+            let exists_count_clone = exists_count.clone();
+
+            let handle = tokio::spawn(async move {
+                // Wait for all workers to be ready
+                barrier_clone.wait().await;
+
+                // All workers try to create the same key at the same time
+                let result = bucket_clone
+                    .lock()
+                    .await
+                    .insert(key_clone, value_clone, 0)
+                    .await;
+
+                match result {
+                    Ok(StorageOutcome::Created(version)) => {
+                        println!(
+                            "Worker {} successfully created key with version {}",
+                            worker_id, version
+                        );
+                        let mut count = success_count_clone.lock().await;
+                        *count += 1;
+                        Ok(version)
+                    }
+                    Ok(StorageOutcome::Exists(version)) => {
+                        println!(
+                            "Worker {} found key already exists with version {}",
+                            worker_id, version
+                        );
+                        let mut count = exists_count_clone.lock().await;
+                        *count += 1;
+                        Ok(version)
+                    }
+                    Err(e) => {
+                        println!("Worker {} got error: {:?}", worker_id, e);
+                        Err(e)
+                    }
+                }
+            });
+
+            handles.push(handle);
+        }
+
+        // Wait for all workers to complete
+        let mut results = Vec::new();
+        for handle in handles {
+            let result = handle.await.unwrap();
+            if let Ok(version) = result {
+                results.push(version);
+            }
+        }
+
+        // Verify results
+        let final_success_count = *success_count.lock().await;
+        let final_exists_count = *exists_count.lock().await;
+
+        println!(
+            "Final counts - Created: {}, Exists: {}",
+            final_success_count, final_exists_count
+        );
+
+        // CRITICAL ASSERTIONS:
+        // 1. Exactly ONE worker should have successfully created the key
+        assert_eq!(
+            final_success_count, 1,
+            "Exactly one worker should create the key"
+        );
+
+        // 2. All other workers should have gotten "Exists" response
+        assert_eq!(
+            final_exists_count,
+            num_workers - 1,
+            "All other workers should see key exists"
+        );
+
+        // 3. Total successful operations should equal number of workers
+        assert_eq!(
+            results.len(),
+            num_workers,
+            "All workers should complete successfully"
+        );
+
+        // 4. Verify the key actually exists in etcd
+        let stored_value = bucket.lock().await.get(&test_key).await?;
+        assert!(stored_value.is_some(), "Key should exist in etcd");
+
+        // 5. The stored value should be from one of the workers
+        let stored_str = String::from_utf8(stored_value.unwrap().to_vec()).unwrap();
+        assert!(
+            stored_str.starts_with(test_value),
+            "Stored value should match expected prefix"
+        );
+
+        // Clean up
+        bucket.lock().await.delete(&test_key).await?;
+
+        Ok(())
+    }
+}
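
Usage note: the commented-out `// #[cfg(feature = "integration")]` above suggests this is intended as an integration test; as written it presumably needs a live etcd endpoint reachable by `DistributedRuntime` (the test calls `drt.etcd_client().expect(...)`), so it will fail in environments where etcd is not running.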
