Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/factor-outbound-pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ spin-factor-outbound-networking = { path = "../factor-outbound-networking" }
spin-factors = { path = "../factors" }
spin-locked-app = { path = "../locked-app" }
spin-resource-table = { path = "../table" }
spin-wasi-async = { path = "../wasi-async" }
spin-world = { path = "../world" }
tokio = { workspace = true, features = ["rt-multi-thread"] }
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-serde_json-1", "with-uuid-1"] }
Expand Down
71 changes: 71 additions & 0 deletions crates/factor-outbound-pg/src/allowed_hosts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use std::sync::Arc;

use spin_factor_outbound_networking::config::allowed_hosts::OutboundAllowedHosts;
use spin_world::spin::postgres4_2_0::postgres::{self as v4};

/// Encapsulates checking of a PostgreSQL address/connection string against
/// an allow-list.
///
/// This is broken out as a distinct object to allow it to be synchronously retrieved
/// within a P3 Accessor block and then asynchronously queried outside the block.
#[derive(Clone)]
pub(crate) struct AllowedHostChecker {
// Shared via `Arc` so that `Clone` (needed to move the checker out of a
// synchronous Accessor block) does not duplicate the allow-list itself.
allowed_hosts: Arc<OutboundAllowedHosts>,
}

impl AllowedHostChecker {
pub fn new(allowed_hosts: OutboundAllowedHosts) -> Self {
Self {
allowed_hosts: Arc::new(allowed_hosts),
}
}

#[allow(clippy::result_large_err)]
pub async fn ensure_address_allowed(&self, address: &str) -> Result<(), v4::Error> {
fn conn_failed(message: impl Into<String>) -> v4::Error {
v4::Error::ConnectionFailed(message.into())
}
fn err_other(err: anyhow::Error) -> v4::Error {
v4::Error::Other(err.to_string())
}

let config = address
.parse::<tokio_postgres::Config>()
.map_err(|e| conn_failed(e.to_string()))?;

for (i, host) in config.get_hosts().iter().enumerate() {
match host {
tokio_postgres::config::Host::Tcp(address) => {
let ports = config.get_ports();
// The port we use is either:
// * The port at the same index as the host
// * The first port if there is only one port
let port = ports.get(i).or_else(|| {
if ports.len() == 1 {
ports.first()
} else {
None
}
});
let port_str = port.map(|p| format!(":{p}")).unwrap_or_default();
let url = format!("{address}{port_str}");
if !self
.allowed_hosts
.check_url(&url, "postgres")
.await
.map_err(err_other)?
{
return Err(conn_failed(format!(
"address postgres://{url} is not permitted"
)));
}
}
#[cfg(unix)]
tokio_postgres::config::Host::Unix(_) => {
return Err(conn_failed("Unix sockets are not supported on WebAssembly"));
}
}
}
Ok(())
}
}
157 changes: 135 additions & 22 deletions crates/factor-outbound-pg/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use std::sync::Arc;

use anyhow::{Context, Result};
use futures::stream::TryStreamExt as _;
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_world::async_trait;
use spin_world::spin::postgres4_1_0::postgres::{
use spin_world::spin::postgres4_2_0::postgres::{
self as v4, Column, DbValue, ParameterValue, RowSet,
};
use std::pin::pin;
use tokio_postgres::config::SslMode;
use tokio_postgres::types::ToSql;
use tokio_postgres::{NoTls, Row};

use crate::types::{convert_data_type, convert_entry, to_sql_parameter};
use crate::types::{convert_data_type, convert_entry, to_sql_parameter, to_sql_parameters};

/// Max connections in a given address' connection pool
const CONNECTION_POOL_SIZE: usize = 64;
Expand Down Expand Up @@ -61,9 +62,18 @@ impl Default for PooledTokioClientFactory {
}
}

/// A cheaply-cloneable handle to one pooled `deadpool_postgres` connection;
/// clones share the same underlying connection object via the `Arc`.
#[derive(Clone)]
pub struct PooledTokioClient(Arc<deadpool_postgres::Object>);

