Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions shuttle/src/future/batch_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,31 @@ impl BatchSemaphore {
}
drop(state);
}

/// Atomically `upgrade`s from holding `permits_currently_held` to holding `permits_to_be_held`.
/// The motivating use case for this is `parking_lot`s `RwLockUpgradableReadGuard::ugrade`, where we want to be able to
/// go from having a read guard to a write guard while honoring the order of `acquire`s.
///
/// This is implemented by first trying to `acquire` `permits_to_be_held` (which for the `RwLock::upgrade` case would never
/// succeed, as the task is holding one permit, and wants to acquire all of them, meaning even with no other tasks it will
/// block on itself), then `release`ing `permits_currently_held`.
///
/// This ensures the order of `acquire`s is honored, and prevents the potential deadlock situation which could occur in the
/// naive implementation where `permits_to_be_held - permits_currently_held` is `acquire`d, and two tasks try to `upgrade`
/// concurrently (or one `upgrade` in the presence of a `write`).
pub fn upgrade(&self, permits_currently_held: usize, permits_to_be_held: usize) -> Acquire<'_> {
assert!(permits_currently_held > 0);
assert!(permits_to_be_held > permits_currently_held);

let mut acquire = Box::pin(self.acquire(permits_to_be_held));
Comment thread
sarsko marked this conversation as resolved.
let waker = ExecutionState::with(|state| state.current_mut().waker());
let cx = &mut Context::from_waker(&waker);
let _poll = acquire.as_mut().poll(cx);

self.release(permits_currently_held);

*Pin::into_inner(acquire)
}
}

// Safety: Semaphore is never actually passed across true threads, only across continuations. The
Expand Down
268 changes: 261 additions & 7 deletions wrappers/parking_lot/parking_lot_impl/src/rwlock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ pub struct RwLock<T: ?Sized> {
//semaphore to coordinate read and write access to T
sem: BatchSemaphore,

// There can only be one upgradable read at a time, so this semaphore
// is used to ensure that.
upgradable_read_sem: BatchSemaphore,

//inner data T
inner: UnsafeCell<T>,
}
Expand All @@ -35,6 +39,7 @@ impl<T> RwLock<T> {
RwLock {
max_readers,
sem: BatchSemaphore::const_new(max_readers, Fairness::StrictlyFair),
upgradable_read_sem: BatchSemaphore::const_new(1, Fairness::StrictlyFair),
inner: UnsafeCell::new(value),
}
}
Expand Down Expand Up @@ -116,8 +121,8 @@ impl<T: ?Sized> RwLock<T> {
trace!("parking_lot rwlock {:p} write acquired RwLockWriteGuard", self);
RwLockWriteGuard {
permits_acquired: self.max_readers,
data: self.inner.get(),
sem: &self.sem,
data: self.inner.get(),
_p: PhantomData,
}
}
Expand Down Expand Up @@ -151,6 +156,109 @@ impl<T: ?Sized> RwLock<T> {
})
}

/// Locks this `RwLock` with upgradable read access, blocking the current
/// thread until it can be acquired.
///
/// At most one task may hold upgradable read access at a time (enforced by
/// `upgradable_read_sem`); it additionally holds one read permit of the main
/// `sem`, so it coexists with plain readers but excludes writers.
#[inline]
#[track_caller]
pub fn upgradable_read(&self) -> RwLockUpgradableReadGuard<'_, T> {
trace!(
"parking_lot rwlock {:p} upgradeable_read acquiring upgradable_read_sem",
self
);

// First claim the exclusive "upgradable reader" slot.
self.upgradable_read_sem.acquire_blocking(1).unwrap_or_else(|_| {
// The semaphore was closed — but we never explicitly close it, and we have a
// handle to it through the Arc, which means that this can never happen
// (outside of unwinding, which the check below allows for).
if !std::thread::panicking() {
unreachable!()
}
});

trace!(
"parking_lot rwlock {:p} upgradeable_read acquiring RwLockUpgradableReadGuard",
self
);

