diff --git a/Cargo.lock b/Cargo.lock
index 9a0f789308..fd22412fa2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -952,9 +952,9 @@ dependencies = [
 
 [[package]]
 name = "crossbeam-utils"
-version = "0.8.19"
+version = "0.8.20"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
+checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
 
 [[package]]
 name = "crossterm"
@@ -1309,6 +1309,7 @@ checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181"
 dependencies = [
  "futures-core",
  "futures-sink",
+ "nanorand",
  "spin 0.9.8",
 ]
 
@@ -1499,8 +1500,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
 dependencies = [
  "cfg-if",
+ "js-sys",
  "libc",
  "wasi",
+ "wasm-bindgen",
 ]
 
 [[package]]
@@ -2036,6 +2039,15 @@ dependencies = [
  "syn 1.0.109",
 ]
 
+[[package]]
+name = "nanorand"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3"
+dependencies = [
+ "getrandom",
+]
+
 [[package]]
 name = "native-tls"
 version = "0.2.11"
@@ -2249,9 +2261,9 @@ checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
 
 [[package]]
 name = "parking_lot"
-version = "0.12.1"
+version = "0.12.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
+checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
 dependencies = [
  "lock_api",
  "parking_lot_core",
@@ -3252,11 +3264,14 @@ dependencies = [
  "bytes",
  "chrono",
  "crc",
+ "criterion",
  "crossbeam-queue",
+ "crossbeam-utils",
  "digest",
  "either",
  "encoding_rs",
  "event-listener 5.2.0",
+ "flume",
  "futures-channel",
  "futures-core",
  "futures-intrusive",
diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml
index f19c1a1e0d..39f764b3d9 100644
--- a/sqlx-core/Cargo.toml
+++ b/sqlx-core/Cargo.toml
@@ -58,6 +58,7 @@ byteorder = { version = "1.4.3", default-features = false, features = ["std"] }
 chrono = { version = "0.4.34", default-features = false, features = ["clock"], optional = true }
 crc = { version = "3", optional = true }
 crossbeam-queue = "0.3.2"
+crossbeam-utils = "0.8.20"
 digest = { version = "0.10.0", default-features = false, optional = true, features = ["std"] }
 encoding_rs = { version = "0.8.30", optional = true }
 either = "1.6.1"
@@ -91,4 +92,12 @@ hashbrown = "0.14.5"
 
 [dev-dependencies]
 sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }
-tokio = { version = "1", features = ["rt"] }
+async-std = { workspace = true, features = ["attributes"] }
+tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] }
+criterion = { version = "0.5.1", features = ["async_tokio"] }
+flume = "0.11.0"
+
+[[bench]]
+name = "channels"
+harness = false
+required-features = ["_rt-tokio"]
diff --git a/sqlx-core/benches/channels.rs b/sqlx-core/benches/channels.rs
new file mode 100644
index 0000000000..a1e563062c
--- /dev/null
+++ b/sqlx-core/benches/channels.rs
@@ -0,0 +1,129 @@
+//! Single-producer, single-consumer throughput benchmarks comparing the double-buffer channel
+//! against `tokio::sync::mpsc` and `flume`.
+
+use criterion::{Bencher, BenchmarkId, Criterion, criterion_group, criterion_main, Throughput};
+
+/// Benchmarks every channel over the same grid of runtime flavor, message count, and capacity.
+fn bench_spsc(c: &mut Criterion) {
+    let mut group = c.benchmark_group("bench_spsc(threaded, count, capacity)");
+
+    for threaded in [false, true] {
+        for count in [100u64, 1000, 10_000] {
+            // Every message is a single `u64`.
+            group.throughput(Throughput::Bytes(std::mem::size_of::<u64>() as u64 * count));
+
+            for capacity in [16usize, 64, 256] {
+                group.bench_with_input(
+                    BenchmarkId::from_parameter(
+                        format!("tokio::sync::mpsc({threaded}, {count}, {capacity})")
+                    ),
+                    &(threaded, count, capacity),
+                    bench_spsc_tokio,
+                );
+
+                group.bench_with_input(
+                    BenchmarkId::from_parameter(
+                        format!("flume({threaded}, {count}, {capacity})")
+                    ),
+                    &(threaded, count, capacity),
+                    bench_spsc_flume,
+                );
+
+                group.bench_with_input(
+                    BenchmarkId::from_parameter(
+                        format!("double_buffer({threaded}, {count}, {capacity})")
+                    ),
+                    &(threaded, count, capacity),
+                    bench_spsc_double_buffer,
+                );
+            }
+        }
+    }
+
+    group.finish();
+}
+
+/// Sends `count` integers through a `tokio::sync::mpsc` channel between two spawned tasks.
+fn bench_spsc_tokio(bencher: &mut Bencher, &(threaded, count, capacity): &(bool, u64, usize)) {
+    bencher.to_async(build_spsc_runtime(threaded)).iter(|| async {
+        let (tx, mut rx) = tokio::sync::mpsc::channel(capacity);
+
+        tokio::try_join!(
+            tokio::spawn(async move {
+                for i in 0 .. count {
+                    tx.send(i).await.expect("BUG: channel closed early");
+                }
+            }),
+            tokio::spawn(async move {
+                for expected in 0 .. count {
+                    assert_eq!(rx.recv().await, Some(expected));
+                }
+
+                assert_eq!(rx.recv().await, None);
+            })
+        ).unwrap();
+    });
+}
+
+/// Sends `count` integers through a bounded `flume` channel between two spawned tasks.
+fn bench_spsc_flume(bencher: &mut Bencher, &(threaded, count, capacity): &(bool, u64, usize)) {
+    bencher.to_async(build_spsc_runtime(threaded)).iter(|| async {
+        let (tx, rx) = flume::bounded(capacity);
+
+        tokio::try_join!(
+            tokio::spawn(async move {
+                for i in 0 .. count {
+                    tx.send_async(i).await.expect("BUG: channel closed early");
+                }
+            }),
+            tokio::spawn(async move {
+                for expected in 0 .. count {
+                    assert_eq!(rx.recv_async().await, Ok(expected));
+                }
+
+                assert_eq!(rx.recv_async().await.ok(), None);
+            })
+        ).unwrap();
+    });
+}
+
+/// Sends `count` integers through the double-buffer channel between two spawned tasks.
+fn bench_spsc_double_buffer(bencher: &mut Bencher, &(threaded, count, capacity): &(bool, u64, usize)) {
+    bencher.to_async(build_spsc_runtime(threaded)).iter(|| async {
+        let (mut tx, mut rx) = sqlx_core::common::channel::double_buffer::channel(capacity);
+
+        tokio::try_join!(
+            tokio::spawn(async move {
+                for i in 0 .. count {
+                    tx.send(i).await.expect("BUG: channel closed early");
+                }
+            }),
+            tokio::spawn(async move {
+                for expected in 0 .. count {
+                    assert_eq!(rx.recv().await, Some(expected));
+                }
+
+                assert_eq!(rx.recv().await, None);
+            })
+        ).unwrap();
+    });
+}
+
+/// Builds a two-worker multi-threaded runtime if `threaded`, otherwise a current-thread one.
+fn build_spsc_runtime(threaded: bool) -> tokio::runtime::Runtime {
+    let mut builder = if threaded {
+        let mut builder = tokio::runtime::Builder::new_multi_thread();
+        builder.worker_threads(2);
+        builder
+    } else {
+        tokio::runtime::Builder::new_current_thread()
+    };
+
+    builder
+        .enable_all()
+        .build()
+        .unwrap()
+}
+
+criterion_group!(benches, bench_spsc);
+criterion_main!(benches);
diff --git a/sqlx-core/src/common/channel/double_buffer.rs b/sqlx-core/src/common/channel/double_buffer.rs
new file mode 100644
index 0000000000..d2021ab384
--- /dev/null
+++ b/sqlx-core/src/common/channel/double_buffer.rs
@@ -0,0 +1,390 @@
+//! A bounded single-producer, single-consumer channel built from two swappable buffers.
+//!
+//! The sender fills one buffer while the receiver drains the other, so the two halves only
+//! contend on shared memory when a buffer is handed over.
+//!
+//! Messages only become visible to the receiver once the sender's current buffer fills up
+//! or the [`Sender`] is dropped.
+
+use std::collections::VecDeque;
+use std::mem;
+use std::sync::{Arc, Mutex, MutexGuard};
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::task::Poll;
+
+use futures_util::task::AtomicWaker;
+
+/// The sending half of a channel created by [`channel()`].
+pub struct Sender<T> {
+    shared: Arc<BufferShared<T>>,
+    buffer: BufferOption<T>,
+}
+
+/// The receiving half of a channel created by [`channel()`].
+pub struct Receiver<T> {
+    shared: Arc<BufferShared<T>>,
+    buffer: BufferOption<T>,
+}
+
+struct BufferShared<T> {
+    header: Header,
+    // Number of messages the sender puts in a buffer before flushing it.
+    //
+    // Tracked explicitly because `VecDeque::with_capacity()` only guarantees room for
+    // _at least_ the requested number of elements.
+    buffer_capacity: usize,
+    // Instead of writing to buffers in shared memory, which would require up to
+    // 128 bytes of padding to prevent false sharing, the sender and receiver each take
+    // exclusive ownership of the buffer they're currently accessing.
+    //
+    // This way, contended access to shared memory only happens when it's time for a buffer swap.
+    front: Mutex<Option<VecDeque<T>>>,
+    back: Mutex<Option<VecDeque<T>>>,
+}
+
+/// The buffer a channel half currently owns, or the one it will take next.
+enum BufferOption<T> {
+    Wants(SelectedBuffer),
+    HasFront(VecDeque<T>),
+    HasBack(VecDeque<T>),
+}
+
+#[derive(Debug)]
+struct Header {
+    sender_waiting: AtomicWaker,
+    receiver_waiting: AtomicWaker,
+
+    closed: AtomicBool,
+
+    // Set by the sender when it flushes a buffer, cleared by the receiver when it releases it.
+    front_flushed: AtomicBool,
+    back_flushed: AtomicBool,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, Debug)]
+enum SelectedBuffer {
+    Front,
+    Back,
+}
+
+/// Creates a channel made of two buffers holding `capacity / 2` messages each.
+///
+/// # Panics
+///
+/// Panics if `capacity / 2` is zero.
+pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
+    let buffer_capacity = capacity / 2;
+    assert_ne!(buffer_capacity, 0, "capacity / 2 must not be zero");
+
+    // Sender starts out owning the front buffer,
+    // receiver starts out _wanting_ the front buffer.
+    let shared = Arc::new(BufferShared {
+        header: Header {
+            closed: AtomicBool::new(false),
+            front_flushed: AtomicBool::new(false),
+            back_flushed: AtomicBool::new(false),
+            sender_waiting: AtomicWaker::new(),
+            receiver_waiting: AtomicWaker::new(),
+        },
+        buffer_capacity,
+        front: Mutex::new(None),
+        back: Mutex::new(Some(VecDeque::with_capacity(buffer_capacity))),
+    });
+
+    (
+        Sender {
+            shared: shared.clone(),
+            buffer: BufferOption::HasFront(VecDeque::with_capacity(buffer_capacity)),
+        },
+        Receiver {
+            shared: shared.clone(),
+            buffer: BufferOption::Wants(SelectedBuffer::Front),
+        }
+    )
+}
+
+impl<T> Sender<T> {
+    /// Flush the current buffer and wake the reader.
+    fn flush_buffer(&mut self) {
+        let selected = self.buffer.as_selected();
+
+        let Some(buf) = mem::replace(&mut self.buffer, BufferOption::Wants(selected.next()))
+            .into_buf() else {
+            return;
+        };
+
+        self.shared.put_buffer(selected, buf);
+
+        self.shared.header.flushed_status(selected)
+            .store(true, Ordering::Release);
+
+        self.shared.header.receiver_waiting.wake();
+    }
+
+    /// Queues `val`, waiting for the receiver to release the next buffer if necessary.
+    ///
+    /// Returns `Err(val)` if the channel is closed.
+    pub async fn send(&mut self, val: T) -> Result<(), T> {
+        loop {
+            if self.shared.header.is_closed() {
+                return Err(val);
+            }
+
+            let selected = self.buffer.as_selected();
+            let flushed_status = self.shared.header.flushed_status(selected);
+
+            if let Some(buf) = self.buffer.get_mut() {
+                buf.push_back(val);
+
+                if buf.len() >= self.shared.buffer_capacity {
+                    // Advances to the next buffer.
+                    self.flush_buffer();
+                }
+
+                return Ok(());
+            }
+
+            let res = std::future::poll_fn(|cx| {
+                self.shared.header.sender_waiting.register(cx.waker());
+
+                if self.shared.header.is_closed() {
+                    return Poll::Ready(Err(()));
+                }
+
+                if flushed_status.load(Ordering::Acquire) {
+                    return Poll::Pending
+                }
+
+                Poll::Ready(Ok(()))
+            }).await;
+
+            if let Err(()) = res {
+                return Err(val);
+            }
+
+            let buf = self.shared.take_buffer(self.buffer.as_selected());
+            self.buffer.put(buf);
+        }
+    }
+}
+
+/// Closes the channel.
+///
+/// The receiver may continue to read messages until the channel is drained.
+impl<T> Drop for Sender<T> {
+    fn drop(&mut self) {
+        self.flush_buffer();
+        self.shared.header.close();
+    }
+}
+
+impl<T> Receiver<T> {
+    /// Hand the drained buffer back and wake the sender.
+    fn release_buffer(&mut self) {
+        let selected = self.buffer.as_selected();
+
+        let Some(buf) = mem::replace(&mut self.buffer, BufferOption::Wants(selected.next()))
+            .into_buf() else {
+            return;
+        };
+
+        self.shared.put_buffer(selected, buf);
+
+        self.shared.header.flushed_status(selected)
+            .store(false, Ordering::Release);
+
+        self.shared.header.sender_waiting.wake();
+    }
+
+    /// Receives the next message, in the order it was sent.
+    ///
+    /// Returns `None` once the channel is closed and no flushed buffer remains.
+    pub async fn recv(&mut self) -> Option<T> {
+        loop {
+            // Note: we don't check if the channel is closed until we swap buffers.
+            if let Some(buf) = self.buffer.get_mut() {
+                if let Some(val) = buf.pop_front() {
+                    if buf.is_empty() {
+                        self.release_buffer();
+                    }
+
+                    return Some(val);
+                }
+
+                // This *should* be a no-op, but it doesn't hurt to check again.
+                self.release_buffer();
+            }
+
+            let flushed_status = self.shared.header.flushed_status(self.buffer.as_selected());
+
+            std::future::poll_fn(|cx| {
+                self.shared.header.receiver_waiting.register(cx.waker());
+
+                // Read `closed` _before_ the flushed flag: the sender flushes its final buffer
+                // and only then closes, so if we observe the close here, the load below is
+                // guaranteed to observe that final flush and no messages are lost.
+                let closed = self.shared.header.is_closed();
+
+                // Sender has flushed this buffer.
+                if flushed_status.load(Ordering::Acquire) {
+                    return Poll::Ready(Some(()));
+                }
+
+                // Allow the reader to drain messages until the channel is empty.
+                if closed {
+                    return Poll::Ready(None);
+                }
+
+                // Waiting for the sender to write to and flush this buffer.
+                Poll::Pending
+            }).await?;
+
+            let buf = self.shared.take_buffer(self.buffer.as_selected());
+            self.buffer.put(buf);
+        }
+    }
+}
+
+impl<T> Drop for Receiver<T> {
+    fn drop(&mut self) {
+        // Unlike `Sender`, there is nothing to flush: any messages still buffered are discarded,
+        // and the sender observes the close on its next `send()`.
+        self.shared.header.close();
+    }
+}
+
+impl Header {
+    fn close(&self) {
+        self.closed.store(true, Ordering::Release);
+        self.sender_waiting.wake();
+        self.receiver_waiting.wake();
+    }
+
+    fn is_closed(&self) -> bool {
+        self.closed.load(Ordering::Acquire)
+    }
+
+    fn flushed_status(&self, buffer: SelectedBuffer) -> &AtomicBool {
+        match buffer {
+            SelectedBuffer::Front => &self.front_flushed,
+            SelectedBuffer::Back => &self.back_flushed,
+        }
+    }
+}
+
+impl<T> BufferShared<T> {
+    fn lock_buffer_place(&self, buffer: SelectedBuffer) -> MutexGuard<'_, Option<VecDeque<T>>> {
+        match buffer {
+            SelectedBuffer::Front => &self.front,
+            SelectedBuffer::Back => &self.back,
+        }
+            .lock()
+            .unwrap_or_else(|it| it.into_inner())
+    }
+
+    fn take_buffer(&self, selected: SelectedBuffer) -> VecDeque<T> {
+        self
+            .lock_buffer_place(selected)
+            .take()
+            .unwrap_or_else(|| panic!("expected to take {selected:?}, found nothing"))
+    }
+
+    fn put_buffer(&self, selected: SelectedBuffer, buf: VecDeque<T>) {
+        let replaced = mem::replace(&mut *self.lock_buffer_place(selected), Some(buf));
+
+        if let Some(replaced) = replaced {
+            panic!("BUG: replaced buffer {selected:?} with {} elements", replaced.len());
+        }
+    }
+}
+
+impl<T> BufferOption<T> {
+    fn as_selected(&self) -> SelectedBuffer {
+        match *self {
+            Self::Wants(wants) => wants,
+            Self::HasFront(_) => SelectedBuffer::Front,
+            Self::HasBack(_) => SelectedBuffer::Back,
+        }
+    }
+
+    fn get_mut(&mut self) -> Option<&mut VecDeque<T>> {
+        match self {
+            Self::HasFront(front) => Some(front),
+            Self::HasBack(back) => Some(back),
+            _ => None,
+        }
+    }
+
+    fn put(&mut self, buf: VecDeque<T>) {
+        match self {
+            Self::Wants(SelectedBuffer::Front) => *self = Self::HasFront(buf),
+            Self::Wants(SelectedBuffer::Back) => *self = Self::HasBack(buf),
+            Self::HasFront(front) => {
+                panic!("BUG: replacing front buffer of len {} with buffer of len {}", front.len(), buf.len());
+            }
+            Self::HasBack(back) => {
+                panic!("BUG: replacing back buffer of len {} with buffer of len {}", back.len(), buf.len());
+            }
+        }
+    }
+
+    fn into_buf(self) -> Option<VecDeque<T>> {
+        match self {
+            Self::HasFront(front) => Some(front),
+            Self::HasBack(back) => Some(back),
+            _ => None,
+        }
+    }
+}
+
+impl SelectedBuffer {
+    fn next(&self) -> Self {
+        match self {
+            Self::Front => Self::Back,
+            Self::Back => Self::Front,
+        }
+    }
+}
+
+#[cfg(all(test, any(feature = "_rt-tokio", feature = "_rt-async-std")))]
+mod tests {
+    // Cannot use `#[sqlx::test]` because we want to configure the Tokio runtime to use 2 threads
+    #[cfg(feature = "_rt-tokio")]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+    async fn test_double_buffer_tokio() {
+        test_double_buffer().await;
+    }
+
+    #[cfg(feature = "_rt-async-std")]
+    #[async_std::test]
+    async fn test_double_buffer_async_std() {
+        test_double_buffer().await;
+    }
+
+    async fn test_double_buffer() {
+        const CAPACITY: usize = 50;
+        const END: usize = 1000;
+
+        let (mut tx, mut rx) = super::channel::<usize>(CAPACITY);
+
+        let reader = crate::rt::spawn(async move {
+            for expected in 0usize..=END {
+                assert_eq!(rx.recv().await, Some(expected));
+            }
+
+            assert_eq!(rx.recv().await, None)
+        });
+
+        let writer = crate::rt::spawn(async move {
+            for val in 0usize..=END {
+                tx.send(val).await.expect("buffer closed prematurely")
+            }
+        });
+
+        // Our wrapper for `JoinHandle` propagates panics in both cases
+        futures_util::future::join(
+            reader,
+            writer,
+        ).await;
+    }
+}
diff --git a/sqlx-core/src/common/channel/mod.rs b/sqlx-core/src/common/channel/mod.rs
new file mode 100644
index 0000000000..0db0f39b99
--- /dev/null
+++ b/sqlx-core/src/common/channel/mod.rs
@@ -0,0 +1 @@
+pub mod double_buffer;
diff --git a/sqlx-core/src/common/mod.rs b/sqlx-core/src/common/mod.rs
index 794007155f..cf176b6cb2 100644
--- a/sqlx-core/src/common/mod.rs
+++ b/sqlx-core/src/common/mod.rs
@@ -1,3 +1,4 @@
+pub mod channel;
 mod statement_cache;
 
 pub use statement_cache::StatementCache;