impl AsRef<deadpool_postgres::Object> for PooledTokioClient {
    /// Borrows the pooled connection object shared by this handle.
    fn as_ref(&self) -> &deadpool_postgres::Object {
        &*self.0
    }
}

#[async_trait]
impl ClientFactory for PooledTokioClientFactory {
type Client = deadpool_postgres::Object;
type Client = PooledTokioClient;

async fn get_client(
&self,
Expand All @@ -81,7 +91,7 @@ impl ClientFactory for PooledTokioClientFactory {
.map_err(ArcError)
.context("establishing PostgreSQL connection pool")?;

Ok(pool.get().await?)
Ok(PooledTokioClient(Arc::new(pool.get().await?)))
}
}

Expand Down Expand Up @@ -123,7 +133,7 @@ fn create_connection_pool(
}

#[async_trait]
pub trait Client: Send + Sync + 'static {
pub trait Client: Clone + Send + Sync + 'static {
async fn execute(
&self,
statement: String,
Expand All @@ -136,6 +146,19 @@ pub trait Client: Send + Sync + 'static {
params: Vec<ParameterValue>,
max_result_bytes: usize,
) -> Result<RowSet, v4::Error>;

async fn query_async(
&self,
statement: String,
params: Vec<ParameterValue>,
max_result_bytes: usize,
) -> Result<QueryAsyncResult, v4::Error>;
}

/// The pieces produced by `Client::query_async`: column metadata plus the
/// channels over which rows and the final completion status arrive.
pub struct QueryAsyncResult {
// Column metadata for the result set; inferred from the first row, so it
// is empty when the query produces no rows.
pub columns: Vec<v4::Column>,
// Receives each converted result row as it is streamed from the database.
pub rows: tokio::sync::mpsc::Receiver<v4::Row>,
// Resolves exactly once when streaming ends: `Ok(())` on a clean finish,
// or the error that terminated the stream.
pub error: tokio::sync::oneshot::Receiver<Result<(), v4::Error>>,
}

/// Extract weak-typed error data for WIT purposes
Expand Down Expand Up @@ -180,8 +203,13 @@ fn query_failed(e: tokio_postgres::error::Error) -> v4::Error {
v4::Error::QueryFailed(query_error)
}

fn query_failed_anyhow(e: anyhow::Error) -> v4::Error {
let text = format!("{e:?}");
v4::Error::QueryFailed(v4::QueryError::Text(text))
}

