Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions libs/@local/graph/api/src/rest/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ use utoipa::{OpenApi, ToSchema};

use super::status::BoxedResponse;
use crate::rest::{
AuthenticatedUserHeader, OpenApiQuery, QueryLogger, RestApiStore,
ApiConfig, AuthenticatedUserHeader, OpenApiQuery, QueryLogger, RestApiStore,
json::Json,
resolve_limit,
status::{report_to_response, status_to_response},
utoipa_typedef::{ListOrValue, MaybeListOfDataType, subgraph::Subgraph},
};
Expand Down Expand Up @@ -341,6 +342,7 @@ async fn query_data_types<S>(
AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader,
store_pool: Extension<Arc<S>>,
temporal_client: Extension<Option<Arc<TemporalClient>>>,
Extension(api_config): Extension<ApiConfig>,
mut query_logger: Option<Extension<QueryLogger>>,
Json(request): Json<serde_json::Value>,
) -> Result<Json<QueryDataTypesResponse>, BoxedResponse>
Expand All @@ -351,20 +353,25 @@ where
query_logger.capture(actor_id, OpenApiQuery::GetDataTypes(&request));
}

// Manually deserialize the query from a JSON value to allow borrowed deserialization
// and better error reporting.
let mut params = QueryDataTypesParams::deserialize(&request)
.map_err(Report::from)
.map_err(report_to_response)?;

params.limit = Some(
resolve_limit(params.limit, api_config.query_ontology_limit)
.attach(hash_status::StatusCode::InvalidArgument)
.map_err(report_to_response)?,
);

let store = store_pool
.acquire(temporal_client.0)
.await
.map_err(report_to_response)?;

let response = store
.query_data_types(
actor_id,
// Manually deserialize the query from a JSON value to allow borrowed deserialization
// and better error reporting.
QueryDataTypesParams::deserialize(&request)
.map_err(Report::from)
.map_err(report_to_response)?,
)
.query_data_types(actor_id, params)
.await
.map_err(report_to_response)
.map(Json);
Expand Down Expand Up @@ -405,6 +412,7 @@ async fn query_data_type_subgraph<S>(
AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader,
store_pool: Extension<Arc<S>>,
temporal_client: Extension<Option<Arc<TemporalClient>>>,
Extension(api_config): Extension<ApiConfig>,
mut query_logger: Option<Extension<QueryLogger>>,
Json(request): Json<serde_json::Value>,
) -> Result<Json<QueryDataTypeSubgraphResponse>, BoxedResponse>
Expand All @@ -415,21 +423,27 @@ where
query_logger.capture(actor_id, OpenApiQuery::GetDataTypeSubgraph(&request));
}

let store = store_pool
.acquire(temporal_client.0)
.await
.map_err(report_to_response)?;

// Manually deserialize the query from a JSON value to allow borrowed deserialization
// and better error reporting.
let params = QueryDataTypeSubgraphParams::deserialize(&request)
let mut params = QueryDataTypeSubgraphParams::deserialize(&request)
.map_err(Report::from)
.map_err(report_to_response)?;
params
.validate()
.map_err(Report::new)
.map_err(report_to_response)?;

params.request_mut().limit = Some(
resolve_limit(params.request().limit, api_config.query_ontology_limit)
.attach(hash_status::StatusCode::InvalidArgument)
.map_err(report_to_response)?,
);

let store = store_pool
.acquire(temporal_client.0)
.await
.map_err(report_to_response)?;

let response = store
.query_data_type_subgraph(actor_id, params)
.await
Expand Down
24 changes: 6 additions & 18 deletions libs/@local/graph/api/src/rest/entity_query_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ use serde_json::value::RawValue as RawJsonValue;
use type_system::knowledge::Entity;
use utoipa::ToSchema;

use super::{ApiConfig, status::BoxedResponse};
use super::{ApiConfig, LimitExceededError, resolve_limit, status::BoxedResponse};