// RAII wrapper: releases `upgradable_read_sem` when the returned guard is
// dropped (or if anything below unwinds), letting the next upgradable
// reader in.
let drop_guard = UpgradableRwLockSemWrapper {
upgradable_read_sem: &self.upgradable_read_sem,
};

// Also take one read permit from the main semaphore, like a plain reader.
self.sem.acquire_blocking(1).unwrap_or_else(|_| {
// The semaphore was closed — but we never explicitly close it, and we have a
// handle to it through the Arc, which means that this can never happen
// (outside of unwinding, which the check below allows for).
if !std::thread::panicking() {
unreachable!()
}
});

trace!(
"parking_lot rwlock {:p} upgradeable_read acquired RwLockUpgradableReadGuard",
self
);

RwLockUpgradableReadGuard {
sem: &self.sem,
_upgradable_read_sem: drop_guard,
max_readers: self.max_readers,
data: self.inner.get(),
_p: PhantomData,
}
}

/// Attempts to acquire this `RwLock` with upgradable read access.
///
/// If the access could not be granted at this time, then `None` is returned.
/// Otherwise, an RAII guard is returned which will release the shared access
/// when it is dropped.
///
/// This function does not block.
#[inline]
#[track_caller]
pub fn try_upgradable_read(&self) -> Option<RwLockUpgradableReadGuard<'_, T>> {
trace!(
"parking_lot rwlock {:p} try_upgradeable_read acquiring upgradable_read_sem",
self
);

// Try to claim the single "upgradable reader" slot without blocking.
if let Err(try_acquire_error) = self.upgradable_read_sem.try_acquire(1) {
trace!(
"parking_lot rwlock {:p} try_upgradeable_read returned {try_acquire_error:?}, returning None",
self
);
return None;
}

// RAII wrapper: if the second acquire below fails (or anything unwinds),
// dropping this releases `upgradable_read_sem` again so we don't leak the
// upgradable-reader slot on the failure path.
let drop_guard = UpgradableRwLockSemWrapper {
upgradable_read_sem: &self.upgradable_read_sem,
};

trace!(
"parking_lot rwlock {:p} try_upgradeable_read acquiring RwLockUpgradableReadGuard",
self
);

// Also need one read permit from the main semaphore, like a plain reader.
if let Err(try_acquire_error) = self.sem.try_acquire(1) {
trace!(
"parking_lot rwlock {:p} try_upgradeable_read returned {try_acquire_error:?}, returning None",
self
);
// `drop_guard` goes out of scope here, releasing `upgradable_read_sem`.
return None;
}

trace!(
"parking_lot rwlock {:p} try_upgradeable_read acquired RwLockUpgradableReadGuard",
self
);

Some(RwLockUpgradableReadGuard {
sem: &self.sem,
_upgradable_read_sem: drop_guard,
max_readers: self.max_readers,
data: self.inner.get(),
_p: PhantomData,
})
}

/// Returns a mutable reference to the underlying data.
///
/// Since this call borrows the `RwLock` mutably, no actual locking needs to
Expand All @@ -171,6 +279,61 @@ impl<T: ?Sized> RwLock<T> {
}
}

/// An RAII guard for upgradable read access to an `RwLock`.
pub struct RwLockUpgradableReadGuard<'a, T: ?Sized> {
// Releases the exclusive upgradable-reader slot (`upgradable_read_sem`) when dropped.
_upgradable_read_sem: UpgradableRwLockSemWrapper<'a>,
// The lock's main reader/writer semaphore; this guard holds one of its permits.
sem: &'a BatchSemaphore,
// Total permit count of `sem`; needed when upgrading to a write lock.
max_readers: usize,
// Raw pointer to the protected value (obtained from the lock's `UnsafeCell`).
data: *const T,
// Ties the guard's lifetime to a borrow of the lock.
_p: PhantomData<&'a T>,
}