#[async_trait]
impl Client for deadpool_postgres::Object {
impl Client for PooledTokioClient {
async fn execute(
&self,
statement: String,
Expand Down Expand Up @@ -210,34 +238,21 @@ impl Client for deadpool_postgres::Object {
params: Vec<ParameterValue>,
max_result_bytes: usize,
) -> Result<RowSet, v4::Error> {
let params = params
.iter()
.map(to_sql_parameter)
.collect::<Result<Vec<_>>>()
.map_err(|e| v4::Error::BadParameter(format!("{e:?}")))?;

let mut results = pin!(self
.as_ref()
.query_raw(&statement, params)
.await
.map_err(query_failed)?);
let (cols_fut, mut results) = self.query_stream(statement, params).await?;

let mut columns = None;
let mut byte_count = std::mem::size_of::<RowSet>();
let mut rows = Vec::new();

async {
while let Some(row) = results.try_next().await? {
if columns.is_none() {
columns = Some(infer_columns(&row));
}
let row = convert_row(&row)?;
byte_count += row.iter().map(|v| v.memory_size()).sum::<usize>();
if byte_count > max_result_bytes {
anyhow::bail!("query result exceeds limit of {max_result_bytes} bytes")
}
rows.push(row);
}
columns = Some(cols_fut.await);
Ok(())
}
.await
Expand All @@ -248,6 +263,104 @@ impl Client for deadpool_postgres::Object {
rows,
})
}

/// Runs `statement` and streams the result rows to the guest instead of
/// collecting them: rows are forwarded over a bounded channel from a
/// background task, and the terminal status arrives on a oneshot channel.
async fn query_async(
&self,
statement: String,
params: Vec<ParameterValue>,
max_result_bytes: usize,
) -> Result<QueryAsyncResult, v4::Error> {
use futures::StreamExt;

let (cols_fut, mut rows) = self.query_stream(statement, params).await?;

// Bounded capacity of 4 rows: the producer task blocks on `send` until the
// consumer drains, providing backpressure.
let (rows_tx, rows_rx) = tokio::sync::mpsc::channel(4);
let (err_tx, err_rx) = tokio::sync::oneshot::channel();

// Background task: pump rows from the database stream into `rows_tx`,
// then report exactly one terminal status on `err_tx` before exiting.
tokio::spawn(async move {
loop {
// Stream exhausted: signal a clean finish.
let Some(row) = rows.next().await else {
_ = err_tx.send(Ok(()));
return;
};
match row {
Ok(row) => {
// NOTE(review): unlike `query`, which accumulates a running total,
// this checks `max_result_bytes` against each row individually —
// confirm a per-row (rather than cumulative) limit is intended for
// the streaming path.
let byte_count = row.iter().map(|v| v.memory_size()).sum::<usize>();
if byte_count > max_result_bytes {
_ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
format!("query result exceeds limit of {max_result_bytes} bytes"),
))));
return;
}

// `send` fails only when the receiver side has been dropped; report
// that as an error and stop pumping.
if let Err(e) = rows_tx.send(row).await {
_ = err_tx.send(Err(v4::Error::QueryFailed(v4::QueryError::Text(
format!("async error: {e}"),
))));
return;
}
}
// Database-side failure: forward it as the terminal status.
Err(e) => {
_ = err_tx.send(Err(e));
return;
}
}
}
});

// Resolves once the background task has seen the first row (or the stream
// ended without one, in which case the column list is empty).
let cols = cols_fut.await;

Ok(QueryAsyncResult {
columns: cols,
rows: rows_rx,
error: err_rx,
})
}
}

impl PooledTokioClient {
    /// Runs `statement` and returns (a) a future resolving to the result-set
    /// column metadata and (b) a lazy stream of rows converted to `DbValue`s.
    ///
    /// The column future is fulfilled from the first successfully received
    /// row; if the stream yields no row (or is dropped first), it resolves to
    /// an empty column list.
    async fn query_stream(
        &self,
        statement: String,
        params: Vec<ParameterValue>,
    ) -> Result<
        (
            impl std::future::Future<Output = Vec<v4::Column>>,
            impl futures::Stream<Item = Result<Vec<DbValue>, v4::Error>>,
        ),
        v4::Error,
    > {
        use futures::{FutureExt, StreamExt};

        let sql_params = to_sql_parameters(params)?;

        let raw_rows = self
            .as_ref()
            .query_raw(&statement, sql_params)
            .await
            .map_err(query_failed)?;

        let (columns_tx, columns_rx) = tokio::sync::oneshot::channel();
        let mut pending_columns_tx = Some(columns_tx);

        let converted_rows = Box::pin(raw_rows)
            .enumerate()
            .map(move |(index, raw_row)| {
                let row = raw_row.map_err(query_failed)?;
                // Column metadata is only known once a concrete row exists;
                // report it at most once, from the first row of the stream.
                if index == 0 {
                    if let Some(tx) = pending_columns_tx.take() {
                        _ = tx.send(infer_columns(&row));
                    }
                }
                convert_row(&row).map_err(query_failed_anyhow)
            });

        // If the sender is dropped without a row, fall back to "no columns".
        let columns_future = columns_rx.map(|received| received.unwrap_or_default());

        Ok((columns_future, Box::pin(converted_rows)))
    }
}

fn infer_columns(row: &Row) -> Vec<Column> {
Expand Down
Loading