diff --git a/src/db.rs b/src/db.rs index f59a5aef..8433ab0f 100644 --- a/src/db.rs +++ b/src/db.rs @@ -8,17 +8,17 @@ use crate::{ReadTransaction, Result, WriteTransaction}; use std::collections::btree_set::BTreeSet; use std::fmt::{Display, Formatter}; use std::fs::{File, OpenOptions}; +use std::io; use std::io::ErrorKind; use std::marker::PhantomData; use std::ops::RangeFull; use std::path::Path; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Mutex; -use std::{io, panic}; use crate::multimap_table::parse_subtree_roots; #[cfg(feature = "logging")] -use log::{error, info}; +use log::info; pub(crate) type TransactionId = u64; type AtomicTransactionId = AtomicU64; @@ -150,7 +150,6 @@ pub struct Database { next_transaction_id: AtomicTransactionId, live_read_transactions: Mutex>, live_write_transaction: Mutex>, - leaked_write_transaction: Mutex>>, } impl Database { @@ -345,23 +344,9 @@ impl Database { next_transaction_id: AtomicTransactionId::new(next_transaction_id), live_write_transaction: Mutex::new(None), live_read_transactions: Mutex::new(Default::default()), - leaked_write_transaction: Mutex::new(Default::default()), }) } - pub(crate) fn record_leaked_write_transaction(&self, transaction_id: TransactionId) { - assert_eq!( - transaction_id, - self.live_write_transaction.lock().unwrap().unwrap() - ); - *self.leaked_write_transaction.lock().unwrap() = Some(panic::Location::caller()); - #[cfg(feature = "logging")] - error!( - "Leaked write transaction from {}", - panic::Location::caller() - ); - } - pub(crate) fn deallocate_read_transaction(&self, id: TransactionId) { self.live_read_transactions.lock().unwrap().remove(&id); } @@ -391,12 +376,6 @@ impl Database { /// Returns a [`WriteTransaction`] which may be used to read/write to the database. Only a single /// write may be in progress at a time pub fn begin_write(&self) -> Result { - let guard = self.leaked_write_transaction.lock().unwrap(); - if let Some(leaked) = *guard { - return Err(Error::LeakedWriteTransaction(leaked)); - } - drop(guard); - assert!(self.live_write_transaction.lock().unwrap().is_none()); let id = self.next_transaction_id.fetch_add(1, Ordering::AcqRel); *self.live_write_transaction.lock().unwrap() = Some(id); diff --git a/src/error.rs b/src/error.rs index bc77f707..232e9a75 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,7 +13,6 @@ pub enum Error { requested_size: usize, }, TableDoesNotExist(String), - LeakedWriteTransaction(&'static panic::Location<'static>), // Tables cannot be opened for writing multiple times, since they could retrieve immutable & // mutable references to the same dirty pages, or multiple mutable references via insert_reserve() TableAlreadyOpen(String, &'static panic::Location<'static>), @@ -57,9 +56,6 @@ impl Display for Error { Error::TableDoesNotExist(table) => { write!(f, "Table '{}' does not exist", table) } - Error::LeakedWriteTransaction(location) => { - write!(f, "Leaked write transaction: {}", location) - } Error::TableAlreadyOpen(name, location) => { write!(f, "Table '{}' already opened at: {}", name, location) } diff --git a/src/transactions.rs b/src/transactions.rs index 4448455f..96147de0 100644 --- a/src/transactions.rs +++ b/src/transactions.rs @@ -9,14 +9,13 @@ use crate::{ Result, Table, TableDefinition, }; #[cfg(feature = "logging")] -use log::info; +use log::{info, warn}; use std::cell::RefCell; use std::cmp::min; use std::collections::HashMap; use std::mem::size_of; use std::panic; use std::rc::Rc; -use std::sync::atomic::{AtomicBool, Ordering}; /// Informational storage stats about the database #[derive(Debug)] @@ -104,7 +103,7 @@ pub struct WriteTransaction<'db> { freed_tree: BtreeMut<'db, FreedTableKey, [u8]>, freed_pages: Rc>>, open_tables: RefCell>>, - completed: AtomicBool, + completed: bool, durability: Durability, } @@ -127,7 +126,7 @@ impl<'db> WriteTransaction<'db> { freed_tree: BtreeMut::new(freed_root, db.get_memory(), freed_pages.clone()), freed_pages, open_tables: RefCell::new(Default::default()), - completed: Default::default(), + completed: false, durability: Durability::Immediate, }) } @@ -296,7 +295,7 @@ impl<'db> WriteTransaction<'db> { Durability::Immediate => self.durable_commit(false)?, } - self.completed.store(true, Ordering::Release); + self.completed = true; #[cfg(feature = "logging")] info!("Finished commit of transaction id={}", self.transaction_id); @@ -306,13 +305,17 @@ impl<'db> WriteTransaction<'db> { /// Abort the transaction /// /// All writes performed in this transaction will be rolled back - pub fn abort(self) -> Result { + pub fn abort(mut self) -> Result { + self.abort_inner() + } + + fn abort_inner(&mut self) -> Result { #[cfg(feature = "logging")] info!("Aborting transaction id={}", self.transaction_id); self.table_tree.borrow_mut().clear_table_root_updates(); self.mem.rollback_uncommitted_writes()?; self.db.deallocate_write_transaction(self.transaction_id); - self.completed.store(true, Ordering::Release); + self.completed = true; #[cfg(feature = "logging")] info!("Finished abort of transaction id={}", self.transaction_id); Ok(()) @@ -475,8 +478,12 @@ impl<'db> WriteTransaction<'db> { impl<'a> Drop for WriteTransaction<'a> { fn drop(&mut self) { - if !self.completed.load(Ordering::Acquire) { - self.db.record_leaked_write_transaction(self.transaction_id); + if !self.completed { + #[allow(unused_variables)] + if let Err(error) = self.abort_inner() { + #[cfg(feature = "logging")] + warn!("Failure automatically aborting transaction: {}", error); + } } } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 05305e4a..af44b3e6 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -951,18 +951,19 @@ fn delete_table() { } #[test] -fn leaked_write() { +fn dropped_write() { let tmpfile: NamedTempFile = NamedTempFile::new().unwrap(); let db = unsafe { Database::create(tmpfile.path(), 1024 * 1024).unwrap() }; let write_txn = db.begin_write().unwrap(); - drop(write_txn); - let result = db.begin_write(); - if let Err(Error::LeakedWriteTransaction(_message)) = result { - // Good - } else { - panic!(); + { + let mut table = write_txn.open_table(SLICE_TABLE).unwrap(); + table.insert(b"hello", b"world").unwrap(); } + drop(write_txn); + let read_txn = db.begin_read().unwrap(); + let result = read_txn.open_table(SLICE_TABLE); + assert!(matches!(result, Err(Error::TableDoesNotExist(_)))); } #[test]