impl<'a, T: ?Sized> RwLockUpgradableReadGuard<'a, T> {
/// Atomically upgrades an upgradable read lock into an exclusive write lock,
/// blocking the current thread until it can be acquired.
#[inline]
pub fn upgrade(s: Self) -> RwLockWriteGuard<'a, T> {
// Go from holding 1 permit (shared read) to all `max_readers` permits
// (exclusive write).
// NOTE(review): the `Acquire` returned by `BatchSemaphore::upgrade` is
// discarded here; this relies on the acquisition completing before the
// write guard is used — confirm against shuttle's `Acquire` semantics.
s.sem.upgrade(1, s.max_readers);

// When we return, the `UpgradableRwLockSemWrapper` goes out of scope and it's possible for another
// task to acquire the upgradable read lock.

RwLockWriteGuard {
permits_acquired: s.max_readers,
sem: s.sem,
data: s.data as *mut T,
_p: PhantomData,
}
}

/// Atomically downgrades an upgradable read lock into a shared read lock
/// without allowing any writers to take exclusive access of the lock in the
/// meantime.
///
/// Note that if there are any writers currently waiting to take the lock
/// then other readers may not be able to acquire the lock even if it was
/// downgraded.
#[track_caller]
pub fn downgrade(s: Self) -> RwLockReadGuard<'a, T> {
// The single read permit this guard holds is handed over to the
// `RwLockReadGuard` unchanged, so no semaphore operation is needed here.
// When we return, the `UpgradableRwLockSemWrapper` goes out of scope and it's possible for another
// task to acquire the upgradable read lock.

RwLockReadGuard {
sem: s.sem,
data: s.data as *mut T,
_p: PhantomData,
}
}
}

impl<T: ?Sized> std::ops::Deref for RwLockUpgradableReadGuard<'_, T> {
type Target = T;

fn deref(&self) -> &T {
// SAFETY: `data` points to the lock's inner value (from `inner.get()`),
// and this guard holds a read permit of `sem` for as long as it — and
// therefore this borrow — lives, so no writer can mutate the value.
unsafe { &*self.data }
}
}

impl<T> From<T> for RwLock<T> {
fn from(s: T) -> Self {
Self::new(s)
Expand All @@ -194,16 +357,20 @@ unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {}
// NB: These impls need to be explicit since we're storing a raw pointer.
// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over
// `T` is `Send`.
unsafe impl<T> Send for RwLockReadGuard<'_, T> where T: ?Sized + Sync {}
unsafe impl<T> Sync for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {}

unsafe impl<T> Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {}
unsafe impl<T> Send for RwLockReadGuard<'_, T> where T: ?Sized + Send + Sync {}
unsafe impl<T> Sync for RwLockReadGuard<'_, T> where T: ?Sized + Sync {}

// Safety: Stores a raw pointer to `T`, so if `T` is `Sync`, the lock guard over
// `T` is `Send` - but since this is also provides mutable access, we need to
// make sure that `T` is `Send` since its value can be sent across thread
// boundaries.
unsafe impl<T> Send for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {}
unsafe impl<T> Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Sync {}

// SAFETY: The raw pointer is not actually sent across threads
unsafe impl<T> Send for RwLockUpgradableReadGuard<'_, T> where T: ?Sized + Send + Sync {}
// SAFETY: The raw pointer is not actually sent across threads
unsafe impl<T> Sync for RwLockUpgradableReadGuard<'_, T> where T: ?Sized + Sync {}

/// RAII structure used to release the shared read access of a lock when
/// dropped.
Expand Down Expand Up @@ -242,6 +409,16 @@ impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
}
}

// RAII helper that owns the single `upgradable_read_sem` permit and gives it
// back when dropped (including on unwind), allowing the next task to take
// upgradable read access.
struct UpgradableRwLockSemWrapper<'a> {
upgradable_read_sem: &'a BatchSemaphore,
}