#[tracing::instrument(level = "info", skip_all)]
fn generate_sorting_paths(
Expand Down Expand Up @@ -485,8 +485,6 @@ pub enum EntityQueryOptionsError {
instead."
)]
InvalidFieldForEntityOptions { field: &'static str },
#[display("The requested limit ({requested}) exceeds the maximum allowed limit ({max}).")]
LimitExceeded { requested: usize, max: usize },
}

impl core::error::Error for EntityQueryOptionsError {}
Expand Down Expand Up @@ -591,27 +589,17 @@ impl<'q, 's, 'p> TryFrom<FlatQueryEntitiesRequestData<'q, 's, 'p>> for EntityQue
impl<'p> EntityQueryOptions<'_, 'p> {
/// # Errors
///
/// Returns `LimitExceeded` if the requested limit exceeds the configured maximum in
/// Returns [`LimitExceededError`] if the requested limit exceeds the configured maximum in
/// [`ApiConfig::query_entity_limit`].
pub fn into_params<'f>(
self,
filter: Filter<'f, Entity>,
config: ApiConfig,
) -> Result<QueryEntitiesParams<'f>, Report<EntityQueryOptionsError>>
) -> Result<QueryEntitiesParams<'f>, Report<LimitExceededError>>
where
'p: 'f,
{
let max = config.query_entity_limit;
let limit = match self.limit {
Some(requested) if requested > max => {
return Err(Report::new(EntityQueryOptionsError::LimitExceeded {
requested,
max,
}));
}
Some(limit) => limit,
None => max,
};
let limit = resolve_limit(self.limit, config.query_entity_limit)?;

Ok(QueryEntitiesParams {
filter,
Expand All @@ -636,14 +624,14 @@ impl<'p> EntityQueryOptions<'_, 'p> {

/// # Errors
///
/// Returns `LimitExceeded` if the requested limit exceeds the configured maximum in
/// Returns [`LimitExceededError`] if the requested limit exceeds the configured maximum in
/// [`ApiConfig::query_entity_limit`].
pub fn into_traversal_params<'q>(
self,
filter: Filter<'q, Entity>,
traversal: SubgraphTraversalParams,
config: ApiConfig,
) -> Result<QueryEntitySubgraphParams<'q>, Report<EntityQueryOptionsError>>
) -> Result<QueryEntitySubgraphParams<'q>, Report<LimitExceededError>>
where
'p: 'q,
{
Expand Down
44 changes: 29 additions & 15 deletions libs/@local/graph/api/src/rest/entity_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ use utoipa::{OpenApi, ToSchema};

use super::status::BoxedResponse;
use crate::rest::{
AuthenticatedUserHeader, OpenApiQuery, QueryLogger, RestApiStore,
ApiConfig, AuthenticatedUserHeader, OpenApiQuery, QueryLogger, RestApiStore,
json::Json,
resolve_limit,
status::{report_to_response, status_to_response},
utoipa_typedef::{ListOrValue, MaybeListOfEntityType, subgraph::Subgraph},
};
Expand Down Expand Up @@ -470,6 +471,7 @@ async fn query_entity_types<S>(
AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader,
store_pool: Extension<Arc<S>>,
temporal_client: Extension<Option<Arc<TemporalClient>>>,
Extension(api_config): Extension<ApiConfig>,
mut query_logger: Option<Extension<QueryLogger>>,
Json(request): Json<serde_json::Value>,
) -> Result<Json<QueryEntityTypesResponse>, BoxedResponse>
Expand All @@ -480,20 +482,25 @@ where
query_logger.capture(actor_id, OpenApiQuery::GetEntityTypes(&request));
}

// Manually deserialize the query from a JSON value to allow borrowed deserialization
// and better error reporting.
let mut params = QueryEntityTypesParams::deserialize(&request)
.map_err(Report::from)
.map_err(report_to_response)?;

params.request.limit = Some(
resolve_limit(params.request.limit, api_config.query_ontology_limit)
.attach(hash_status::StatusCode::InvalidArgument)
.map_err(report_to_response)?,
);

let store = store_pool
.acquire(temporal_client.0)
.await
.map_err(report_to_response)?;

let response = store
.query_entity_types(
actor_id,
// Manually deserialize the query from a JSON value to allow borrowed deserialization
// and better error reporting.
QueryEntityTypesParams::deserialize(&request)
.map_err(Report::from)
.map_err(report_to_response)?,
)
.query_entity_types(actor_id, params)
.await
.map_err(report_to_response)
.map(Json);
Expand Down Expand Up @@ -605,6 +612,7 @@ async fn query_entity_type_subgraph<S>(
AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader,
store_pool: Extension<Arc<S>>,
temporal_client: Extension<Option<Arc<TemporalClient>>>,
Extension(api_config): Extension<ApiConfig>,
mut query_logger: Option<Extension<QueryLogger>>,
Json(request): Json<serde_json::Value>,
) -> Result<Json<QueryEntityTypeSubgraphResponse>, BoxedResponse>
Expand All @@ -615,19 +623,25 @@ where
query_logger.capture(actor_id, OpenApiQuery::GetEntityTypeSubgraph(&request));
}