diff --git a/sqlx-core/src/sync.rs b/sqlx-core/src/sync.rs
index 27ad29c33e..2677ad62db 100644
--- a/sqlx-core/src/sync.rs
+++ b/sqlx-core/src/sync.rs
@@ -10,6 +10,9 @@ pub use async_std::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
 #[cfg(feature = "_rt-tokio")]
 pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
 
+#[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
+pub use dummy::*;
+
 pub struct AsyncSemaphore {
     // We use the semaphore from futures-intrusive as the one from async-std
     // is missing the ability to add arbitrary permits, and is not guaranteed to be fair:
@@ -141,3 +144,45 @@ impl AsyncSemaphoreReleaser<'_> {
         crate::rt::missing_rt(())
     }
 }
+
+/// Stand-in `AsyncMutex` types so the crate still type-checks without a runtime feature.
+///
+/// Every method forwards to `crate::rt::missing_rt()`.
+#[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
+mod dummy {
+    use std::marker::PhantomData;
+
+    use std::ops::{Deref, DerefMut};
+
+    pub struct AsyncMutex<T> {
+        _marker: PhantomData<T>,
+    }
+
+    impl<T> AsyncMutex<T> {
+        pub fn new(val: T) -> Self {
+            crate::rt::missing_rt(val)
+        }
+
+        pub async fn lock(&self) -> AsyncMutexGuard<'_, T> {
+            crate::rt::missing_rt(())
+        }
+    }
+
+    pub struct AsyncMutexGuard<'a, T> {
+        mutex: &'a AsyncMutex<T>,
+    }
+
+    impl<'a, T> Deref for AsyncMutexGuard<'a, T> {
+        type Target = T;
+
+        fn deref(&self) -> &Self::Target {
+            crate::rt::missing_rt(())
+        }
+    }
+
+    impl<'a, T> DerefMut for AsyncMutexGuard<'a, T> {
+        fn deref_mut(&mut self) -> &mut Self::Target {
+            crate::rt::missing_rt(())
+        }
+    }
+}