impl<'a> Drop for UpgradableRwLockSemWrapper<'a> {
fn drop(&mut self) {
// Release the one permit acquired in `upgradable_read`/`try_upgradable_read`.
self.upgradable_read_sem.release(1);
}
}

/// RAII structure used to release the exclusive write access of a lock when
/// dropped.
pub struct RwLockWriteGuard<'a, T: ?Sized> {
Expand Down Expand Up @@ -312,9 +489,9 @@ impl<T: ?Sized> Drop for RwLockWriteGuard<'_, T> {

#[cfg(test)]
mod tests {
use super::RwLock;
use super::{RwLock, RwLockUpgradableReadGuard};
use shuttle::{check_dfs, thread::spawn};
use std::sync::Arc;
use std::sync::{Arc, atomic::Ordering};

#[test]
#[should_panic = "deadlock"]
Expand All @@ -335,4 +512,81 @@ mod tests {
None,
);
}

// Same as above, but checking that we don't allow multiple upgradable read locks at the same time.
#[test]
#[should_panic = "deadlock"]
fn mem_forget_upgradable_read_guard_deadlock() {
check_dfs(
move || {
let rwlock = Arc::new(RwLock::new(()));
let r1 = rwlock.clone();
// Leak the guard so the upgradable-reader slot is never released...
let t1 = spawn(move || {
std::mem::forget(r1.upgradable_read());
});
// ...so in schedules where t1 ran first, this second upgradable_read
// blocks forever, and shuttle reports the deadlock.
let t2 = spawn(move || {
let _g = rwlock.upgradable_read();
});
t1.join().unwrap();
t2.join().unwrap();
},
None,
);
}

#[test]
fn upgrade_sanity() {
    check_dfs(
        move || {
            let rwlock = Arc::new(RwLock::new(0));
            // Counts tasks currently holding the upgradable read lock; it must
            // never exceed one, since upgradable access is exclusive.
            let current_holders = Arc::new(std::sync::atomic::AtomicUsize::new(0));
            let current_holders1 = current_holders.clone();
            let r1 = rwlock.clone();
            let r2 = rwlock.clone();
            let t1 = spawn(move || {
                let guard = r1.upgradable_read();
                // `fetch_add` returns the previous value: no other task may
                // hold the upgradable lock while we do.
                let holders = current_holders.fetch_add(1, Ordering::SeqCst);
                assert_eq!(holders, 0);
                assert!(*guard < 2);
                let mut write = RwLockUpgradableReadGuard::<'_, _>::upgrade(guard);
                let holders = current_holders.fetch_sub(1, Ordering::SeqCst);
                assert_eq!(holders, 1);
                *write += 1;
            });
            let t2 = spawn(move || {
                let guard = r2.upgradable_read();
                let holders = current_holders1.fetch_add(1, Ordering::SeqCst);
                assert_eq!(holders, 0);
                assert!(*guard < 2);
                let mut write = RwLockUpgradableReadGuard::<'_, _>::upgrade(guard);
                let holders = current_holders1.fetch_sub(1, Ordering::SeqCst);
                assert_eq!(holders, 1);
                *write += 1;
            });
            t1.join().unwrap();
            t2.join().unwrap();
            // Each task incremented the value exactly once under its write lock.
            assert_eq!(*rwlock.read(), 2);
        },
        None,
    );
}

#[test]
fn upgradable_read_does_not_block_write() {
check_dfs(
move || {
let rwlock = Arc::new(RwLock::new(0));
let r1 = rwlock.clone();

// Hold an upgradable read on the main thread, then let a writer
// contend for the lock.
let rg = rwlock.upgradable_read();
let t2 = spawn(move || {
let _g = r1.write();
});
// Upgrading while a writer may be queued must not deadlock in any
// schedule (this test has no `should_panic`, so exhaustive DFS
// exploration must complete cleanly).
let g = RwLockUpgradableReadGuard::upgrade(rg);
drop(g);
t2.join().unwrap();
},
None,
);
}
}
Loading