From 3f9a0d787b02328950840d644c95c380af0d290a Mon Sep 17 00:00:00 2001 From: itowlson Date: Fri, 14 Nov 2025 13:21:53 +1300 Subject: [PATCH] Async PostgreSQL API Signed-off-by: itowlson --- Cargo.lock | 10 + crates/factor-outbound-pg/Cargo.toml | 1 + .../factor-outbound-pg/src/allowed_hosts.rs | 71 ++++++ crates/factor-outbound-pg/src/client.rs | 157 ++++++++++-- crates/factor-outbound-pg/src/host.rs | 230 +++++++++++++----- crates/factor-outbound-pg/src/lib.rs | 31 ++- crates/factor-outbound-pg/src/types.rs | 18 +- .../factor-outbound-pg/src/types/convert.rs | 2 +- .../factor-outbound-pg/src/types/interval.rs | 2 +- .../factor-outbound-pg/tests/factor_test.rs | 19 +- crates/wasi-async/Cargo.toml | 10 + crates/wasi-async/src/lib.rs | 1 + crates/wasi-async/src/stream.rs | 47 ++++ crates/world/src/conversions.rs | 4 +- crates/world/src/lib.rs | 4 +- wit/deps/spin-postgres@4.2.0/postgres.wit | 184 ++++++++++++++ wit/world.wit | 2 +- 17 files changed, 690 insertions(+), 103 deletions(-) create mode 100644 crates/factor-outbound-pg/src/allowed_hosts.rs create mode 100644 crates/wasi-async/Cargo.toml create mode 100644 crates/wasi-async/src/lib.rs create mode 100644 crates/wasi-async/src/stream.rs create mode 100644 wit/deps/spin-postgres@4.2.0/postgres.wit diff --git a/Cargo.lock b/Cargo.lock index 30561172d..0591dc1cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8769,6 +8769,7 @@ dependencies = [ "spin-factors-test", "spin-locked-app", "spin-resource-table", + "spin-wasi-async", "spin-world", "tokio", "tokio-postgres", @@ -9455,6 +9456,15 @@ dependencies = [ "vaultrs", ] +[[package]] +name = "spin-wasi-async" +version = "3.7.0-pre0" +dependencies = [ + "anyhow", + "spin-core", + "tokio", +] + [[package]] name = "spin-world" version = "3.7.0-pre0" diff --git a/crates/factor-outbound-pg/Cargo.toml b/crates/factor-outbound-pg/Cargo.toml index 5a5654ff4..4ad764814 100644 --- a/crates/factor-outbound-pg/Cargo.toml +++ b/crates/factor-outbound-pg/Cargo.toml @@ -24,6 +24,7 @@ spin-factor-outbound-networking = { path = "../factor-outbound-networking" } spin-factors = { path = "../factors" } spin-locked-app = { path = "../locked-app" } spin-resource-table = { path = "../table" } +spin-wasi-async = { path = "../wasi-async" } spin-world = { path = "../world" } tokio = { workspace = true, features = ["rt-multi-thread"] } tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1"] } diff --git a/crates/factor-outbound-pg/src/allowed_hosts.rs b/crates/factor-outbound-pg/src/allowed_hosts.rs new file mode 100644 index 000000000..f3cfc4c35 --- /dev/null +++ b/crates/factor-outbound-pg/src/allowed_hosts.rs @@ -0,0 +1,71 @@ +use std::sync::Arc; + +use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts; +use spin_world::spin::postgres4_2_0::postgres::{self as v4}; + +/// Encapsulates checking of a PostgreSQL address/connection string against +/// an allow-list. +/// +/// This is broken out as a distinct object to allow it to be synchronously retrieved +/// within a P3 Accessor block and then asynchronously queried outside the block. +#[derive(Clone)] +pub(crate) struct AllowedHostChecker { + allowed_hosts: Arc, +} + +impl AllowedHostChecker { + pub fn new(allowed_hosts: OutboundAllowedHosts) -> Self { + Self { + allowed_hosts: Arc::new(allowed_hosts), + } + } + + #[allow(clippy::result_large_err)] + pub async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> { + fn conn_failed(message: impl Into) -> v4::Error { + v4::Error::ConnectionFailed(message.into()) + } + fn err_other(err: anyhow::Error) -> v4::Error { + v4::Error::Other(err.to_string()) + } + + let config = address + .parse::() + .map_err(|e| conn_failed(e.to_string()))?; + + for (i, host) in config.get_hosts().iter().enumerate() { + match host { + tokio_postgres::config::Host::Tcp(address) => { + let ports = config.get_ports(); + // The port we use is either: + // * The port at the same index as the host + // * The first port if there is only one port + let port = ports.get(i).or_else(|| { + if ports.len() == 1 { + ports.first() + } else { + None + } + }); + let port_str = port.map(|p| format!(":{p}")).unwrap_or_default(); + let url = format!("{address}{port_str}"); + if !self + .allowed_hosts + .check_url(&url, "postgres") + .await + .map_err(err_other)? + { + return Err(conn_failed(format!( + "address postgres://{url} is not permitted" + ))); + } + } + #[cfg(unix)] + tokio_postgres::config::Host::Unix(_) => { + return Err(conn_failed("Unix sockets are not supported on WebAssembly")); + } + } + } + Ok(()) + } +} diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs index 657081f4e..e8b701197 100644 --- a/crates/factor-outbound-pg/src/client.rs +++ b/crates/factor-outbound-pg/src/client.rs @@ -1,17 +1,18 @@ +use std::sync::Arc; + use anyhow::{Context, Result}; use futures::stream::TryStreamExt as _; use native_tls::TlsConnector; use postgres_native_tls::MakeTlsConnector; use spin_world::async_trait; -use spin_world::spin::postgres4_1_0::postgres::{ +use spin_world::spin::postgres4_2_0::postgres::{ self as v4, Column, DbValue, ParameterValue, RowSet, }; -use std::pin::pin; use tokio_postgres::config::SslMode; use tokio_postgres::types::ToSql; use tokio_postgres::{NoTls, Row}; -use crate::types::{convert_data_type, convert_entry, to_sql_parameter}; +use crate::types::{convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters}; /// Max connections in a given address' connection pool const CONNECTION_POOL_SIZE: usize = 64; @@ -61,9 +62,18 @@ impl Default for PooledTokioClientFactory { } } +#[derive(Clone)] +pub struct PooledTokioClient(Arc); + +impl AsRef for PooledTokioClient { + fn as_ref(&self) -> &deadpool_postgres::Object { + self.0.as_ref() + } +} + #[async_trait] impl ClientFactory for PooledTokioClientFactory { - type Client = deadpool_postgres::Object; + type Client = PooledTokioClient; async fn get_client( &self, @@ -81,7 +91,7 @@ impl ClientFactory for PooledTokioClientFactory { .map_err(ArcError) .context("establishing PostgreSQL connection pool")?; - Ok(pool.get().await?) + Ok(PooledTokioClient(Arc::new(pool.get().await?))) } } @@ -123,7 +133,7 @@ fn create_connection_pool( } #[async_trait] -pub trait Client: Send + Sync + 'static { +pub trait Client: Clone + Send + Sync + 'static { async fn execute( &self, statement: String, @@ -136,6 +146,19 @@ pub trait Client: Send + Sync + 'static { params: Vec, max_result_bytes: usize, ) -> Result; + + async fn query_async( + &self, + statement: String, + params: Vec, + max_result_bytes: usize, + ) -> Result; +} + +pub struct QueryAsyncResult { + pub columns: Vec, + pub rows: tokio::sync::mpsc::Receiver, + pub error: tokio::sync::oneshot::Receiver>, } /// Extract weak-typed error data for WIT purposes @@ -180,8 +203,13 @@ fn query_failed(e: tokio_postgres::error::Error) -> v4::Error { v4::Error::QueryFailed(query_error) } +fn query_failed_anyhow(e: anyhow::Error) -> v4::Error { + let text = format!("{e:?}"); + v4::Error::QueryFailed(v4::QueryError::Text(text)) +} + #[async_trait] -impl Client for deadpool_postgres::Object { +impl Client for PooledTokioClient { async fn execute( &self, statement: String, @@ -210,17 +238,7 @@ impl Client for deadpool_postgres::Object { params: Vec, max_result_bytes: usize, ) -> Result { - let params = params - .iter() - .map(to_sql_parameter) - .collect::>>() - .map_err(|e| v4::Error::BadParameter(format!("{e:?}")))?; - - let mut results = pin!(self - .as_ref() - .query_raw(&statement, params) - .await - .map_err(query_failed)?); + let (cols_fut, mut results) = self.query_stream(statement, params).await?; let mut columns = None; let mut byte_count = std::mem::size_of::(); @@ -228,16 +246,13 @@ impl Client for deadpool_postgres::Object { async { while let Some(row) = results.try_next().await? { - if columns.is_none() { - columns = Some(infer_columns(&row)); - } - let row = convert_row(&row)?; byte_count += row.iter().map(|v| v.memory_size()).sum::(); if byte_count > max_result_bytes { anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes") } rows.push(row); } + columns = Some(cols_fut.await); Ok(()) } .await @@ -248,6 +263,104 @@ impl Client for deadpool_postgres::Object { rows, }) } + + async fn query_async( + &self, + statement: String, + params: Vec, + max_result_bytes: usize, + ) -> Result { + use futures::StreamExt; + + let (cols_fut, mut rows) = self.query_stream(statement, params).await?; + + let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4); + let (err_tx, err_rx) = tokio::sync::oneshot::channel(); + + tokio::spawn(async move { + loop { + let Some(row) = rows.next().await else { + _ = err_tx.send(Ok(())); + return; + }; + match row { + Ok(row) => { + let byte_count = row.iter().map(|v| v.memory_size()).sum::(); + if byte_count > max_result_bytes { + _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text( + format!("query result exceeds limit of {max_result_bytes} bytes"), + )))); + return; + } + + if let Err(e) = rows_tx.send(row).await { + _ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text( + format!("async error: {e}"), + )))); + return; + } + } + Err(e) => { + _ = err_tx.send(Err(e)); + return; + } + } + } + }); + + let cols = cols_fut.await; + + Ok(QueryAsyncResult { + columns: cols, + rows: rows_rx, + error: err_rx, + }) + } +} + +impl PooledTokioClient { + async fn query_stream( + &self, + statement: String, + params: Vec, + ) -> Result< + ( + impl std::future::Future>, + impl futures::Stream, v4::Error>>, + ), + v4::Error, + > { + use futures::{FutureExt, StreamExt}; + + let params = to_sql_parameters(params)?; + + let results = Box::pin( + self.as_ref() + .query_raw(&statement, params) + .await + .map_err(query_failed)?, + ); + + let (cols_tx, cols_rx) = tokio::sync::oneshot::channel(); + let mut cols_tx_opt = Some(cols_tx); + + let row_stm = results.enumerate().map(move |(index, row_res)| { + let row_res = row_res.map_err(query_failed); + row_res.and_then(|r| { + if index == 0 { + if let Some(cols_tx) = cols_tx_opt.take() { + let cols = infer_columns(&r); + _ = cols_tx.send(cols); + } + } + convert_row(&r).map_err(query_failed_anyhow) + }) + }); + + let cols_rx = cols_rx.map(|result| result.unwrap_or_default()); + + Ok((cols_rx, Box::pin(row_stm))) + } } fn infer_columns(row: &Row) -> Vec { diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index 4c22dd984..827225ed6 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -1,7 +1,8 @@ use anyhow::Result; -use spin_core::wasmtime::component::Resource; +use spin_core::wasmtime; +use spin_core::wasmtime::component::{Accessor, FutureReader, Resource, StreamReader}; use spin_world::spin::postgres3_0_0::postgres::{self as v3}; -use spin_world::spin::postgres4_1_0::postgres::{self as v4}; +use spin_world::spin::postgres4_2_0::postgres::{self as v4}; use spin_world::v1::postgres as v1; use spin_world::v1::rdbms_types as v1_types; use spin_world::v2::postgres::{self as v2}; @@ -11,7 +12,8 @@ use tracing::field::Empty; use tracing::instrument; use tracing::Level; -use crate::client::{Client, ClientFactory, HashableCertificate}; +use crate::allowed_hosts::AllowedHostChecker; +use crate::client::{Client, ClientFactory, HashableCertificate, QueryAsyncResult}; use crate::InstanceState; impl InstanceState { @@ -40,53 +42,15 @@ impl InstanceState { .ok_or_else(|| v4::Error::ConnectionFailed("no connection found".into())) } + fn allowed_host_checker(&self) -> AllowedHostChecker { + self.allowed_host_checker.clone() + } + #[allow(clippy::result_large_err)] async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> { - fn conn_failed(message: impl Into) -> v4::Error { - v4::Error::ConnectionFailed(message.into()) - } - fn err_other(err: anyhow::Error) -> v4::Error { - v4::Error::Other(err.to_string()) - } - - let config = address - .parse::() - .map_err(|e| conn_failed(e.to_string()))?; - - for (i, host) in config.get_hosts().iter().enumerate() { - match host { - tokio_postgres::config::Host::Tcp(address) => { - let ports = config.get_ports(); - // The port we use is either: - // * The port at the same index as the host - // * The first port if there is only one port - let port = ports.get(i).or_else(|| { - if ports.len() == 1 { - ports.first() - } else { - None - } - }); - let port_str = port.map(|p| format!(":{p}")).unwrap_or_default(); - let url = format!("{address}{port_str}"); - if !self - .allowed_hosts - .check_url(&url, "postgres") - .await - .map_err(err_other)? - { - return Err(conn_failed(format!( - "address postgres://{url} is not permitted" - ))); - } - } - #[cfg(unix)] - tokio_postgres::config::Host::Unix(_) => { - return Err(conn_failed("Unix sockets are not supported on WebAssembly")); - } - } - } - Ok(()) + self.allowed_host_checker + .ensure_address_allowed(address) + .await } } @@ -183,15 +147,8 @@ impl v4::HostConnectionBuilder for InstanceState { &mut self, self_: Resource, ) -> Result, v4::Error> { - let builder = self - .builders - .get_mut(self_.rep()) - .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?; - // borrow checker gets pedantic here, so we need to outsmart it - let address = builder.address.clone(); - let root_ca = builder.root_ca.clone(); - let conn = self.open_connection(&address, root_ca).await; - conn + let (address, root_ca) = self.get_builder_info(self_.rep())?; + self.open_connection(&address, root_ca).await } async fn drop(&mut self, builder: Resource) -> Result<()> { @@ -242,6 +199,164 @@ impl v4::HostConnection for InstanceState { } } +impl spin_world::spin::postgres4_2_0::postgres::HostConnectionWithStore + for crate::PgFactorData +{ + #[instrument(name = "spin_outbound_pg.open_async", skip(accessor, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))] + async fn open_async( + accessor: &Accessor, + address: String, + ) -> Result, v4::Error> { + spin_factor_outbound_networking::record_address_fields(&address); + + Self::ensure_address_allowed_async(accessor, &address).await?; + Self::open_connection_async(accessor, &address, None).await + } + + #[instrument(name = "spin_outbound_pg.execute", skip(accessor, connection, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))] + async fn execute_async( + accessor: &Accessor, + connection: Resource, + statement: String, + params: Vec, + ) -> Result { + let client = accessor.with(|mut access| { + let host = access.get(); + host.connections.get(connection.rep()).unwrap().clone() + }); + + client.execute(statement, params).await + } + + #[allow(clippy::type_complexity)] // blame bindgen, clippy, blame bindgen + #[instrument(name = "spin_outbound_pg.query_async", skip(accessor, params), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))] + async fn query_async( + accessor: &Accessor, + connection: Resource, + statement: String, + params: Vec, + ) -> Result< + ( + Vec, + StreamReader, + FutureReader>, + ), + v4::Error, + > { + use wasmtime::AsContextMut; + + let client = accessor.with(|mut access| { + let host = access.get(); + host.connections.get(connection.rep()).unwrap().clone() + }); + + let QueryAsyncResult { + columns, + rows, + error, + } = client + .query_async(statement, params, MAX_HOST_BUFFERED_BYTES) + .await?; + + let row_producer = spin_wasi_async::stream::producer(rows); + + let (sr, efr) = accessor.with(|mut access| { + let sr = StreamReader::new(access.as_context_mut(), row_producer); + let efr = FutureReader::new(access.as_context_mut(), error); + (sr, efr) + }); + + Ok((columns, sr, efr)) + } +} + +impl InstanceState { + #[allow(clippy::result_large_err)] + fn get_builder_info( + &mut self, + builder_rep: u32, + ) -> Result<(String, Option), v4::Error> { + let builder = self + .builders + .get_mut(builder_rep) + .ok_or_else(|| v4::Error::ConnectionFailed("no builder found".into()))?; + + let address = builder.address.clone(); + let root_ca = builder.root_ca.clone(); + + Ok((address, root_ca)) + } +} + +impl crate::PgFactorData { + #[allow(clippy::result_large_err)] + fn get_builder_info( + accessor: &Accessor, + builder: Resource, + ) -> Result<(String, Option), v4::Error> { + let builder_rep = builder.rep(); + accessor.with(|mut access| { + let host = access.get(); + host.get_builder_info(builder_rep) + }) + } + + async fn ensure_address_allowed_async( + accessor: &Accessor, + address: &str, + ) -> Result<(), v4::Error> { + // A merry dance to avoid doing the async allow check under the accessor + let allowed_host_checker = accessor.with(|mut access| { + let host = access.get(); + host.allowed_host_checker() + }); + + allowed_host_checker.ensure_address_allowed(address).await + } + + async fn open_connection_async( + accessor: &Accessor, + address: &str, + root_ca: Option, + ) -> Result, v4::Error> { + let cf = accessor.with(|mut access| { + let host = access.get(); + host.client_factory.clone() + }); + + let client = cf + .get_client(address, root_ca) + .await + .map_err(|e| v4::Error::ConnectionFailed(format!("{e:?}")))?; + + let rsrc = accessor.with(|mut access| { + let host = access.get(); + host.connections + .push(client) + .map_err(|_| v4::Error::ConnectionFailed("too many connections".into())) + .map(Resource::new_own) + }); + + rsrc + } +} + +impl spin_world::spin::postgres4_2_0::postgres::HostConnectionBuilderWithStore + for crate::PgFactorData +{ + async fn build_async( + accessor: &Accessor, + builder: Resource, + ) -> Result, v4::Error> { + let (address, root_ca) = Self::get_builder_info(accessor, builder)?; + + spin_factor_outbound_networking::record_address_fields(&address); + + Self::ensure_address_allowed_async(accessor, &address).await?; + Self::open_connection_async(accessor, &address, root_ca).await + } +} + impl v2_types::Host for InstanceState { fn convert_error(&mut self, error: v2::Error) -> Result { Ok(error) @@ -283,7 +398,6 @@ impl v2::HostConnection for InstanceState { spin_factor_outbound_networking::record_address_fields(&address); self.ensure_address_allowed(&address).await?; - Ok(self.open_connection(&address, None).await?) } diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index 8a8891e25..aae1c2443 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -1,17 +1,16 @@ +mod allowed_hosts; pub mod client; mod host; mod types; use std::{collections::HashMap, sync::Arc}; +use allowed_hosts::AllowedHostChecker; use client::ClientFactory; use spin_factor_otel::OtelFactorState; -use spin_factor_outbound_networking::{ - config::allowed_hosts::OutboundAllowedHosts, OutboundNetworkingFactor, -}; +use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factors::{ - anyhow, ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, - SelfInstanceBuilder, + anyhow, ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder, }; pub struct OutboundPgFactor { @@ -24,13 +23,13 @@ impl Factor for OutboundPgFactor { type InstanceBuilder = InstanceState; fn init(&mut self, ctx: &mut impl spin_factors::InitContext) -> anyhow::Result<()> { - ctx.link_bindings(spin_world::v1::postgres::add_to_linker::<_, FactorData>)?; - ctx.link_bindings(spin_world::v2::postgres::add_to_linker::<_, FactorData>)?; + ctx.link_bindings(spin_world::v1::postgres::add_to_linker::<_, PgFactorData>)?; + ctx.link_bindings(spin_world::v2::postgres::add_to_linker::<_, PgFactorData>)?; ctx.link_bindings( - spin_world::spin::postgres3_0_0::postgres::add_to_linker::<_, FactorData>, + spin_world::spin::postgres3_0_0::postgres::add_to_linker::<_, PgFactorData>, )?; ctx.link_bindings( - spin_world::spin::postgres4_1_0::postgres::add_to_linker::<_, FactorData>, + spin_world::spin::postgres4_2_0::postgres::add_to_linker::<_, PgFactorData>, )?; Ok(()) } @@ -57,7 +56,7 @@ impl Factor for OutboundPgFactor { let cf = ctx.app_state().get(ctx.app_component().id()).unwrap(); Ok(InstanceState { - allowed_hosts, + allowed_host_checker: AllowedHostChecker::new(allowed_hosts), client_factory: cf.clone(), connections: Default::default(), otel, @@ -81,7 +80,7 @@ impl OutboundPgFactor { } pub struct InstanceState { - allowed_hosts: OutboundAllowedHosts, + allowed_host_checker: AllowedHostChecker, client_factory: Arc, connections: spin_resource_table::Table, otel: OtelFactorState, @@ -89,3 +88,13 @@ pub struct InstanceState { } impl SelfInstanceBuilder for InstanceState {} + +pub struct PgFactorData(OutboundPgFactor); + +impl spin_core::wasmtime::component::HasData for PgFactorData { + type Data<'a> = &'a mut InstanceState; +} + +impl spin_core::wasmtime::component::HasData for InstanceState { + type Data<'a> = &'a mut InstanceState; +} diff --git a/crates/factor-outbound-pg/src/types.rs b/crates/factor-outbound-pg/src/types.rs index 85608058f..d8449c2d0 100644 --- a/crates/factor-outbound-pg/src/types.rs +++ b/crates/factor-outbound-pg/src/types.rs @@ -1,4 +1,5 @@ -use spin_world::spin::postgres4_1_0::postgres::{DbDataType, DbValue, ParameterValue}; +use anyhow::Result; +use spin_world::spin::postgres4_2_0::postgres::{self as v4, DbDataType, DbValue, ParameterValue}; use tokio_postgres::types::{FromSql, Type}; use tokio_postgres::{types::ToSql, Row}; @@ -162,3 +163,18 @@ pub fn to_sql_parameter(value: &ParameterValue) -> anyhow::Result Ok(Box::new(PgNull)), } } + +// The logic for "vector of ParameterValue to vector of &dyn ToSql" is +// used in multiple places, but needs to be broken into two functions +// because the return value of the first (the Vec) needs to be kept +// around to provide an owner for the refs. +#[allow(clippy::result_large_err)] +pub fn to_sql_parameters( + params: Vec, +) -> Result>, v4::Error> { + params + .iter() + .map(to_sql_parameter) + .collect::>>() + .map_err(|e| v4::Error::BadParameter(format!("{e:?}"))) +} diff --git a/crates/factor-outbound-pg/src/types/convert.rs b/crates/factor-outbound-pg/src/types/convert.rs index 5d325f8fe..303cc53b1 100644 --- a/crates/factor-outbound-pg/src/types/convert.rs +++ b/crates/factor-outbound-pg/src/types/convert.rs @@ -2,7 +2,7 @@ //! the tokio_postgres driver. use anyhow::{anyhow, Context}; -use spin_world::spin::postgres4_1_0::postgres::{self as v4}; +use spin_world::spin::postgres4_2_0::postgres::{self as v4}; use super::decimal::RangeableDecimal; diff --git a/crates/factor-outbound-pg/src/types/interval.rs b/crates/factor-outbound-pg/src/types/interval.rs index cd6632d6e..a87bdbde0 100644 --- a/crates/factor-outbound-pg/src/types/interval.rs +++ b/crates/factor-outbound-pg/src/types/interval.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use spin_world::spin::postgres4_1_0::postgres::{self as v4}; +use spin_world::spin::postgres4_2_0::postgres::{self as v4}; use tokio_postgres::types::{FromSql, ToSql, Type}; #[derive(Debug)] diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs index 0c4b6500e..eeafb5bfa 100644 --- a/crates/factor-outbound-pg/tests/factor_test.rs +++ b/crates/factor-outbound-pg/tests/factor_test.rs @@ -3,15 +3,16 @@ use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factor_outbound_pg::client::Client; use spin_factor_outbound_pg::client::ClientFactory; use spin_factor_outbound_pg::client::HashableCertificate; +use spin_factor_outbound_pg::client::QueryAsyncResult; use spin_factor_outbound_pg::OutboundPgFactor; use spin_factor_variables::VariablesFactor; use spin_factors::{anyhow, RuntimeFactors}; use spin_factors_test::{toml, TestEnvironment}; use spin_world::async_trait; -use spin_world::spin::postgres4_1_0::postgres::Error as PgError; -use spin_world::spin::postgres4_1_0::postgres::HostConnection; -use spin_world::spin::postgres4_1_0::postgres::{self as v2}; -use spin_world::spin::postgres4_1_0::postgres::{ParameterValue, RowSet}; +use spin_world::spin::postgres4_2_0::postgres::Error as PgError; +use spin_world::spin::postgres4_2_0::postgres::HostConnection; +use spin_world::spin::postgres4_2_0::postgres::{self as v2}; +use spin_world::spin::postgres4_2_0::postgres::{ParameterValue, RowSet}; #[derive(RuntimeFactors)] struct TestFactors { @@ -108,6 +109,7 @@ async fn exercise_query() -> anyhow::Result<()> { // TODO: We can expand this mock to track calls and simulate return values #[derive(Default)] pub struct MockClientFactory {} +#[derive(Clone)] pub struct MockClient {} #[async_trait] @@ -143,4 +145,13 @@ impl Client for MockClient { rows: vec![], }) } + + async fn query_async( + &self, + _statement: String, + _params: Vec, + _max_result_bytes: usize, + ) -> Result { + panic!("not implemented"); + } } diff --git a/crates/wasi-async/Cargo.toml b/crates/wasi-async/Cargo.toml new file mode 100644 index 000000000..d352b7c57 --- /dev/null +++ b/crates/wasi-async/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "spin-wasi-async" +version.workspace = true +authors.workspace = true +edition.workspace = true + +[dependencies] +anyhow = { workspace = true } +spin-core = { path = "../core" } +tokio = { workspace = true } diff --git a/crates/wasi-async/src/lib.rs b/crates/wasi-async/src/lib.rs new file mode 100644 index 000000000..baf29e06a --- /dev/null +++ b/crates/wasi-async/src/lib.rs @@ -0,0 +1 @@ +pub mod stream; diff --git a/crates/wasi-async/src/stream.rs b/crates/wasi-async/src/stream.rs new file mode 100644 index 000000000..06e4d283e --- /dev/null +++ b/crates/wasi-async/src/stream.rs @@ -0,0 +1,47 @@ +use spin_core::wasmtime; + +pub fn producer(rx: tokio::sync::mpsc::Receiver) -> StreamProducer { + StreamProducer { rx } +} + +pub struct StreamProducer { + rx: tokio::sync::mpsc::Receiver, +} + +impl wasmtime::component::StreamProducer for StreamProducer { + type Item = T; + + type Buffer = Option; + + fn poll_produce<'a>( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + store: wasmtime::StoreContextMut<'a, D>, + mut destination: wasmtime::component::Destination<'a, Self::Item, Self::Buffer>, + finish: bool, + ) -> std::task::Poll> { + use std::task::Poll; + use wasmtime::component::StreamResult; + + let remaining = destination.remaining(store); + if remaining.is_some_and(|r| r == 0) { + return Poll::Ready(Ok(StreamResult::Completed)); + } + + let recv = self.get_mut().rx.poll_recv(cx); + match recv { + Poll::Ready(None) => Poll::Ready(Ok(StreamResult::Dropped)), + Poll::Pending => { + if finish { + Poll::Ready(Ok(StreamResult::Cancelled)) + } else { + Poll::Pending + } + } + Poll::Ready(Some(row)) => { + destination.set_buffer(Some(row)); + Poll::Ready(Ok(StreamResult::Completed)) + } + } + } +} diff --git a/crates/world/src/conversions.rs b/crates/world/src/conversions.rs index ec977bdfc..53bc78772 100644 --- a/crates/world/src/conversions.rs +++ b/crates/world/src/conversions.rs @@ -3,7 +3,7 @@ use super::*; mod rdbms_types { use super::*; use spin::postgres3_0_0::postgres as pg3; - use spin::postgres4_1_0::postgres as pg4; + use spin::postgres4_2_0::postgres as pg4; impl From for v1::rdbms_types::Column { fn from(value: v2::rdbms_types::Column) -> Self { @@ -422,7 +422,7 @@ mod rdbms_types { mod postgres { use super::*; use spin::postgres3_0_0::postgres as pg3; - use spin::postgres4_1_0::postgres as pg4; + use spin::postgres4_2_0::postgres as pg4; impl From for v1::postgres::RowSet { fn from(value: pg4::RowSet) -> v1::postgres::RowSet { diff --git a/crates/world/src/lib.rs b/crates/world/src/lib.rs index 12ef9507b..f89dccafd 100644 --- a/crates/world/src/lib.rs +++ b/crates/world/src/lib.rs @@ -35,7 +35,7 @@ wasmtime::component::bindgen!({ "fermyon:spin/sqlite.error" => v1::sqlite::Error, "fermyon:spin/variables@2.0.0.error" => v2::variables::Error, "spin:postgres/postgres@3.0.0.error" => spin::postgres3_0_0::postgres::Error, - "spin:postgres/postgres@4.1.0.error" => spin::postgres4_1_0::postgres::Error, + "spin:postgres/postgres@4.2.0.error" => spin::postgres4_2_0::postgres::Error, "spin:sqlite/sqlite.error" => spin::sqlite::sqlite::Error, "wasi:config/store@0.2.0-draft-2024-09-27.error" => wasi::config::store::Error, "wasi:keyvalue/store.error" => wasi::keyvalue::store::Error, @@ -67,7 +67,7 @@ impl spin::sqlite::sqlite::Value { } } -impl spin::postgres4_1_0::postgres::DbValue { +impl spin::postgres4_2_0::postgres::DbValue { pub fn memory_size(&self) -> usize { match self { Self::DbNull diff --git a/wit/deps/spin-postgres@4.2.0/postgres.wit b/wit/deps/spin-postgres@4.2.0/postgres.wit new file mode 100644 index 000000000..093e8a125 --- /dev/null +++ b/wit/deps/spin-postgres@4.2.0/postgres.wit @@ -0,0 +1,184 @@ +package spin:postgres@4.2.0; + +interface postgres { + /// Errors related to interacting with a database. + variant error { + connection-failed(string), + bad-parameter(string), + query-failed(query-error), + value-conversion-failed(string), + other(string) + } + + variant query-error { + /// An error occurred but we do not have structured info for it + text(string), + /// Postgres returned a structured database error + db-error(db-error), + } + + record db-error { + /// Stringised version of the error. This is primarily to facilitate migration of older code. + as-text: string, + severity: string, + code: string, + message: string, + detail: option, + /// Any error information provided by Postgres and not captured above. + extras: list>, + } + + /// Data types for a database column + variant db-data-type { + boolean, + int8, + int16, + int32, + int64, + floating32, + floating64, + str, + binary, + date, + time, + datetime, + timestamp, + uuid, + jsonb, + decimal, + range-int32, + range-int64, + range-decimal, + array-int32, + array-int64, + array-decimal, + array-str, + interval, + other(string), + } + + /// Database values + variant db-value { + boolean(bool), + int8(s8), + int16(s16), + int32(s32), + int64(s64), + floating32(f32), + floating64(f64), + str(string), + binary(list), + date(tuple), // (year, month, day) + time(tuple), // (hour, minute, second, nanosecond) + /// Date-time types are always treated as UTC (without timezone info). + /// The instant is represented as a (year, month, day, hour, minute, second, nanosecond) tuple. + datetime(tuple), + /// Unix timestamp (seconds since epoch) + timestamp(s64), + uuid(string), + jsonb(list), + decimal(string), // I admit defeat. Base 10 + range-int32(tuple>, option>>), + range-int64(tuple>, option>>), + range-decimal(tuple>, option>>), + array-int32(list>), + array-int64(list>), + array-decimal(list>), + array-str(list>), + interval(interval), + db-null, + unsupported(list), + } + + /// Values used in parameterized queries + variant parameter-value { + boolean(bool), + int8(s8), + int16(s16), + int32(s32), + int64(s64), + floating32(f32), + floating64(f64), + str(string), + binary(list), + date(tuple), // (year, month, day) + time(tuple), // (hour, minute, second, nanosecond) + /// Date-time types are always treated as UTC (without timezone info). + /// The instant is represented as a (year, month, day, hour, minute, second, nanosecond) tuple. + datetime(tuple), + /// Unix timestamp (seconds since epoch) + timestamp(s64), + uuid(string), + jsonb(list), + decimal(string), // base 10 + range-int32(tuple>, option>>), + range-int64(tuple>, option>>), + range-decimal(tuple>, option>>), + array-int32(list>), + array-int64(list>), + array-decimal(list>), + array-str(list>), + interval(interval), + db-null, + } + + record interval { + micros: s64, + days: s32, + months: s32, + } + + /// A database column + record column { + name: string, + data-type: db-data-type, + } + + /// A database row + type row = list; + + /// A set of database rows + record row-set { + columns: list, + rows: list, + } + + /// For range types, indicates if each bound is inclusive or exclusive + enum range-bound-kind { + inclusive, + exclusive, + } + + @since(version = 4.1.0) + resource connection-builder { + constructor(address: string); + set-ca-root: func(certificate: string) -> result<_, error>; + build: func() -> result; + @since(version = 4.2.0) + build-async: async func() -> result; + } + + /// A connection to a postgres database. + resource connection { + /// Open a connection to the Postgres instance at `address`. + open: static func(address: string) -> result; + + /// Open a connection to the Postgres instance at `address`. + @since(version = 4.2.0) + open-async: static async func(address: string) -> result; + + /// Query the database. + query: func(statement: string, params: list) -> result; + + /// Query the database. + @since(version = 4.2.0) + query-async: async func(statement: string, params: list) -> result, stream, future>>, error>; + + /// Execute command to the database. + execute: func(statement: string, params: list) -> result; + + /// Execute command to the database. + @since(version = 4.2.0) + execute-async: async func(statement: string, params: list) -> result; + } +} diff --git a/wit/world.wit b/wit/world.wit index 0d28a32a2..820e48f25 100644 --- a/wit/world.wit +++ b/wit/world.wit @@ -20,7 +20,7 @@ world platform { include fermyon:spin/platform@2.0.0; include wasi:keyvalue/imports@0.2.0-draft2; import spin:postgres/postgres@3.0.0; - import spin:postgres/postgres@4.1.0; + import spin:postgres/postgres@4.2.0; import spin:sqlite/sqlite@3.0.0; import wasi:config/store@0.2.0-draft-2024-09-27; }