From 776e950a3db42a34cb898be8b899407ae0d2275f Mon Sep 17 00:00:00 2001
From: Chris Staite
Date: Tue, 14 Oct 2025 22:19:27 +0100
Subject: [PATCH] Fast-slow deadlock

There are a number of different semaphores in the system, for example
the file open semaphore and the GCS connection semaphore. When the
fast-slow store interacts with more than one of these, it can deadlock.

Synchronising the collection of semaphores throughout the system is
incredibly hard with the interfaces that are in place, so we don't even
try to. Instead, add a check that the reader and writer have started on
both sides of the fast-slow store, and time out if they haven't. Like a
watchdog timer, this should catch deadlocks and kick the system back to
life. It's not the best solution, but it's something for now.
---
 nativelink-store/src/fast_slow_store.rs       |  87 +++++++--
 nativelink-store/src/gcs_client/client.rs     |  98 ++++++----
 nativelink-store/src/gcs_client/mocks.rs      |  16 +-
 nativelink-store/src/gcs_store.rs             |  90 ++++-----
 .../tests/fast_slow_store_test.rs             | 183 ++++++++++++++++++
 nativelink-store/tests/gcs_client_test.rs     |   5 +-
 nativelink-util/src/buf_channel.rs            |  21 ++
 7 files changed, 385 insertions(+), 115 deletions(-)

diff --git a/nativelink-store/src/fast_slow_store.rs b/nativelink-store/src/fast_slow_store.rs
index b76e13fd3..b04961736 100644
--- a/nativelink-store/src/fast_slow_store.rs
+++ b/nativelink-store/src/fast_slow_store.rs
@@ -17,6 +17,7 @@ use core::cmp::{max, min};
 use core::ops::Range;
 use core::pin::Pin;
 use core::sync::atomic::{AtomicU64, Ordering};
+use core::time::Duration;
 use std::collections::HashMap;
 use std::ffi::OsString;
 use std::sync::{Arc, Weak};
@@ -61,6 +62,9 @@ pub struct FastSlowStore {
     // actually it's faster because we're not downloading the file multiple
     // times or doing loads of duplicate IO.
     populating_digests: Mutex<HashMap<StoreKey<'static>, Loader>>,
+    // The amount of time to allow stores to start before determining that they
+    // have deadlocked and retrying.
+    deadlock_timeout: Duration,
 }

 // This guard ensures that the populating_digests is cleared even if the future
@@ -114,12 +118,22 @@ impl Drop for LoaderGuard<'_> {

 impl FastSlowStore {
     pub fn new(_spec: &FastSlowSpec, fast_store: Store, slow_store: Store) -> Arc<Self> {
+        Self::new_with_deadlock_timeout(_spec, fast_store, slow_store, Duration::from_secs(5))
+    }
+
+    pub fn new_with_deadlock_timeout(
+        _spec: &FastSlowSpec,
+        fast_store: Store,
+        slow_store: Store,
+        deadlock_timeout: Duration,
+    ) -> Arc<Self> {
         Arc::new_cyclic(|weak_self| Self {
             fast_store,
             slow_store,
             weak_self: weak_self.clone(),
             metrics: FastSlowStoreMetrics::default(),
             populating_digests: Mutex::new(HashMap::new()),
+            deadlock_timeout,
         })
     }

@@ -185,8 +199,62 @@ impl FastSlowStore {
         let send_range = offset..length.map_or(u64::MAX, |length| length + offset);
         let mut bytes_received: u64 = 0;

-        let (mut fast_tx, fast_rx) = make_buf_channel_pair();
-        let (slow_tx, mut slow_rx) = make_buf_channel_pair();
+        // There's a strong possibility of a deadlock here as we're working with multiple
+        // stores: we must be careful not to hold a get semaphore if we can't open the
+        // update. This code knows nothing about the downstream implementations, so it
+        // simply uses a timeout to check that the reader and writer are set up.
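+        // If the timeout fires before both sides are ready, every half-built future and
+        // channel is dropped (releasing any permits acquired so far) and we retry.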
+        let (stores_fut, mut slow_rx, mut fast_tx) = loop {
+            let (mut fast_tx, fast_rx) = make_buf_channel_pair();
+            let (slow_tx, mut slow_rx) = make_buf_channel_pair();
+
+            let slow_store_fut = self.slow_store.get(key.borrow(), slow_tx);
+            let fast_store_fut =
+                self.fast_store
+                    .update(key.borrow(), fast_rx, UploadSizeInfo::ExactSize(sz));
+            let mut stores_fut = futures::future::join(slow_store_fut, fast_store_fut);
+            let has_semaphores_fut = tokio::time::timeout(
+                self.deadlock_timeout,
+                futures::future::join(slow_rx.peek(), fast_tx.is_waiting()),
+            );
+            tokio::select! {
+                result = &mut stores_fut => {
+                    match result {
+                        (Ok(()), Ok(())) => {
+                            // Both stores completed without ever using the channels,
+                            // probably a zero-byte object.
+                            return Ok(());
+                        }
+                        (Ok(()), Err(err)) | (Err(err), Ok(())) => {
+                            return Err(err);
+                        }
+                        (Err(err1), Err(err2)) => {
+                            return Err(err1.merge(err2));
+                        }
+                    }
+                }
+                result = has_semaphores_fut => {
+                    match result {
+                        Ok((Ok(_), Ok(()))) => {
+                            // Both sides have started reading/writing; we assume they hold
+                            // all the permits they require and it's safe to continue.
+                            break (stores_fut, slow_rx, fast_tx);
+                        }
+                        Ok((Ok(_), Err(err)) | (Err(err), Ok(()))) => {
+                            return Err(err);
+                        }
+                        Ok((Err(err1), Err(err2))) => {
+                            return Err(err1.merge(err2));
+                        }
+                        Err(_timeout) => {
+                            // There was probably a deadlock; drop everything and try again.
+                            drop(stores_fut);
+                            tracing::warn!("Possible deadlock in fast-slow, retrying.");
+                            tokio::time::sleep(Duration::from_millis(10)).await;
+                        }
+                    }
+                }
+            };
+        };

         let data_stream_fut = async move {
             let mut maybe_writer_pin = maybe_writer.map(Pin::new);
             loop {
@@ -225,13 +293,7 @@ impl FastSlowStore {
             }
         };

-        let slow_store_fut = self.slow_store.get(key.borrow(), slow_tx);
-        let fast_store_fut =
-            self.fast_store
-                .update(key.borrow(), fast_rx, UploadSizeInfo::ExactSize(sz));
-
-        let (data_stream_res, slow_res, fast_res) =
-            join!(data_stream_fut, slow_store_fut, fast_store_fut);
+        let (data_stream_res, (slow_res, fast_res)) = join!(data_stream_fut, stores_fut);
         match data_stream_res {
             Ok((fast_eof_res, maybe_writer_pin)) =>
             // Sending the EOF will drop us almost immediately in bytestream_server
@@ -262,8 +324,7 @@ impl FastSlowStore {
         if maybe_size_info.is_some() {
             return Ok(());
         }
-        let loader = self.get_loader(key.borrow());
-        loader
+        self.get_loader(key.borrow())
             .get_or_try_init(|| {
                 Pin::new(self).populate_and_maybe_stream(key.borrow(), None, 0, None)
             })
@@ -474,9 +535,8 @@ impl StoreDriver for FastSlowStore {
             return Ok(());
         }

-        let loader = self.get_loader(key.borrow());
         let mut writer = Some(writer);
-        loader
+        self.get_loader(key.borrow())
             .get_or_try_init(|| {
                 writer
                     .take()
@@ -486,7 +546,6 @@ impl StoreDriver for FastSlowStore {
                     .expect("writer somehow became None")
             })
             .await?;
-        drop(loader);

         // If we didn't stream then re-enter which will stream from the fast
         // store, or retry the download. We should not get in a loop here
diff --git a/nativelink-store/src/gcs_client/client.rs b/nativelink-store/src/gcs_client/client.rs
index dd27df601..65c0faf59 100644
--- a/nativelink-store/src/gcs_client/client.rs
+++ b/nativelink-store/src/gcs_client/client.rs
@@ -41,6 +41,12 @@ use crate::gcs_client::types::{
     SIMPLE_UPLOAD_THRESHOLD, Timestamp,
 };

+#[derive(Debug)]
+pub struct UploadRef {
+    pub upload_ref: String,
+    pub(crate) _permit: OwnedSemaphorePermit,
+}
+
 /// A trait that defines the required GCS operations.
 /// This abstraction allows for easier testing by mocking GCS responses.
 pub trait GcsOperations: Send + Sync + Debug {
@@ -71,12 +77,12 @@ pub trait GcsOperations: Send + Sync + Debug {
     fn start_resumable_write(
         &self,
         object_path: &ObjectPath,
-    ) -> impl Future<Output = Result<String, Error>> + Send;
+    ) -> impl Future<Output = Result<UploadRef, Error>> + Send;

     /// Upload a chunk of data in a resumable upload session
     fn upload_chunk(
         &self,
-        upload_url: &str,
+        upload_url: &UploadRef,
         object_path: &ObjectPath,
         data: Bytes,
         offset: u64,
@@ -306,12 +312,21 @@ impl GcsClient {
             }

             // Check if the object exists
-            match self.read_object_metadata(object_path).await? {
-                Some(_) => Ok(()),
-                None => Err(make_err!(
-                    Code::Internal,
-                    "Upload completed but object not found"
-                )),
+            let request = GetObjectRequest {
+                bucket: object_path.bucket.clone(),
+                object: object_path.path.clone(),
+                ..Default::default()
+            };
+
+            match self.client.get_object(&request).await {
+                Ok(_) => Ok(()),
+                Err(GcsError::Response(resp)) if resp.code == 404 => {
+                    return Err(make_err!(
+                        Code::Internal,
+                        "Upload completed but object not found"
+                    ));
+                }
+                Err(err) => Err(Self::handle_gcs_error(&err)),
             }
         })
         .await
@@ -440,55 +455,58 @@ impl GcsOperations for GcsClient {
         .await
     }

-    async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result<String, Error> {
-        self.with_connection(|| async {
-            let request = UploadObjectRequest {
-                bucket: object_path.bucket.clone(),
-                ..Default::default()
-            };
+    async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result<UploadRef, Error> {
+        let permit =
+            self.semaphore.clone().acquire_owned().await.map_err(|e| {
+                make_err!(Code::Internal, "Failed to acquire connection permit: {}", e)
+            })?;
+        let request = UploadObjectRequest {
+            bucket: object_path.bucket.clone(),
+            ..Default::default()
+        };

-            let upload_type = UploadType::Multipart(Box::new(Object {
-                name: object_path.path.clone(),
-                content_type: Some(DEFAULT_CONTENT_TYPE.to_string()),
-                ..Default::default()
-            }));
+        let upload_type = UploadType::Multipart(Box::new(Object {
+            name: object_path.path.clone(),
+            content_type: Some(DEFAULT_CONTENT_TYPE.to_string()),
+            ..Default::default()
+        }));

-            // Start resumable upload session
-            let uploader = self
-                .client
-                .prepare_resumable_upload(&request, &upload_type)
-                .await
-                .map_err(|e| Self::handle_gcs_error(&e))?;
+        // Start resumable upload session
+        let uploader = self
+            .client
+            .prepare_resumable_upload(&request, &upload_type)
+            .await
+            .map_err(|e| Self::handle_gcs_error(&e))?;

-            Ok(uploader.url().to_string())
+        Ok(UploadRef {
+            upload_ref: uploader.url().to_string(),
+            _permit: permit,
         })
-        .await
     }

     async fn upload_chunk(
         &self,
-        upload_url: &str,
+        upload_url: &UploadRef,
         _object_path: &ObjectPath,
         data: Bytes,
         offset: u64,
         end_offset: u64,
         total_size: Option<u64>,
     ) -> Result<(), Error> {
-        self.with_connection(|| async {
-            let uploader = self.client.get_resumable_upload(upload_url.to_string());
+        let uploader = self
+            .client
+            .get_resumable_upload(upload_url.upload_ref.clone());

-            let last_byte = if end_offset == 0 { 0 } else { end_offset - 1 };
-            let chunk_def = ChunkSize::new(offset, last_byte, total_size);
+        let last_byte = if end_offset == 0 { 0 } else { end_offset - 1 };
+        let chunk_def = ChunkSize::new(offset, last_byte, total_size);

-            // Upload chunk
-            uploader
-                .upload_multiple_chunk(data, &chunk_def)
-                .await
-                .map_err(|e| Self::handle_gcs_error(&e))?;
+        // Upload chunk
+        uploader
+            .upload_multiple_chunk(data, &chunk_def)
+            .await
+            .map_err(|e| Self::handle_gcs_error(&e))?;

-            Ok(())
-        })
-        .await
+        Ok(())
     }

     async fn upload_from_reader(
diff --git a/nativelink-store/src/gcs_client/mocks.rs b/nativelink-store/src/gcs_client/mocks.rs
index 5d593283e..0a5773a81 100644
--- a/nativelink-store/src/gcs_client/mocks.rs
+++ b/nativelink-store/src/gcs_client/mocks.rs
@@ -15,15 +15,16 @@
 use core::fmt::Debug;
 use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
 use std::collections::HashMap;
+use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};

 use bytes::Bytes;
 use futures::Stream;
 use nativelink_error::{Code, Error, make_err};
 use nativelink_util::buf_channel::DropCloserReadHalf;
-use tokio::sync::RwLock;
+use tokio::sync::{RwLock, Semaphore};

-use crate::gcs_client::client::GcsOperations;
+use crate::gcs_client::client::{GcsOperations, UploadRef};
 use crate::gcs_client::types::{DEFAULT_CONTENT_TYPE, GcsObject, ObjectPath, Timestamp};

 /// A mock implementation of `GcsOperations` for testing
@@ -379,7 +380,7 @@ impl GcsOperations for MockGcsOperations {
         Ok(())
     }

-    async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result<String, Error> {
+    async fn start_resumable_write(&self, object_path: &ObjectPath) -> Result<UploadRef, Error> {
         self.call_counts
             .start_resumable_calls
             .fetch_add(1, Ordering::Relaxed);
@@ -392,12 +393,15 @@ impl GcsOperations for MockGcsOperations {
         self.handle_failure().await?;

         let upload_id = format!("mock-upload-{}-{}", object_path.bucket, object_path.path);
-        Ok(upload_id)
+        Ok(UploadRef {
+            upload_ref: upload_id,
+            _permit: Arc::new(Semaphore::new(1)).acquire_owned().await.unwrap(),
+        })
     }

     async fn upload_chunk(
         &self,
-        upload_url: &str,
+        upload_url: &UploadRef,
         object_path: &ObjectPath,
         data: Bytes,
         offset: u64,
@@ -408,7 +412,7 @@ impl GcsOperations for MockGcsOperations {
             .upload_chunk_calls
             .fetch_add(1, Ordering::Relaxed);
         self.requests.write().await.push(MockRequest::UploadChunk {
-            upload_url: upload_url.to_string(),
+            upload_url: upload_url.upload_ref.clone(),
             object_path: object_path.clone(),
             data_len: data.len(),
             offset,
diff --git a/nativelink-store/src/gcs_store.rs b/nativelink-store/src/gcs_store.rs
index b8bcacc20..1083cd6b5 100644
--- a/nativelink-store/src/gcs_store.rs
+++ b/nativelink-store/src/gcs_store.rs
@@ -261,8 +261,23 @@ where
         } else {
             None
         };
-        let mut upload_id: Option<String> = None;
         let client = &self.client;
+        let upload_id = self
+            .retrier
+            .retry(unfold((), |()| async {
+                match client.start_resumable_write(&object_path).await {
+                    Ok(id) => Some((RetryResult::Ok(id), ())),
+                    Err(e) => Some((
+                        RetryResult::Retry(make_err!(
+                            Code::Aborted,
+                            "Failed to start resumable upload: {:?}",
+                            e
+                        )),
+                        (),
+                    )),
+                }
+            }))
+            .await?;

         loop {
             let chunk = reader.consume(Some(self.max_chunk_size)).await?;
@@ -274,35 +289,12 @@ where
             if total_size.is_none() {
                 total_size = Some(offset + chunk.len() as u64);
             }

-            let upload_id_ref = if let Some(upload_id_ref) = &upload_id {
-                upload_id_ref
-            } else {
-                // Initiate the upload session on the first non-empty chunk.
-                upload_id = Some(
-                    self.retrier
-                        .retry(unfold((), |()| async {
-                            match client.start_resumable_write(&object_path).await {
-                                Ok(id) => Some((RetryResult::Ok(id), ())),
-                                Err(e) => Some((
-                                    RetryResult::Retry(make_err!(
-                                        Code::Aborted,
-                                        "Failed to start resumable upload: {:?}",
-                                        e
-                                    )),
-                                    (),
-                                )),
-                            }
-                        }))
-                        .await?,
-                );
-                upload_id.as_deref().unwrap()
-            };
-
             let current_offset = offset;
             offset += chunk.len() as u64;

             // Uploading the chunk with a retry
             let object_path_ref = &object_path;
+            let upload_id_ref = &upload_id;
             self.retrier
                 .retry(unfold(chunk, |chunk| async move {
                     match client
@@ -325,40 +317,30 @@ where
         // Handle the case that the stream was of unknown length and
         // happened to be an exact multiple of chunk size.
-        if let Some(upload_id_ref) = &upload_id {
-            if total_size.is_none() {
-                let object_path_ref = &object_path;
-                self.retrier
-                    .retry(unfold((), |()| async move {
-                        match client
-                            .upload_chunk(
-                                upload_id_ref,
-                                object_path_ref,
-                                Bytes::new(),
-                                offset,
-                                offset,
-                                Some(offset),
-                            )
-                            .await
-                        {
-                            Ok(()) => Some((RetryResult::Ok(()), ())),
-                            Err(e) => Some((RetryResult::Retry(e), ())),
-                        }
-                    }))
-                    .await?;
-            }
-        } else {
-            // Handle streamed empty file.
-            return self
-                .retrier
-                .retry(unfold((), |()| async {
-                    match client.write_object(&object_path, Vec::new()).await {
+        if total_size.is_none() {
+            let object_path_ref = &object_path;
+            let upload_id_ref = &upload_id;
+            self.retrier
+                .retry(unfold((), |()| async move {
+                    match client
+                        .upload_chunk(
+                            upload_id_ref,
+                            object_path_ref,
+                            Bytes::new(),
+                            offset,
+                            offset,
+                            Some(offset),
+                        )
+                        .await
+                    {
                         Ok(()) => Some((RetryResult::Ok(()), ())),
                         Err(e) => Some((RetryResult::Retry(e), ())),
                     }
                 }))
-                .await;
+                .await?;
         }
+        // Ensure we drop the permit before verifying.
+        drop(upload_id);

         // Verifying if the upload was successful
         self.retrier
diff --git a/nativelink-store/tests/fast_slow_store_test.rs b/nativelink-store/tests/fast_slow_store_test.rs
index 0ea4be4f1..52dddc1d1 100644
--- a/nativelink-store/tests/fast_slow_store_test.rs
+++ b/nativelink-store/tests/fast_slow_store_test.rs
@@ -437,3 +437,186 @@ async fn has_checks_fast_store_when_noop() -> Result<(), Error> {
     );
     Ok(())
 }
+
+#[derive(MetricsComponent)]
+struct SemaphoreStore {
+    sem: Arc<tokio::sync::Semaphore>,
+    inner: Arc<MemoryStore>,
+}
+
+impl SemaphoreStore {
+    fn new(sem: Arc<tokio::sync::Semaphore>) -> Arc<Self> {
+        Arc::new(Self {
+            sem,
+            inner: MemoryStore::new(&MemorySpec::default()),
+        })
+    }
+
+    async fn get_permit(&self) -> Result<tokio::sync::SemaphorePermit<'_>, Error> {
+        self.sem
+            .acquire()
+            .await
+            .map_err(|e| make_err!(Code::Internal, "Failed to acquire permit: {e:?}"))
+    }
+}
+
+#[async_trait]
+impl StoreDriver for SemaphoreStore {
+    async fn get_part(
+        self: Pin<&Self>,
+        key: StoreKey<'_>,
+        writer: &mut nativelink_util::buf_channel::DropCloserWriteHalf,
+        offset: u64,
+        length: Option<u64>,
+    ) -> Result<(), Error> {
+        let _guard = self.get_permit().await?;
+        // Ensure this isn't returned in two or fewer writes, as two is the channel's buffer size.
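+        // Splitting the payload into more than two chunks fills the channel while this
+        // store still holds its permit, reproducing the cross-store deadlock under test.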
+        let (second_writer, mut second_reader) = make_buf_channel_pair();
+        let write_fut = async move {
+            let data = second_reader.recv().await?;
+            if data.len() > 6 {
+                writer.send(data.slice(0..1)).await?;
+                writer.send(data.slice(1..2)).await?;
+                writer.send(data.slice(2..3)).await?;
+                writer.send(data.slice(3..4)).await?;
+                writer.send(data.slice(4..5)).await?;
+                writer.send(data.slice(5..)).await?;
+            } else {
+                writer.send(data).await?;
+            }
+            loop {
+                let data = second_reader.recv().await?;
+                if data.is_empty() {
+                    break;
+                }
+                writer.send(data).await?;
+            }
+            writer.send_eof()
+        };
+        let (res1, res2) = tokio::join!(
+            write_fut,
+            self.inner.get_part(key, second_writer, offset, length)
+        );
+        res1.merge(res2)
+    }
+
+    async fn has_with_results(
+        self: Pin<&Self>,
+        digests: &[StoreKey<'_>],
+        results: &mut [Option<u64>],
+    ) -> Result<(), Error> {
+        let _guard = self.get_permit().await?;
+        self.inner.has_with_results(digests, results).await
+    }
+
+    async fn update(
+        self: Pin<&Self>,
+        key: StoreKey<'_>,
+        mut reader: nativelink_util::buf_channel::DropCloserReadHalf,
+        upload_size: nativelink_util::store_trait::UploadSizeInfo,
+    ) -> Result<(), Error> {
+        let _guard = self.get_permit().await?;
+        let (mut second_writer, second_reader) = make_buf_channel_pair();
+        let write_fut = async move {
+            let data = reader.recv().await?;
+            if data.len() > 6 {
+                // There are two channels each buffering two chunks, so we must split the
+                // data into more pieces than that to cause a lock-up.
+                second_writer.send(data.slice(0..1)).await?;
+                second_writer.send(data.slice(1..2)).await?;
+                second_writer.send(data.slice(2..3)).await?;
+                second_writer.send(data.slice(3..4)).await?;
+                second_writer.send(data.slice(4..5)).await?;
+                second_writer.send(data.slice(5..)).await?;
+            } else {
+                second_writer.send(data).await?;
+            }
+            loop {
+                let data = reader.recv().await?;
+                if data.is_empty() {
+                    break;
+                }
+                second_writer.send(data).await?;
+            }
+            second_writer.send_eof()
+        };
+        let (res1, res2) = tokio::join!(
+            write_fut,
+            self.inner.update(key, second_reader, upload_size)
+        );
+        res1.merge(res2)
+    }
+
+    fn inner_store(&self, _digest: Option<StoreKey<'_>>) -> &dyn StoreDriver {
+        self
+    }
+
+    fn as_any(&self) -> &(dyn core::any::Any + Sync + Send + 'static) {
+        self
+    }
+
+    fn as_any_arc(self: Arc<Self>) -> Arc<dyn core::any::Any + Sync + Send + 'static> {
+        self
+    }
+
+    fn register_remove_callback(
+        self: Arc<Self>,
+        callback: &Arc>,
+    ) -> Result<(), Error> {
+        self.inner.clone().register_remove_callback(callback)
+    }
+}
+
+default_health_status_indicator!(SemaphoreStore);
+
+#[nativelink_test]
+async fn semaphore_deadlocks_handled() -> Result<(), Error> {
+    // Just enough permits for the action to function: one for each store.
+    let semaphore = Arc::new(tokio::sync::Semaphore::new(2));
+    let fast_store = Store::new(SemaphoreStore::new(semaphore.clone()));
+    let slow_store = Store::new(SemaphoreStore::new(semaphore.clone()));
+    let fast_slow_store_config = FastSlowSpec {
+        fast: StoreSpec::Memory(MemorySpec::default()),
+        slow: StoreSpec::Noop(NoopSpec::default()),
+    };
+    let fast_slow_store = Arc::new(FastSlowStore::new_with_deadlock_timeout(
+        &fast_slow_store_config,
+        fast_store.clone(),
+        slow_store.clone(),
+        core::time::Duration::from_secs(1),
+    ));
+
+    let data = make_random_data(100);
+    let digest = DigestInfo::try_new(VALID_HASH, data.len()).unwrap();
+
+    // Upload some dummy data to the slow store.
+    slow_store
+        .update_oneshot(digest, data.clone().into())
+        .await?;
+
+    // Now try to get it back while holding a permit ourselves; this should deadlock.
+    // We only release our permit once it has been released by the other store.
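+    // Holding this permit leaves a single permit for the two SemaphoreStores to share,
+    // so the fast-slow get/update pair can never both start until the watchdog fires.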
+    let guard = semaphore.clone().acquire_owned().await.unwrap();
+    let release_fut = async move {
+        // Wait for the store to get the last permit.
+        while semaphore.available_permits() > 0 {
+            tokio::time::sleep(core::time::Duration::from_millis(10)).await;
+        }
+        // Now wait for it to be released.
+        let _second_guard = semaphore.acquire().await.unwrap();
+        // Now release all the permits.
+        drop(guard);
+    };
+    let (_, result) = tokio::join!(
+        release_fut,
+        tokio::time::timeout(
+            core::time::Duration::from_secs(10),
+            fast_slow_store.get_part_unchunked(digest, 0, None)
+        )
+    );
+    assert_eq!(
+        result.map_err(|_| make_err!(Code::Internal, "Semaphore deadlock"))?,
+        Ok(data.into())
+    );
+
+    Ok(())
+}
diff --git a/nativelink-store/tests/gcs_client_test.rs b/nativelink-store/tests/gcs_client_test.rs
index 95cff00f8..c0b14b210 100644
--- a/nativelink-store/tests/gcs_client_test.rs
+++ b/nativelink-store/tests/gcs_client_test.rs
@@ -171,7 +171,10 @@ async fn test_resumable_upload() -> Result<(), Error> {

     // Start a resumable upload
     let upload_id = mock_ops.start_resumable_write(&object_path).await?;
-    assert!(!upload_id.is_empty(), "Expected non-empty upload ID");
+    assert!(
+        !upload_id.upload_ref.is_empty(),
+        "Expected non-empty upload ID"
+    );

     // Upload chunks
     let chunk1 = Bytes::from_static(b"first chunk ");
diff --git a/nativelink-util/src/buf_channel.rs b/nativelink-util/src/buf_channel.rs
index ad3b8c288..9cf13677c 100644
--- a/nativelink-util/src/buf_channel.rs
+++ b/nativelink-util/src/buf_channel.rs
@@ -39,12 +39,14 @@ pub fn make_buf_channel_pair() -> (DropCloserWriteHalf, DropCloserReadHalf) {
     // a little time for another thread to wake up and consume data if another
     // thread is pumping large amounts of data into the channel.
     let (tx, rx) = mpsc::channel(2);
+    let (recv_tx, recv_rx) = tokio::sync::oneshot::channel();
     let eof_sent = Arc::new(AtomicBool::new(false));
     (
         DropCloserWriteHalf {
             tx: Some(tx),
             bytes_written: 0,
             eof_sent: eof_sent.clone(),
+            recv_rx: Some(recv_rx),
         },
         DropCloserReadHalf {
             rx,
@@ -54,6 +56,7 @@ pub fn make_buf_channel_pair() -> (DropCloserWriteHalf, DropCloserReadHalf) {
             bytes_received: 0,
             recent_data: Vec::new(),
             max_recent_data_size: 0,
+            recv_tx: Some(recv_tx),
         },
     )
 }
@@ -64,6 +67,7 @@ pub struct DropCloserWriteHalf {
     tx: Option<mpsc::Sender<Bytes>>,
     bytes_written: u64,
     eof_sent: Arc<AtomicBool>,
+    recv_rx: Option<tokio::sync::oneshot::Receiver<()>>,
 }

 impl DropCloserWriteHalf {
@@ -72,6 +76,18 @@ impl DropCloserWriteHalf {
         self.send_get_bytes_on_error(buf).map_err(|err| err.0)
     }

+    /// Resolves once the `DropCloserReadHalf` has called `recv()` for the first time.
+    pub async fn is_waiting(&mut self) -> Result<(), Error> {
+        let Some(recv_rx) = self.recv_rx.take() else {
+            // Once it's None, it has already completed successfully.
+            return Ok(());
+        };
+        match recv_rx.await {
+            Ok(()) => Ok(()),
+            Err(_err) => Err(make_err!(Code::Internal, "Dropped before recv")),
+        }
+    }
+
     /// Sends data over the channel to the receiver.
     #[inline]
     async fn send_get_bytes_on_error(&mut self, buf: Bytes) -> Result<(), (Error, Bytes)> {
@@ -207,6 +223,8 @@ pub struct DropCloserReadHalf {
     /// Amount of data to keep in the `recent_data` buffer before clearing it
     /// and no longer populating it.
     max_recent_data_size: u64,
+    /// A oneshot that is sent on the first call to `recv()`.
+    recv_tx: Option<tokio::sync::oneshot::Sender<()>>,
 }

 impl DropCloserReadHalf {
@@ -238,6 +256,9 @@ impl DropCloserReadHalf {

     /// Try to receive a chunk of data, returning `None` if none is available.
     pub fn try_recv(&mut self) -> Option<Result<Bytes, Error>> {
+        if let Some(recv_tx) = self.recv_tx.take() {
+            let _ = recv_tx.send(());
+        }
         if let Some(err) = &self.last_err {
             return Some(Err(err.clone()));
         }