Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion native/core/src/execution/shuffle/spark_unsafe/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ use crate::{
errors::CometError,
execution::shuffle::spark_unsafe::{
map::append_map_elements,
row::{append_field, downcast_builder_ref, SparkUnsafeObject, SparkUnsafeRow},
row::{
append_field, downcast_builder_ref, impl_primitive_accessors, SparkUnsafeObject,
SparkUnsafeRow,
},
},
};
use arrow::array::{
Expand Down Expand Up @@ -101,6 +104,10 @@ impl SparkUnsafeObject for SparkUnsafeArray {
/// Returns a raw pointer to the element at `index`, computed as
/// `self.element_offset + index * element_size`.
fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8 {
    // Byte distance of the requested element from the first element.
    let byte_delta = (index * element_size) as i64;
    let element_addr = self.element_offset + byte_delta;
    element_addr as *const u8
}

// SparkUnsafeArray base address may be unaligned when nested within a row's variable-length
// region, so we must use ptr::read_unaligned() for all typed accesses.
impl_primitive_accessors!(read_unaligned);
}

impl SparkUnsafeArray {
Expand Down
226 changes: 124 additions & 102 deletions native/core/src/execution/shuffle/spark_unsafe/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,32 @@ const NESTED_TYPE_BUILDER_CAPACITY: usize = 100;
/// safe to call as long as:
/// - The index is within bounds (caller's responsibility)
/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data
///
/// # Alignment
///
/// Primitive accessor methods are implemented separately for each type because they have
/// different alignment guarantees:
/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is a multiple of 8,
/// and each field slot is 8 bytes), so accessors use aligned `ptr::read()`.
/// - `SparkUnsafeArray`: The array base address may be unaligned when nested within a row's
/// variable-length region, so accessors use `ptr::read_unaligned()`.
pub trait SparkUnsafeObject {
/// Returns the address of the row.
fn get_row_addr(&self) -> i64;

/// Returns the offset of the element at the given index.
fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8;

fn get_boolean(&self, index: usize) -> bool;
fn get_byte(&self, index: usize) -> i8;
fn get_short(&self, index: usize) -> i16;
fn get_int(&self, index: usize) -> i32;
fn get_long(&self, index: usize) -> i64;
fn get_float(&self, index: usize) -> f32;
fn get_double(&self, index: usize) -> f64;
fn get_date(&self, index: usize) -> i32;
fn get_timestamp(&self, index: usize) -> i64;

/// Returns the offset and length of the element at the given index.
#[inline]
fn get_offset_and_len(&self, index: usize) -> (i32, i32) {
Expand All @@ -87,79 +106,6 @@ pub trait SparkUnsafeObject {
(offset, len)
}

/// Reads the single byte stored at `index` and reports whether it is non-zero.
#[inline]
fn get_boolean(&self, index: usize) -> bool {
    let byte_ptr = self.get_element_offset(index, 1);
    debug_assert!(
        !byte_ptr.is_null(),
        "get_boolean: null pointer at index {index}"
    );
    // SAFETY: byte_ptr points to valid element data within the UnsafeRow/UnsafeArray
    // region. The caller ensures index is within bounds.
    let raw = unsafe { byte_ptr.read() };
    raw != 0
}

/// Reads the single byte stored at `index` and returns it as a signed 8-bit value.
#[inline]
fn get_byte(&self, index: usize) -> i8 {
    let base = self.get_element_offset(index, 1);
    debug_assert!(!base.is_null(), "get_byte: null pointer at index {index}");
    // SAFETY: base points to valid element data (1 byte) within the row/array region.
    // A byte array has alignment 1, so this read is valid at any address.
    let raw: [u8; 1] = unsafe { base.cast::<[u8; 1]>().read() };
    i8::from_le_bytes(raw)
}

/// Reads the 2 bytes stored at `index` and decodes them as a little-endian `i16`.
#[inline]
fn get_short(&self, index: usize) -> i16 {
    let base = self.get_element_offset(index, 2);
    debug_assert!(!base.is_null(), "get_short: null pointer at index {index}");
    // SAFETY: base points to valid element data (2 bytes) within the row/array region.
    // A byte array has alignment 1, so this read is valid at any address.
    let raw: [u8; 2] = unsafe { base.cast::<[u8; 2]>().read() };
    i16::from_le_bytes(raw)
}

/// Reads the 4 bytes stored at `index` and decodes them as a little-endian `i32`.
#[inline]
fn get_int(&self, index: usize) -> i32 {
    let base = self.get_element_offset(index, 4);
    debug_assert!(!base.is_null(), "get_int: null pointer at index {index}");
    // SAFETY: base points to valid element data (4 bytes) within the row/array region.
    // A byte array has alignment 1, so this read is valid at any address.
    let raw: [u8; 4] = unsafe { base.cast::<[u8; 4]>().read() };
    i32::from_le_bytes(raw)
}

/// Reads the 8 bytes stored at `index` and decodes them as a little-endian `i64`.
#[inline]
fn get_long(&self, index: usize) -> i64 {
    let base = self.get_element_offset(index, 8);
    debug_assert!(!base.is_null(), "get_long: null pointer at index {index}");
    // SAFETY: base points to valid element data (8 bytes) within the row/array region.
    // A byte array has alignment 1, so this read is valid at any address.
    let raw: [u8; 8] = unsafe { base.cast::<[u8; 8]>().read() };
    i64::from_le_bytes(raw)
}

/// Reads the 4 bytes stored at `index` and decodes them as a little-endian `f32`.
#[inline]
fn get_float(&self, index: usize) -> f32 {
    let base = self.get_element_offset(index, 4);
    debug_assert!(!base.is_null(), "get_float: null pointer at index {index}");
    // SAFETY: base points to valid element data (4 bytes) within the row/array region.
    // A byte array has alignment 1, so this read is valid at any address.
    let raw: [u8; 4] = unsafe { base.cast::<[u8; 4]>().read() };
    f32::from_le_bytes(raw)
}

/// Reads the 8 bytes stored at `index` and decodes them as a little-endian `f64`.
#[inline]
fn get_double(&self, index: usize) -> f64 {
    let base = self.get_element_offset(index, 8);
    debug_assert!(!base.is_null(), "get_double: null pointer at index {index}");
    // SAFETY: base points to valid element data (8 bytes) within the row/array region.
    // A byte array has alignment 1, so this read is valid at any address.
    let raw: [u8; 8] = unsafe { base.cast::<[u8; 8]>().read() };
    f64::from_le_bytes(raw)
}

/// Returns string value at the given index of the object.
fn get_string(&self, index: usize) -> &str {
let (offset, len) = self.get_offset_and_len(index);
Expand Down Expand Up @@ -190,29 +136,6 @@ pub trait SparkUnsafeObject {
unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
}

/// Reads the date stored at `index`: 4 bytes decoded as a little-endian `i32`.
#[inline]
fn get_date(&self, index: usize) -> i32 {
    let base = self.get_element_offset(index, 4);
    debug_assert!(!base.is_null(), "get_date: null pointer at index {index}");
    // SAFETY: base points to valid element data (4 bytes) within the row/array region.
    // A byte array has alignment 1, so this read is valid at any address.
    let raw: [u8; 4] = unsafe { base.cast::<[u8; 4]>().read() };
    i32::from_le_bytes(raw)
}

/// Reads the timestamp stored at `index`: 8 bytes decoded as a little-endian `i64`.
#[inline]
fn get_timestamp(&self, index: usize) -> i64 {
    let base = self.get_element_offset(index, 8);
    debug_assert!(
        !base.is_null(),
        "get_timestamp: null pointer at index {index}"
    );
    // SAFETY: base points to valid element data (8 bytes) within the row/array region.
    // A byte array has alignment 1, so this read is valid at any address.
    let raw: [u8; 8] = unsafe { base.cast::<[u8; 8]>().read() };
    i64::from_le_bytes(raw)
}

/// Returns decimal value at the given index of the object.
fn get_decimal(&self, index: usize, precision: u8) -> i128 {
if precision <= MAX_LONG_DIGITS {
Expand Down Expand Up @@ -244,6 +167,94 @@ pub trait SparkUnsafeObject {
}
}

/// Expands to the primitive getter methods of `SparkUnsafeObject`.
///
/// Intended to be invoked inside an `impl SparkUnsafeObject for ...` block that already
/// provides `get_element_offset`. The single argument names the raw-pointer method used
/// to load every multi-byte value:
/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte aligned)
/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray)
///
/// The one-byte getters (`get_boolean`, `get_byte`) never need an alignment choice and
/// therefore ignore `$read_method`.
macro_rules! impl_primitive_accessors {
    ($read_method:ident) => {
        #[inline]
        fn get_boolean(&self, index: usize) -> bool {
            let byte_ptr = self.get_element_offset(index, 1);
            debug_assert!(
                !byte_ptr.is_null(),
                "get_boolean: null pointer at index {index}"
            );
            // SAFETY: byte_ptr points to valid element data within the row/array region.
            let raw = unsafe { byte_ptr.read() };
            raw != 0
        }

        #[inline]
        fn get_byte(&self, index: usize) -> i8 {
            let byte_ptr = self.get_element_offset(index, 1).cast::<i8>();
            debug_assert!(!byte_ptr.is_null(), "get_byte: null pointer at index {index}");
            // SAFETY: byte_ptr points to valid element data (1 byte) within the row/array
            // region; a one-byte value has no alignment requirement.
            unsafe { byte_ptr.read() }
        }

        #[inline]
        fn get_short(&self, index: usize) -> i16 {
            let typed_ptr = self.get_element_offset(index, 2).cast::<i16>();
            debug_assert!(!typed_ptr.is_null(), "get_short: null pointer at index {index}");
            // SAFETY: typed_ptr points to valid element data (2 bytes) within the row/array
            // region.
            unsafe { typed_ptr.$read_method() }
        }

        #[inline]
        fn get_int(&self, index: usize) -> i32 {
            let typed_ptr = self.get_element_offset(index, 4).cast::<i32>();
            debug_assert!(!typed_ptr.is_null(), "get_int: null pointer at index {index}");
            // SAFETY: typed_ptr points to valid element data (4 bytes) within the row/array
            // region.
            unsafe { typed_ptr.$read_method() }
        }

        #[inline]
        fn get_long(&self, index: usize) -> i64 {
            let typed_ptr = self.get_element_offset(index, 8).cast::<i64>();
            debug_assert!(!typed_ptr.is_null(), "get_long: null pointer at index {index}");
            // SAFETY: typed_ptr points to valid element data (8 bytes) within the row/array
            // region.
            unsafe { typed_ptr.$read_method() }
        }

        #[inline]
        fn get_float(&self, index: usize) -> f32 {
            let typed_ptr = self.get_element_offset(index, 4).cast::<f32>();
            debug_assert!(!typed_ptr.is_null(), "get_float: null pointer at index {index}");
            // SAFETY: typed_ptr points to valid element data (4 bytes) within the row/array
            // region.
            unsafe { typed_ptr.$read_method() }
        }

        #[inline]
        fn get_double(&self, index: usize) -> f64 {
            let typed_ptr = self.get_element_offset(index, 8).cast::<f64>();
            debug_assert!(!typed_ptr.is_null(), "get_double: null pointer at index {index}");
            // SAFETY: typed_ptr points to valid element data (8 bytes) within the row/array
            // region.
            unsafe { typed_ptr.$read_method() }
        }

        #[inline]
        fn get_date(&self, index: usize) -> i32 {
            let typed_ptr = self.get_element_offset(index, 4).cast::<i32>();
            debug_assert!(!typed_ptr.is_null(), "get_date: null pointer at index {index}");
            // SAFETY: typed_ptr points to valid element data (4 bytes) within the row/array
            // region.
            unsafe { typed_ptr.$read_method() }
        }

        #[inline]
        fn get_timestamp(&self, index: usize) -> i64 {
            let typed_ptr = self.get_element_offset(index, 8).cast::<i64>();
            debug_assert!(
                !typed_ptr.is_null(),
                "get_timestamp: null pointer at index {index}"
            );
            // SAFETY: typed_ptr points to valid element data (8 bytes) within the row/array
            // region.
            unsafe { typed_ptr.$read_method() }
        }
    };
}
pub(crate) use impl_primitive_accessors;

pub struct SparkUnsafeRow {
row_addr: i64,
row_size: i32,
Expand All @@ -265,6 +276,11 @@ impl SparkUnsafeObject for SparkUnsafeRow {
);
(self.row_addr + offset) as *const u8
}

// SparkUnsafeRow field offsets are always 8-byte aligned: the base address is 8-byte
// aligned (JVM guarantee), bitset_width is a multiple of 8, and each field slot is
// 8 bytes. This means we can safely use aligned ptr::read() for all typed accesses.
impl_primitive_accessors!(read);
}

impl Default for SparkUnsafeRow {
Expand Down Expand Up @@ -328,11 +344,13 @@ impl SparkUnsafeRow {
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
// The bitset starts at row_addr (8-byte aligned) and each word is at offset 8*k,
// so word_offset is always 8-byte aligned — we can use aligned ptr::read().
debug_assert!(self.row_addr != -1, "is_null_at: row not initialized");
unsafe {
let mask: i64 = 1i64 << (index & 0x3f);
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64;
let word: i64 = word_offset.read_unaligned();
let word: i64 = word_offset.read();
(word & mask) != 0
}
}
Expand All @@ -343,12 +361,13 @@ impl SparkUnsafeRow {
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
// Writing is safe because we have mutable access and the memory is owned by the JVM.
// The bitset is always 8-byte aligned — we can use aligned ptr::read()/write().
debug_assert!(self.row_addr != -1, "set_not_null_at: row not initialized");
unsafe {
let mask: i64 = 1i64 << (index & 0x3f);
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64;
let word: i64 = word_offset.read_unaligned();
word_offset.write_unaligned(word & !mask);
let word: i64 = word_offset.read();
word_offset.write(word & !mask);
}
}
}
Expand Down Expand Up @@ -1668,9 +1687,12 @@ mod test {
let mut row = SparkUnsafeRow::new_with_num_fields(1);
// 8 bytes null bitset + 8 bytes field value = 16 bytes
// Set bit 0 in the null bitset to mark field 0 as null
let mut data = [0u8; 16];
data[0] = 1;
row.point_to_slice(&data);
// Use aligned buffer to match real Spark UnsafeRow layout (8-byte aligned)
#[repr(align(8))]
struct Aligned([u8; 16]);
let mut data = Aligned([0u8; 16]);
data.0[0] = 1;
row.point_to_slice(&data.0);
append_field(&data_type, &mut struct_builder, &row, 0).expect("append field");
struct_builder.append_null();
let struct_array = struct_builder.finish();
Expand Down
Loading