let store = store_pool
.acquire(temporal_client.0)
.await
.map_err(report_to_response)?;

let params = QueryEntityTypeSubgraphParams::deserialize(&request)
let mut params = QueryEntityTypeSubgraphParams::deserialize(&request)
.map_err(Report::from)
.map_err(report_to_response)?;
params
.validate()
.map_err(Report::new)
.map_err(report_to_response)?;

params.request_mut().limit = Some(
resolve_limit(params.request().limit, api_config.query_ontology_limit)
.attach(hash_status::StatusCode::InvalidArgument)
.map_err(report_to_response)?,
);

let store = store_pool
.acquire(temporal_client.0)
.await
.map_err(report_to_response)?;

let response = store
.query_entity_type_subgraph(actor_id, params)
.await
Expand Down
43 changes: 42 additions & 1 deletion libs/@local/graph/api/src/rest/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod entity_query_request;
mod json;
mod utoipa_typedef;
use alloc::{borrow::Cow, sync::Arc};
use core::str::FromStr as _;
use core::{error::Error, str::FromStr as _};
use std::{
fs,
io::{self, Write as _},
Expand Down Expand Up @@ -323,6 +323,37 @@ pub enum OpenApiQuery<'a> {
DiffEntity(&'a DiffEntityParams),
}

/// The requested limit exceeds the configured maximum.
#[derive(Debug, Copy, Clone, PartialEq, Eq, derive_more::Display)]
#[display("The requested limit ({requested}) exceeds the maximum allowed limit ({max}).")]
pub struct LimitExceededError {
pub requested: usize,
pub max: usize,
}

impl Error for LimitExceededError {}

/// Resolves an optional request limit against a configured maximum.
///
/// Returns the configured maximum when no limit is requested. Returns the requested limit if it
/// does not exceed the maximum.
///
/// # Errors
///
/// Returns [`LimitExceededError`] if `requested` exceeds `max`.
pub(crate) fn resolve_limit(
requested: Option<usize>,
max: usize,
) -> Result<usize, Report<LimitExceededError>> {
match requested {
Some(requested) if requested > max => {
Err(Report::new(LimitExceededError { requested, max }))
}
Some(limit) => Ok(limit),
None => Ok(max),
}
}

/// Server-side configuration for the REST API, shared across handlers via an [`Extension`].
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
Expand All @@ -336,6 +367,16 @@ pub struct ApiConfig {
clap(long, default_value_t = 1000, env = "HASH_GRAPH_QUERY_ENTITY_LIMIT")
)]
pub query_entity_limit: usize,

/// The default and maximum number of ontology types returned by a single query.
///
/// When a request omits `limit`, this value is used. Requests that specify a `limit` larger
/// than this value are rejected.
#[cfg_attr(
feature = "clap",
clap(long, default_value_t = 1000, env = "HASH_GRAPH_QUERY_ONTOLOGY_LIMIT")
)]
pub query_ontology_limit: usize,
}

pub struct RestRouterDependencies<S>
Expand Down
Loading
Loading