From 1fa7c262fc6b16e48c8e98ef7c4c2cf1c3c596c5 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 19 Aug 2025 18:06:28 -0700 Subject: [PATCH] WIP feat: create sharding structure for pool --- Cargo.lock | 13 +-- sqlx-core/Cargo.toml | 4 + sqlx-core/src/pool/mod.rs | 2 + sqlx-core/src/pool/shard.rs | 194 ++++++++++++++++++++++++++++++++++++ 4 files changed, 207 insertions(+), 6 deletions(-) create mode 100644 sqlx-core/src/pool/shard.rs diff --git a/Cargo.lock b/Cargo.lock index 74f9054912..463b2dca37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2095,9 +2095,9 @@ checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -2447,9 +2447,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -2457,9 +2457,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", @@ -3456,6 +3456,7 @@ dependencies = [ "mac_address", "memchr", "native-tls", + "parking_lot", "percent-encoding", "rust_decimal", "rustls", diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 37cf9d3b91..b5d8729ddc 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -92,6 +92,10 @@ indexmap = "2.0" event-listener = "5.2.0" hashbrown = "0.15.0" +[dependencies.parking_lot] +version = "0.12.4" +features = ["arc_lock"] + [dev-dependencies] sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] } tokio = { version = "1", features = ["rt"] } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index f11ff1d76a..d05140855b 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -87,6 +87,8 @@ mod connection; mod inner; mod options; +mod shard; + /// An asynchronous pool of SQLx database connections. /// /// Create a pool with [Pool::connect] or [Pool::connect_with] and then call [Pool::acquire] diff --git a/sqlx-core/src/pool/shard.rs b/sqlx-core/src/pool/shard.rs new file mode 100644 index 0000000000..2a1e3f5907 --- /dev/null +++ b/sqlx-core/src/pool/shard.rs @@ -0,0 +1,194 @@ +use event_listener::Event; +use std::cell::OnceCell; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{atomic, Arc}; +use std::{array, iter}; + +use parking_lot::Mutex; + +type ShardId = usize; +type ConnectionIndex = usize; + +pub struct Sharded { + shards: Box<[Arc>>]>>]>, + global_unlock_event: Event<(ShardId, ConnectionIndex)>, +} + +type ArcMutexGuard = parking_lot::ArcMutexGuard; + +pub struct ConnectedGuard { + locked: ArcMutexGuard>, +} + +pub struct UnconnectedGuard { + locked: ArcMutexGuard>, +} + +// Align to cache lines. +// Simplified from https://docs.rs/crossbeam-utils/0.8.21/src/crossbeam_utils/cache_padded.rs.html#80 +// +// Instead of listing every possible architecture, we just assume 64-bit architectures have 128-byte +// cache lines, which is at least true for newer versions of x86-64 and AArch64. +// A larger alignment isn't harmful as long as we make use of the space. +#[cfg_attr(target_pointer_width = "64", repr(align(128)))] +#[cfg_attr(not(target_pointer_width = "64"), repr(align(64)))] +struct Shard { + locked_set: AtomicUsize, + unlock_event: Event, + connected_set: AtomicUsize, + connections: T, +} + +#[derive(Debug)] +struct Params { + shards: usize, + shard_size: usize, + remainder: usize, +} + +const MAX_SHARD_SIZE: usize = if usize::BITS > 64 { + 64 +} else { + usize::BITS as usize +}; + +impl Sharded { + pub fn new(connections: usize, shards: usize) -> Sharded { + let shards = Params::calc(connections, shards) + .shard_sizes() + .map(|shard_size| Shard::new(shard_size, || Arc::new(Mutex::new(None)))) + .collect::>(); + + Sharded { + shards, + global_unlock_event: Event::with_tag(), + } + } + + pub async fn lock_connected(&self) -> ConnectedGuard {} + + pub async fn lock_unconnected(&self) -> UnconnectedGuard {} +} + +impl Shard<[T]> { + fn new(len: usize, mut fill: impl FnMut() -> T) -> Arc> { + macro_rules! make_array { + ($($n:literal),+) => { + match len { + $($n => Arc::new(Shard { + locked_set: AtomicUsize::new(0), + unlock_event: Event::with_tag(), + connected_set: AtomicUsize::new(0), + connections: array::from_fn::<_, $n, _>(|_| fill()) + }),)* + _ => unreachable!("BUG: length not supported: {len}"), + } + } + } + + make_array!( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 + ) + } + + async fn acquire(&self, connected: bool) -> ArcMutexGuard> { + if self.unlock_event.total_listeners() > 0 {} + + loop { + let locked_set = self.locked_set.load(Ordering::Acquire); + let connected_set = self.connected_set.load(Ordering::Relaxed); + + let connected_mask = if connected { + connected_set + } else { + !connected_set + }; + + let index = (locked_set & connected_mask).trailing_zeros() as usize; + + if let Some(guard) = self.try_lock(index) { + return guard; + } + } + } + + fn try_lock(&self, index: ConnectionIndex) -> Option>> {} +} + +impl Params { + fn calc(connections: usize, mut shards: usize) -> Params { + let mut shard_size = connections / shards; + let mut remainder = connections % shards; + + if shard_size == 0 { + tracing::debug!(connections, shards, "more shards than connections; clamping shard size to 1, shard count to connections"); + shards = connections; + shard_size = 1; + remainder = 0; + } else if shard_size >= MAX_SHARD_SIZE { + let new_shards = connections.div_ceil(MAX_SHARD_SIZE); + + tracing::debug!(connections, shards, "clamping shard count to {new_shards}"); + + shards = new_shards; + shard_size = connections / shards; + remainder = connections % shards; + } + + Params { + shards, + shard_size, + remainder, + } + } + + fn shard_sizes(&self) -> impl Iterator { + iter::repeat_n(self.shard_size + 1, self.remainder).chain(iter::repeat_n( + self.shard_size, + self.shards - self.remainder, + )) + } +} + +fn thread_id() -> usize { + static THREAD_ID: AtomicUsize = AtomicUsize::new(0); + + thread_local! { + static CURRENT_THREAD_ID: usize = { + THREAD_ID.fetch_add(1, Ordering::SeqCst) + }; + } + + CURRENT_THREAD_ID.with(|i| *i) +} + +#[cfg(test)] +mod tests { + use super::{Params, MAX_SHARD_SIZE}; + + #[test] + fn test_params() { + for connections in 0..100 { + for shards in 1..32 { + let params = Params::calc(connections, shards); + + let mut sum = 0; + + for (i, size) in params.shard_sizes().enumerate() { + assert!(size <= MAX_SHARD_SIZE, "Params::calc({connections}, {shards}) exceeded MAX_SHARD_SIZE at shard #{i}, size {size}"); + + sum += size; + + assert!(sum <= connections, "Params::calc({connections}, {shards}) exceeded connections at shard #{i}, size {size}"); + } + + assert_eq!( + sum, connections, + "Params::calc({connections}, {shards}) does not add up ({params:?}" + ); + } + } + } +}