diff --git a/cli/crates/federated-dev/src/dev/gateway_nanny.rs b/cli/crates/federated-dev/src/dev/gateway_nanny.rs index 1f3e3ad0a4..0666e6c2c0 100644 --- a/cli/crates/federated-dev/src/dev/gateway_nanny.rs +++ b/cli/crates/federated-dev/src/dev/gateway_nanny.rs @@ -7,9 +7,8 @@ use super::bus::{EngineSender, GraphWatcher}; use engine_v2::Engine; use futures_concurrency::stream::Merge; use futures_util::{future::BoxFuture, stream::BoxStream, FutureExt as _, StreamExt}; -use runtime::rate_limiting::KeyedRateLimitConfig; use runtime_local::rate_limiting::in_memory::key_based::InMemoryRateLimiter; -use tokio::sync::mpsc; +use tokio::sync::watch; use tokio_stream::wrappers::WatchStream; /// The GatewayNanny looks after the `Gateway` - on updates to the graph or config it'll @@ -57,7 +56,12 @@ pub(super) async fn new_gateway(config: Option) -> O .into_iter() .map(|(k, v)| { ( - k.to_string(), + match k { + engine_v2::config::RateLimitKey::Global => runtime::rate_limiting::RateLimitKey::Global, + engine_v2::config::RateLimitKey::Subgraph(name) => { + runtime::rate_limiting::RateLimitKey::Subgraph(name.to_string().into()) + } + }, runtime::rate_limiting::GraphRateLimit { limit: v.limit, duration: v.duration, @@ -66,7 +70,7 @@ pub(super) async fn new_gateway(config: Option) -> O }) .collect::>(); - let (_, rx) = mpsc::channel(100); + let (_, rx) = watch::channel(rate_limiting_configs); let runtime = CliRuntime { fetcher: runtime_local::NativeFetcher::runtime_fetcher(), @@ -75,7 +79,7 @@ pub(super) async fn new_gateway(config: Option) -> O ), kv: runtime_local::InMemoryKvStore::runtime(), meter: grafbase_telemetry::metrics::meter_from_global_provider(), - rate_limiter: InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }, rx), + rate_limiter: InMemoryRateLimiter::runtime(rx), }; let schema = config.try_into().ok()?; diff --git a/engine/crates/engine-v2/config/src/lib.rs b/engine/crates/engine-v2/config/src/lib.rs index 424c9d1a0b..a37dba6a0c 100644 --- a/engine/crates/engine-v2/config/src/lib.rs +++ b/engine/crates/engine-v2/config/src/lib.rs @@ -8,8 +8,6 @@ mod v3; mod v4; mod v5; -pub const GLOBAL_RATE_LIMIT_KEY: &str = "global"; - /// The latest version of the configuration. /// /// Users of the crate should always use this verison, and we can keep the details diff --git a/engine/crates/engine-v2/config/src/v5.rs b/engine/crates/engine-v2/config/src/v5.rs index d01b6983e4..94e28b29c2 100644 --- a/engine/crates/engine-v2/config/src/v5.rs +++ b/engine/crates/engine-v2/config/src/v5.rs @@ -2,12 +2,11 @@ mod header; mod rate_limit; use std::{ - collections::{BTreeMap, HashMap}, + collections::BTreeMap, path::{Path, PathBuf}, time::Duration, }; -use crate::GLOBAL_RATE_LIMIT_KEY; use federated_graph::{FederatedGraphV3, SubgraphId}; use self::rate_limit::{RateLimitConfigRef, RateLimitRedisConfigRef, RateLimitRedisTlsConfigRef}; @@ -93,16 +92,17 @@ impl Config { }) } - pub fn as_keyed_rate_limit_config(&self) -> HashMap<&str, GraphRateLimit> { - let mut key_based_config = HashMap::new(); + pub fn as_keyed_rate_limit_config(&self) -> Vec<(RateLimitKey<'_>, GraphRateLimit)> { + let mut key_based_config = Vec::new(); if let Some(global_config) = self.rate_limit.as_ref().and_then(|c| c.global) { - key_based_config.insert(GLOBAL_RATE_LIMIT_KEY, global_config); + key_based_config.push((RateLimitKey::Global, global_config)); } for subgraph in self.subgraph_configs.values() { if let Some(subgraph_rate_limit) = subgraph.rate_limit { - key_based_config.insert(&self.strings[subgraph.name.0], subgraph_rate_limit); + let key = RateLimitKey::Subgraph(&self.strings[subgraph.name.0]); + key_based_config.push((key, subgraph_rate_limit)); } } @@ -110,6 +110,12 @@ impl Config { } } +#[derive(Clone, Copy, Debug)] +pub enum RateLimitKey<'a> { + Global, + Subgraph(&'a str), +} + impl std::ops::Index for Config { type Output = String; diff --git a/engine/crates/engine-v2/src/engine.rs b/engine/crates/engine-v2/src/engine.rs index 4f3eab2eda..112b4111d1 100644 --- a/engine/crates/engine-v2/src/engine.rs +++ b/engine/crates/engine-v2/src/engine.rs @@ -2,6 +2,7 @@ use ::runtime::{ auth::AccessToken, hooks::Hooks, hot_cache::{CachedDataKind, HotCache, HotCacheFactory}, + rate_limiting::RateLimitKey, }; use async_runtime::stream::StreamExt as _; use engine::{BatchRequest, Request}; @@ -33,11 +34,9 @@ use crate::{ }; mod cache; -mod rate_limiting; mod runtime; mod trusted_documents; -pub use rate_limiting::RateLimitContext; pub use runtime::Runtime; pub(crate) struct SchemaVersion(Vec); @@ -125,7 +124,7 @@ impl Engine { Err(response) => return HttpGraphqlResponse::build(response, format, Default::default()), }; - if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitContext::Global).await { + if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitKey::Global).await { return HttpGraphqlResponse::build( Response::pre_execution_error(GraphqlError::new(err.to_string(), ErrorCode::RateLimited)), format, @@ -160,7 +159,7 @@ impl Engine { } pub async fn create_session(self: &Arc, headers: http::HeaderMap) -> Result, Cow<'static, str>> { - if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitContext::Global).await { + if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitKey::Global).await { return Err( Response::pre_execution_error(GraphqlError::new(err.to_string(), ErrorCode::RateLimited)) .first_error_message() diff --git a/engine/crates/engine-v2/src/engine/rate_limiting.rs b/engine/crates/engine-v2/src/engine/rate_limiting.rs deleted file mode 100644 index 8c6beabf16..0000000000 --- a/engine/crates/engine-v2/src/engine/rate_limiting.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::net::IpAddr; - -use config::GLOBAL_RATE_LIMIT_KEY; -use serde_json::Value; - -use runtime::rate_limiting::RateLimiterContext; - -pub enum RateLimitContext<'a> { - Global, - Subgraph(&'a str), -} - -impl RateLimiterContext for RateLimitContext<'_> { - fn header(&self, _name: http::HeaderName) -> Option<&http::HeaderValue> { - None - } - - fn graphql_operation_name(&self) -> Option<&str> { - None - } - - fn ip(&self) -> Option { - None - } - - fn jwt_claim(&self, _key: &str) -> Option<&Value> { - None - } - - fn key(&self) -> Option<&str> { - Some(match self { - RateLimitContext::Global => GLOBAL_RATE_LIMIT_KEY, - RateLimitContext::Subgraph(name) => name, - }) - } - - fn is_global(&self) -> bool { - matches!(self, Self::Global) - } -} diff --git a/engine/crates/engine-v2/src/lib.rs b/engine/crates/engine-v2/src/lib.rs index fae11d70d4..cbe39c21e4 100644 --- a/engine/crates/engine-v2/src/lib.rs +++ b/engine/crates/engine-v2/src/lib.rs @@ -10,7 +10,7 @@ mod utils; pub mod websocket; pub use ::engine::{BatchRequest, Request}; -pub use engine::{Engine, RateLimitContext, Runtime, Session}; +pub use engine::{Engine, Runtime, Session}; pub use http_response::{HttpGraphqlResponse, HttpGraphqlResponseBody}; pub use schema::{CacheControl, Schema}; diff --git a/engine/crates/engine-v2/src/sources/graphql/request.rs b/engine/crates/engine-v2/src/sources/graphql/request.rs index 2268ff8519..56db9ae940 100644 --- a/engine/crates/engine-v2/src/sources/graphql/request.rs +++ b/engine/crates/engine-v2/src/sources/graphql/request.rs @@ -4,7 +4,10 @@ use grafbase_telemetry::{ gql_response_status::{GraphqlResponseStatus, SubgraphResponseStatus}, span::{GqlRecorderSpanExt, GRAFBASE_TARGET}, }; -use runtime::fetch::{FetchRequest, FetchResponse}; +use runtime::{ + fetch::{FetchRequest, FetchResponse}, + rate_limiting::RateLimitKey, +}; use schema::sources::graphql::{GraphqlEndpointId, GraphqlEndpointWalker}; use tower::retry::budget::Budget; use tracing::Span; @@ -13,7 +16,7 @@ use web_time::Duration; use crate::{ execution::{ExecutionContext, ExecutionError, ExecutionResult}, response::SubgraphResponse, - RateLimitContext, Runtime, + Runtime, }; pub trait ResponseIngester: Send { @@ -130,7 +133,7 @@ async fn rate_limited_fetch<'ctx, R: Runtime>( ctx.engine .runtime .rate_limiter() - .limit(&RateLimitContext::Subgraph(subgraph.name())) + .limit(&RateLimitKey::Subgraph(subgraph.name().into())) .await?; ctx.engine diff --git a/engine/crates/engine-v2/src/sources/graphql/subscription.rs b/engine/crates/engine-v2/src/sources/graphql/subscription.rs index 14422e6585..e648caab8a 100644 --- a/engine/crates/engine-v2/src/sources/graphql/subscription.rs +++ b/engine/crates/engine-v2/src/sources/graphql/subscription.rs @@ -1,5 +1,5 @@ use futures_util::{stream::BoxStream, StreamExt}; -use runtime::fetch::GraphqlRequest; +use runtime::{fetch::GraphqlRequest, rate_limiting::RateLimitKey}; use serde::de::DeserializeSeed; use super::{ @@ -37,7 +37,7 @@ impl GraphqlPreparedExecutor { ctx.engine .runtime .rate_limiter() - .limit(&crate::engine::RateLimitContext::Subgraph(subgraph.name())) + .limit(&RateLimitKey::Subgraph(subgraph.name().into())) .await?; let stream = ctx diff --git a/engine/crates/integration-tests/src/federation/builder/test_runtime.rs b/engine/crates/integration-tests/src/federation/builder/test_runtime.rs index 9d546cc57b..ad741fb150 100644 --- a/engine/crates/integration-tests/src/federation/builder/test_runtime.rs +++ b/engine/crates/integration-tests/src/federation/builder/test_runtime.rs @@ -4,7 +4,7 @@ use runtime_local::{ rate_limiting::in_memory::key_based::InMemoryRateLimiter, InMemoryHotCacheFactory, InMemoryKvStore, NativeFetcher, }; use runtime_noop::trusted_documents::NoopTrustedDocuments; -use tokio::sync::mpsc; +use tokio::sync::watch; pub struct TestRuntime { pub fetcher: runtime::fetch::Fetcher, @@ -17,14 +17,15 @@ pub struct TestRuntime { impl Default for TestRuntime { fn default() -> Self { - let (_, rx) = mpsc::channel(100); + let (_, rx) = watch::channel(Default::default()); + Self { fetcher: NativeFetcher::runtime_fetcher(), trusted_documents: trusted_documents_client::Client::new(NoopTrustedDocuments), kv: InMemoryKvStore::runtime(), meter: metrics::meter_from_global_provider(), hooks: Default::default(), - rate_limiter: InMemoryRateLimiter::runtime(Default::default(), rx), + rate_limiter: InMemoryRateLimiter::runtime(rx), } } } diff --git a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs index 27006fddcb..9d019c69f9 100644 --- a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs +++ b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs @@ -1,4 +1,3 @@ -use std::net::IpAddr; use std::num::NonZeroU32; use std::sync::Arc; use std::{collections::HashMap, sync::RwLock}; @@ -7,69 +6,48 @@ use futures_util::future::BoxFuture; use futures_util::FutureExt; use governor::Quota; use grafbase_telemetry::span::GRAFBASE_TARGET; -use serde_json::Value; -use http::{HeaderName, HeaderValue}; -use runtime::rate_limiting::{Error, GraphRateLimit, KeyedRateLimitConfig, RateLimiter, RateLimiterContext}; -use tokio::sync::mpsc; +use runtime::rate_limiting::{Error, GraphRateLimit, RateLimitKey, RateLimiter, RateLimiterContext}; +use tokio::sync::watch; -pub struct RateLimitingContext(pub String); - -impl RateLimiterContext for RateLimitingContext { - fn header(&self, _name: HeaderName) -> Option<&HeaderValue> { - None - } - - fn graphql_operation_name(&self) -> Option<&str> { - None - } - - fn ip(&self) -> Option { - None - } - - fn jwt_claim(&self, _key: &str) -> Option<&Value> { - None - } - - fn key(&self) -> Option<&str> { - Some(&self.0) - } -} +type Limits = HashMap, GraphRateLimit>; +type Limiters = HashMap, governor::DefaultKeyedRateLimiter>; pub struct InMemoryRateLimiter { - limiters: Arc>>>, + limiters: Arc>, } impl InMemoryRateLimiter { - pub fn runtime( - config: KeyedRateLimitConfig, - mut updates: mpsc::Receiver>, - ) -> RateLimiter { + pub fn runtime(mut updates: watch::Receiver) -> RateLimiter { let mut limiters = HashMap::new(); // add subgraph rate limiting configuration - for (name, config) in config.rate_limiting_configs { - let Some(limiter) = create_limiter(config) else { + for (name, config) in updates.borrow_and_update().iter() { + let Some(limiter) = create_limiter(*config) else { continue; }; - limiters.insert(name.to_string(), limiter); + limiters.insert(name.clone(), limiter); } let limiters = Arc::new(RwLock::new(limiters)); - let limiters_copy = limiters.clone(); + let limiters_copy = Arc::downgrade(&limiters); tokio::spawn(async move { - while let Some(updates) = updates.recv().await { - let mut limiters = limiters_copy.write().unwrap(); + while let Ok(()) = updates.changed().await { + let Some(limiters) = limiters_copy.upgrade() else { + break; + }; + + let mut limiters = limiters.write().unwrap(); + limiters.clear(); - for (name, config) in updates { - let Some(limiter) = create_limiter(config) else { + for (name, config) in updates.borrow_and_update().iter() { + let Some(limiter) = create_limiter(*config) else { continue; }; - limiters.insert(name.to_string(), limiter); + limiters.insert(name.clone(), limiter); } } }); diff --git a/engine/crates/runtime-local/src/rate_limiting/redis.rs b/engine/crates/runtime-local/src/rate_limiting/redis.rs index 99fdf0e1f2..ad8eeae852 100644 --- a/engine/crates/runtime-local/src/rate_limiting/redis.rs +++ b/engine/crates/runtime-local/src/rate_limiting/redis.rs @@ -1,6 +1,5 @@ mod pool; -use core::fmt; use std::{ collections::HashMap, fs::File, @@ -14,7 +13,7 @@ use deadpool::managed::Pool; use futures_util::future::BoxFuture; use grafbase_telemetry::span::GRAFBASE_TARGET; use redis::ClientTlsConfig; -use runtime::rate_limiting::{Error, GraphRateLimit, RateLimiter, RateLimiterContext}; +use runtime::rate_limiting::{Error, GraphRateLimit, RateLimitKey, RateLimiter, RateLimiterContext}; use tokio::sync::watch; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -31,7 +30,7 @@ pub struct RateLimitRedisTlsConfig<'a> { pub ca: Option<&'a Path>, } -pub type Limits = watch::Receiver>; +pub type Limits = watch::Receiver, GraphRateLimit>>; /// Rate limiter by utilizing Redis as a backend. It uses a averaging fixed window algorithm /// to define is the limit reached or not. @@ -53,24 +52,6 @@ pub struct RedisRateLimiter { limits: Limits, } -enum Key<'a> { - Graph { name: &'a str }, -} - -impl<'a> fmt::Display for Key<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("rate_limit:")?; - - match self { - Key::Graph { name } => { - f.write_str(name)?; - } - } - - Ok(()) - } -} - impl RedisRateLimiter { pub async fn runtime(config: RateLimitRedisConfig<'_>, limits: Limits) -> anyhow::Result { Ok(RateLimiter::new(Self::new(config, limits).await?)) @@ -147,11 +128,14 @@ impl RedisRateLimiter { }) } - fn generate_key(&self, bucket: u64, context: &dyn RateLimiterContext, key: Key<'_>) -> String { - if context.is_global() { - format!("{}:{key}:{bucket}", self.key_prefix) - } else { - format!("{}:subgraph:{key}:{bucket}", self.key_prefix) + fn generate_key(&self, bucket: u64, key: &RateLimitKey<'_>) -> String { + match key { + RateLimitKey::Global => { + format!("{}:rate_limit:global:{bucket}", self.key_prefix) + } + RateLimitKey::Subgraph(ref graph) => { + format!("{}:subgraph:rate_limit:{graph}:{bucket}", self.key_prefix) + } } } @@ -187,9 +171,9 @@ impl RedisRateLimiter { let bucket_percentage = (current_ts % duration_ns) as f64 / duration_ns as f64; // The counter key for the current window. - let current_bucket = self.generate_key(current_bucket, context, Key::Graph { name: key }); + let current_bucket = self.generate_key(current_bucket, key); // The counter key for the previous window. - let previous_bucket = self.generate_key(previous_bucket, context, Key::Graph { name: key }); + let previous_bucket = self.generate_key(previous_bucket, key); // We execute multiple commands in one pipelined query to be _fast_. let mut pipe = redis::pipe(); diff --git a/engine/crates/runtime/src/rate_limiting.rs b/engine/crates/runtime/src/rate_limiting.rs index 710852165c..c91bed2fde 100644 --- a/engine/crates/runtime/src/rate_limiting.rs +++ b/engine/crates/runtime/src/rate_limiting.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::collections::HashMap; use std::net::IpAddr; use std::sync::Arc; @@ -18,12 +19,9 @@ pub trait RateLimiterContext: Send + Sync { fn graphql_operation_name(&self) -> Option<&str>; fn ip(&self) -> Option; fn jwt_claim(&self, key: &str) -> Option<&serde_json::Value>; - fn key(&self) -> Option<&str> { - None - } - fn is_global(&self) -> bool { - true + fn key(&self) -> Option<&RateLimitKey<'_>> { + None } } @@ -44,6 +42,46 @@ impl RateLimiter { } } +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum RateLimitKey<'a> { + Global, + Subgraph(Cow<'a, str>), +} + +impl<'a> From<&'a str> for RateLimitKey<'a> { + fn from(value: &'a str) -> Self { + Self::Subgraph(Cow::Borrowed(value)) + } +} + +impl<'a> From for RateLimitKey<'a> { + fn from(value: String) -> Self { + Self::Subgraph(Cow::Owned(value)) + } +} + +impl<'a> RateLimiterContext for RateLimitKey<'a> { + fn header(&self, _: http::HeaderName) -> Option<&http::HeaderValue> { + None + } + + fn graphql_operation_name(&self) -> Option<&str> { + None + } + + fn ip(&self) -> Option { + None + } + + fn jwt_claim(&self, _: &str) -> Option<&serde_json::Value> { + None + } + + fn key(&self) -> Option<&RateLimitKey<'a>> { + Some(self) + } +} + impl std::ops::Deref for RateLimiter { type Target = dyn RateLimiterInner; diff --git a/gateway/crates/federated-server/src/config.rs b/gateway/crates/federated-server/src/config.rs index c110c2f810..3f29e597bb 100644 --- a/gateway/crates/federated-server/src/config.rs +++ b/gateway/crates/federated-server/src/config.rs @@ -6,12 +6,7 @@ mod health; pub(crate) mod hot_reload; mod rate_limit; -use std::{ - collections::{BTreeMap, HashMap}, - net::SocketAddr, - path::PathBuf, - time::Duration, -}; +use std::{collections::BTreeMap, net::SocketAddr, path::PathBuf, time::Duration}; pub use self::health::HealthConfig; use ascii::AsciiString; @@ -21,6 +16,7 @@ pub use entity_caching::EntityCachingConfig; use grafbase_telemetry::config::TelemetryConfig; pub use header::{HeaderForward, HeaderInsert, HeaderRemove, HeaderRule, NameOrPattern}; pub use rate_limit::{GraphRateLimit, RateLimitConfig}; +use runtime::rate_limiting::RateLimitKey; use runtime_local::HooksWasiConfig; use serde_dynamic_string::DynamicString; use url::Url; @@ -75,16 +71,16 @@ pub struct Config { impl Config { /// Load the rate limit configuration for global and subgraph level settings. - pub fn as_keyed_rate_limit_config(&self) -> HashMap<&str, GraphRateLimit> { - let mut key_based_config = HashMap::new(); + pub fn as_keyed_rate_limit_config(&self) -> Vec<(RateLimitKey<'static>, GraphRateLimit)> { + let mut key_based_config = Vec::new(); if let Some(global_config) = self.gateway.rate_limit.as_ref().and_then(|c| c.global) { - key_based_config.insert("global", global_config); + key_based_config.push((RateLimitKey::Global, global_config)); } for (subgraph_name, subgraph) in self.subgraphs.iter() { if let Some(limit) = subgraph.rate_limit { - key_based_config.insert(subgraph_name, limit); + key_based_config.push((RateLimitKey::Subgraph(subgraph_name.clone().into()), limit)); } } diff --git a/gateway/crates/federated-server/src/config/hot_reload.rs b/gateway/crates/federated-server/src/config/hot_reload.rs index 6c9fa3c5a1..eb454be7ee 100644 --- a/gateway/crates/federated-server/src/config/hot_reload.rs +++ b/gateway/crates/federated-server/src/config/hot_reload.rs @@ -2,49 +2,23 @@ use std::{collections::HashMap, fs, path::PathBuf, sync::OnceLock, time::Duratio use grafbase_telemetry::span::GRAFBASE_TARGET; use notify::{EventHandler, EventKind, PollWatcher, Watcher}; -use runtime::rate_limiting::GraphRateLimit; -use tokio::sync::{mpsc, watch}; +use runtime::rate_limiting::{GraphRateLimit, RateLimitKey}; +use tokio::sync::watch; use crate::Config; -type RateLimitData = HashMap; - -pub(crate) enum RateLimitSender { - Watch(watch::Sender), - Mpsc(mpsc::Sender), -} - -impl RateLimitSender { - fn send(&self, data: RateLimitData) -> crate::Result<()> { - match self { - RateLimitSender::Watch(channel) => Ok(channel.send(data)?), - RateLimitSender::Mpsc(channel) => Ok(channel.blocking_send(data)?), - } - } -} - -impl From> for RateLimitSender { - fn from(value: watch::Sender) -> Self { - Self::Watch(value) - } -} - -impl From> for RateLimitSender { - fn from(value: mpsc::Sender) -> Self { - Self::Mpsc(value) - } -} +type RateLimitData = HashMap, GraphRateLimit>; pub(crate) struct ConfigWatcher { config_path: PathBuf, - rate_limit_sender: RateLimitSender, + rate_limit_sender: watch::Sender, } impl ConfigWatcher { - pub fn new(config_path: PathBuf, rate_limit_sender: impl Into) -> Self { + pub fn new(config_path: PathBuf, rate_limit_sender: watch::Sender) -> Self { Self { config_path, - rate_limit_sender: rate_limit_sender.into(), + rate_limit_sender, } } @@ -90,7 +64,7 @@ impl ConfigWatcher { .into_iter() .map(|(k, v)| { ( - k.to_string(), + k, runtime::rate_limiting::GraphRateLimit { limit: v.limit, duration: v.duration, diff --git a/gateway/crates/federated-server/src/server/gateway.rs b/gateway/crates/federated-server/src/server/gateway.rs index 8946ca92d4..17c1c83f61 100644 --- a/gateway/crates/federated-server/src/server/gateway.rs +++ b/gateway/crates/federated-server/src/server/gateway.rs @@ -5,12 +5,11 @@ use std::{collections::BTreeMap, sync::Arc}; use runtime_local::rate_limiting::in_memory::key_based::InMemoryRateLimiter; use runtime_local::rate_limiting::redis::RedisRateLimiter; -use tokio::sync::{mpsc, watch}; +use tokio::sync::watch; use engine_v2::Engine; use graphql_composition::FederatedGraph; use parser_sdl::federation::{header::SubgraphHeaderRule, FederatedGraphConfig}; -use runtime::rate_limiting::KeyedRateLimitConfig; use runtime_local::{ComponentLoader, HooksWasi, HooksWasiConfig, InMemoryKvStore}; use runtime_noop::trusted_documents::NoopTrustedDocuments; @@ -142,7 +141,12 @@ pub(super) async fn generate( .into_iter() .map(|(k, v)| { ( - k.to_string(), + match k { + engine_v2::config::RateLimitKey::Global => runtime::rate_limiting::RateLimitKey::Global, + engine_v2::config::RateLimitKey::Subgraph(name) => { + runtime::rate_limiting::RateLimitKey::Subgraph(name.to_string().into()) + } + }, runtime::rate_limiting::GraphRateLimit { limit: v.limit, duration: v.duration, @@ -151,10 +155,14 @@ pub(super) async fn generate( }) .collect::>(); + let (rate_limit_tx, rate_limit_rx) = watch::channel(rate_limiting_configs); + + if let Some(path) = config_hot_reload.then_some(config_path).flatten() { + hot_reload::ConfigWatcher::new(path, rate_limit_tx).watch()?; + } + let rate_limiter = match config.rate_limit_config() { Some(config) if config.storage.is_redis() => { - let (tx, rx) = watch::channel(rate_limiting_configs); - let tls = config .redis .tls @@ -170,29 +178,11 @@ pub(super) async fn generate( tls, }; - match config_path { - Some(path) if config_hot_reload => { - hot_reload::ConfigWatcher::new(path, tx).watch()?; - } - _ => (), - } - - RedisRateLimiter::runtime(global_config, rx) + RedisRateLimiter::runtime(global_config, rate_limit_rx) .await .map_err(|e| crate::Error::InternalError(e.to_string()))? } - _ => { - let (tx, rx) = mpsc::channel(100); - - match config_path { - Some(path) if config_hot_reload => { - hot_reload::ConfigWatcher::new(path, tx).watch()?; - } - _ => (), - } - - InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }, rx) - } + _ => InMemoryRateLimiter::runtime(rate_limit_rx), }; let runtime = GatewayRuntime { diff --git a/gateway/crates/gateway-binary/tests/telemetry/metrics/operation.rs b/gateway/crates/gateway-binary/tests/telemetry/metrics/operation.rs index 1d83526592..cb99ee0bbf 100644 --- a/gateway/crates/gateway-binary/tests/telemetry/metrics/operation.rs +++ b/gateway/crates/gateway-binary/tests/telemetry/metrics/operation.rs @@ -126,7 +126,7 @@ fn used_fields_should_be_unique() { "data": null, "errors": [ { - "message": "error sending request for url (http://127.0.0.1:46697/)", + "message": "Request to subgraph 'accounts' failed with: error sending request for url (http://127.0.0.1:46697/)", "path": [ "me" ],