diff --git a/Cargo.lock b/Cargo.lock index 033f40303b..8bd066a7c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2980,6 +2980,7 @@ name = "libdd-common" version = "2.0.0" dependencies = [ "anyhow", + "async-trait", "bytes", "cc", "const_format", @@ -3094,6 +3095,7 @@ version = "2.0.0" dependencies = [ "anyhow", "arc-swap", + "async-trait", "bytes", "clap", "criterion", @@ -3316,6 +3318,7 @@ name = "libdd-telemetry" version = "3.0.0" dependencies = [ "anyhow", + "async-trait", "base64 0.22.1", "futures", "hashbrown 0.15.1", diff --git a/libdd-common/Cargo.toml b/libdd-common/Cargo.toml index aef95db03c..e418591b26 100644 --- a/libdd-common/Cargo.toml +++ b/libdd-common/Cargo.toml @@ -17,6 +17,7 @@ bench = false [dependencies] anyhow = "1.0" +async-trait = "0.1" futures = "0.3" futures-core = { version = "0.3.0", default-features = false } futures-util = { version = "0.3.0", default-features = false } diff --git a/libdd-common/src/worker.rs b/libdd-common/src/worker.rs index c79c9317f2..a88c81a192 100644 --- a/libdd-common/src/worker.rs +++ b/libdd-common/src/worker.rs @@ -1,12 +1,65 @@ // Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ // SPDX-License-Identifier: Apache-2.0 +use async_trait::async_trait; + /// Trait representing a generic worker. /// -/// The worker runs an async looping function running periodic tasks. -/// -/// This trait can be used to provide wrapper around a worker. -pub trait Worker { - /// Main worker loop - fn run(&mut self) -> impl std::future::Future + Send; +/// # Lifecycle +/// The worker's `Self::run` method should be executed everytime the `Self::trigger` method returns. +/// On startup `Self::initial_trigger` should be called before `Self::run`. +#[async_trait] +pub trait Worker: std::fmt::Debug { + /// Main worker function + /// + /// Code in this function should always use timeout on long-running await calls to avoid + /// blocking forks if an await call takes too long to complete. + async fn run(&mut self); + + /// Function called between each `run` to wait for the next run + async fn trigger(&mut self); + + /// Alternative trigger called on start to provide custom behavior + /// Defaults to `trigger` behavior. + async fn initial_trigger(&mut self) { + self.trigger().await + } + + /// Reset the worker in the child after a fork + fn reset(&mut self) {} + + /// Hook called after the worker has been paused (e.g. before a fork). + /// Default is a no-op. + async fn on_pause(&mut self) {} + + /// Hook called when the app is shutting down. Can be used to flush remaining data. + async fn shutdown(&mut self) {} +} + +// Blanket implementation for boxed trait objects +#[async_trait] +impl Worker for Box { + async fn run(&mut self) { + (**self).run().await + } + + async fn trigger(&mut self) { + (**self).trigger().await + } + + async fn initial_trigger(&mut self) { + (**self).initial_trigger().await + } + + fn reset(&mut self) { + (**self).reset() + } + + async fn on_pause(&mut self) { + (**self).on_pause().await + } + + async fn shutdown(&mut self) { + (**self).shutdown().await + } } diff --git a/libdd-data-pipeline-ffi/src/lib.rs b/libdd-data-pipeline-ffi/src/lib.rs index 9e4e4bc278..c1c71bad49 100644 --- a/libdd-data-pipeline-ffi/src/lib.rs +++ b/libdd-data-pipeline-ffi/src/lib.rs @@ -8,6 +8,7 @@ mod error; mod response; +mod shared_runtime; mod trace_exporter; #[cfg(all(feature = "catch_panic", panic = "unwind"))] diff --git a/libdd-data-pipeline-ffi/src/shared_runtime.rs b/libdd-data-pipeline-ffi/src/shared_runtime.rs new file mode 100644 index 0000000000..917f60dffa --- /dev/null +++ b/libdd-data-pipeline-ffi/src/shared_runtime.rs @@ -0,0 +1,273 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +use libdd_data_pipeline::shared_runtime::{SharedRuntime, SharedRuntimeError}; +use std::ffi::{c_char, CString}; +use std::ptr::NonNull; +use std::sync::Arc; + +/// Error codes for SharedRuntime FFI operations. +#[repr(C)] +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum SharedRuntimeErrorCode { + /// Invalid argument provided (e.g. null handle). + InvalidArgument, + /// The runtime is not available or in an invalid state. + RuntimeUnavailable, + /// Failed to acquire a lock on internal state. + LockFailed, + /// A worker operation failed. + WorkerError, + /// Failed to create the tokio runtime. + RuntimeCreation, + /// Shutdown timed out. + ShutdownTimedOut, +} + +/// Error returned by SharedRuntime FFI functions. +#[repr(C)] +pub struct SharedRuntimeFFIError { + pub code: SharedRuntimeErrorCode, + pub msg: *mut c_char, +} + +impl SharedRuntimeFFIError { + fn new(code: SharedRuntimeErrorCode, msg: &str) -> Self { + Self { + code, + msg: CString::new(msg).unwrap_or_default().into_raw(), + } + } +} + +impl From for SharedRuntimeFFIError { + fn from(err: SharedRuntimeError) -> Self { + let code = match &err { + SharedRuntimeError::RuntimeUnavailable => SharedRuntimeErrorCode::RuntimeUnavailable, + SharedRuntimeError::LockFailed(_) => SharedRuntimeErrorCode::LockFailed, + SharedRuntimeError::WorkerError(_) => SharedRuntimeErrorCode::WorkerError, + SharedRuntimeError::RuntimeCreation(_) => SharedRuntimeErrorCode::RuntimeCreation, + SharedRuntimeError::ShutdownTimedOut(_) => SharedRuntimeErrorCode::ShutdownTimedOut, + }; + SharedRuntimeFFIError::new(code, &err.to_string()) + } +} + +impl Drop for SharedRuntimeFFIError { + fn drop(&mut self) { + if !self.msg.is_null() { + // SAFETY: `msg` is always produced by `CString::into_raw` in `new`. + unsafe { + drop(CString::from_raw(self.msg)); + self.msg = std::ptr::null_mut(); + } + } + } +} + +/// Frees a `SharedRuntimeFFIError`. After this call the pointer is invalid. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_error_free(error: Option>) { + drop(error); +} + +/// Create a new `SharedRuntime`. +/// +/// On success writes a raw handle into `*out_handle` and returns `None`. +/// On failure leaves `*out_handle` unchanged and returns an error. +/// +/// The caller owns the handle and must eventually pass it to +/// [`ddog_shared_runtime_free`] (or another consumer that takes ownership). +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_new( + out_handle: NonNull<*const SharedRuntime>, +) -> Option> { + match SharedRuntime::new() { + Ok(runtime) => { + out_handle.as_ptr().write(Arc::into_raw(Arc::new(runtime))); + None + } + Err(err) => Some(Box::new(SharedRuntimeFFIError::from(err))), + } +} + +/// Free a handle, decrementing the `Arc` strong count. +/// +/// The underlying runtime may not be dropped if other components are still using it. +/// Use [`ddog_shared_runtime_shutdown`] to cleanly stop workers. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_free(handle: *const SharedRuntime) { + if !handle.is_null() { + // SAFETY: handle was produced by Arc::into_raw; this call takes ownership. + drop(Arc::from_raw(handle)); + } +} + +/// Must be called in the parent process before `fork()`. +/// +/// Pauses all workers so that no background threads are running during the +/// fork, preventing deadlocks in the child process. +/// +/// Returns an error if `handle` is null. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_before_fork( + handle: *const SharedRuntime, +) -> Option> { + if handle.is_null() { + return Some(Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::InvalidArgument, + "handle is null", + ))); + } + // SAFETY: handle was produced by Arc::into_raw and the Arc is still alive. + (*handle).before_fork(); + None +} + +/// Must be called in the parent process after `fork()`. +/// +/// Restarts all workers that were paused by [`ddog_shared_runtime_before_fork`]. +/// +/// Returns `None` on success, or an error if workers could not be restarted. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_after_fork_parent( + handle: *const SharedRuntime, +) -> Option> { + if handle.is_null() { + return Some(Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::InvalidArgument, + "handle is null", + ))); + } + // SAFETY: handle was produced by Arc::into_raw and the Arc is still alive. + match (*handle).after_fork_parent() { + Ok(()) => None, + Err(err) => Some(Box::new(SharedRuntimeFFIError::from(err))), + } +} + +/// Must be called in the child process after `fork()`. +/// +/// Creates a fresh tokio runtime and restarts all workers. The original +/// runtime cannot be safely reused after a fork. +/// +/// Returns `None` on success, or an error if the runtime could not be +/// reinitialized. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_after_fork_child( + handle: *const SharedRuntime, +) -> Option> { + if handle.is_null() { + return Some(Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::InvalidArgument, + "handle is null", + ))); + } + // SAFETY: handle was produced by Arc::into_raw and the Arc is still alive. + match (*handle).after_fork_child() { + Ok(()) => None, + Err(err) => Some(Box::new(SharedRuntimeFFIError::from(err))), + } +} + +/// Shut down the `SharedRuntime`, stopping all workers. +/// +/// `timeout_ms` is the maximum time to wait for workers to stop, in +/// milliseconds. Pass `0` for no timeout. +/// +/// Returns `None` on success, or `SharedRuntimeErrorCode::ShutdownTimedOut` +/// if the timeout was reached. +#[no_mangle] +pub unsafe extern "C" fn ddog_shared_runtime_shutdown( + handle: *const SharedRuntime, + timeout_ms: u64, +) -> Option> { + if handle.is_null() { + return Some(Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::InvalidArgument, + "handle is null", + ))); + } + + let timeout = if timeout_ms > 0 { + Some(std::time::Duration::from_millis(timeout_ms)) + } else { + None + }; + + // SAFETY: handle was produced by Arc::into_raw and the Arc is still alive. + match (*handle).shutdown(timeout) { + Ok(()) => None, + Err(err) => Some(Box::new(SharedRuntimeFFIError::from(err))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem::MaybeUninit; + + #[test] + fn test_new_and_free() { + unsafe { + let mut handle: MaybeUninit<*const SharedRuntime> = MaybeUninit::uninit(); + let err = ddog_shared_runtime_new(NonNull::new_unchecked(handle.as_mut_ptr())); + assert!(err.is_none()); + ddog_shared_runtime_free(handle.assume_init()); + } + } + + #[test] + fn test_before_after_fork_null() { + unsafe { + let err = ddog_shared_runtime_before_fork(std::ptr::null()); + assert_eq!(err.unwrap().code, SharedRuntimeErrorCode::InvalidArgument); + + let err = ddog_shared_runtime_after_fork_parent(std::ptr::null()); + assert_eq!(err.unwrap().code, SharedRuntimeErrorCode::InvalidArgument); + + let err = ddog_shared_runtime_after_fork_child(std::ptr::null()); + assert_eq!(err.unwrap().code, SharedRuntimeErrorCode::InvalidArgument); + } + } + + #[test] + fn test_fork_lifecycle() { + unsafe { + let mut handle: MaybeUninit<*const SharedRuntime> = MaybeUninit::uninit(); + ddog_shared_runtime_new(NonNull::new_unchecked(handle.as_mut_ptr())); + let handle = handle.assume_init(); + + let err = ddog_shared_runtime_before_fork(handle); + assert!(err.is_none(), "{:?}", err.map(|e| e.code)); + + let err = ddog_shared_runtime_after_fork_parent(handle); + assert!(err.is_none(), "{:?}", err.map(|e| e.code)); + + ddog_shared_runtime_free(handle); + } + } + + #[test] + fn test_shutdown() { + unsafe { + let mut handle: MaybeUninit<*const SharedRuntime> = MaybeUninit::uninit(); + ddog_shared_runtime_new(NonNull::new_unchecked(handle.as_mut_ptr())); + let handle = handle.assume_init(); + + let err = ddog_shared_runtime_shutdown(handle, 0); + assert!(err.is_none()); + + ddog_shared_runtime_free(handle); + } + } + + #[test] + fn test_error_free() { + let error = Box::new(SharedRuntimeFFIError::new( + SharedRuntimeErrorCode::InvalidArgument, + "test error", + )); + unsafe { ddog_shared_runtime_error_free(Some(error)) }; + } +} diff --git a/libdd-data-pipeline-ffi/src/trace_exporter.rs b/libdd-data-pipeline-ffi/src/trace_exporter.rs index 27226a00dd..ce26d11051 100644 --- a/libdd-data-pipeline-ffi/src/trace_exporter.rs +++ b/libdd-data-pipeline-ffi/src/trace_exporter.rs @@ -8,11 +8,11 @@ use libdd_common_ffi::{ CharSlice, {slice::AsBytes, slice::ByteSlice}, }; - +use libdd_data_pipeline::shared_runtime::SharedRuntime; use libdd_data_pipeline::trace_exporter::{ TelemetryConfig, TraceExporter, TraceExporterInputFormat, TraceExporterOutputFormat, }; -use std::{ptr::NonNull, time::Duration}; +use std::{ptr::NonNull, sync::Arc, time::Duration}; use tracing::{debug, error}; #[inline] @@ -67,6 +67,7 @@ pub struct TraceExporterConfig { health_metrics_enabled: bool, test_session_token: Option, connection_timeout: Option, + shared_runtime: Option>, } #[no_mangle] @@ -393,6 +394,36 @@ pub unsafe extern "C" fn ddog_trace_exporter_config_set_connection_timeout( ) } +/// Sets a shared runtime for the TraceExporter to use for background workers. +/// +/// `handle` must have been initialized with [`ddog_shared_runtime_new`]. +/// +/// When set, the exporter will use the provided runtime instead of creating its own. +/// This allows multiple exporters (or other components) to share a single runtime. +/// The config holds a clone of the `Arc` (increments the strong count), so the +/// original handle remains valid and must still be freed with +/// [`ddog_shared_runtime_free`]. +#[no_mangle] +pub unsafe extern "C" fn ddog_trace_exporter_config_set_shared_runtime( + config: Option<&mut TraceExporterConfig>, + handle: *const SharedRuntime, +) -> Option> { + catch_panic!( + match config { + Some(config) if !handle.is_null() => { + // SAFETY: handle was produced by Arc::into_raw and the Arc is still alive. + // Increment the strong count before reconstructing so the config's Arc + // is independent from the caller's handle. + Arc::increment_strong_count(handle); + config.shared_runtime = Some(Arc::from_raw(handle)); + None + } + _ => gen_error!(ErrorCode::InvalidArgument), + }, + gen_error!(ErrorCode::Panic) + ) +} + /// Create a new TraceExporter instance. /// /// # Arguments @@ -445,6 +476,10 @@ pub unsafe extern "C" fn ddog_trace_exporter_new( builder.enable_health_metrics(); } + if let Some(runtime) = config.shared_runtime.clone() { + builder.set_shared_runtime(runtime); + } + match builder.build() { Ok(exporter) => { out_handle.as_ptr().write(Box::new(exporter)); diff --git a/libdd-data-pipeline/Cargo.toml b/libdd-data-pipeline/Cargo.toml index 7688eface5..c513c4c5a5 100644 --- a/libdd-data-pipeline/Cargo.toml +++ b/libdd-data-pipeline/Cargo.toml @@ -14,6 +14,7 @@ autobenches = false [dependencies] anyhow = { version = "1.0" } arc-swap = "1.7.1" +async-trait = "0.1" http = "1.1" http-body-util = "0.1" tracing = { version = "0.1", default-features = false } diff --git a/libdd-data-pipeline/examples/send-traces-with-stats.rs b/libdd-data-pipeline/examples/send-traces-with-stats.rs index e5a2180754..7542826a7e 100644 --- a/libdd-data-pipeline/examples/send-traces-with-stats.rs +++ b/libdd-data-pipeline/examples/send-traces-with-stats.rs @@ -2,8 +2,11 @@ // SPDX-License-Identifier: Apache-2.0 use clap::Parser; -use libdd_data_pipeline::trace_exporter::{ - TelemetryConfig, TraceExporter, TraceExporterInputFormat, TraceExporterOutputFormat, +use libdd_data_pipeline::{ + shared_runtime::SharedRuntime, + trace_exporter::{ + TelemetryConfig, TraceExporter, TraceExporterInputFormat, TraceExporterOutputFormat, + }, }; use libdd_log::logger::{ logger_configure_std, logger_set_log_level, LogEventLevel, StdConfig, StdTarget, @@ -11,6 +14,7 @@ use libdd_log::logger::{ use libdd_trace_protobuf::pb; use std::{ collections::HashMap, + sync::Arc, time::{Duration, UNIX_EPOCH}, }; @@ -53,6 +57,8 @@ fn main() { .expect("Failed to configure logger"); logger_set_log_level(LogEventLevel::Debug).expect("Failed to set log level"); + let shared_runtime = Arc::new(SharedRuntime::new().expect("Failed to create runtime")); + let args = Args::parse(); let telemetry_cfg = TelemetryConfig::default(); let mut builder = TraceExporter::builder(); @@ -67,6 +73,7 @@ fn main() { .set_language_version(env!("CARGO_PKG_RUST_VERSION")) .set_input_format(TraceExporterInputFormat::V04) .set_output_format(TraceExporterOutputFormat::V04) + .set_shared_runtime(shared_runtime.clone()) .enable_telemetry(telemetry_cfg) .enable_stats(Duration::from_secs(10)); let exporter = builder.build().expect("Failed to build TraceExporter"); @@ -86,7 +93,7 @@ fn main() { let data = rmp_serde::to_vec_named(&traces).expect("Failed to serialize traces"); exporter.send(data.as_ref()).expect("Failed to send traces"); - exporter + shared_runtime .shutdown(None) - .expect("Failed to shutdown exporter"); + .expect("Failed to shutdown runtime"); } diff --git a/libdd-data-pipeline/src/agent_info/fetcher.rs b/libdd-data-pipeline/src/agent_info/fetcher.rs index 9bd7200288..773496afa7 100644 --- a/libdd-data-pipeline/src/agent_info/fetcher.rs +++ b/libdd-data-pipeline/src/agent_info/fetcher.rs @@ -5,6 +5,7 @@ use super::{schema::AgentInfo, AGENT_INFO_CACHE}; use anyhow::{anyhow, Result}; +use async_trait::async_trait; use http::header::HeaderName; use http_body_util::BodyExt; use libdd_common::{http_common, worker::Worker, Endpoint}; @@ -97,9 +98,12 @@ async fn fetch_and_hash_response(info_endpoint: &Endpoint) -> Result<(String, by /// Fetch the info endpoint and update an ArcSwap keeping it up-to-date. /// -/// Once the run method has been started, the fetcher will -/// update the global info state based on the given refresh interval. You can access the current -/// state with [`crate::agent_info::get_agent_info`] +/// This type implements [`libdd_common::worker::Worker`] and is intended to be driven by a worker +/// runner such as [`crate::shared_runtime::SharedRuntime`]. +/// In that lifecycle, `trigger()` waits for the next refresh event and `run()` performs a single +/// fetch. +/// +/// You can access the current state with [`crate::agent_info::get_agent_info`]. /// /// # Response observer /// When the fetcher is created it also returns a [`ResponseObserver`] which can be used to check @@ -120,10 +124,9 @@ async fn fetch_and_hash_response(info_endpoint: &Endpoint) -> Result<(String, by /// endpoint, /// std::time::Duration::from_secs(5 * 60), /// ); -/// // Start the runner -/// tokio::spawn(async move { -/// fetcher.run().await; -/// }); +/// // Start the fetcher on a shared runtime +/// let runtime = libdd_data_pipeline::shared_runtime::SharedRuntime::new()?; +/// runtime.spawn_worker(fetcher)?; /// /// // Get the Arc to access the info /// let agent_info_arc = agent_info::get_agent_info(); @@ -143,6 +146,7 @@ pub struct AgentInfoFetcher { info_endpoint: Endpoint, refresh_interval: Duration, trigger_rx: Option>, + trigger_tx: mpsc::Sender<()>, } impl AgentInfoFetcher { @@ -160,6 +164,7 @@ impl AgentInfoFetcher { info_endpoint, refresh_interval, trigger_rx: Some(trigger_rx), + trigger_tx: trigger_tx.clone(), }; let response_observer = ResponseObserver::new(trigger_tx); @@ -176,47 +181,54 @@ impl AgentInfoFetcher { } } +#[async_trait] impl Worker for AgentInfoFetcher { - /// Start fetching the info endpoint with the given interval. - /// - /// # Warning - /// This method does not return and should be called within a dedicated task. - async fn run(&mut self) { - // Skip the first fetch if some info is present to avoid calling the /info endpoint - // at fork for heavy-forking environment. + async fn initial_trigger(&mut self) { + // Skip initial wait if cache is not populated if AGENT_INFO_CACHE.load().is_none() { - self.fetch_and_update().await; + return; } + self.trigger().await + } - // Main loop waiting for a trigger event or the end of the refresh interval to trigger the - // fetch. - loop { - match &mut self.trigger_rx { - Some(trigger_rx) => { - tokio::select! { - // Wait for manual trigger (new state from headers) - trigger = trigger_rx.recv() => { - if trigger.is_some() { - self.fetch_and_update().await; - } else { - // The channel has been closed - self.trigger_rx = None; - } - } - // Regular periodic fetch timer - _ = sleep(self.refresh_interval) => { - self.fetch_and_update().await; + async fn trigger(&mut self) { + // Wait for either a manual trigger or the refresh interval + match &mut self.trigger_rx { + Some(trigger_rx) => { + tokio::select! { + // Wait for manual trigger (new state from headers) + trigger = trigger_rx.recv() => { + if trigger.is_none() { + // The channel has been closed + self.trigger_rx = None; } - }; - } - None => { - // If the trigger channel is closed we only use timed fetch. - sleep(self.refresh_interval).await; - self.fetch_and_update().await; + } + // Regular periodic fetch timer + _ = sleep(self.refresh_interval) => {} } } + None => { + // If the trigger channel is closed we only use timed fetch. + sleep(self.refresh_interval).await; + } } } + + async fn on_pause(&mut self) { + // Release the IoStack waker stored in trigger_rx by waking the channel, + // then drain the message to avoid a spurious fetch on restart. + let _ = self.trigger_tx.try_send(()); + self.drain(); + } + + fn reset(&mut self) { + // Drain all messages from the channel to remove messages sent to release the reference on + self.drain(); + } + + async fn run(&mut self) { + self.fetch_and_update().await; + } } impl AgentInfoFetcher { @@ -291,6 +303,7 @@ impl ResponseObserver { mod single_threaded_tests { use super::*; use crate::agent_info; + use crate::shared_runtime::SharedRuntime; use httpmock::prelude::*; const TEST_INFO: &str = r#"{ @@ -432,29 +445,31 @@ mod single_threaded_tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn test_agent_info_fetcher_run() { + #[test] + fn test_agent_info_fetcher_run() { AGENT_INFO_CACHE.store(None); let server = MockServer::start(); - let mock_v1 = server - .mock_async(|when, then| { - when.path("/info"); - then.status(200) - .header("content-type", "application/json") - .body(r#"{"version":"1"}"#); - }) - .await; + let mut mock_v1 = server.mock(|when, then| { + when.path("/info"); + then.status(200) + .header("content-type", "application/json") + .body(r#"{"version":"1"}"#); + }); let endpoint = Endpoint::from_url(server.url("/info").parse().unwrap()); - let (mut fetcher, _response_observer) = + let (fetcher, _response_observer) = AgentInfoFetcher::new(endpoint.clone(), Duration::from_millis(100)); assert!(agent_info::get_agent_info().is_none()); - tokio::spawn(async move { - fetcher.run().await; - }); + let shared_runtime = SharedRuntime::new().unwrap(); + shared_runtime.spawn_worker(fetcher).unwrap(); // Wait until the info is fetched + let start = std::time::Instant::now(); while agent_info::get_agent_info().is_none() { - tokio::time::sleep(Duration::from_millis(100)).await; + assert!( + start.elapsed() <= Duration::from_secs(10), + "Timeout waiting for first /info fetch" + ); + std::thread::sleep(Duration::from_millis(100)); } let version_1 = agent_info::get_agent_info() @@ -465,22 +480,25 @@ mod single_threaded_tests { .clone() .unwrap(); assert_eq!(version_1, "1"); - mock_v1.assert_async().await; + mock_v1.assert(); // Update the info endpoint - mock_v1.delete_async().await; - let mock_v2 = server - .mock_async(|when, then| { - when.path("/info"); - then.status(200) - .header("content-type", "application/json") - .body(r#"{"version":"2"}"#); - }) - .await; + mock_v1.delete(); + let mock_v2 = server.mock(|when, then| { + when.path("/info"); + then.status(200) + .header("content-type", "application/json") + .body(r#"{"version":"2"}"#); + }); // Wait for second fetch - while mock_v2.calls_async().await == 0 { - tokio::time::sleep(Duration::from_millis(100)).await; + let start = std::time::Instant::now(); + while mock_v2.calls() == 0 { + assert!( + start.elapsed() <= Duration::from_secs(10), + "Timeout waiting for second /info fetch" + ); + std::thread::sleep(Duration::from_millis(100)); } // This check is not 100% deterministic, but between the time the mock returns the response @@ -499,22 +517,20 @@ mod single_threaded_tests { assert_eq!(version_2, "2"); break; } - tokio::time::sleep(Duration::from_millis(100)).await; + std::thread::sleep(Duration::from_millis(100)); } } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn test_agent_info_trigger_different_state() { + #[test] + fn test_agent_info_trigger_different_state() { let server = MockServer::start(); - let mock = server - .mock_async(|when, then| { - when.path("/info"); - then.status(200) - .header("content-type", "application/json") - .body(r#"{"version":"triggered"}"#); - }) - .await; + let mock = server.mock(|when, then| { + when.path("/info"); + then.status(200) + .header("content-type", "application/json") + .body(r#"{"version":"triggered"}"#); + }); // Populate the cache with initial state AGENT_INFO_CACHE.store(Some(Arc::new(AgentInfo { @@ -523,13 +539,12 @@ mod single_threaded_tests { }))); let endpoint = Endpoint::from_url(server.url("/info").parse().unwrap()); - let (mut fetcher, response_observer) = + let (fetcher, response_observer) = // Interval is too long to fetch during the test AgentInfoFetcher::new(endpoint, Duration::from_secs(3600)); - tokio::spawn(async move { - fetcher.run().await; - }); + let shared_runtime = SharedRuntime::new().unwrap(); + shared_runtime.spawn_worker(fetcher).unwrap(); // Create a mock HTTP response with the new agent state let response = http_common::empty_response( @@ -547,13 +562,13 @@ mod single_threaded_tests { const SLEEP_DURATION_MS: u64 = 10; let mut attempts = 0; - while mock.calls_async().await == 0 && attempts < MAX_ATTEMPTS { + while mock.calls() == 0 && attempts < MAX_ATTEMPTS { attempts += 1; - tokio::time::sleep(Duration::from_millis(SLEEP_DURATION_MS)).await; + std::thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); } // Should trigger a fetch since the state is different - mock.assert_calls_async(1).await; + mock.assert_calls(1); // Wait for the cache to be updated with proper timeout let mut attempts = 0; @@ -567,7 +582,7 @@ mod single_threaded_tests { } } attempts += 1; - tokio::time::sleep(Duration::from_millis(SLEEP_DURATION_MS)).await; + std::thread::sleep(Duration::from_millis(SLEEP_DURATION_MS)); } // Verify the cache was updated @@ -587,17 +602,15 @@ mod single_threaded_tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn test_agent_info_trigger_same_state() { + #[test] + fn test_agent_info_trigger_same_state() { let server = MockServer::start(); - let mock = server - .mock_async(|when, then| { - when.path("/info"); - then.status(200) - .header("content-type", "application/json") - .body(r#"{"version":"same"}"#); - }) - .await; + let mock = server.mock(|when, then| { + when.path("/info"); + then.status(200) + .header("content-type", "application/json") + .body(r#"{"version":"same"}"#); + }); let same_json = r#"{"version":"same"}"#; let same_hash = calculate_hash(same_json); @@ -609,12 +622,11 @@ mod single_threaded_tests { }))); let endpoint = Endpoint::from_url(server.url("/info").parse().unwrap()); - let (mut fetcher, response_observer) = + let (fetcher, response_observer) = AgentInfoFetcher::new(endpoint, Duration::from_secs(3600)); // Very long interval - tokio::spawn(async move { - fetcher.run().await; - }); + let shared_runtime = SharedRuntime::new().unwrap(); + shared_runtime.spawn_worker(fetcher).unwrap(); // Create a mock HTTP response with the same agent state let response = http_common::empty_response( @@ -628,9 +640,9 @@ mod single_threaded_tests { response_observer.check_response(&response); // Wait to ensure no fetch occurs - tokio::time::sleep(Duration::from_millis(500)).await; + std::thread::sleep(Duration::from_millis(500)); // Should not trigger a fetch since the state is the same - mock.assert_calls_async(0).await; + mock.assert_calls(0); } } diff --git a/libdd-data-pipeline/src/lib.rs b/libdd-data-pipeline/src/lib.rs index 57572cd97a..c059939a55 100644 --- a/libdd-data-pipeline/src/lib.rs +++ b/libdd-data-pipeline/src/lib.rs @@ -13,6 +13,7 @@ pub mod agent_info; mod health_metrics; mod pausable_worker; +pub mod shared_runtime; #[allow(missing_docs)] pub mod stats_exporter; pub(crate) mod telemetry; diff --git a/libdd-data-pipeline/src/pausable_worker.rs b/libdd-data-pipeline/src/pausable_worker.rs index 223d0af246..f2a1ad5090 100644 --- a/libdd-data-pipeline/src/pausable_worker.rs +++ b/libdd-data-pipeline/src/pausable_worker.rs @@ -5,11 +5,7 @@ use libdd_common::worker::Worker; use std::fmt::Display; -use tokio::{ - runtime::Runtime, - select, - task::{JoinError, JoinHandle}, -}; +use tokio::{runtime::Runtime, select, task::JoinHandle}; use tokio_util::sync::CancellationToken; /// A pausable worker which can be paused and restarted on forks. @@ -68,40 +64,93 @@ impl PausableWorker { /// Start the worker on the given runtime. /// /// The worker's main loop will be run on the runtime. - /// - /// # Errors - /// Fails if the worker is in an invalid state. pub fn start(&mut self, rt: &Runtime) -> Result<(), PausableWorkerError> { - if let Self::Running { .. } = self { - Ok(()) - } else if let Self::Paused { mut worker } = std::mem::replace(self, Self::InvalidState) { - // Worker is temporarily in an invalid state, but since this block is failsafe it will - // be replaced by a valid state. - let stop_token = CancellationToken::new(); - let cloned_token = stop_token.clone(); - let handle = rt.spawn(async move { - select! { - _ = worker.run() => {worker} - _ = cloned_token.cancelled() => {worker} - } - }); + match self { + PausableWorker::Running { .. } => Ok(()), + PausableWorker::Paused { .. } => { + let PausableWorker::Paused { mut worker } = + std::mem::replace(self, PausableWorker::InvalidState) + else { + // Unreachable + return Ok(()); + }; + + // Worker is temporarily in an invalid state, but since this block is failsafe it + // will be replaced by a valid state. + let stop_token = CancellationToken::new(); + let cloned_token = stop_token.clone(); + let handle = rt.spawn(async move { + // First iteration using initial_trigger + select! { + _ = worker.initial_trigger() => { + worker.run().await; + } + _ = cloned_token.cancelled() => { + return worker; + } + } + + // Regular iterations + loop { + select! { + _ = worker.trigger() => { + worker.run().await; + } + _ = cloned_token.cancelled() => { + break; + } + } + } + worker + }); + + *self = PausableWorker::Running { handle, stop_token }; + Ok(()) + } + PausableWorker::InvalidState => Err(PausableWorkerError::InvalidState), + } + } - *self = PausableWorker::Running { handle, stop_token }; - Ok(()) - } else { - Err(PausableWorkerError::InvalidState) + /// Request the worker to pause without waiting for task termination. + /// + /// This is useful when pausing multiple workers in parallel. + pub fn request_pause(&self) -> Result<(), PausableWorkerError> { + match self { + PausableWorker::Running { stop_token, .. } => { + stop_token.cancel(); + Ok(()) + } + PausableWorker::Paused { .. } => Ok(()), + PausableWorker::InvalidState => Err(PausableWorkerError::InvalidState), } } - /// Pause the worker saving it's state to be restarted. + /// Pause the worker and wait for it to complete, storing its state for restart. + /// + /// This method will cancel the worker's cancellation token if it hasn't been cancelled yet, + /// then wait for the worker to finish and store its state. Calling [`Self::request_pause`] + /// before this method is optional - it's only needed when shutting down multiple workers + /// simultaneously to allow them to pause concurrently before waiting for all of them. /// /// # Errors /// Fails if the worker handle has been aborted preventing the worker from being retrieved. - pub async fn pause(&mut self) -> Result<(), PausableWorkerError> { + pub async fn join(&mut self) -> Result<(), PausableWorkerError> { match self { - PausableWorker::Running { handle, stop_token } => { - stop_token.cancel(); - if let Ok(worker) = handle.await { + PausableWorker::Running { .. } => { + let PausableWorker::Running { handle, stop_token } = + std::mem::replace(self, PausableWorker::InvalidState) + else { + // Unreachable + return Ok(()); + }; + + // Cancel the token if it hasn't been cancelled yet to avoid deadlock + if !stop_token.is_cancelled() { + stop_token.cancel(); + } + + if let Ok(mut worker) = handle.await { + worker.on_pause().await; *self = PausableWorker::Paused { worker }; Ok(()) } else { @@ -115,17 +164,24 @@ impl PausableWorker { } } - /// Wait for the run method of the worker to exit. - pub async fn join(self) -> Result<(), JoinError> { - if let PausableWorker::Running { handle, .. } = self { - handle.await?; + /// Reset the worker state (e.g. in a fork child). + pub fn reset(&mut self) { + if let PausableWorker::Paused { worker } = self { + worker.reset(); + } + } + + /// Shutdown the worker. + pub async fn shutdown(&mut self) { + if let PausableWorker::Paused { worker } = self { + worker.shutdown().await; } - Ok(()) } } #[cfg(test)] mod tests { + use async_trait::async_trait; use tokio::{runtime::Builder, time::sleep}; use super::*; @@ -135,18 +191,21 @@ mod tests { }; /// Test worker incrementing the state and sending it with the sender. + #[derive(Debug)] struct TestWorker { state: u32, sender: Sender, } + #[async_trait] impl Worker for TestWorker { async fn run(&mut self) { - loop { - let _ = self.sender.send(self.state); - self.state += 1; - sleep(Duration::from_millis(100)).await; - } + let _ = self.sender.send(self.state); + self.state += 1; + } + + async fn trigger(&mut self) { + sleep(Duration::from_millis(100)).await; } } @@ -160,7 +219,7 @@ mod tests { pausable_worker.start(&runtime).unwrap(); assert_eq!(receiver.recv().unwrap(), 0); - runtime.block_on(async { pausable_worker.pause().await.unwrap() }); + runtime.block_on(async { pausable_worker.join().await.unwrap() }); // Empty the message queue and get the last message let mut next_message = 1; for message in receiver.try_iter() { diff --git a/libdd-data-pipeline/src/shared_runtime.rs b/libdd-data-pipeline/src/shared_runtime.rs new file mode 100644 index 0000000000..4db4755be7 --- /dev/null +++ b/libdd-data-pipeline/src/shared_runtime.rs @@ -0,0 +1,454 @@ +// Copyright 2025-Present Datadog, Inc. https://www.datadoghq.com/ +// SPDX-License-Identifier: Apache-2.0 + +//! SharedRuntime for managing PausableWorkers across fork boundaries. +//! +//! This module provides a SharedRuntime that manages a tokio runtime and allows +//! spawning PausableWorkers on it. It also provides hooks for safely handling +//! fork operations by pausing workers before fork and restarting them appropriately +//! in parent and child processes. + +use crate::pausable_worker::{PausableWorker, PausableWorkerError}; +use libdd_common::{worker::Worker, MutexExt}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::{fmt, io}; +use tokio::runtime::{Builder, Runtime}; +use tokio::task::JoinSet; + +type BoxedWorker = Box; + +#[derive(Debug)] +struct WorkerEntry { + id: u64, + worker: PausableWorker, +} + +/// Handle to a worker registered on a [`SharedRuntime`]. +/// +/// This handle can be used to stop the worker. +#[derive(Clone, Debug)] +pub struct WorkerHandle { + worker_id: u64, + workers: Arc>>, +} + +#[derive(Debug)] +pub enum WorkerHandleError { + AlreadyStopped, + WorkerError(PausableWorkerError), +} + +impl fmt::Display for WorkerHandleError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::AlreadyStopped => { + write!(f, "Worker has already been stopped") + } + Self::WorkerError(err) => write!(f, "Worker error: {}", err), + } + } +} + +impl From for WorkerHandleError { + fn from(err: PausableWorkerError) -> Self { + Self::WorkerError(err) + } +} + +impl WorkerHandle { + /// Stop the worker and execute the shutdown logic. + /// + /// # Errors + /// Returns an error if the worker has already been stopped. + pub async fn stop(self) -> Result<(), WorkerHandleError> { + let mut worker = { + let mut workers_lock = self.workers.lock_or_panic(); + let Some(position) = workers_lock + .iter() + .position(|entry| entry.id == self.worker_id) + else { + return Err(WorkerHandleError::AlreadyStopped); + }; + let WorkerEntry { worker, .. } = workers_lock.swap_remove(position); + worker + }; + worker.join().await?; + worker.shutdown().await; + Ok(()) + } +} + +/// Errors that can occur when using SharedRuntime. +#[derive(Debug)] +pub enum SharedRuntimeError { + /// The runtime is not available or in an invalid state. + RuntimeUnavailable, + /// Failed to acquire a lock on internal state. + LockFailed(String), + /// A worker operation failed. + WorkerError(PausableWorkerError), + /// Failed to create the tokio runtime. + RuntimeCreation(io::Error), + /// Shutdown timed out. + ShutdownTimedOut(std::time::Duration), +} + +impl fmt::Display for SharedRuntimeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::RuntimeUnavailable => { + write!(f, "Runtime is not available or in an invalid state") + } + Self::LockFailed(msg) => write!(f, "Failed to acquire lock: {}", msg), + Self::WorkerError(err) => write!(f, "Worker error: {}", err), + Self::RuntimeCreation(err) => { + write!(f, "Failed to create runtime: {}", err) + } + Self::ShutdownTimedOut(duration) => { + write!(f, "Shutdown timed out after {:?}", duration) + } + } + } +} + +impl std::error::Error for SharedRuntimeError {} + +impl From for SharedRuntimeError { + fn from(err: PausableWorkerError) -> Self { + SharedRuntimeError::WorkerError(err) + } +} + +impl From for SharedRuntimeError { + fn from(err: io::Error) -> Self { + SharedRuntimeError::RuntimeCreation(err) + } +} + +/// A shared runtime that manages PausableWorkers and provides fork safety hooks. +/// +/// The SharedRuntime owns a tokio runtime and tracks PausableWorkers spawned on it. +/// It provides methods to safely pause workers before forking and restart them +/// after fork in both parent and child processes. +#[derive(Debug)] +pub struct SharedRuntime { + runtime: Arc>>>, + workers: Arc>>, + next_worker_id: AtomicU64, +} + +impl SharedRuntime { + /// Create a new SharedRuntime with a default multi-threaded tokio runtime. + /// + /// # Errors + /// Returns an error if the tokio runtime cannot be created. + pub fn new() -> Result { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build()?; + + Ok(Self { + runtime: Arc::new(Mutex::new(Some(Arc::new(runtime)))), + workers: Arc::new(Mutex::new(Vec::new())), + next_worker_id: AtomicU64::new(1), + }) + } + + /// Spawn a PausableWorker on this runtime. + /// + /// The worker will be tracked by this SharedRuntime and will be paused/resumed + /// during fork operations. + /// + /// # Errors + /// Returns an error if the runtime is not available or the worker cannot be started. + pub fn spawn_worker( + &self, + worker: T, + ) -> Result { + let boxed_worker: BoxedWorker = Box::new(worker); + let mut pausable_worker = PausableWorker::new(boxed_worker); + let worker_id = self.next_worker_id.fetch_add(1, Ordering::Relaxed); + + let runtime_lock = self.runtime.lock_or_panic(); + + // If the runtime is not available, it's added to the worker list and will be started when + // the runtime is recreated. + if let Some(runtime) = runtime_lock.as_ref() { + pausable_worker.start(runtime)?; + } + + let mut workers_lock = self.workers.lock_or_panic(); + workers_lock.push(WorkerEntry { + id: worker_id, + worker: pausable_worker, + }); + + Ok(WorkerHandle { + worker_id, + workers: self.workers.clone(), + }) + } + + /// Hook to be called before forking. + /// + /// This method pauses all workers and prepares the runtime for forking. + /// It ensures that no background tasks are running when the fork occurs, + /// preventing potential deadlocks in the child process. + /// + /// Worker errors are logged but do not cause the function to fail. + pub fn before_fork(&self) { + use tracing::error; + + if let Some(runtime) = self.runtime.lock_or_panic().take() { + let mut workers_lock = self.workers.lock_or_panic(); + runtime.block_on(async move { + for worker_entry in workers_lock.iter_mut() { + let _ = worker_entry.worker.request_pause(); + } + + for worker_entry in workers_lock.iter_mut() { + if let Err(e) = worker_entry.worker.join().await { + error!("Worker failed to pause before fork: {:?}", e); + } + } + }); + } + } + + fn restart_runtime(&self) -> Result<(), SharedRuntimeError> { + let mut runtime_lock = self.runtime.lock_or_panic(); + if runtime_lock.is_none() { + *runtime_lock = Some(Arc::new( + Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build()?, + )); + } + Ok(()) + } + + /// Hook to be called in the parent process after forking. + /// + /// This method restarts workers and resumes normal operation in the parent process. + /// The runtime may need to be recreated if it was shut down in before_fork. + /// + /// # Errors + /// Returns an error if workers cannot be restarted or the runtime cannot be recreated. + pub fn after_fork_parent(&self) -> Result<(), SharedRuntimeError> { + self.restart_runtime()?; + + let runtime_lock = self.runtime.lock_or_panic(); + let runtime = runtime_lock + .as_ref() + .ok_or(SharedRuntimeError::RuntimeUnavailable)?; + + let mut workers_lock = self.workers.lock_or_panic(); + + // Restart all workers + for worker_entry in workers_lock.iter_mut() { + worker_entry.worker.start(runtime)?; + } + + Ok(()) + } + + /// Hook to be called in the child process after forking. + /// + /// This method reinitializes the runtime and workers in the child process. + /// A new runtime must be created since tokio runtimes cannot be safely forked. + /// Workers can optionally be restarted to resume operations in the child. + /// + /// # Errors + /// Returns an error if the runtime cannot be reinitialized or workers cannot be started. + pub fn after_fork_child(&self) -> Result<(), SharedRuntimeError> { + self.restart_runtime()?; + + let runtime_lock = self.runtime.lock_or_panic(); + let runtime = runtime_lock + .as_ref() + .ok_or(SharedRuntimeError::RuntimeUnavailable)?; + + let mut workers_lock = self.workers.lock_or_panic(); + + // Restart all workers in child process + for worker_entry in workers_lock.iter_mut() { + worker_entry.worker.reset(); + worker_entry.worker.start(runtime)?; + } + + Ok(()) + } + + /// Get a reference to the underlying runtime or create a single-threaded one. + /// + /// This allows external code to spawn additional tasks on the runtime if needed. + /// + /// # Warning + /// Since this method can return a single-threaded runtime it should only be use to + /// execute async code with `block_on` if you need to spawn async code on it without blocking, + /// you should us a `Worker` instead. + /// + /// # Errors + /// Returns an error if it fails to create a runtime. + pub fn runtime(&self) -> Result, io::Error> { + match self.runtime.lock_or_panic().as_ref() { + None => Ok(Arc::new( + Builder::new_current_thread().enable_all().build()?, + )), + Some(runtime) => Ok(runtime.clone()), + } + } + + /// Shutdown the runtime and all workers synchronously with optional timeout. + /// + /// Worker errors are logged but do not cause the function to fail. + /// + /// # Errors + /// Returns an error only if shutdown times out. + pub fn shutdown(&self, timeout: Option) -> Result<(), SharedRuntimeError> { + match self.runtime.lock_or_panic().take() { + Some(runtime) => { + let result = if let Some(timeout) = timeout { + match runtime.block_on(async { + tokio::time::timeout(timeout, self.shutdown_async()).await + }) { + Ok(()) => Ok(()), + Err(_) => Err(SharedRuntimeError::ShutdownTimedOut(timeout)), + } + } else { + runtime.block_on(self.shutdown_async()); + Ok(()) + }; + result + } + None => Ok(()), // The runtime is not running so there's nothing to shutdown + } + } + + /// Shutdown all workers asynchronously. + /// + /// This should be called during application shutdown to cleanly stop all + /// background workers and the runtime. + /// + /// Worker errors are logged but do not cause the function to fail. + /// + /// This function should not take ownership of the SharedRuntime as it will cause the runtime + /// to be dropped in a non-blocking context causing a panic. + pub async fn shutdown_async(&self) { + use tracing::error; + + let workers = { + let mut workers_lock = self.workers.lock_or_panic(); + std::mem::take(&mut *workers_lock) + }; + + let mut join_set = JoinSet::new(); + for mut worker_entry in workers { + join_set.spawn(async move { + let result = worker_entry.worker.join().await; + if let Err(e) = result { + error!("Worker failed to shutdown: {:?}", e); + return; + } + worker_entry.worker.shutdown().await; + }); + } + + join_set.join_all().await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + use std::sync::mpsc::{channel, Sender}; + use std::time::Duration; + use tokio::time::sleep; + + #[derive(Debug)] + struct TestWorker { + state: u32, + sender: Sender, + } + + #[async_trait] + impl Worker for TestWorker { + async fn run(&mut self) { + let _ = self.sender.send(self.state); + self.state += 1; + } + + async fn trigger(&mut self) { + sleep(Duration::from_millis(100)).await; + } + } + + #[test] + fn test_shared_runtime_creation() { + let shared_runtime = SharedRuntime::new(); + assert!(shared_runtime.is_ok()); + } + + #[test] + fn test_spawn_worker() { + let shared_runtime = SharedRuntime::new().unwrap(); + let (sender, _receiver) = channel::(); + let worker = TestWorker { state: 0, sender }; + + let result = shared_runtime.spawn_worker(worker); + assert!(result.is_ok()); + assert_eq!(shared_runtime.workers.lock_or_panic().len(), 1); + } + + #[test] + fn test_worker_handle_stop_removes_worker() { + let rt = tokio::runtime::Runtime::new().unwrap(); + let shared_runtime = SharedRuntime::new().unwrap(); + let (sender, _receiver) = channel::(); + let worker = TestWorker { state: 0, sender }; + + let handle = shared_runtime.spawn_worker(worker).unwrap(); + assert_eq!(shared_runtime.workers.lock_or_panic().len(), 1); + + rt.block_on(async { + assert!(handle.stop().await.is_ok()); + }); + + assert_eq!(shared_runtime.workers.lock_or_panic().len(), 0); + } + + #[test] + fn test_before_and_after_fork_parent() { + // Run in a separate thread to ensure we're not in any async context + let handle = std::thread::spawn(|| { + let rt = tokio::runtime::Runtime::new().unwrap(); + let shared_runtime = SharedRuntime::new().unwrap(); + + // Test before_fork + shared_runtime.before_fork(); + + // Test after_fork_parent (synchronous) + assert!(shared_runtime.after_fork_parent().is_ok()); + + // Clean shutdown + rt.block_on(async { + shared_runtime.shutdown_async().await; + }); + }); + + handle.join().expect("Thread panicked"); + } + + #[test] + fn test_after_fork_child() { + // Test after_fork_child in a non-async context + let shared_runtime = SharedRuntime::new().unwrap(); + + // This should succeed as we're not in an async context + assert!(shared_runtime.after_fork_child().is_ok()); + } +} diff --git a/libdd-data-pipeline/src/stats_exporter.rs b/libdd-data-pipeline/src/stats_exporter.rs index 6b64c09ecc..4dc67e2f01 100644 --- a/libdd-data-pipeline/src/stats_exporter.rs +++ b/libdd-data-pipeline/src/stats_exporter.rs @@ -12,12 +12,11 @@ use std::{ }; use crate::trace_exporter::TracerMetadata; +use async_trait::async_trait; use libdd_common::{worker::Worker, Endpoint, HttpClient}; use libdd_trace_protobuf::pb; use libdd_trace_stats::span_concentrator::SpanConcentrator; use libdd_trace_utils::send_with_retry::{send_with_retry, RetryStrategy}; -use tokio::select; -use tokio_util::sync::CancellationToken; use tracing::error; const STATS_ENDPOINT_PATH: &str = "/v0.6/stats"; @@ -30,7 +29,6 @@ pub struct StatsExporter { endpoint: Endpoint, meta: TracerMetadata, sequence_id: AtomicU64, - cancellation_token: CancellationToken, client: HttpClient, } @@ -48,7 +46,6 @@ impl StatsExporter { concentrator: Arc>, meta: TracerMetadata, endpoint: Endpoint, - cancellation_token: CancellationToken, client: HttpClient, ) -> Self { Self { @@ -57,7 +54,6 @@ impl StatsExporter { endpoint, meta, sequence_id: AtomicU64::new(0), - cancellation_token, client, } } @@ -132,24 +128,20 @@ impl StatsExporter { } } +#[async_trait] impl Worker for StatsExporter { - /// Run loop of the stats exporter - /// - /// Once started, the stats exporter will flush and send stats on every `self.flush_interval`. - /// If the `self.cancellation_token` is cancelled, the exporter will force flush all stats and - /// return. + async fn trigger(&mut self) { + tokio::time::sleep(self.flush_interval).await; + } + + /// Flush and send stats on every trigger. async fn run(&mut self) { - loop { - select! { - _ = self.cancellation_token.cancelled() => { - let _ = self.send(true).await; - break; - }, - _ = tokio::time::sleep(self.flush_interval) => { - let _ = self.send(false).await; - }, - }; - } + let _ = self.send(false).await; + } + + async fn shutdown(&mut self) { + // Force flush all stats on shutdown + let _ = self.send(true).await; } } @@ -189,6 +181,7 @@ pub fn stats_url_from_agent_url(agent_url: &str) -> anyhow::Result { #[cfg(test)] mod tests { use super::*; + use crate::shared_runtime::SharedRuntime; use httpmock::prelude::*; use httpmock::MockServer; use libdd_common::http_common::new_default_client; @@ -267,7 +260,6 @@ mod tests { Arc::new(Mutex::new(get_test_concentrator())), get_test_metadata(), Endpoint::from_url(stats_url_from_agent_url(&server.url("/")).unwrap()), - CancellationToken::new(), new_default_client(), ); @@ -295,7 +287,6 @@ mod tests { Arc::new(Mutex::new(get_test_concentrator())), get_test_metadata(), Endpoint::from_url(stats_url_from_agent_url(&server.url("/")).unwrap()), - CancellationToken::new(), new_default_client(), ); @@ -328,7 +319,6 @@ mod tests { Arc::new(Mutex::new(get_test_concentrator())), get_test_metadata(), Endpoint::from_url(stats_url_from_agent_url(&server.url("/")).unwrap()), - CancellationToken::new(), new_default_client(), ); @@ -347,40 +337,39 @@ mod tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn test_cancellation_token() { - let server = MockServer::start_async().await; - - let mut mock = server - .mock_async(|when, then| { - when.method(POST) - .header("Content-type", "application/msgpack") - .path("/v0.6/stats") - .body_includes("libdatadog-test"); - then.status(200).body(""); - }) - .await; + #[test] + fn test_worker_shutdown() { + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + let server = MockServer::start(); + + let mut mock = server.mock(|when, then| { + when.method(POST) + .header("Content-type", "application/msgpack") + .path("/v0.6/stats") + .body_includes("libdatadog-test"); + then.status(200).body(""); + }); let buckets_duration = Duration::from_secs(10); - let cancellation_token = CancellationToken::new(); - let mut stats_exporter = StatsExporter::new( + let stats_exporter = StatsExporter::new( buckets_duration, Arc::new(Mutex::new(get_test_concentrator())), get_test_metadata(), Endpoint::from_url(stats_url_from_agent_url(&server.url("/")).unwrap()), - cancellation_token.clone(), new_default_client(), ); - tokio::spawn(async move { - stats_exporter.run().await; - }); - // Cancel token to trigger force flush - cancellation_token.cancel(); + let _handle = shared_runtime + .spawn_worker(stats_exporter) + .expect("Failed to spawn worker"); + + shared_runtime.shutdown(None).unwrap(); assert!( - poll_for_mock_hit(&mut mock, 10, 100, 1, false).await, + rt.block_on(poll_for_mock_hit(&mut mock, 10, 100, 1, false)), "Expected max retry attempts" ); } diff --git a/libdd-data-pipeline/src/telemetry/mod.rs b/libdd-data-pipeline/src/telemetry/mod.rs index 9715aa50ae..0ec7d1817f 100644 --- a/libdd-data-pipeline/src/telemetry/mod.rs +++ b/libdd-data-pipeline/src/telemetry/mod.rs @@ -297,14 +297,6 @@ impl TelemetryClient { .send_msg(TelemetryActions::Lifecycle(LifecycleAction::Start)) .await; } - - /// Shutdowns the telemetry client. - pub async fn shutdown(self) { - _ = self - .worker - .send_msg(TelemetryActions::Lifecycle(LifecycleAction::Stop)) - .await; - } } #[cfg(test)] @@ -312,28 +304,32 @@ mod tests { use http::{Response, StatusCode}; use httpmock::Method::POST; use httpmock::MockServer; - use libdd_common::{http_common, worker::Worker}; + use libdd_common::http_common; + use libdd_trace_utils::test_utils::poll_for_mock_hit; use regex::Regex; use tokio::time::sleep; use super::*; + use crate::shared_runtime::{SharedRuntime, WorkerHandle}; - async fn get_test_client(url: &str) -> TelemetryClient { - let (client, mut worker) = TelemetryClientBuilder::default() + fn get_test_client(url: &str, runtime: &SharedRuntime) -> (TelemetryClient, WorkerHandle) { + let (client, worker) = TelemetryClientBuilder::default() .set_service_name("test_service") .set_service_version("test_version") .set_env("test_env") .set_language("test_language") .set_language_version("test_language_version") .set_tracer_version("test_tracer_version") + .set_runtime_id("foo") .set_url(url) .set_heartbeat(100) .set_debug_enabled(true) .build(Handle::current()); - tokio::spawn(async move { worker.run().await }); - client + let handle = runtime + .spawn_worker(worker) + .expect("Failed to spawn worker"); + (client, handle) } - #[test] fn builder_test() { let builder = TelemetryClientBuilder::default() @@ -365,306 +361,353 @@ mod tests { } #[cfg_attr(miri, ignore)] - #[tokio::test(flavor = "multi_thread")] - async fn spawn_test() { - let _ = TelemetryClientBuilder::default() - .set_service_name("test_service") - .set_service_version("test_version") - .set_env("test_env") - .set_language("test_language") - .set_language_version("test_language_version") - .set_tracer_version("test_tracer_version") - .build(Handle::current()); - } - - #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn api_bytes_test() { + #[test] + fn api_bytes_test() { let payload = Regex::new(r#""metric":"trace_api.bytes","tags":\["src_library:libdatadog"\],"sketch_b64":".+","common":true,"interval":\d+,"type":"distribution""#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - - let data = SendPayloadTelemetry { - bytes_sent: 1, - ..Default::default() - }; + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }) + .await; + + let data = SendPayloadTelemetry { + bytes_sent: 1, + ..Default::default() + }; - let client = get_test_client(&server.url("/")).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + }); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn requests_test() { + #[test] + fn requests_test() { let payload = Regex::new(r#""metric":"trace_api.requests","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog"\],"common":true,"type":"count""#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - - let data = SendPayloadTelemetry { - requests_count: 1, - ..Default::default() - }; + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }) + .await; + + let data = SendPayloadTelemetry { + requests_count: 1, + ..Default::default() + }; - let client = get_test_client(&server.url("/")).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + }); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn responses_per_code_test() { + #[test] + fn responses_per_code_test() { let payload = Regex::new(r#""metric":"trace_api.responses","points":\[\[\d+,1\.0\]\],"tags":\["status_code:200","src_library:libdatadog"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - - let data = SendPayloadTelemetry { - responses_count_per_code: HashMap::from([(200, 1)]), - ..Default::default() - }; + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }) + .await; + + let data = SendPayloadTelemetry { + responses_count_per_code: HashMap::from([(200, 1)]), + ..Default::default() + }; - let client = get_test_client(&server.url("/")).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + }); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn errors_timeout_test() { + #[test] + fn errors_timeout_test() { let payload = Regex::new(r#""metric":"trace_api.errors","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","type:timeout"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - - let data = SendPayloadTelemetry { - errors_timeout: 1, - ..Default::default() - }; + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }) + .await; + + let data = SendPayloadTelemetry { + errors_timeout: 1, + ..Default::default() + }; - let client = get_test_client(&server.url("/")).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + }); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn errors_network_test() { + #[test] + fn errors_network_test() { let payload = Regex::new(r#""metric":"trace_api.errors","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","type:network"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - - let data = SendPayloadTelemetry { - errors_network: 1, - ..Default::default() - }; + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }) + .await; + + let data = SendPayloadTelemetry { + errors_network: 1, + ..Default::default() + }; - let client = get_test_client(&server.url("/")).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + }); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn errors_status_code_test() { + #[test] + fn errors_status_code_test() { let payload = Regex::new(r#""metric":"trace_api.errors","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","type:status_code"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - - let data = SendPayloadTelemetry { - errors_status_code: 1, - ..Default::default() - }; + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }) + .await; + + let data = SendPayloadTelemetry { + errors_status_code: 1, + ..Default::default() + }; - let client = get_test_client(&server.url("/")).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + }); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn chunks_sent_test() { + #[test] + fn chunks_sent_test() { let payload = Regex::new(r#""metric":"trace_chunks_sent","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - - let data = SendPayloadTelemetry { - chunks_sent: 1, - ..Default::default() - }; + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }) + .await; + + let data = SendPayloadTelemetry { + chunks_sent: 1, + ..Default::default() + }; - let client = get_test_client(&server.url("/")).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + }); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn chunks_dropped_send_failure_test() { + #[test] + fn chunks_dropped_send_failure_test() { let payload = Regex::new(r#""metric":"trace_chunks_dropped","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","reason:send_failure"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - - let data = SendPayloadTelemetry { - chunks_dropped_send_failure: 1, - ..Default::default() - }; + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }) + .await; + + let data = SendPayloadTelemetry { + chunks_dropped_send_failure: 1, + ..Default::default() + }; - let client = get_test_client(&server.url("/")).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + }); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn chunks_dropped_p0_test() { + #[test] + fn chunks_dropped_p0_test() { let payload = Regex::new(r#""metric":"trace_chunks_dropped","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","reason:p0_drop"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - - let data = SendPayloadTelemetry { - chunks_dropped_p0: 1, - ..Default::default() - }; + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }) + .await; + + let data = SendPayloadTelemetry { + chunks_dropped_p0: 1, + ..Default::default() + }; - let client = get_test_client(&server.url("/")).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + }); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn chunks_dropped_serialization_error_test() { + #[test] + fn chunks_dropped_serialization_error_test() { let payload = Regex::new(r#""metric":"trace_chunks_dropped","points":\[\[\d+,1\.0\]\],"tags":\["src_library:libdatadog","reason:serialization_error"\],"common":true,"type":"count"#).unwrap(); - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_matches(payload); - then.status(200).body(""); - }) - .await; - - let data = SendPayloadTelemetry { - chunks_dropped_serialization_error: 1, - ..Default::default() - }; + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_matches(payload); + then.status(200).body(""); + }) + .await; + + let data = SendPayloadTelemetry { + chunks_dropped_serialization_error: 1, + ..Default::default() + }; - let client = get_test_client(&server.url("/")).await; + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + let _ = client.send(&data); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; - client.start().await; - let _ = client.send(&data); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { - sleep(Duration::from_millis(10)).await; - } - telemetry_srv.assert_calls_async(1).await; + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + }); } #[test] @@ -763,8 +806,8 @@ mod tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn telemetry_from_build_error_test() { + #[test] + fn telemetry_from_build_error_test() { let result = Err(SendWithRetryError::Build(5)); let telemetry = SendPayloadTelemetry::from_retry_result(&result, 1, 2, 0); assert_eq!( @@ -807,88 +850,74 @@ mod tests { } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn runtime_id_test() { - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST).body_includes(r#""runtime_id":"foo""#); - then.status(200).body(""); - }) - .await; - - let (client, mut worker) = TelemetryClientBuilder::default() - .set_service_name("test_service") - .set_service_version("test_version") - .set_env("test_env") - .set_language("test_language") - .set_language_version("test_language_version") - .set_tracer_version("test_tracer_version") - .set_url(&server.url("/")) - .set_heartbeat(100) - .set_runtime_id("foo") - .build(Handle::current()); - tokio::spawn(async move { worker.run().await }); - - client.start().await; - client - .send(&SendPayloadTelemetry { - requests_count: 1, - ..Default::default() - }) - .unwrap(); - client.shutdown().await; - while telemetry_srv.calls_async().await == 0 { + #[test] + fn runtime_id_test() { + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST).body_includes(r#""runtime_id":"foo""#); + then.status(200).body(""); + }) + .await; + + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + client + .send(&SendPayloadTelemetry { + requests_count: 1, + ..Default::default() + }) + .unwrap(); + // Wait for send to be processed sleep(Duration::from_millis(10)).await; - } - // One payload generate-metrics - telemetry_srv.assert_calls_async(1).await; + + handle.stop().await.expect("Failed to stop worker"); + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + // One payload generate-metrics + }); } #[cfg_attr(miri, ignore)] - #[tokio::test] - async fn application_metadata_test() { - let server = MockServer::start_async().await; - - let telemetry_srv = server - .mock_async(|when, then| { - when.method(POST) - .body_includes(r#""application":{"service_name":"test_service","service_version":"test_version","env":"test_env","language_name":"test_language","language_version":"test_language_version","tracer_version":"test_tracer_version"}"#); - then.status(200).body(""); - }) - .await; - - let (client, mut worker) = TelemetryClientBuilder::default() - .set_service_name("test_service") - .set_service_version("test_version") - .set_env("test_env") - .set_language("test_language") - .set_language_version("test_language_version") - .set_tracer_version("test_tracer_version") - .set_url(&server.url("/")) - .set_heartbeat(100) - .set_runtime_id("foo") - .build(Handle::current()); - tokio::spawn(async move { worker.run().await }); - - client.start().await; - client - .send(&SendPayloadTelemetry { - requests_count: 1, - ..Default::default() - }) - .unwrap(); - client.shutdown().await; - // Wait for the server to receive at least one call, but don't hang forever. - let start = std::time::Instant::now(); - while telemetry_srv.calls_async().await == 0 { - if start.elapsed() > Duration::from_secs(180) { - panic!("telemetry server did not receive calls within timeout"); - } - sleep(Duration::from_millis(10)).await; - } - // One payload generate-metrics - telemetry_srv.assert_calls_async(1).await; + #[test] + fn application_metadata_test() { + let shared_runtime = SharedRuntime::new().expect("Failed to create runtime"); + let rt = shared_runtime.runtime().expect("Failed to get runtime"); + + rt.block_on(async { + let server = MockServer::start_async().await; + let mut telemetry_srv = server + .mock_async(|when, then| { + when.method(POST) + .body_includes(r#""application":{"service_name":"test_service","service_version":"test_version","env":"test_env","language_name":"test_language","language_version":"test_language_version","tracer_version":"test_tracer_version"}"#); + then.status(200).body(""); + }) + .await; + + let (client, handle) = get_test_client(&server.url("/"), &shared_runtime); + client.start().await; + client + .send(&SendPayloadTelemetry { + requests_count: 1, + ..Default::default() + }) + .unwrap(); + // Wait for send to be processed + sleep(Duration::from_millis(1)).await; + + handle.stop().await.expect("Failed to stop worker"); + // Wait for the server to receive at least one call, but don't hang forever. + assert!( + poll_for_mock_hit(&mut telemetry_srv, 1000, 10, 1, false).await, + "telemetry server did not receive calls within timeout" + ); + // One payload generate-metrics + }); } } diff --git a/libdd-data-pipeline/src/trace_exporter/builder.rs b/libdd-data-pipeline/src/trace_exporter/builder.rs index f9833fa668..e68851c332 100644 --- a/libdd-data-pipeline/src/trace_exporter/builder.rs +++ b/libdd-data-pipeline/src/trace_exporter/builder.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use crate::agent_info::AgentInfoFetcher; -use crate::pausable_worker::PausableWorker; +use crate::shared_runtime::SharedRuntime; use crate::telemetry::TelemetryClientBuilder; use crate::trace_exporter::agent_response::AgentResponsePayloadVersion; use crate::trace_exporter::error::BuilderErrorKind; @@ -15,7 +15,7 @@ use arc_swap::ArcSwap; use libdd_common::http_common::new_default_client; use libdd_common::{parse_uri, tag, Endpoint}; use libdd_dogstatsd_client::new; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::time::Duration; const DEFAULT_AGENT_URL: &str = "http://127.0.0.1:8126"; @@ -46,6 +46,7 @@ pub struct TraceExporterBuilder { compute_stats_by_span_kind: bool, peer_tags: Vec, telemetry: Option, + shared_runtime: Option>, health_metrics_enabled: bool, test_session_token: Option, agent_rates_payload_version_enabled: bool, @@ -199,6 +200,12 @@ impl TraceExporterBuilder { self } + /// Set a shared runtime used by the exporter for background workers. + pub fn set_shared_runtime(&mut self, shared_runtime: Arc) -> &mut Self { + self.shared_runtime = Some(shared_runtime); + self + } + /// Enables health metrics emission. pub fn enable_health_metrics(&mut self) -> &mut Self { self.health_metrics_enabled = true; @@ -227,12 +234,13 @@ impl TraceExporterBuilder { )); } - let runtime = Arc::new( - tokio::runtime::Builder::new_multi_thread() - .worker_threads(1) - .enable_all() - .build()?, - ); + let shared_runtime = + self.shared_runtime + .unwrap_or(Arc::new(SharedRuntime::new().map_err(|e| { + TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration( + e.to_string(), + )) + })?)); let dogstatsd = self.dogstatsd_url.and_then(|u| { new(Endpoint::from_slice(&u)).ok() // If we couldn't set the endpoint return @@ -251,8 +259,7 @@ impl TraceExporterBuilder { let info_endpoint = Endpoint::from_url(add_path(&agent_url, INFO_ENDPOINT)); let (info_fetcher, info_response_observer) = AgentInfoFetcher::new(info_endpoint.clone(), Duration::from_secs(5 * 60)); - let mut info_fetcher_worker = PausableWorker::new(info_fetcher); - info_fetcher_worker.start(&runtime).map_err(|e| { + let info_fetcher_handle = shared_runtime.spawn_worker(info_fetcher).map_err(|e| { TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration(e.to_string())) })?; @@ -276,20 +283,32 @@ impl TraceExporterBuilder { if let Some(id) = telemetry_config.runtime_id { builder = builder.set_runtime_id(&id); } - builder.build(runtime.handle().clone()) + let runtime = shared_runtime.runtime().map_err(|e| { + TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration(e.to_string())) + })?; + // This handle is never used since we run it as a SharedRuntime worker. So it is fine + // if the tokio runtime is dropped by SharedRuntime. + Ok(builder.build(runtime.handle().clone())) }); - let (telemetry_client, telemetry_worker) = match telemetry { - Some((client, worker)) => { - let mut telemetry_worker = PausableWorker::new(worker); - telemetry_worker.start(&runtime).map_err(|e| { + let (telemetry_client, telemetry_handle) = match telemetry { + Some(Ok((client, worker))) => { + let handle = shared_runtime.spawn_worker(worker).map_err(|e| { TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration( e.to_string(), )) })?; - runtime.block_on(client.start()); - (Some(client), Some(telemetry_worker)) + shared_runtime + .runtime() + .map_err(|e| { + TraceExporterError::Builder(BuilderErrorKind::InvalidConfiguration( + e.to_string(), + )) + })? + .block_on(client.start()); + (Some(client), Some(handle)) } + Some(Err(e)) => return Err(e), None => (None, None), }; @@ -320,7 +339,7 @@ impl TraceExporterBuilder { input_format: self.input_format, output_format: self.output_format, client_computed_top_level: self.client_computed_top_level, - runtime: Arc::new(Mutex::new(Some(runtime))), + shared_runtime, dogstatsd, common_stats_tags: vec![libdatadog_version], client_side_stats: ArcSwap::new(stats.into()), @@ -328,16 +347,14 @@ impl TraceExporterBuilder { info_response_observer, telemetry: telemetry_client, health_metrics_enabled: self.health_metrics_enabled, - workers: Arc::new(Mutex::new(TraceExporterWorkers { - info: info_fetcher_worker, - stats: None, - telemetry: telemetry_worker, - })), - agent_payload_response_version: self .agent_rates_payload_version_enabled .then(AgentResponsePayloadVersion::new), http_client: new_default_client(), + workers: TraceExporterWorkers { + info_fetcher: info_fetcher_handle, + telemetry: telemetry_handle, + }, }) } @@ -420,6 +437,22 @@ mod tests { assert_eq!(exporter.metadata.language_interpreter, ""); assert!(!exporter.metadata.client_computed_stats); assert!(exporter.telemetry.is_none()); + assert!( + exporter.shared_runtime.runtime().is_ok(), + "default shared runtime should be initialized" + ); + } + + #[cfg_attr(miri, ignore)] + #[test] + fn test_set_shared_runtime() { + let mut builder = TraceExporterBuilder::default(); + let shared_runtime = Arc::new(SharedRuntime::new().unwrap()); + builder.set_shared_runtime(shared_runtime.clone()); + + let exporter = builder.build().unwrap(); + + assert!(Arc::ptr_eq(&exporter.shared_runtime, &shared_runtime)); } #[test] diff --git a/libdd-data-pipeline/src/trace_exporter/mod.rs b/libdd-data-pipeline/src/trace_exporter/mod.rs index 02320a684d..550862bb28 100644 --- a/libdd-data-pipeline/src/trace_exporter/mod.rs +++ b/libdd-data-pipeline/src/trace_exporter/mod.rs @@ -14,14 +14,15 @@ use self::agent_response::AgentResponse; use self::metrics::MetricsEmitter; use self::stats::StatsComputationStatus; use self::trace_serializer::TraceSerializer; -use crate::agent_info::{AgentInfoFetcher, ResponseObserver}; -use crate::pausable_worker::PausableWorker; -use crate::stats_exporter::StatsExporter; +use crate::agent_info::ResponseObserver; +use crate::shared_runtime::{SharedRuntime, WorkerHandle}; use crate::telemetry::{SendPayloadTelemetry, TelemetryClient}; use crate::trace_exporter::agent_response::{ AgentResponsePayloadVersion, DATADOG_RATES_PAYLOAD_VERSION_HEADER, }; -use crate::trace_exporter::error::{InternalErrorKind, RequestError, TraceExporterError}; +use crate::trace_exporter::error::{ + InternalErrorKind, RequestError, ShutdownError, TraceExporterError, +}; use crate::{ agent_info::{self, schema::AgentInfo}, health_metrics, @@ -32,10 +33,8 @@ use http::uri::PathAndQuery; use http::Uri; use http_body_util::BodyExt; use libdd_common::tag::Tag; -use libdd_common::{http_common, Endpoint}; -use libdd_common::{HttpClient, MutexExt}; +use libdd_common::{http_common, Endpoint, HttpClient}; use libdd_dogstatsd_client::Client; -use libdd_telemetry::worker::TelemetryWorker; use libdd_trace_utils::msgpack_decoder; use libdd_trace_utils::send_with_retry::{ send_with_retry, RetryStrategy, SendWithRetryError, SendWithRetryResult, @@ -43,10 +42,12 @@ use libdd_trace_utils::send_with_retry::{ use libdd_trace_utils::span::{v04::Span, TraceData}; use libdd_trace_utils::trace_utils::TracerHeaderTags; use std::io; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +#[cfg(feature = "test-utils")] use std::time::Duration; use std::{borrow::Borrow, collections::HashMap, str::FromStr}; use tokio::runtime::Runtime; +use tokio::task::JoinSet; use tracing::{debug, error, warn}; const INFO_ENDPOINT: &str = "/info"; @@ -153,11 +154,11 @@ impl<'a> From<&'a TracerMetadata> for HashMap<&'static str, String> { } } +/// Handles for the background workers owned by a [`TraceExporter`]. #[derive(Debug)] pub(crate) struct TraceExporterWorkers { - pub info: PausableWorker, - pub stats: Option>, - pub telemetry: Option>, + info_fetcher: WorkerHandle, + telemetry: Option, } /// The TraceExporter ingest traces from the tracers serialized as messagepack and forward them to @@ -190,8 +191,7 @@ pub struct TraceExporter { metadata: TracerMetadata, input_format: TraceExporterInputFormat, output_format: TraceExporterOutputFormat, - // TODO - do something with the response callback - https://datadoghq.atlassian.net/browse/APMSP-1019 - runtime: Arc>>>, + shared_runtime: Arc, /// None if dogstatsd is disabled dogstatsd: Option, common_stats_tags: Vec, @@ -201,9 +201,9 @@ pub struct TraceExporter { info_response_observer: ResponseObserver, telemetry: Option, health_metrics_enabled: bool, - workers: Arc>, agent_payload_response_version: Option, http_client: HttpClient, + workers: TraceExporterWorkers, } impl TraceExporter { @@ -212,113 +212,70 @@ impl TraceExporter { TraceExporterBuilder::default() } - /// Return the existing runtime or create a new one and start all workers - fn runtime(&self) -> Result, TraceExporterError> { - let mut runtime_guard = self.runtime.lock_or_panic(); - match runtime_guard.as_ref() { - Some(runtime) => { - // Runtime already running - Ok(runtime.clone()) - } - None => { - // Create a new current thread runtime with all features enabled - let runtime = Arc::new( - tokio::runtime::Builder::new_multi_thread() - .worker_threads(1) - .enable_all() - .build()?, - ); - *runtime_guard = Some(runtime.clone()); - self.start_all_workers(&runtime)?; - Ok(runtime) + /// Stop the background workers owned by this exporter. + /// + /// Only the workers spawned for this exporter are stopped. Workers from other components + /// sharing the same [`SharedRuntime`] are unaffected. + /// + /// # Errors + /// Returns [`SharedRuntimeError::ShutdownTimedOut`] if a timeout was given and elapsed before + /// all workers finished. + pub fn shutdown(self, timeout: Option) -> Result<(), TraceExporterError> { + let runtime = self.runtime()?; + if let Some(timeout) = timeout { + match runtime + .block_on(async { tokio::time::timeout(timeout, self.shutdown_workers()).await }) + { + Ok(()) => Ok(()), + Err(_) => Err(TraceExporterError::Shutdown(ShutdownError::TimedOut( + timeout, + ))), } + } else { + runtime.block_on(self.shutdown_workers()); + Ok(()) } } - /// Manually start all workers - pub fn run_worker(&self) -> Result<(), TraceExporterError> { - self.runtime()?; - Ok(()) - } - - /// Start all workers with the given runtime - fn start_all_workers(&self, runtime: &Arc) -> Result<(), TraceExporterError> { - let mut workers = self.workers.lock_or_panic(); + async fn shutdown_workers(self) { + let mut join_set = JoinSet::new(); - self.start_info_worker(&mut workers, runtime)?; - self.start_stats_worker(&mut workers, runtime)?; - self.start_telemetry_worker(&mut workers, runtime)?; + // Extract the stats handle before moving other fields. + if let StatsComputationStatus::Enabled { worker_handle, .. } = + &**self.client_side_stats.load() + { + let handle = worker_handle.clone(); + join_set.spawn(async move { handle.stop().await }); + } - Ok(()) - } + let info_fetcher = self.workers.info_fetcher; + let telemetry = self.workers.telemetry; - /// Start the info worker - fn start_info_worker( - &self, - workers: &mut TraceExporterWorkers, - runtime: &Arc, - ) -> Result<(), TraceExporterError> { - workers.info.start(runtime).map_err(|e| { - TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) - }) - } + join_set.spawn(async move { info_fetcher.stop().await }); - /// Start the stats worker if present - fn start_stats_worker( - &self, - workers: &mut TraceExporterWorkers, - runtime: &Arc, - ) -> Result<(), TraceExporterError> { - if let Some(stats_worker) = &mut workers.stats { - stats_worker.start(runtime).map_err(|e| { - TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) - })?; + if let Some(telemetry) = telemetry { + join_set.spawn(async move { telemetry.stop().await }); } - Ok(()) - } - /// Start the telemetry worker if present - fn start_telemetry_worker( - &self, - workers: &mut TraceExporterWorkers, - runtime: &Arc, - ) -> Result<(), TraceExporterError> { - if let Some(telemetry_worker) = &mut workers.telemetry { - telemetry_worker.start(runtime).map_err(|e| { - TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) - })?; - if let Some(client) = &self.telemetry { - runtime.block_on(client.start()); + while let Some(result) = join_set.join_next().await { + if let Ok(Err(e)) = result { + error!("Worker failed to shutdown: {:?}", e); } } - Ok(()) } - pub fn stop_worker(&self) { - let runtime = self.runtime.lock_or_panic().take(); - if let Some(ref rt) = runtime { - // Stop workers to save their state - let mut workers = self.workers.lock_or_panic(); - rt.block_on(async { - let _ = workers.info.pause().await; - if let Some(stats_worker) = &mut workers.stats { - let _ = stats_worker.pause().await; - }; - if let Some(telemetry_worker) = &mut workers.telemetry { - let _ = telemetry_worker.pause().await; - }; - }); - } - // When the info fetcher is paused, the trigger channel keeps a reference to the runtime's - // IoStack as a waker. This prevents the IoStack from being dropped when shutting - // down runtime. By manually sending a message to the trigger channel we trigger the - // waker releasing the reference to the IoStack. Finally we drain the channel to - // avoid triggering a fetch when the info fetcher is restarted. - if let PausableWorker::Paused { worker } = &mut self.workers.lock_or_panic().info { - self.info_response_observer.manual_trigger(); - worker.drain(); - } - drop(runtime); + /// Return a runtime from the shared runtime manager. + fn runtime(&self) -> Result, TraceExporterError> { + self.shared_runtime + .runtime() + .map_err(TraceExporterError::Io) + } + + /// Manually start all workers + pub fn run_worker(&self) -> Result<(), TraceExporterError> { + self.shared_runtime.after_fork_parent().map_err(|e| { + TraceExporterError::Internal(InternalErrorKind::InvalidWorkerState(e.to_string())) + }) } /// Send msgpack serialized traces to the agent @@ -347,55 +304,6 @@ impl TraceExporter { Ok(res) } - /// Safely shutdown the TraceExporter and all related tasks - pub fn shutdown(mut self, timeout: Option) -> Result<(), TraceExporterError> { - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - - if let Some(timeout) = timeout { - match runtime - .block_on(async { tokio::time::timeout(timeout, self.shutdown_async()).await }) - { - Ok(()) => Ok(()), - Err(_e) => Err(TraceExporterError::Shutdown( - error::ShutdownError::TimedOut(timeout), - )), - } - } else { - runtime.block_on(self.shutdown_async()); - Ok(()) - } - } - - /// Future used inside `Self::shutdown`. - /// - /// This function should not take ownership of the trace exporter as it will cause the runtime - /// stored in the trace exporter to be dropped in a non-blocking context causing a panic. - async fn shutdown_async(&mut self) { - let stats_status = self.client_side_stats.load(); - if let StatsComputationStatus::Enabled { - cancellation_token, .. - } = stats_status.as_ref() - { - cancellation_token.cancel(); - - let stats_worker = self.workers.lock_or_panic().stats.take(); - - if let Some(stats_worker) = stats_worker { - let _ = stats_worker.join().await; - } - } - if let Some(telemetry) = self.telemetry.take() { - telemetry.shutdown().await; - let telemetry_worker = self.workers.lock_or_panic().telemetry.take(); - - if let Some(telemetry_worker) = telemetry_worker { - let _ = telemetry_worker.join().await; - } - } - } - /// Check if agent info state has changed fn has_agent_info_state_changed(&self, agent_info: &Arc) -> bool { Some(agent_info.state_hash.as_str()) @@ -415,13 +323,12 @@ impl TraceExporter { let ctx = stats::StatsContext { metadata: &self.metadata, endpoint_url: &self.endpoint.url, - runtime: &self.runtime, + shared_runtime: &self.shared_runtime, }; stats::handle_stats_disabled_by_agent( &ctx, &agent_info, &self.client_side_stats, - &self.workers, self.http_client.clone(), ); } @@ -431,14 +338,13 @@ impl TraceExporter { let ctx = stats::StatsContext { metadata: &self.metadata, endpoint_url: &self.endpoint.url, - runtime: &self.runtime, + shared_runtime: &self.shared_runtime, }; stats::handle_stats_enabled( &ctx, &agent_info, stats_concentrator, &self.client_side_stats, - &self.workers, ); } } @@ -848,7 +754,7 @@ impl TraceExporter { #[cfg(test)] /// Test only function to check if the stats computation is active and the worker is running pub fn is_stats_worker_active(&self) -> bool { - stats::is_stats_worker_active(&self.client_side_stats, &self.workers) + stats::is_stats_worker_active(&self.client_side_stats) } } @@ -1521,15 +1427,9 @@ mod tests { traces_endpoint.assert_calls(1); while metrics_endpoint.calls() == 0 { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter.shared_runtime.runtime().unwrap().block_on(async { + sleep(Duration::from_millis(100)).await; + }) } metrics_endpoint.assert_calls(1); } @@ -1579,15 +1479,9 @@ mod tests { traces_endpoint.assert_calls(1); while metrics_endpoint.calls() == 0 { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter.shared_runtime.runtime().unwrap().block_on(async { + sleep(Duration::from_millis(100)).await; + }) } metrics_endpoint.assert_calls(1); } @@ -1648,15 +1542,9 @@ mod tests { traces_endpoint.assert_calls(1); while metrics_endpoint.calls() == 0 { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter.shared_runtime.runtime().unwrap().block_on(async { + sleep(Duration::from_millis(100)).await; + }) } metrics_endpoint.assert_calls(1); } @@ -1831,21 +1719,13 @@ mod tests { // Wait for the info fetcher to get the config while mock_info.calls() == 0 { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter.shared_runtime.runtime().unwrap().block_on(async { + sleep(Duration::from_millis(100)).await; + }) } let _ = exporter.send(data.as_ref()).unwrap(); - exporter.shutdown(None).unwrap(); - mock_traces.assert(); } @@ -1864,15 +1744,6 @@ mod tests { assert_eq!(exporter.endpoint.timeout_ms, 42); } - - #[test] - #[cfg_attr(miri, ignore)] - fn stop_and_start_runtime() { - let builder = TraceExporterBuilder::default(); - let exporter = builder.build().unwrap(); - exporter.stop_worker(); - exporter.run_worker().unwrap(); - } } #[cfg(test)] @@ -1915,6 +1786,8 @@ mod single_threaded_tests { .body(r#"{"version":"1","client_drop_p0s":true,"endpoints":["/v0.4/traces","/v0.6/stats"]}"#); }); + let runtime = Arc::new(SharedRuntime::new().unwrap()); + let mut builder = TraceExporterBuilder::default(); builder .set_url(&server.url("/")) @@ -1926,6 +1799,7 @@ mod single_threaded_tests { .set_language_interpreter("v8") .set_input_format(TraceExporterInputFormat::V04) .set_output_format(TraceExporterOutputFormat::V04) + .set_shared_runtime(runtime.clone()) .enable_stats(Duration::from_secs(10)); let exporter = builder.build().unwrap(); @@ -1938,15 +1812,9 @@ mod single_threaded_tests { // Wait for the info fetcher to get the config while agent_info::get_agent_info().is_none() { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter.shared_runtime.runtime().unwrap().block_on(async { + sleep(Duration::from_millis(100)).await; + }) } let result = exporter.send(data.as_ref()); @@ -1963,7 +1831,7 @@ mod single_threaded_tests { std::thread::sleep(Duration::from_millis(10)); } - exporter.shutdown(None).unwrap(); + runtime.shutdown(None).unwrap(); // Wait for the mock server to process the stats for _ in 0..1000 { @@ -2015,6 +1883,8 @@ mod single_threaded_tests { .body(r#"{"version":"1","client_drop_p0s":true,"endpoints":["/v0.4/traces","/v0.6/stats"]}"#); }); + let runtime = Arc::new(SharedRuntime::new().unwrap()); + let mut builder = TraceExporterBuilder::default(); builder .set_url(&server.url("/")) @@ -2026,6 +1896,7 @@ mod single_threaded_tests { .set_language_interpreter("v8") .set_input_format(TraceExporterInputFormat::V04) .set_output_format(TraceExporterOutputFormat::V04) + .set_shared_runtime(runtime.clone()) .enable_stats(Duration::from_secs(10)); let exporter = builder.build().unwrap(); @@ -2043,15 +1914,9 @@ mod single_threaded_tests { // Wait for agent_info to be present so that sending a trace will trigger the stats worker // to start while agent_info::get_agent_info().is_none() { - exporter - .runtime - .lock() - .unwrap() - .as_ref() - .unwrap() - .block_on(async { - sleep(Duration::from_millis(100)).await; - }) + exporter.shared_runtime.runtime().unwrap().block_on(async { + sleep(Duration::from_millis(100)).await; + }) } exporter.send(data.as_ref()).unwrap(); @@ -2066,7 +1931,7 @@ mod single_threaded_tests { std::thread::sleep(Duration::from_millis(10)); } - exporter + runtime .shutdown(Some(Duration::from_millis(5))) .unwrap_err(); // The shutdown should timeout diff --git a/libdd-data-pipeline/src/trace_exporter/stats.rs b/libdd-data-pipeline/src/trace_exporter/stats.rs index 943ebc5dd1..b33a014c1d 100644 --- a/libdd-data-pipeline/src/trace_exporter/stats.rs +++ b/libdd-data-pipeline/src/trace_exporter/stats.rs @@ -8,14 +8,13 @@ //! and processing traces for stats collection. use crate::agent_info::schema::AgentInfo; +use crate::shared_runtime::{SharedRuntime, WorkerHandle}; use crate::stats_exporter; use arc_swap::ArcSwap; use libdd_common::{Endpoint, HttpClient, MutexExt}; use libdd_trace_stats::span_concentrator::SpanConcentrator; use std::sync::{Arc, Mutex}; use std::time::Duration; -use tokio::runtime::Runtime; -use tokio_util::sync::CancellationToken; use tracing::{debug, error}; use super::add_path; @@ -28,7 +27,7 @@ pub(crate) const STATS_ENDPOINT: &str = "/v0.6/stats"; pub(crate) struct StatsContext<'a> { pub metadata: &'a super::TracerMetadata, pub endpoint_url: &'a http::Uri, - pub runtime: &'a Arc>>>, + pub shared_runtime: &'a SharedRuntime, } #[derive(Debug)] @@ -42,7 +41,7 @@ pub(crate) enum StatsComputationStatus { /// Client-side stats is enabled Enabled { stats_concentrator: Arc>, - cancellation_token: CancellationToken, + worker_handle: WorkerHandle, }, } @@ -61,7 +60,6 @@ fn get_span_kinds_for_stats(agent_info: &Arc) -> Vec { pub(crate) fn start_stats_computation( ctx: &StatsContext, client_side_stats: &ArcSwap, - workers: &Arc>, span_kinds: Vec, peer_tags: Vec, client: HttpClient, @@ -73,13 +71,10 @@ pub(crate) fn start_stats_computation( span_kinds, peer_tags, ))); - let cancellation_token = CancellationToken::new(); create_and_start_stats_worker( ctx, bucket_size, &stats_concentrator, - &cancellation_token, - workers, client_side_stats, client, )?; @@ -92,8 +87,6 @@ fn create_and_start_stats_worker( ctx: &StatsContext, bucket_size: Duration, stats_concentrator: &Arc>, - cancellation_token: &CancellationToken, - workers: &Arc>, client_side_stats: &ArcSwap, client: HttpClient, ) -> anyhow::Result<()> { @@ -102,28 +95,17 @@ fn create_and_start_stats_worker( stats_concentrator.clone(), ctx.metadata.clone(), Endpoint::from_url(add_path(ctx.endpoint_url, STATS_ENDPOINT)), - cancellation_token.clone(), client, ); - let mut stats_worker = crate::pausable_worker::PausableWorker::new(stats_exporter); + let worker_handle = ctx + .shared_runtime + .spawn_worker(stats_exporter) + .map_err(|e| anyhow::anyhow!(e.to_string()))?; - // Get runtime guard - let runtime_guard = ctx.runtime.lock_or_panic(); - if let Some(rt) = runtime_guard.as_ref() { - stats_worker.start(rt).map_err(|e| { - super::error::TraceExporterError::Internal( - super::error::InternalErrorKind::InvalidWorkerState(e.to_string()), - ) - })?; - } else { - return Err(anyhow::anyhow!("Runtime not available")); - } - - // Update the stats computation state with the new worker and components - workers.lock_or_panic().stats = Some(stats_worker); + // Update the stats computation state with the new worker components. client_side_stats.store(Arc::new(StatsComputationStatus::Enabled { stats_concentrator: stats_concentrator.clone(), - cancellation_token: cancellation_token.clone(), + worker_handle, })); Ok(()) @@ -135,25 +117,21 @@ fn create_and_start_stats_worker( pub(crate) fn stop_stats_computation( ctx: &StatsContext, client_side_stats: &ArcSwap, - workers: &Arc>, ) { if let StatsComputationStatus::Enabled { stats_concentrator, - cancellation_token, + worker_handle, } = &**client_side_stats.load() { - // If there's no runtime there's no exporter to stop - let runtime_guard = ctx.runtime.lock_or_panic(); - if let Some(rt) = runtime_guard.as_ref() { - rt.block_on(async { - cancellation_token.cancel(); - }); - workers.lock_or_panic().stats = None; - let bucket_size = stats_concentrator.lock_or_panic().get_bucket_size(); - - client_side_stats.store(Arc::new(StatsComputationStatus::DisabledByAgent { - bucket_size, - })); + let bucket_size = stats_concentrator.lock_or_panic().get_bucket_size(); + client_side_stats.store(Arc::new(StatsComputationStatus::DisabledByAgent { + bucket_size, + })); + match ctx.shared_runtime.runtime() { + Ok(runtime) => { + let _ = runtime.block_on(async { worker_handle.clone().stop().await }); + } + Err(e) => error!("Failed to stop stats worker: {e}"), } } } @@ -163,7 +141,6 @@ pub(crate) fn handle_stats_disabled_by_agent( ctx: &StatsContext, agent_info: &Arc, client_side_stats: &ArcSwap, - workers: &Arc>, client: HttpClient, ) { if agent_info.info.client_drop_p0s.is_some_and(|v| v) { @@ -171,7 +148,6 @@ pub(crate) fn handle_stats_disabled_by_agent( let status = start_stats_computation( ctx, client_side_stats, - workers, get_span_kinds_for_stats(agent_info), agent_info.info.peer_tags.clone().unwrap_or_default(), client, @@ -191,14 +167,13 @@ pub(crate) fn handle_stats_enabled( agent_info: &Arc, stats_concentrator: &Mutex, client_side_stats: &ArcSwap, - workers: &Arc>, ) { if agent_info.info.client_drop_p0s.is_some_and(|v| v) { let mut concentrator = stats_concentrator.lock_or_panic(); concentrator.set_span_kinds(get_span_kinds_for_stats(agent_info)); concentrator.set_peer_tags(agent_info.info.peer_tags.clone().unwrap_or_default()); } else { - stop_stats_computation(ctx, client_side_stats, workers); + stop_stats_computation(ctx, client_side_stats); debug!("Client-side stats computation has been disabled by the agent") } } @@ -258,25 +233,9 @@ pub(crate) fn process_traces_for_stats( #[cfg(test)] /// Test only function to check if the stats computation is active and the worker is running -pub(crate) fn is_stats_worker_active( - client_side_stats: &ArcSwap, - workers: &Arc>, -) -> bool { - if !matches!( +pub(crate) fn is_stats_worker_active(client_side_stats: &ArcSwap) -> bool { + matches!( **client_side_stats.load(), StatsComputationStatus::Enabled { .. } - ) { - return false; - } - - if let Ok(workers) = workers.try_lock() { - if let Some(stats_worker) = &workers.stats { - return matches!( - stats_worker, - crate::pausable_worker::PausableWorker::Running { .. } - ); - } - } - - false + ) } diff --git a/libdd-telemetry/Cargo.toml b/libdd-telemetry/Cargo.toml index fc910a1a68..b41cb0d01e 100644 --- a/libdd-telemetry/Cargo.toml +++ b/libdd-telemetry/Cargo.toml @@ -18,6 +18,7 @@ https = ["libdd-common/https"] [dependencies] anyhow = { version = "1.0" } +async-trait = "0.1" base64 = "0.22" futures = { version = "0.3", default-features = false } http-body-util = "0.1" diff --git a/libdd-telemetry/src/worker/mod.rs b/libdd-telemetry/src/worker/mod.rs index 3bfa1bcccb..b453d4b2ce 100644 --- a/libdd-telemetry/src/worker/mod.rs +++ b/libdd-telemetry/src/worker/mod.rs @@ -11,6 +11,7 @@ use crate::{ metrics::{ContextKey, MetricBuckets, MetricContexts}, }; +use async_trait::async_trait; use libdd_common::{http_common, tag::Tag, worker::Worker}; use std::iter::Sum; @@ -140,6 +141,7 @@ pub struct TelemetryWorker { metrics_flush_interval: Duration, deadlines: scheduler::Scheduler, data: TelemetryWorkerData, + next_action: Option, } impl Debug for TelemetryWorker { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -157,58 +159,64 @@ impl Debug for TelemetryWorker { } } +#[async_trait] impl Worker for TelemetryWorker { - // Runs a state machine that waits for actions, either from the worker's - // mailbox, or scheduled actions from the worker's deadline object. - async fn run(&mut self) { - debug!( - worker.flavor = ?self.flavor, - worker.runtime_id = %self.runtime_id, - "Starting telemetry worker" - ); - - loop { - if self.cancellation_token.is_cancelled() { - debug!( - worker.runtime_id = %self.runtime_id, - "Telemetry worker cancelled, shutting down" - ); - return; - } + async fn trigger(&mut self) { + // Wait for the next action and store it + let action = self.recv_next_action().await; + self.next_action = Some(action); + } - let action = self.recv_next_action().await; + // Processes a single action from the state machine + async fn run(&mut self) { + // Take the action that was stored by trigger() + if let Some(action) = self.next_action.take() { debug!( worker.runtime_id = %self.runtime_id, action = ?action, "Received telemetry action" ); - let action_result = match self.flavor { + let _action_result = match self.flavor { TelemetryWorkerFlavor::Full => self.dispatch_action(action).await, TelemetryWorkerFlavor::MetricsLogs => { self.dispatch_metrics_logs_action(action).await } }; - - match action_result { - ControlFlow::Continue(()) => {} - ControlFlow::Break(()) => { - debug!( - worker.runtime_id = %self.runtime_id, - worker.restartable = self.config.restartable, - "Telemetry worker received break signal" - ); - if !self.config.restartable { - break; - } - } - }; } + } - debug!( - worker.runtime_id = %self.runtime_id, - "Telemetry worker stopped" - ); + /// Reset the worker state in the child process after a fork. + /// + /// Discards inherited pending telemetry state without sending anything, and drains + /// the mailbox so that actions queued before the fork are not processed by the child. + /// Dedupe history is preserved across forks so the child does not re-emit already + /// seen dependencies, integrations, or configurations unless they are observed again + /// as new data. + fn reset(&mut self) { + // Drain all actions queued in the mailbox before the fork. + while self.mailbox.try_recv().is_ok() {} + + // Discard any action that was staged by the last trigger() call. + self.next_action = None; + + // Clear all unbuffered telemetry data; the child must not send pre-fork data. + self.data.logs = store::QueueHashMap::default(); + self.data.metric_buckets = MetricBuckets::default(); + self.data.dependencies.clear(); + self.data.integrations.clear(); + self.data.configurations.clear(); + self.data.endpoints.clear(); + } + + async fn shutdown(&mut self) { + let stop_action = TelemetryActions::Lifecycle(LifecycleAction::Stop); + let _action_result = match self.flavor { + TelemetryWorkerFlavor::Full => self.dispatch_action(stop_action).await, + TelemetryWorkerFlavor::MetricsLogs => { + self.dispatch_metrics_logs_action(stop_action).await + } + }; } } @@ -828,6 +836,59 @@ impl TelemetryWorker { metric_buckets: self.data.metric_buckets.stats(), } } + + // Runs a state machine that waits for actions, either from the worker's + // mailbox, or scheduled actions from the worker's deadline object. + async fn run_loop(mut self) { + debug!( + worker.flavor = ?self.flavor, + worker.runtime_id = %self.runtime_id, + "Starting telemetry worker" + ); + + loop { + if self.cancellation_token.is_cancelled() { + debug!( + worker.runtime_id = %self.runtime_id, + "Telemetry worker cancelled, shutting down" + ); + return; + } + + let action = self.recv_next_action().await; + debug!( + worker.runtime_id = %self.runtime_id, + action = ?action, + "Received telemetry action" + ); + + let action_result = match self.flavor { + TelemetryWorkerFlavor::Full => self.dispatch_action(action).await, + TelemetryWorkerFlavor::MetricsLogs => { + self.dispatch_metrics_logs_action(action).await + } + }; + + match action_result { + ControlFlow::Continue(()) => {} + ControlFlow::Break(()) => { + debug!( + worker.runtime_id = %self.runtime_id, + worker.restartable = self.config.restartable, + "Telemetry worker received break signal" + ); + if !self.config.restartable { + break; + } + } + }; + } + + debug!( + worker.runtime_id = %self.runtime_id, + "Telemetry worker stopped" + ); + } } #[derive(Debug)] @@ -1145,6 +1206,7 @@ impl TelemetryWorkerBuilder { ), ]), cancellation_token: token.clone(), + next_action: None, }; ( @@ -1153,6 +1215,7 @@ impl TelemetryWorkerBuilder { shutdown, cancellation_token: token, runtime: tokio_runtime, + contexts, }, worker, @@ -1164,9 +1227,9 @@ impl TelemetryWorkerBuilder { pub fn spawn(self) -> (TelemetryWorkerHandle, JoinHandle<()>) { let tokio_runtime = tokio::runtime::Handle::current(); - let (worker_handle, mut worker) = self.build_worker(tokio_runtime.clone()); + let (worker_handle, worker) = self.build_worker(tokio_runtime.clone()); - let join_handle = tokio_runtime.spawn(async move { worker.run().await }); + let join_handle = tokio_runtime.spawn(async move { worker.run_loop().await }); (worker_handle, join_handle) } @@ -1176,10 +1239,10 @@ impl TelemetryWorkerBuilder { let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; - let (handle, mut worker) = self.build_worker(runtime.handle().clone()); + let (handle, worker) = self.build_worker(runtime.handle().clone()); let notify_shutdown = handle.shutdown.clone(); std::thread::spawn(move || { - runtime.block_on(worker.run()); + runtime.block_on(worker.run_loop()); runtime.shutdown_background(); notify_shutdown.shutdown_finished(); }); @@ -1202,4 +1265,231 @@ mod tests { #[allow(clippy::redundant_closure)] let _ = |h: TelemetryWorkerHandle| is_sync(h); } + + mod reset { + use super::super::*; + use crate::data::{ + metrics::{MetricNamespace, MetricType}, + Configuration, ConfigurationOrigin, Dependency, Endpoint, Integration, Log, LogLevel, + }; + use libdd_common::worker::Worker; + + fn build_test_worker() -> (TelemetryWorkerHandle, TelemetryWorker) { + let builder = TelemetryWorkerBuilder::new( + "hostname".to_string(), + "test-service".to_string(), + "rust".to_string(), + "1.0.0".to_string(), + "1.0.0".to_string(), + ); + // build_worker requires a tokio Handle; tests using this must be #[tokio::test] + builder.build_worker(tokio::runtime::Handle::current()) + } + + fn make_log(id: u64, message: &str) -> (LogIdentifier, Log) { + ( + LogIdentifier { identifier: id }, + Log { + message: message.to_string(), + level: LogLevel::Warn, + stack_trace: None, + count: 1, + tags: String::new(), + is_sensitive: false, + is_crash: false, + }, + ) + } + + /// After reset(), pending buffered telemetry is cleared while dedupe history is preserved. + #[tokio::test] + async fn test_reset_clears_buffered_data() { + let (handle, mut worker) = build_test_worker(); + + // Populate every data field that reset() should clear. + worker.data.dependencies.insert(Dependency { + name: "dep".to_string(), + version: None, + }); + worker.data.integrations.insert(Integration { + name: "integration".to_string(), + version: None, + enabled: true, + compatible: None, + auto_enabled: None, + }); + worker.data.configurations.insert(Configuration { + name: "cfg".to_string(), + value: "true".to_string(), + origin: ConfigurationOrigin::Code, + config_id: None, + seq_id: None, + }); + worker.data.endpoints.insert(Endpoint { + operation_name: "GET /health".to_string(), + resource_name: "/health".to_string(), + ..Default::default() + }); + let (id, log) = make_log(42, "msg"); + worker.data.logs.get_mut_or_insert(id, log); + + // Register a metric context and add a data point. + let key = handle.register_metric_context( + "test.metric".to_string(), + vec![], + MetricType::Count, + false, + MetricNamespace::Tracers, + ); + worker.data.metric_buckets.add_point(key, 1.0, vec![]); + + worker.reset(); + + let stats = worker.stats(); + assert_eq!( + stats.dependencies_stored, 1, + "dependency dedupe history should be preserved" + ); + assert_eq!( + stats.dependencies_unflushed, 0, + "dependency pending queue should be cleared" + ); + assert_eq!( + stats.integrations_stored, 1, + "integration dedupe history should be preserved" + ); + assert_eq!( + stats.integrations_unflushed, 0, + "integration pending queue should be cleared" + ); + assert_eq!( + stats.configurations_stored, 1, + "configuration dedupe history should be preserved" + ); + assert_eq!( + stats.configurations_unflushed, 0, + "configuration pending queue should be cleared" + ); + assert_eq!(stats.logs, 0, "logs should be cleared"); + assert_eq!( + stats.metric_buckets.buckets, 0, + "metric buckets should be cleared" + ); + assert_eq!( + stats.metric_buckets.series, 0, + "metric series should be cleared" + ); + assert!( + worker.data.endpoints.is_empty(), + "endpoints should be cleared" + ); + assert!(worker.next_action.is_none(), "next_action should be None"); + } + + /// After reset(), actions queued in the mailbox before the fork are discarded. + #[tokio::test] + async fn test_reset_drains_mailbox() { + let (handle, mut worker) = build_test_worker(); + + // Enqueue several actions that should be discarded. + handle + .try_send_msg(TelemetryActions::AddDependency(Dependency { + name: "dep".to_string(), + version: None, + })) + .unwrap(); + let (id, log) = make_log(1, "pre-fork log"); + handle + .try_send_msg(TelemetryActions::AddLog((id, log))) + .unwrap(); + + // Stage one action as if trigger() had already stored it. + worker.next_action = Some(TelemetryActions::Lifecycle(LifecycleAction::Start)); + + worker.reset(); + + // The mailbox must be empty and next_action cleared. + assert!( + worker.mailbox.try_recv().is_err(), + "mailbox should be empty" + ); + assert!(worker.next_action.is_none(), "next_action should be None"); + // None of the queued actions should have been applied to pending state. + let stats = worker.stats(); + assert_eq!( + stats.dependencies_stored, 0, + "queued AddDependency must not be applied" + ); + assert_eq!( + stats.dependencies_unflushed, 0, + "queued AddDependency must not be pending" + ); + assert_eq!(stats.logs, 0, "queued AddLog must be discarded"); + } + + /// After reset(), the worker accepts new telemetry and processes it normally. + #[tokio::test] + async fn test_worker_accepts_new_data_after_reset() { + let (handle, mut worker) = build_test_worker(); + worker.flavor = TelemetryWorkerFlavor::MetricsLogs; + + // Populate state before reset – this data must not survive. + let (id, log) = make_log(99, "pre-fork"); + worker.data.logs.get_mut_or_insert(id, log); + + worker.reset(); + + // Send a new log from the child side. + let (id2, log2) = make_log(1, "post-fork"); + handle + .try_send_msg(TelemetryActions::AddLog((id2, log2))) + .unwrap(); + + // Simulate one trigger() + run() cycle. + worker.trigger().await; + worker.run().await; + + let stats = worker.stats(); + // Only the new post-fork log should be buffered. + assert_eq!(stats.logs, 1, "only post-fork log should be present"); + } + + /// After reset(), lifecycle state needed to keep periodic flushing alive is preserved. + #[tokio::test] + async fn test_reset_preserves_started_and_deadlines() { + let (_handle, mut worker) = build_test_worker(); + + worker.data.started = true; + worker + .deadlines + .schedule_event(LifecycleAction::FlushMetricAggr) + .unwrap(); + worker + .deadlines + .schedule_event(LifecycleAction::FlushData) + .unwrap(); + + let deadlines_before = worker.deadlines.deadlines.clone(); + + worker.reset(); + + assert!(worker.data.started, "started flag should be preserved"); + assert_eq!( + worker.deadlines.deadlines.len(), + deadlines_before.len(), + "scheduled deadlines should be preserved" + ); + for ((_, actual), (_, expected)) in worker + .deadlines + .deadlines + .iter() + .zip(deadlines_before.iter()) + { + assert_eq!( + actual, expected, + "deadline kinds should be preserved across reset" + ); + } + } + } } diff --git a/libdd-telemetry/src/worker/store.rs b/libdd-telemetry/src/worker/store.rs index 3986941de1..1c2d400900 100644 --- a/libdd-telemetry/src/worker/store.rs +++ b/libdd-telemetry/src/worker/store.rs @@ -208,6 +208,11 @@ where pub fn len_stored(&self) -> usize { self.items.len() } + + /// Discard only pending unflushed items while preserving stored dedupe history. + pub fn clear(&mut self) { + self.unflushed.clear(); + } } impl Extend for Store