diff --git a/benchmarks/tpc/engines/comet-gracejoin.toml b/benchmarks/tpc/engines/comet-gracejoin.toml new file mode 100644 index 0000000000..ee756abaf1 --- /dev/null +++ b/benchmarks/tpc/engines/comet-gracejoin.toml @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[engine] +name = "comet-gracejoin" + +[env] +required = ["COMET_JAR"] + +[spark_submit] +jars = ["$COMET_JAR"] +driver_class_path = ["$COMET_JAR"] + +[spark_conf] +"spark.driver.extraClassPath" = "$COMET_JAR" +"spark.executor.extraClassPath" = "$COMET_JAR" +"spark.plugins" = "org.apache.spark.CometPlugin" +"spark.shuffle.manager" = "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager" +"spark.comet.scan.impl" = "native_datafusion" +"spark.comet.exec.replaceSortMergeJoin" = "true" +"spark.comet.exec.replaceSortMergeJoin.maxBuildSize" = "104857600" +"spark.comet.exec.graceHashJoin.fastPathThreshold" = "34359738368" +"spark.executor.cores" = "8" +"spark.comet.expression.Cast.allowIncompatible" = "true" diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 41b69952a7..25b63335be 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -305,6 +305,29 @@ object CometConf extends ShimCometConf { val COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("localTableScan", defaultValue = false) + val COMET_EXEC_GRACE_HASH_JOIN_NUM_PARTITIONS: ConfigEntry[Int] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.graceHashJoin.numPartitions") + .category(CATEGORY_EXEC) + .doc("The number of partitions (buckets) to use for Grace Hash Join. A higher number " + + "reduces the size of each partition but increases overhead.") + .intConf + .checkValue(v => v > 0, "The number of partitions must be positive.") + .createWithDefault(16) + + val COMET_EXEC_GRACE_HASH_JOIN_FAST_PATH_THRESHOLD: ConfigEntry[Long] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.graceHashJoin.fastPathThreshold") + .category(CATEGORY_EXEC) + .doc( + "Total memory budget in bytes for Grace Hash Join fast-path hash tables across " + + "all concurrent tasks. This is divided by spark.executor.cores to get the per-task " + + "threshold. 
When a build side fits in memory and is smaller than the per-task " + + "threshold, the join executes as a single HashJoinExec without spilling. " + + "Set to 0 to disable the fast path. Larger values risk OOM because HashJoinExec " + + "creates non-spillable hash tables.") + .longConf + .checkValue(v => v >= 0, "The fast path threshold must be non-negative.") + .createWithDefault(10L * 1024 * 1024) // 10 MB + val COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.columnarToRow.native.enabled") .category(CATEGORY_EXEC) @@ -381,6 +404,18 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_REPLACE_SMJ_MAX_BUILD_SIZE: ConfigEntry[Long] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin.maxBuildSize") + .category(CATEGORY_EXEC) + .doc( + "Maximum estimated size in bytes of the build side for replacing SortMergeJoin " + + "with ShuffledHashJoin. When the build side's logical plan statistics exceed this " + + "threshold, the SortMergeJoin is kept because sort-merge join's streaming merge " + + "on pre-sorted data outperforms hash join's per-task hash table construction " + + "for large build sides. Set to -1 to disable this check and always replace.") + .longConf + .createWithDefault(-1L) + val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.native.shuffle.partitioning.hash.enabled") .category(CATEGORY_SHUFFLE) diff --git a/docs/source/contributor-guide/grace-hash-join-design.md b/docs/source/contributor-guide/grace-hash-join-design.md new file mode 100644 index 0000000000..9e7cf01531 --- /dev/null +++ b/docs/source/contributor-guide/grace-hash-join-design.md @@ -0,0 +1,293 @@ + + +# Grace Hash Join Design Document + +## Overview + +Grace Hash Join (GHJ) is the hash join implementation in Apache DataFusion Comet. 
When `spark.comet.exec.replaceSortMergeJoin` is enabled, Comet's `RewriteJoin` rule converts `SortMergeJoinExec` to `ShuffledHashJoinExec` (removing the input sorts), and all `ShuffledHashJoinExec` operators are then executed natively as `GraceHashJoinExec`. + +GHJ partitions both build and probe sides into N buckets by hashing join keys, then joins each bucket independently. When memory is tight, partitions spill to disk using Arrow IPC format. A fast path skips partitioning entirely when the build side is small enough. + +Supports all join types: Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark, RightSemi, RightAnti, RightMark. + +## Configuration + +| Config Key | Type | Default | Description | +| --- | --- | --- | --- | +| `spark.comet.exec.replaceSortMergeJoin` | boolean | `false` | Replace SortMergeJoin with ShuffledHashJoin (enables GHJ) | +| `spark.comet.exec.replaceSortMergeJoin.maxBuildSize` | long | `-1` | Max build-side bytes for SMJ replacement. `-1` = no limit | +| `spark.comet.exec.graceHashJoin.numPartitions` | int | `16` | Number of hash partitions (buckets) | +| `spark.comet.exec.graceHashJoin.fastPathThreshold` | long | `10485760` | Total fast-path budget in bytes, divided by executor cores | + +### SMJ Replacement Guard + +The `RewriteJoin` rule checks `maxBuildSize` against Spark's logical plan statistics before replacing a `SortMergeJoinExec`. When both sides are large (e.g., TPC-DS q72's `catalog_sales JOIN inventory`), sort-merge join's streaming merge on pre-sorted data outperforms hash join's per-task hash table construction. Setting `maxBuildSize` (e.g., `104857600` for 100 MB) keeps SMJ for these cases. + +### Fast Path Threshold + +The configured threshold is the total budget across all concurrent tasks on the executor. The planner divides it by `spark.executor.cores` so each task's fast-path hash table stays within its fair share. For example, with a 32 GB threshold and 8 cores, each task gets a 4 GB per-task limit. 
+ +## Architecture + +### Plan Integration + +``` +SortMergeJoinExec + -> RewriteJoin converts to ShuffledHashJoinExec (removes input sorts) + -> CometExecRule wraps as CometHashJoinExec + -> CometHashJoinExec.createExec() creates CometGraceHashJoinExec + -> Serialized to protobuf via JNI + -> PhysicalPlanner (Rust) creates GraceHashJoinExec +``` + +### Key Data Structures + +``` +GraceHashJoinExec ExecutionPlan implementation ++-- left/right Child input plans ++-- on Join key pairs [(left_key, right_key)] ++-- filter Optional post-join filter ++-- join_type Inner/Left/Right/Full/Semi/Anti/Mark ++-- num_partitions Number of hash buckets (default 16) ++-- build_left Whether left input is the build side ++-- fast_path_threshold Per-task threshold for fast path (0 = disabled) ++-- schema Output schema + +HashPartition Per-bucket state during partitioning ++-- build_batches In-memory build-side RecordBatches ++-- probe_batches In-memory probe-side RecordBatches ++-- build_spill_writer Optional SpillWriter for build data ++-- probe_spill_writer Optional SpillWriter for probe data ++-- build_mem_size Tracked memory for build side ++-- probe_mem_size Tracked memory for probe side + +FinishedPartition State after spill writers are closed ++-- build_batches In-memory build batches (if not spilled) ++-- probe_batches In-memory probe batches (if not spilled) ++-- build_spill_file Temp file for spilled build data ++-- probe_spill_file Temp file for spilled probe data +``` + +## Execution Flow + +``` +execute() + | + +- Phase 1: Partition build side + | Hash-partition all build input into N buckets. + | Spill the largest bucket on memory pressure. + | + +- Phase 2: Partition probe side + | Hash-partition probe input into N buckets. + | Spill ALL non-spilled buckets on first memory pressure. + | + +- Decision: fast path or slow path? 
+ | If no spilling occurred and total build size <= per-task threshold: + | -> Fast path: single HashJoinExec, stream probe directly + | Otherwise: + | -> Slow path: merge partitions, join sequentially + | + +- Phase 3 (slow path): Join each partition sequentially + Merge adjacent partitions to ~32 MB build-side groups. + For each group, create a per-partition HashJoinExec. + Spilled probes use streaming SpillReaderExec. + Oversized builds trigger recursive repartitioning. +``` + +### Fast Path + +After partitioning both sides, GHJ checks whether the build side is small enough to join in a single `HashJoinExec`: + +1. No partitions were spilled during Phases 1 or 2 +2. The fast path threshold is non-zero +3. The actual build-side memory (measured via `get_array_memory_size()`) is within the per-task threshold + +When all conditions are met, GHJ concatenates all build-side batches, wraps the probe stream in a `StreamSourceExec`, and creates a single `HashJoinExec` with `CollectLeft` mode. The probe side streams directly through without buffering. This avoids the overhead of partition merging and sequential per-partition joins. + +The fast path threshold is intentionally conservative because `HashJoinExec` creates non-spillable hash tables (`can_spill: false`). The per-task division ensures that concurrent tasks don't collectively exceed memory. + +### Phase 1: Build-Side Partitioning + +For each incoming batch from the build input: + +1. Evaluate join key expressions and compute hash values +2. Assign each row to a partition: `partition_id = hash % num_partitions` +3. Use the prefix-sum algorithm to efficiently extract contiguous row groups per partition via `arrow::compute::take()` +4. 
For each partition's sub-batch: + - If the partition is already spilled, append to its `SpillWriter` + - Otherwise, call `reservation.try_grow(batch_size)` + - On failure: spill the largest non-spilled partition, retry + - If still fails: spill this partition and write to disk + +All in-memory build data is tracked in a shared `MemoryReservation` registered as `can_spill: true`, making GHJ a cooperative citizen in DataFusion's memory pool. + +### Phase 2: Probe-Side Partitioning + +Same hash-partitioning algorithm as Phase 1, with key differences: + +1. **Spilled build implies spilled probe**: If a partition's build side was spilled, the probe side is also spilled. Both sides must be on disk (or both in memory) for the join phase. + +2. **Aggressive spilling**: On the first memory pressure event, all non-spilled partitions are spilled (both build and probe sides). This prevents thrashing between spilling and accumulating when multiple concurrent GHJ instances share a memory pool. + +3. **Shared reservation**: The same `MemoryReservation` from Phase 1 continues to track probe-side memory. + +### Phase 3: Per-Partition Joins (Slow Path) + +Before joining, adjacent `FinishedPartition`s are merged so each group has roughly `TARGET_PARTITION_BUILD_SIZE` (32 MB) of build data. This reduces the number of `HashJoinExec` invocations while keeping each hash table small. + +Merged groups are joined sequentially — one at a time — so only one `HashJoinInput` consumer exists at any moment. The GHJ reservation is freed before Phase 3 begins; each per-partition `HashJoinExec` tracks its own memory. 
+ +**In-memory partitions** are joined via `join_partition_recursive()`: + +- Concatenate build and probe sub-batches +- Create `HashJoinExec` with both sides as `MemorySourceConfig` +- If the build side is too large for a hash table: recursively repartition (up to `MAX_RECURSION_DEPTH = 3`, yielding up to 16^3 = 4096 effective partitions) + +**Spilled partitions** are joined via `join_with_spilled_probe()`: + +- Build side loaded from memory or disk via `spawn_blocking` +- Probe side streamed via `SpillReaderExec` (never fully loaded into memory) +- If the build side is too large: fall back to eager probe read + recursive repartitioning + +## Spill Mechanism + +### Writing + +`SpillWriter` wraps Arrow IPC `StreamWriter` for incremental appends: + +- Uses `BufWriter` with 1 MB buffer (vs 8 KB default) for sequential throughput +- Batches are appended one at a time — no need to rewrite the file +- `finish()` flushes the writer and returns the `RefCountedTempFile` + +Temp files are created via DataFusion's `DiskManager`, which handles allocation and cleanup. + +### Reading + +Two read paths depending on context: + +**Eager read** (`read_spilled_batches`): Opens file, reads all batches into `Vec`. Used for build-side spill files bounded by `TARGET_PARTITION_BUILD_SIZE`. + +**Streaming read** (`SpillReaderExec`): An `ExecutionPlan` that reads batches on-demand: + +- Spawns a `tokio::task::spawn_blocking` to read from the file on a blocking thread pool +- Uses an `mpsc` channel (capacity 4) to feed batches to the async executor +- Coalesces small sub-batches into ~8192-row chunks before sending, reducing per-batch overhead in the downstream hash join kernel +- The `RefCountedTempFile` handle is moved into the blocking closure to keep the file alive until reading completes + +### Spill Coalescing + +Hash-partitioning creates N sub-batches per input batch. With N=16 partitions and 1000-row input batches, spill files contain ~62-row sub-batches. 
`SpillReaderExec` coalesces these into ~8192-row batches on read, reducing channel send/recv overhead, hash join kernel invocations, and per-batch `RecordBatch` construction costs. + +## Memory Management + +### Reservation Model + +GHJ uses a single `MemoryReservation` registered as a spillable consumer (`with_can_spill(true)`). This reservation: + +- Tracks all in-memory build and probe data across all partitions during Phases 1 and 2 +- Grows via `try_grow()` before each batch is added to memory +- Shrinks via `shrink()` when partitions are spilled to disk +- Is freed before Phase 3, where each per-partition `HashJoinExec` tracks its own memory via `HashJoinInput` + +### Concurrent Instances + +In a typical Spark executor, multiple tasks run concurrently, each potentially executing a GHJ. All instances share the same DataFusion memory pool. The "spill ALL non-spilled partitions" strategy in Phase 2 makes each instance's spill decision atomic — once triggered, the instance moves all its data to disk in one operation, preventing interleaving with other instances that would otherwise claim freed memory immediately. + +### DataFusion Memory Pool Integration + +DataFusion's memory pool (typically `FairSpillPool`) divides memory between spillable and non-spillable consumers. GHJ registers as spillable so the pool can account for its memory when computing fair shares. The per-partition `HashJoinExec` instances in Phase 3 use non-spillable `HashJoinInput` reservations, but since partitions are joined sequentially, only one hash table exists at a time, keeping peak memory at roughly `build_size / num_partitions`. + +## Hash Partitioning Algorithm + +### Prefix-Sum Approach + +Instead of N separate `take()` kernel calls (one per partition), GHJ uses a prefix-sum algorithm: + +1. **Hash**: Compute hash values for all rows +2. **Assign**: Map each row to a partition: `partition_id = hash % N` +3. **Count**: Count rows per partition +4. 
**Prefix-sum**: Accumulate counts into start offsets +5. **Scatter**: Place row indices into contiguous regions per partition +6. **Take**: Single `arrow::compute::take()` per partition using the precomputed indices + +This is O(rows) with good cache locality, compared to O(rows x partitions) for the naive approach. + +### Hash Seed Variation + +GHJ hashes on the same join keys that Spark already used for its shuffle exchange, but with a different hash function (ahash via `RandomState` with fixed seeds). Spark's shuffle uses Murmur3, so all rows arriving at a given Spark partition share the same `murmur3(key) % num_spark_partitions` value but have diverse actual key values. GHJ's ahash produces a completely different distribution. + +At each recursion level, a different random seed is used: + +```rust +fn partition_random_state(recursion_level: usize) -> RandomState { + RandomState::with_seeds( + 0x517cc1b727220a95 ^ (recursion_level as u64), + 0x3a8b7c9d1e2f4056, 0, 0, + ) +} +``` + +This ensures rows that hash to the same partition at level 0 are distributed across different sub-partitions at level 1. The only case where repartitioning cannot help is true data skew — many rows with the same key value. No amount of rehashing can separate identical keys, which is why there is a `MAX_RECURSION_DEPTH = 3` limit. + +## Recursive Repartitioning + +When a partition's build side is too large for a hash table (tested via `try_grow(build_size * 3)`, where the 3x accounts for hash table overhead), GHJ recursively repartitions: + +1. Sub-partition both build and probe into 16 new buckets using a different hash seed +2. Recursively join each sub-partition +3. Maximum depth: 3 (yielding up to 16^3 = 4096 effective partitions) +4. If still too large at max depth: return `ResourcesExhausted` error + +## Partition Merging + +After Phase 2, GHJ merges adjacent `FinishedPartition`s to reduce the number of per-partition `HashJoinExec` invocations. 
The target is `TARGET_PARTITION_BUILD_SIZE` (32 MB) per merged group. For example, with 16 partitions and 200 MB total build data, partitions are merged into ~6 groups of ~32 MB each instead of 16 groups of ~12 MB. + +Merging only combines adjacent partitions (preserving hash locality) and never merges spilled with non-spilled partitions. The merge is a metadata-only operation — it combines batch lists and spill file handles without copying data. + +## Build Side Selection + +GHJ respects Spark's build side selection (`BuildLeft` or `BuildRight`). The `build_left` flag determines: + +- Which input is consumed in Phase 1 (build) vs Phase 2 (probe) +- How join key expressions are mapped +- How `HashJoinExec` is constructed (build side is always left in `CollectLeft` mode) + +When `build_left = false`, the `HashJoinExec` is created with swapped inputs and then `swap_inputs()` is called to produce correct output column ordering. + +## Metrics + +| Metric | Description | +| --- | --- | +| `build_time` | Time spent partitioning the build side | +| `probe_time` | Time spent partitioning the probe side | +| `spill_count` | Number of partition spill events | +| `spilled_bytes` | Total bytes written to spill files | +| `build_input_rows` | Total rows from build input | +| `build_input_batches` | Total batches from build input | +| `input_rows` | Total rows from probe input | +| `input_batches` | Total batches from probe input | +| `output_rows` | Total output rows | +| `elapsed_compute` | Total compute time | + +## Future Work + +- **Adaptive partition count**: Dynamically choose the number of partitions based on input size rather than a fixed default +- **Spill file compression**: Compress Arrow IPC data on disk to reduce I/O volume at the cost of CPU +- **Upstream DataFusion spill support**: Contribute spill capability to DataFusion's `HashJoinExec` to eliminate the need for a separate GHJ operator diff --git a/native/Cargo.lock b/native/Cargo.lock index 
78fa3fa124..b1c4565890 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -278,6 +278,7 @@ dependencies = [ "arrow-select", "flatbuffers", "lz4_flex", + "zstd", ] [[package]] @@ -1783,6 +1784,7 @@ dependencies = [ name = "datafusion-comet" version = "0.14.0" dependencies = [ + "ahash", "arrow", "assertables", "async-trait", diff --git a/native/Cargo.toml b/native/Cargo.toml index d5a6aeabc9..49bb498f60 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -34,7 +34,7 @@ edition = "2021" rust-version = "1.88" [workspace.dependencies] -arrow = { version = "57.3.0", features = ["prettyprint", "ffi", "chrono-tz"] } +arrow = { version = "57.3.0", features = ["prettyprint", "ffi", "chrono-tz", "ipc_compression"] } async-trait = { version = "0.1" } bytes = { version = "1.11.1" } parquet = { version = "57.3.0", default-features = false, features = ["experimental"] } diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index cbe397b12b..81132fe534 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -35,6 +35,7 @@ include = [ publish = false [dependencies] +ahash = "0.8" arrow = { workspace = true } parquet = { workspace = true, default-features = false, features = ["experimental", "arrow"] } futures = { workspace = true } diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 1030e30aaf..00591f88fe 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -175,6 +175,8 @@ struct ExecutionContext { pub memory_pool_config: MemoryPoolConfig, /// Whether to log memory usage on each call to execute_plan pub tracing_enabled: bool, + /// Spark configuration map for comet-specific settings + pub spark_conf: HashMap<String, String>, } /// Accept serialized query plan and return the address of the native query plan. 
@@ -322,6 +324,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( explain_native, memory_pool_config, tracing_enabled, + spark_conf: spark_config, }); Ok(Box::into_raw(exec_context) as i64) @@ -535,7 +538,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let start = Instant::now(); let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) - .with_exec_id(exec_context_id); + .with_exec_id(exec_context_id) + .with_spark_conf(exec_context.spark_conf.clone()); let (scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(), diff --git a/native/core/src/execution/operators/grace_hash_join.rs b/native/core/src/execution/operators/grace_hash_join.rs new file mode 100644 index 0000000000..f749d47114 --- /dev/null +++ b/native/core/src/execution/operators/grace_hash_join.rs @@ -0,0 +1,2625 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Grace Hash Join operator for Apache DataFusion Comet. +//! +//! Partitions both build and probe sides into N buckets by hashing join keys, +//! then performs per-partition hash joins. Spills partitions to disk (Arrow IPC) +//! when memory is tight. +//! +//! 
Supports all join types. Recursively repartitions oversized partitions +//! up to `MAX_RECURSION_DEPTH` levels. + +use std::any::Any; +use std::fmt; +use std::fs::File; +use std::io::{BufReader, BufWriter}; +use std::sync::Arc; +use std::sync::Mutex; + +use ahash::RandomState; +use arrow::array::UInt32Array; +use arrow::compute::{concat_batches, take}; +use arrow::datatypes::SchemaRef; +use arrow::ipc::reader::StreamReader; +use arrow::ipc::writer::{IpcWriteOptions, StreamWriter}; +use arrow::ipc::CompressionType; +use arrow::record_batch::RecordBatch; +use datafusion::common::hash_utils::create_hashes; +use datafusion::common::{DataFusionError, JoinType, NullEquality, Result as DFResult}; +use datafusion::execution::context::TaskContext; +use datafusion::execution::disk_manager::RefCountedTempFile; +use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::joins::utils::JoinFilter; +use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, Time, +}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, +}; +use futures::stream::{self, StreamExt, TryStreamExt}; +use futures::Stream; +use log::info; +use tokio::sync::mpsc; + +/// Global atomic counter for unique GHJ instance IDs (debug tracing). +static GHJ_INSTANCE_COUNTER: std::sync::atomic::AtomicUsize = + std::sync::atomic::AtomicUsize::new(0); + +/// Type alias for join key expression pairs. +type JoinOnRef<'a> = &'a [(Arc, Arc)]; + +/// Number of partitions (buckets) for the grace hash join. 
+const DEFAULT_NUM_PARTITIONS: usize = 16; + +/// Maximum recursion depth for repartitioning oversized partitions. +/// At depth 3 with 16 partitions per level, effective partitions = 16^3 = 4096. +const MAX_RECURSION_DEPTH: usize = 3; + +/// I/O buffer size for spill file reads and writes. The default BufReader/BufWriter +/// size (8 KB) is far too small for multi-GB spill files. 1 MB provides good +/// sequential throughput while keeping per-partition memory overhead modest. +const SPILL_IO_BUFFER_SIZE: usize = 1024 * 1024; + +/// Target number of rows per coalesced batch when reading spill files. +/// Spill files contain many tiny sub-batches (from partitioning). Coalescing +/// into larger batches reduces per-batch overhead in the hash join kernel +/// and channel send/recv costs. +const SPILL_READ_COALESCE_TARGET: usize = 8192; + +/// Target build-side size per merged partition. After Phase 2, adjacent +/// `FinishedPartition`s are merged so each group has roughly this much +/// build data, reducing the number of per-partition HashJoinExec calls. +const TARGET_PARTITION_BUILD_SIZE: usize = 32 * 1024 * 1024; + +/// Random state for hashing join keys into partitions. Uses fixed seeds +/// different from DataFusion's HashJoinExec to avoid correlation. +/// The `recursion_level` is XORed into the seed so that recursive +/// repartitioning uses different hash functions at each level. +fn partition_random_state(recursion_level: usize) -> RandomState { + RandomState::with_seeds( + 0x517cc1b727220a95 ^ (recursion_level as u64), + 0x3a8b7c9d1e2f4056, + 0, + 0, + ) +} + +// --------------------------------------------------------------------------- +// SpillWriter: incremental append to Arrow IPC spill files +// --------------------------------------------------------------------------- + +/// Wraps an Arrow IPC `StreamWriter` for incremental spill writes. +/// Avoids the O(n²) read-rewrite pattern by keeping the writer open. 
+struct SpillWriter { + writer: StreamWriter>, + temp_file: RefCountedTempFile, + bytes_written: usize, +} + +impl SpillWriter { + /// Create a new spill writer backed by a temp file. + fn new(temp_file: RefCountedTempFile, schema: &SchemaRef) -> DFResult { + let file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(temp_file.path()) + .map_err(|e| DataFusionError::Execution(format!("Failed to open spill file: {e}")))?; + let buf_writer = BufWriter::with_capacity(SPILL_IO_BUFFER_SIZE, file); + let write_options = + IpcWriteOptions::default().try_with_compression(Some(CompressionType::LZ4_FRAME))?; + let writer = StreamWriter::try_new_with_options(buf_writer, schema, write_options)?; + Ok(Self { + writer, + temp_file, + bytes_written: 0, + }) + } + + /// Append a single batch to the spill file. + fn write_batch(&mut self, batch: &RecordBatch) -> DFResult<()> { + if batch.num_rows() > 0 { + self.bytes_written += batch.get_array_memory_size(); + self.writer.write(batch)?; + } + Ok(()) + } + + /// Append multiple batches to the spill file. + fn write_batches(&mut self, batches: &[RecordBatch]) -> DFResult<()> { + for batch in batches { + self.write_batch(batch)?; + } + Ok(()) + } + + /// Finish writing. Must be called before reading back. + fn finish(mut self) -> DFResult<(RefCountedTempFile, usize)> { + self.writer.finish()?; + Ok((self.temp_file, self.bytes_written)) + } +} + +// --------------------------------------------------------------------------- +// SpillReaderExec: streaming ExecutionPlan for reading spill files +// --------------------------------------------------------------------------- + +/// An ExecutionPlan that streams record batches from an Arrow IPC spill file. +/// Used during the join phase so that spilled probe data is read on-demand +/// instead of loaded entirely into memory. 
+#[derive(Debug)] +struct SpillReaderExec { + spill_file: RefCountedTempFile, + schema: SchemaRef, + cache: PlanProperties, +} + +impl SpillReaderExec { + fn new(spill_file: RefCountedTempFile, schema: SchemaRef) -> Self { + let cache = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + Partitioning::UnknownPartitioning(1), + datafusion::physical_plan::execution_plan::EmissionType::Incremental, + datafusion::physical_plan::execution_plan::Boundedness::Bounded, + ); + Self { + spill_file, + schema, + cache, + } + } +} + +impl DisplayAs for SpillReaderExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "SpillReaderExec") + } +} + +impl ExecutionPlan for SpillReaderExec { + fn name(&self) -> &str { + "SpillReaderExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> DFResult> { + Ok(self) + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> DFResult { + let schema = Arc::clone(&self.schema); + let coalesce_schema = Arc::clone(&self.schema); + let path = self.spill_file.path().to_path_buf(); + // Move the spill file handle into the blocking closure to keep + // the temp file alive until the reader is done. + let spill_file_handle = self.spill_file.clone(); + + // Use a channel so file I/O runs on a blocking thread and doesn't + // block the async executor. This lets select_all interleave multiple + // partition streams effectively. 
+ let (tx, rx) = mpsc::channel::>(4); + + tokio::task::spawn_blocking(move || { + let _keep_alive = spill_file_handle; + let file = match File::open(&path) { + Ok(f) => f, + Err(e) => { + let _ = tx.blocking_send(Err(DataFusionError::Execution(format!( + "Failed to open spill file: {e}" + )))); + return; + } + }; + let reader = match StreamReader::try_new( + BufReader::with_capacity(SPILL_IO_BUFFER_SIZE, file), + None, + ) { + Ok(r) => r, + Err(e) => { + let _ = tx.blocking_send(Err(DataFusionError::ArrowError(Box::new(e), None))); + return; + } + }; + + // Coalesce small sub-batches into larger ones to reduce per-batch + // overhead in the downstream hash join. + let mut pending: Vec = Vec::new(); + let mut pending_rows = 0usize; + + for batch_result in reader { + let batch = match batch_result { + Ok(b) => b, + Err(e) => { + let _ = + tx.blocking_send(Err(DataFusionError::ArrowError(Box::new(e), None))); + return; + } + }; + if batch.num_rows() == 0 { + continue; + } + pending_rows += batch.num_rows(); + pending.push(batch); + + if pending_rows >= SPILL_READ_COALESCE_TARGET { + let merged = if pending.len() == 1 { + Ok(pending.pop().unwrap()) + } else { + concat_batches(&coalesce_schema, &pending) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + }; + pending.clear(); + pending_rows = 0; + if tx.blocking_send(merged).is_err() { + return; + } + } + } + + // Flush remaining + if !pending.is_empty() { + let merged = if pending.len() == 1 { + Ok(pending.pop().unwrap()) + } else { + concat_batches(&coalesce_schema, &pending) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + }; + let _ = tx.blocking_send(merged); + } + }); + + let batch_stream = futures::stream::unfold(rx, |mut rx| async move { + rx.recv().await.map(|batch| (batch, rx)) + }); + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + batch_stream, + ))) + } +} + +// --------------------------------------------------------------------------- +// StreamSourceExec: wrap an 
existing stream as an ExecutionPlan +// --------------------------------------------------------------------------- + +/// An ExecutionPlan that yields batches from a pre-existing stream. +/// Used in the fast path to feed the probe side's live stream into +/// a `HashJoinExec` without buffering or spilling. +struct StreamSourceExec { + stream: Mutex>, + schema: SchemaRef, + cache: PlanProperties, +} + +impl StreamSourceExec { + fn new(stream: SendableRecordBatchStream, schema: SchemaRef) -> Self { + let cache = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + Partitioning::UnknownPartitioning(1), + datafusion::physical_plan::execution_plan::EmissionType::Incremental, + datafusion::physical_plan::execution_plan::Boundedness::Bounded, + ); + Self { + stream: Mutex::new(Some(stream)), + schema, + cache, + } + } +} + +impl fmt::Debug for StreamSourceExec { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("StreamSourceExec").finish() + } +} + +impl DisplayAs for StreamSourceExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "StreamSourceExec") + } +} + +impl ExecutionPlan for StreamSourceExec { + fn name(&self) -> &str { + "StreamSourceExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> DFResult> { + Ok(self) + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> DFResult { + self.stream + .lock() + .map_err(|e| DataFusionError::Internal(format!("lock poisoned: {e}")))? 
+ .take() + .ok_or_else(|| { + DataFusionError::Internal("StreamSourceExec: stream already consumed".to_string()) + }) + } +} + +// --------------------------------------------------------------------------- +// GraceHashJoinMetrics +// --------------------------------------------------------------------------- + +/// Production metrics for the Grace Hash Join operator. +struct GraceHashJoinMetrics { + /// Baseline metrics (output rows, elapsed compute) + baseline: BaselineMetrics, + /// Time spent partitioning the build side + build_time: Time, + /// Time spent partitioning the probe side + probe_time: Time, + /// Number of spill events + spill_count: Count, + /// Total bytes spilled to disk + spilled_bytes: Count, + /// Number of build-side input rows + build_input_rows: Count, + /// Number of build-side input batches + build_input_batches: Count, + /// Number of probe-side input rows + input_rows: Count, + /// Number of probe-side input batches + input_batches: Count, +} + +impl GraceHashJoinMetrics { + fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + baseline: BaselineMetrics::new(metrics, partition), + build_time: MetricBuilder::new(metrics).subset_time("build_time", partition), + probe_time: MetricBuilder::new(metrics).subset_time("probe_time", partition), + spill_count: MetricBuilder::new(metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition), + build_input_rows: MetricBuilder::new(metrics).counter("build_input_rows", partition), + build_input_batches: MetricBuilder::new(metrics) + .counter("build_input_batches", partition), + input_rows: MetricBuilder::new(metrics).counter("input_rows", partition), + input_batches: MetricBuilder::new(metrics).counter("input_batches", partition), + } + } +} + +// --------------------------------------------------------------------------- +// GraceHashJoinExec +// --------------------------------------------------------------------------- + +/// 
Grace Hash Join execution plan. +/// +/// Partitions both sides into N buckets, then joins each bucket independently +/// using DataFusion's HashJoinExec. Spills partitions to disk when memory +/// pressure is detected. +#[derive(Debug)] +pub struct GraceHashJoinExec { + /// Left input + left: Arc, + /// Right input + right: Arc, + /// Join key pairs: (left_key, right_key) + on: Vec<(Arc, Arc)>, + /// Optional join filter applied after key matching + filter: Option, + /// Join type + join_type: JoinType, + /// Number of hash partitions + num_partitions: usize, + /// Whether left is the build side (true) or right is (false) + build_left: bool, + /// Maximum build-side bytes for the fast path (0 = disabled) + fast_path_threshold: usize, + /// Output schema + schema: SchemaRef, + /// Plan properties cache + cache: PlanProperties, + /// Metrics + metrics: ExecutionPlanMetricsSet, +} + +impl GraceHashJoinExec { + #[allow(clippy::too_many_arguments)] + pub fn try_new( + left: Arc, + right: Arc, + on: Vec<(Arc, Arc)>, + filter: Option, + join_type: &JoinType, + num_partitions: usize, + build_left: bool, + fast_path_threshold: usize, + ) -> DFResult { + // Build the output schema using HashJoinExec's logic. + // HashJoinExec expects left=build, right=probe. When build_left=false, + // we swap inputs + keys + join type for schema derivation, then store + // original values for our own partitioning logic. 
+ let hash_join = HashJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?; + let (schema, cache) = if build_left { + (hash_join.schema(), hash_join.properties().clone()) + } else { + // Swap to get correct output schema for build-right + let swapped = hash_join.swap_inputs(PartitionMode::CollectLeft)?; + (swapped.schema(), swapped.properties().clone()) + }; + + Ok(Self { + left, + right, + on, + filter, + join_type: *join_type, + num_partitions: if num_partitions == 0 { + DEFAULT_NUM_PARTITIONS + } else { + num_partitions + }, + build_left, + fast_path_threshold, + schema, + cache, + metrics: ExecutionPlanMetricsSet::new(), + }) + } +} + +impl DisplayAs for GraceHashJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default + | DisplayFormatType::Verbose + | DisplayFormatType::TreeRender => { + let on: Vec = self.on.iter().map(|(l, r)| format!("({l}, {r})")).collect(); + write!( + f, + "GraceHashJoinExec: join_type={:?}, on=[{}], num_partitions={}", + self.join_type, + on.join(", "), + self.num_partitions, + ) + } + } + } +} + +impl ExecutionPlan for GraceHashJoinExec { + fn name(&self) -> &str { + "GraceHashJoinExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DFResult> { + Ok(Arc::new(GraceHashJoinExec::try_new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.on.clone(), + self.filter.clone(), + &self.join_type, + self.num_partitions, + self.build_left, + self.fast_path_threshold, + )?)) + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DFResult { + 
info!( + "GraceHashJoin: execute() called. build_left={}, join_type={:?}, \ + num_partitions={}, fast_path_threshold={}\n left: {}\n right: {}", + self.build_left, + self.join_type, + self.num_partitions, + self.fast_path_threshold, + DisplayableExecutionPlan::new(self.left.as_ref()).one_line(), + DisplayableExecutionPlan::new(self.right.as_ref()).one_line(), + ); + let left_stream = self.left.execute(partition, Arc::clone(&context))?; + let right_stream = self.right.execute(partition, Arc::clone(&context))?; + + let join_metrics = GraceHashJoinMetrics::new(&self.metrics, partition); + + // Determine build/probe streams and schemas based on build_left. + // The internal execution always treats first arg as build, second as probe. + let (build_stream, probe_stream, build_schema, probe_schema, build_on, probe_on) = + if self.build_left { + let build_keys: Vec<_> = self.on.iter().map(|(l, _)| Arc::clone(l)).collect(); + let probe_keys: Vec<_> = self.on.iter().map(|(_, r)| Arc::clone(r)).collect(); + ( + left_stream, + right_stream, + self.left.schema(), + self.right.schema(), + build_keys, + probe_keys, + ) + } else { + // Build right: right is build side, left is probe side + let build_keys: Vec<_> = self.on.iter().map(|(_, r)| Arc::clone(r)).collect(); + let probe_keys: Vec<_> = self.on.iter().map(|(l, _)| Arc::clone(l)).collect(); + ( + right_stream, + left_stream, + self.right.schema(), + self.left.schema(), + build_keys, + probe_keys, + ) + }; + + let on = self.on.clone(); + let filter = self.filter.clone(); + let join_type = self.join_type; + let num_partitions = self.num_partitions; + let build_left = self.build_left; + let fast_path_threshold = self.fast_path_threshold; + let output_schema = Arc::clone(&self.schema); + + let result_stream = futures::stream::once(async move { + execute_grace_hash_join( + build_stream, + probe_stream, + build_on, + probe_on, + on, + filter, + join_type, + num_partitions, + build_left, + fast_path_threshold, + build_schema, + 
probe_schema, + output_schema, + context, + join_metrics, + ) + .await + }) + .try_flatten(); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + result_stream, + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +// --------------------------------------------------------------------------- +// Per-partition state +// --------------------------------------------------------------------------- + +/// Per-partition state tracking buffered data or spill writers. +struct HashPartition { + /// In-memory build-side batches for this partition. + build_batches: Vec, + /// In-memory probe-side batches for this partition. + probe_batches: Vec, + /// Incremental spill writer for build side (if spilling). + build_spill_writer: Option, + /// Incremental spill writer for probe side (if spilling). + probe_spill_writer: Option, + /// Approximate memory used by build-side batches in this partition. + build_mem_size: usize, + /// Approximate memory used by probe-side batches in this partition. + probe_mem_size: usize, +} + +impl HashPartition { + fn new() -> Self { + Self { + build_batches: Vec::new(), + probe_batches: Vec::new(), + build_spill_writer: None, + probe_spill_writer: None, + build_mem_size: 0, + probe_mem_size: 0, + } + } + + /// Whether the build side has been spilled to disk. + fn build_spilled(&self) -> bool { + self.build_spill_writer.is_some() + } +} + +// --------------------------------------------------------------------------- +// Main execution logic +// --------------------------------------------------------------------------- + +/// Main execution logic for the grace hash join. +/// +/// `build_stream`/`probe_stream`: already swapped based on build_left. +/// `build_keys`/`probe_keys`: key expressions for their respective sides. +/// `original_on`: original (left_key, right_key) pairs for HashJoinExec. +/// `build_left`: whether left is build side (affects HashJoinExec construction). 
+#[allow(clippy::too_many_arguments)] +async fn execute_grace_hash_join( + build_stream: SendableRecordBatchStream, + probe_stream: SendableRecordBatchStream, + build_keys: Vec>, + probe_keys: Vec>, + original_on: Vec<(Arc, Arc)>, + filter: Option, + join_type: JoinType, + num_partitions: usize, + build_left: bool, + fast_path_threshold: usize, + build_schema: SchemaRef, + probe_schema: SchemaRef, + _output_schema: SchemaRef, + context: Arc, + metrics: GraceHashJoinMetrics, +) -> DFResult>> { + let ghj_id = GHJ_INSTANCE_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + // Set up memory reservation (shared across build and probe phases) + let mut reservation = MutableReservation( + MemoryConsumer::new("GraceHashJoinExec") + .with_can_spill(true) + .register(&context.runtime_env().memory_pool), + ); + + info!( + "GHJ#{}: started. build_left={}, join_type={:?}, pool reserved={}", + ghj_id, + build_left, + join_type, + context.runtime_env().memory_pool.reserved(), + ); + + let mut partitions: Vec = + (0..num_partitions).map(|_| HashPartition::new()).collect(); + + let mut scratch = ScratchSpace::default(); + + // Phase 1: Partition the build side + { + let _timer = metrics.build_time.timer(); + partition_build_side( + build_stream, + &build_keys, + num_partitions, + &build_schema, + &mut partitions, + &mut reservation, + &context, + &metrics, + &mut scratch, + ) + .await?; + } + + // Log build-side partition summary + { + let pool = &context.runtime_env().memory_pool; + let total_build_rows: usize = partitions + .iter() + .flat_map(|p| p.build_batches.iter()) + .map(|b| b.num_rows()) + .sum(); + let total_build_bytes: usize = partitions.iter().map(|p| p.build_mem_size).sum(); + let spilled_count = partitions.iter().filter(|p| p.build_spilled()).count(); + info!( + "GraceHashJoin: build phase complete. {} partitions ({} spilled), \ + total build: {} rows, {} bytes. 
Memory pool reserved={}", + num_partitions, + spilled_count, + total_build_rows, + total_build_bytes, + pool.reserved(), + ); + for (i, p) in partitions.iter().enumerate() { + if !p.build_batches.is_empty() || p.build_spilled() { + let rows: usize = p.build_batches.iter().map(|b| b.num_rows()).sum(); + info!( + "GraceHashJoin: partition[{}] build: {} batches, {} rows, {} bytes, spilled={}", + i, + p.build_batches.len(), + rows, + p.build_mem_size, + p.build_spilled(), + ); + } + } + } + + // Fast path: if no build partitions spilled and the build side is + // genuinely tiny, skip probe partitioning and stream the probe directly + // through a single HashJoinExec. This avoids spilling gigabytes of + // probe data to disk for a trivial hash table (e.g. 10-row build side). + // + // The threshold uses actual batch sizes (not the unreliable proportional + // estimate). The configured value is divided by spark.executor.cores in + // the planner so each concurrent task gets its fair share. + // Configurable via spark.comet.exec.graceHashJoin.fastPathThreshold. + + let build_spilled = partitions.iter().any(|p| p.build_spilled()); + let actual_build_bytes: usize = partitions + .iter() + .flat_map(|p| p.build_batches.iter()) + .map(|b| b.get_array_memory_size()) + .sum(); + + if !build_spilled && fast_path_threshold > 0 && actual_build_bytes <= fast_path_threshold { + let total_build_rows: usize = partitions + .iter() + .flat_map(|p| p.build_batches.iter()) + .map(|b| b.num_rows()) + .sum(); + info!( + "GHJ#{}: fast path — build side tiny ({} rows, {} bytes). \ + Streaming probe directly through HashJoinExec. pool reserved={}", + ghj_id, + total_build_rows, + actual_build_bytes, + context.runtime_env().memory_pool.reserved(), + ); + + // Release our reservation — HashJoinExec tracks its own memory. 
+ reservation.free(); + + let build_data: Vec = partitions + .into_iter() + .flat_map(|p| p.build_batches) + .collect(); + + let build_source = memory_source_exec(build_data, &build_schema)?; + + let probe_source: Arc = Arc::new(StreamSourceExec::new( + probe_stream, + Arc::clone(&probe_schema), + )); + + let (left_source, right_source): (Arc, Arc) = + if build_left { + (build_source, probe_source) + } else { + (probe_source, build_source) + }; + + info!( + "GraceHashJoin: FAST PATH creating HashJoinExec, \ + build_left={}, actual_build_bytes={}", + build_left, actual_build_bytes, + ); + + let stream = if build_left { + let hash_join = HashJoinExec::try_new( + left_source, + right_source, + original_on, + filter, + &join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?; + info!( + "GraceHashJoin: FAST PATH plan:\n{}", + DisplayableExecutionPlan::new(&hash_join).indent(true) + ); + hash_join.execute(0, context_for_join_output(&context))? + } else { + let hash_join = Arc::new(HashJoinExec::try_new( + left_source, + right_source, + original_on, + filter, + &join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?); + let swapped = hash_join.swap_inputs(PartitionMode::CollectLeft)?; + info!( + "GraceHashJoin: FAST PATH (swapped) plan:\n{}", + DisplayableExecutionPlan::new(swapped.as_ref()).indent(true) + ); + swapped.execute(0, context_for_join_output(&context))? + }; + + let output_metrics = metrics.baseline.clone(); + let result_stream = stream.inspect_ok(move |batch| { + output_metrics.record_output(batch.num_rows()); + }); + + return Ok(result_stream.boxed()); + } + + let total_build_rows: usize = partitions + .iter() + .flat_map(|p| p.build_batches.iter()) + .map(|b| b.num_rows()) + .sum(); + info!( + "GHJ#{}: slow path — build spilled={}, {} rows, {} bytes (actual). \ + join_type={:?}, build_left={}. pool reserved={}. 
Partitioning probe side.", + ghj_id, + build_spilled, + total_build_rows, + actual_build_bytes, + join_type, + build_left, + context.runtime_env().memory_pool.reserved(), + ); + + // Phase 2: Partition the probe side + { + let _timer = metrics.probe_time.timer(); + partition_probe_side( + probe_stream, + &probe_keys, + num_partitions, + &probe_schema, + &mut partitions, + &mut reservation, + &build_schema, + &context, + &metrics, + &mut scratch, + ) + .await?; + } + + // Log probe-side partition summary + { + let total_probe_rows: usize = partitions + .iter() + .flat_map(|p| p.probe_batches.iter()) + .map(|b| b.num_rows()) + .sum(); + let total_probe_bytes: usize = partitions.iter().map(|p| p.probe_mem_size).sum(); + let probe_spilled = partitions + .iter() + .filter(|p| p.probe_spill_writer.is_some()) + .count(); + info!( + "GHJ#{}: probe phase complete. \ + total probe (in-memory): {} rows, {} bytes, {} spilled. \ + reservation={}, pool reserved={}", + ghj_id, + total_probe_rows, + total_probe_bytes, + probe_spilled, + reservation.0.size(), + context.runtime_env().memory_pool.reserved(), + ); + } + + // Finish all open spill writers before reading back + let finished_partitions = + finish_spill_writers(partitions, &build_schema, &probe_schema, &metrics)?; + + // Merge adjacent partitions to reduce the number of HashJoinExec calls. + // Compute desired partition count from total build bytes. 
+ let total_build_bytes: usize = finished_partitions.iter().map(|p| p.build_bytes).sum(); + let desired_partitions = if total_build_bytes > 0 { + let desired = total_build_bytes.div_ceil(TARGET_PARTITION_BUILD_SIZE); + desired.max(1).min(num_partitions) + } else { + 1 + }; + let original_partition_count = finished_partitions.len(); + let finished_partitions = merge_finished_partitions(finished_partitions, desired_partitions); + if finished_partitions.len() < original_partition_count { + info!( + "GraceHashJoin: merged {} partitions into {} (total build {} bytes, \ + target {} bytes/partition)", + original_partition_count, + finished_partitions.len(), + total_build_bytes, + TARGET_PARTITION_BUILD_SIZE, + ); + } + + // Release all remaining reservation before Phase 3. The in-memory + // partition data is now owned by finished_partitions and will be moved + // into per-partition HashJoinExec instances (which track memory via + // their own HashJoinInput reservations). Keeping our reservation alive + // would double-count the memory and starve other consumers. + info!( + "GHJ#{}: freeing reservation ({} bytes) before Phase 3. pool reserved={}", + ghj_id, + reservation.0.size(), + context.runtime_env().memory_pool.reserved(), + ); + reservation.free(); + + // Phase 3: Join partitions sequentially. + // We use a concurrency limit of 1 to avoid creating multiple simultaneous + // HashJoinInput reservations per task. With multiple Spark tasks sharing + // the same memory pool, even modest build sides (e.g. 22 MB) can exhaust + // memory when many tasks run concurrent hash table builds simultaneously. 
+ const MAX_CONCURRENT_PARTITIONS: usize = 1; + let semaphore = Arc::new(tokio::sync::Semaphore::new(MAX_CONCURRENT_PARTITIONS)); + let (tx, rx) = mpsc::channel::>(MAX_CONCURRENT_PARTITIONS * 2); + + for partition in finished_partitions { + let tx = tx.clone(); + let sem = Arc::clone(&semaphore); + let original_on = original_on.clone(); + let filter = filter.clone(); + let build_schema = Arc::clone(&build_schema); + let probe_schema = Arc::clone(&probe_schema); + let context = Arc::clone(&context); + + tokio::spawn(async move { + let _permit = match sem.acquire().await { + Ok(p) => p, + Err(_) => return, // semaphore closed + }; + match join_single_partition( + partition, + original_on, + filter, + join_type, + build_left, + build_schema, + probe_schema, + context, + ) + .await + { + Ok(streams) => { + for mut stream in streams { + while let Some(batch) = stream.next().await { + if tx.send(batch).await.is_err() { + return; + } + } + } + } + Err(e) => { + let _ = tx.send(Err(e)).await; + } + } + }); + } + drop(tx); + + let output_metrics = metrics.baseline.clone(); + let output_row_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + let counter = Arc::clone(&output_row_count); + let jt = join_type; + let bl = build_left; + let result_stream = futures::stream::unfold(rx, |mut rx| async move { + rx.recv().await.map(|batch| (batch, rx)) + }) + .inspect_ok(move |batch| { + output_metrics.record_output(batch.num_rows()); + let prev = counter.fetch_add(batch.num_rows(), std::sync::atomic::Ordering::Relaxed); + let new_total = prev + batch.num_rows(); + // Log every ~1M rows to detect exploding joins + if new_total / 1_000_000 > prev / 1_000_000 { + info!( + "GraceHashJoin: slow path output: {} rows emitted so far \ + (join_type={:?}, build_left={})", + new_total, jt, bl, + ); + } + }); + + Ok(result_stream.boxed()) +} + +/// Wraps MemoryReservation to allow mutation through reference. 
+struct MutableReservation(MemoryReservation); + +impl MutableReservation { + fn try_grow(&mut self, additional: usize) -> DFResult<()> { + self.0.try_grow(additional) + } + + fn shrink(&mut self, amount: usize) { + self.0.shrink(amount); + } + + fn free(&mut self) -> usize { + self.0.free() + } +} + +// --------------------------------------------------------------------------- +// ScratchSpace: reusable buffers for efficient hash partitioning +// --------------------------------------------------------------------------- + +/// Reusable scratch buffers for partitioning batches. Uses a prefix-sum +/// algorithm (borrowed from the shuffle `multi_partition.rs`) to compute +/// contiguous row-index regions per partition in a single pass, avoiding +/// N separate `take()` kernel calls. +#[derive(Default)] +struct ScratchSpace { + /// Hash values for each row. + hashes: Vec, + /// Partition id assigned to each row. + partition_ids: Vec, + /// Row indices reordered so that each partition's rows are contiguous. + partition_row_indices: Vec, + /// `partition_starts[k]..partition_starts[k+1]` gives the slice of + /// `partition_row_indices` belonging to partition k. + partition_starts: Vec, +} + +impl ScratchSpace { + /// Compute hashes and partition ids, then build the prefix-sum index + /// structures for the given batch. 
+ fn compute_partitions( + &mut self, + batch: &RecordBatch, + keys: &[Arc], + num_partitions: usize, + recursion_level: usize, + ) -> DFResult<()> { + let num_rows = batch.num_rows(); + + // Evaluate key columns + let key_columns: Vec<_> = keys + .iter() + .map(|expr| expr.evaluate(batch).and_then(|cv| cv.into_array(num_rows))) + .collect::>>()?; + + // Hash + self.hashes.resize(num_rows, 0); + self.hashes.truncate(num_rows); + self.hashes.fill(0); + let random_state = partition_random_state(recursion_level); + create_hashes(&key_columns, &random_state, &mut self.hashes)?; + + // Assign partition ids + self.partition_ids.resize(num_rows, 0); + for (i, hash) in self.hashes[..num_rows].iter().enumerate() { + self.partition_ids[i] = (*hash as u32) % (num_partitions as u32); + } + + // Prefix-sum to get contiguous regions + self.map_partition_ids_to_starts_and_indices(num_partitions, num_rows); + + Ok(()) + } + + /// Prefix-sum algorithm from `multi_partition.rs`. + fn map_partition_ids_to_starts_and_indices(&mut self, num_partitions: usize, num_rows: usize) { + let partition_ids = &self.partition_ids[..num_rows]; + + // Count each partition size + let partition_counters = &mut self.partition_starts; + partition_counters.resize(num_partitions + 1, 0); + partition_counters.fill(0); + partition_ids + .iter() + .for_each(|pid| partition_counters[*pid as usize] += 1); + + // Accumulate into partition ends + let mut accum = 0u32; + for v in partition_counters.iter_mut() { + *v += accum; + accum = *v; + } + + // Build partition_row_indices (iterate in reverse to turn ends into starts) + self.partition_row_indices.resize(num_rows, 0); + for (index, pid) in partition_ids.iter().enumerate().rev() { + self.partition_starts[*pid as usize] -= 1; + let pos = self.partition_starts[*pid as usize]; + self.partition_row_indices[pos as usize] = index as u32; + } + } + + /// Get the row index slice for a given partition. 
+ fn partition_slice(&self, partition_id: usize) -> &[u32] { + let start = self.partition_starts[partition_id] as usize; + let end = self.partition_starts[partition_id + 1] as usize; + &self.partition_row_indices[start..end] + } + + /// Number of rows in a given partition. + fn partition_len(&self, partition_id: usize) -> usize { + (self.partition_starts[partition_id + 1] - self.partition_starts[partition_id]) as usize + } + + fn take_partition( + &self, + batch: &RecordBatch, + partition_id: usize, + ) -> DFResult> { + let row_indices = self.partition_slice(partition_id); + if row_indices.is_empty() { + return Ok(None); + } + let indices_array = UInt32Array::from(row_indices.to_vec()); + let columns: Vec<_> = batch + .columns() + .iter() + .map(|col| take(col.as_ref(), &indices_array, None)) + .collect::, _>>()?; + Ok(Some(RecordBatch::try_new(batch.schema(), columns)?)) + } +} + +// --------------------------------------------------------------------------- +// Spill reading +// --------------------------------------------------------------------------- + +/// Read record batches from a finished spill file. +fn read_spilled_batches( + spill_file: &RefCountedTempFile, + _schema: &SchemaRef, +) -> DFResult> { + let file = File::open(spill_file.path()) + .map_err(|e| DataFusionError::Execution(format!("Failed to open spill file: {e}")))?; + let reader = BufReader::with_capacity(SPILL_IO_BUFFER_SIZE, file); + let stream_reader = StreamReader::try_new(reader, None)?; + let batches: Vec = stream_reader.into_iter().collect::, _>>()?; + Ok(batches) +} + +// --------------------------------------------------------------------------- +// Phase 1: Build-side partitioning +// --------------------------------------------------------------------------- + +/// Phase 1: Read all build-side batches, hash-partition into N buckets. +/// Spills the largest partition when memory pressure is detected. 
+#[allow(clippy::too_many_arguments)] +async fn partition_build_side( + mut input: SendableRecordBatchStream, + keys: &[Arc], + num_partitions: usize, + schema: &SchemaRef, + partitions: &mut [HashPartition], + reservation: &mut MutableReservation, + context: &Arc, + metrics: &GraceHashJoinMetrics, + scratch: &mut ScratchSpace, +) -> DFResult<()> { + while let Some(batch) = input.next().await { + let batch = batch?; + if batch.num_rows() == 0 { + continue; + } + + metrics.build_input_batches.add(1); + metrics.build_input_rows.add(batch.num_rows()); + + // Track total batch size once, estimate per-partition proportionally + let total_batch_size = batch.get_array_memory_size(); + let total_rows = batch.num_rows(); + + scratch.compute_partitions(&batch, keys, num_partitions, 0)?; + + #[allow(clippy::needless_range_loop)] + for part_idx in 0..num_partitions { + if scratch.partition_len(part_idx) == 0 { + continue; + } + + let sub_rows = scratch.partition_len(part_idx); + let sub_batch = if sub_rows == total_rows { + batch.clone() + } else { + scratch.take_partition(&batch, part_idx)?.unwrap() + }; + let batch_size = if total_rows > 0 { + (total_batch_size as u64 * sub_rows as u64 / total_rows as u64) as usize + } else { + 0 + }; + + if partitions[part_idx].build_spilled() { + // This partition is already spilled; append incrementally + if let Some(ref mut writer) = partitions[part_idx].build_spill_writer { + writer.write_batch(&sub_batch)?; + } + } else { + // Try to reserve memory + if reservation.try_grow(batch_size).is_err() { + // Memory pressure: spill the largest in-memory partition + info!( + "GraceHashJoin: memory pressure during build, spilling largest partition" + ); + spill_largest_partition(partitions, schema, context, reservation, metrics)?; + + // Retry reservation after spilling + if reservation.try_grow(batch_size).is_err() { + // Still can't fit; spill this partition too + info!( + "GraceHashJoin: still under pressure, spilling partition {}", + 
part_idx + ); + spill_partition_build( + &mut partitions[part_idx], + schema, + context, + reservation, + metrics, + )?; + if let Some(ref mut writer) = partitions[part_idx].build_spill_writer { + writer.write_batch(&sub_batch)?; + } + continue; + } + } + + partitions[part_idx].build_mem_size += batch_size; + partitions[part_idx].build_batches.push(sub_batch); + } + } + } + + Ok(()) +} + +/// Spill the largest in-memory build partition to disk. +fn spill_largest_partition( + partitions: &mut [HashPartition], + schema: &SchemaRef, + context: &Arc, + reservation: &mut MutableReservation, + metrics: &GraceHashJoinMetrics, +) -> DFResult<()> { + // Find the largest non-spilled partition + let largest_idx = partitions + .iter() + .enumerate() + .filter(|(_, p)| !p.build_spilled() && !p.build_batches.is_empty()) + .max_by_key(|(_, p)| p.build_mem_size) + .map(|(idx, _)| idx); + + if let Some(idx) = largest_idx { + info!( + "GraceHashJoin: spilling partition {} ({} bytes, {} batches)", + idx, + partitions[idx].build_mem_size, + partitions[idx].build_batches.len() + ); + spill_partition_build(&mut partitions[idx], schema, context, reservation, metrics)?; + } + + Ok(()) +} + +/// Spill a single partition's build-side data to disk using SpillWriter. 
+fn spill_partition_build( + partition: &mut HashPartition, + schema: &SchemaRef, + context: &Arc, + reservation: &mut MutableReservation, + metrics: &GraceHashJoinMetrics, +) -> DFResult<()> { + let temp_file = context + .runtime_env() + .disk_manager + .create_tmp_file("grace hash join build")?; + + let mut writer = SpillWriter::new(temp_file, schema)?; + writer.write_batches(&partition.build_batches)?; + + // Free memory + let freed = partition.build_mem_size; + reservation.shrink(freed); + + metrics.spill_count.add(1); + metrics.spilled_bytes.add(freed); + + partition.build_spill_writer = Some(writer); + partition.build_batches.clear(); + partition.build_mem_size = 0; + + Ok(()) +} + +/// Spill a single partition's probe-side data to disk using SpillWriter. +fn spill_partition_probe( + partition: &mut HashPartition, + schema: &SchemaRef, + context: &Arc, + reservation: &mut MutableReservation, + metrics: &GraceHashJoinMetrics, +) -> DFResult<()> { + if partition.probe_batches.is_empty() && partition.probe_spill_writer.is_some() { + return Ok(()); + } + + let temp_file = context + .runtime_env() + .disk_manager + .create_tmp_file("grace hash join probe")?; + + let mut writer = SpillWriter::new(temp_file, schema)?; + writer.write_batches(&partition.probe_batches)?; + + let freed = partition.probe_mem_size; + reservation.shrink(freed); + + metrics.spill_count.add(1); + metrics.spilled_bytes.add(freed); + + partition.probe_spill_writer = Some(writer); + partition.probe_batches.clear(); + partition.probe_mem_size = 0; + + Ok(()) +} + +/// Spill both build and probe sides of a partition to disk. +/// When spilling during the probe phase, both sides must be spilled so the +/// join phase reads both consistently from disk. 
+fn spill_partition_both_sides( + partition: &mut HashPartition, + probe_schema: &SchemaRef, + build_schema: &SchemaRef, + context: &Arc, + reservation: &mut MutableReservation, + metrics: &GraceHashJoinMetrics, +) -> DFResult<()> { + if !partition.build_spilled() { + spill_partition_build(partition, build_schema, context, reservation, metrics)?; + } + if partition.probe_spill_writer.is_none() { + spill_partition_probe(partition, probe_schema, context, reservation, metrics)?; + } + Ok(()) +} + +// --------------------------------------------------------------------------- +// Phase 2: Probe-side partitioning +// --------------------------------------------------------------------------- + +/// Phase 2: Read all probe-side batches, route to in-memory buffers or spill files. +/// Tracks probe-side memory in the reservation and spills partitions when pressure +/// is detected, preventing OOM when the probe side is much larger than the build side. +#[allow(clippy::too_many_arguments)] +async fn partition_probe_side( + mut input: SendableRecordBatchStream, + keys: &[Arc], + num_partitions: usize, + schema: &SchemaRef, + partitions: &mut [HashPartition], + reservation: &mut MutableReservation, + build_schema: &SchemaRef, + context: &Arc, + metrics: &GraceHashJoinMetrics, + scratch: &mut ScratchSpace, +) -> DFResult<()> { + let mut probe_rows_accumulated: usize = 0; + while let Some(batch) = input.next().await { + let batch = batch?; + if batch.num_rows() == 0 { + continue; + } + let prev_milestone = probe_rows_accumulated / 5_000_000; + probe_rows_accumulated += batch.num_rows(); + let new_milestone = probe_rows_accumulated / 5_000_000; + if new_milestone > prev_milestone { + info!( + "GraceHashJoin: probe accumulation progress: {} rows, \ + reservation={}, pool reserved={}", + probe_rows_accumulated, + reservation.0.size(), + context.runtime_env().memory_pool.reserved(), + ); + } + + metrics.input_batches.add(1); + metrics.input_rows.add(batch.num_rows()); + + let 
total_rows = batch.num_rows(); + scratch.compute_partitions(&batch, keys, num_partitions, 0)?; + + #[allow(clippy::needless_range_loop)] + for part_idx in 0..num_partitions { + if scratch.partition_len(part_idx) == 0 { + continue; + } + + let sub_batch = if scratch.partition_len(part_idx) == total_rows { + batch.clone() + } else { + scratch.take_partition(&batch, part_idx)?.unwrap() + }; + + if partitions[part_idx].build_spilled() { + // Build side was spilled, so spill probe side too + if partitions[part_idx].probe_spill_writer.is_none() { + let temp_file = context + .runtime_env() + .disk_manager + .create_tmp_file("grace hash join probe")?; + let mut writer = SpillWriter::new(temp_file, schema)?; + // Write any accumulated in-memory probe batches first + if !partitions[part_idx].probe_batches.is_empty() { + let freed = partitions[part_idx].probe_mem_size; + let batches = std::mem::take(&mut partitions[part_idx].probe_batches); + writer.write_batches(&batches)?; + partitions[part_idx].probe_mem_size = 0; + reservation.shrink(freed); + } + partitions[part_idx].probe_spill_writer = Some(writer); + } + if let Some(ref mut writer) = partitions[part_idx].probe_spill_writer { + writer.write_batch(&sub_batch)?; + } + } else { + let batch_size = sub_batch.get_array_memory_size(); + if reservation.try_grow(batch_size).is_err() { + // Memory pressure: spill ALL non-spilled partitions. + // With multiple concurrent GHJ instances sharing the pool, + // partial spilling just lets data re-accumulate. Spilling + // everything ensures all subsequent probe data goes directly + // to disk, keeping in-memory footprint near zero. 
+ let total_in_memory: usize = partitions + .iter() + .filter(|p| !p.build_spilled()) + .map(|p| p.build_mem_size + p.probe_mem_size) + .sum(); + let spillable_count = partitions.iter().filter(|p| !p.build_spilled()).count(); + + info!( + "GraceHashJoin: memory pressure during probe, \ + spilling all {} non-spilled partitions ({} bytes)", + spillable_count, total_in_memory, + ); + + for i in 0..partitions.len() { + if !partitions[i].build_spilled() { + spill_partition_both_sides( + &mut partitions[i], + schema, + build_schema, + context, + reservation, + metrics, + )?; + } + } + } + + if partitions[part_idx].build_spilled() { + // Partition was just spilled above — write to spill writer + if partitions[part_idx].probe_spill_writer.is_none() { + let temp_file = context + .runtime_env() + .disk_manager + .create_tmp_file("grace hash join probe")?; + partitions[part_idx].probe_spill_writer = + Some(SpillWriter::new(temp_file, schema)?); + } + if let Some(ref mut writer) = partitions[part_idx].probe_spill_writer { + writer.write_batch(&sub_batch)?; + } + } else { + partitions[part_idx].probe_mem_size += batch_size; + partitions[part_idx].probe_batches.push(sub_batch); + } + } + } + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Finish spill writers +// --------------------------------------------------------------------------- + +/// State of a finished partition ready for joining. +/// After merging, a partition may hold multiple spill files from adjacent +/// original partitions. +struct FinishedPartition { + build_batches: Vec, + probe_batches: Vec, + build_spill_files: Vec, + probe_spill_files: Vec, + /// Total build-side bytes (in-memory + spilled) for merge decisions. + build_bytes: usize, +} + +/// Finish all open spill writers so files can be read back. 
+fn finish_spill_writers( + partitions: Vec, + _left_schema: &SchemaRef, + _right_schema: &SchemaRef, + _metrics: &GraceHashJoinMetrics, +) -> DFResult> { + let mut finished = Vec::with_capacity(partitions.len()); + + for partition in partitions { + let (build_spill_files, spilled_build_bytes) = + if let Some(writer) = partition.build_spill_writer { + let (file, bytes) = writer.finish()?; + (vec![file], bytes) + } else { + (vec![], 0) + }; + + let probe_spill_files = if let Some(writer) = partition.probe_spill_writer { + let (file, _bytes) = writer.finish()?; + vec![file] + } else { + vec![] + }; + + finished.push(FinishedPartition { + build_bytes: partition.build_mem_size + spilled_build_bytes, + build_batches: partition.build_batches, + probe_batches: partition.probe_batches, + build_spill_files, + probe_spill_files, + }); + } + + Ok(finished) +} + +/// Merge adjacent finished partitions to reduce the number of per-partition +/// HashJoinExec calls. Groups adjacent partitions so each merged group has +/// roughly `TARGET_PARTITION_BUILD_SIZE` bytes of build data. 
+fn merge_finished_partitions( + partitions: Vec, + target_count: usize, +) -> Vec { + let original_count = partitions.len(); + if target_count >= original_count { + return partitions; + } + + // Divide original_count partitions into target_count groups as evenly as possible + let base_group_size = original_count / target_count; + let remainder = original_count % target_count; + + let mut merged = Vec::with_capacity(target_count); + let mut iter = partitions.into_iter(); + + for group_idx in 0..target_count { + // First `remainder` groups get one extra partition + let group_size = base_group_size + if group_idx < remainder { 1 } else { 0 }; + + let mut build_batches = Vec::new(); + let mut probe_batches = Vec::new(); + let mut build_spill_files = Vec::new(); + let mut probe_spill_files = Vec::new(); + let mut build_bytes = 0usize; + + for _ in 0..group_size { + if let Some(p) = iter.next() { + build_batches.extend(p.build_batches); + probe_batches.extend(p.probe_batches); + build_spill_files.extend(p.build_spill_files); + probe_spill_files.extend(p.probe_spill_files); + build_bytes += p.build_bytes; + } + } + + merged.push(FinishedPartition { + build_batches, + probe_batches, + build_spill_files, + probe_spill_files, + build_bytes, + }); + } + + merged +} + +// --------------------------------------------------------------------------- +// Phase 3: Per-partition hash joins +// --------------------------------------------------------------------------- + +/// The output batch size for HashJoinExec within GHJ. +/// +/// With the default Comet batch size (8192), HashJoinExec produces thousands +/// of small output batches, causing significant per-batch overhead for large +/// joins (e.g., 150M output rows = 18K batches at 8192). +/// +/// 1M rows gives ~150 batches for a 150M row join — enough to avoid +/// per-batch overhead while keeping each output batch at a few hundred MB. 
+/// Cannot use `usize::MAX` because HashJoinExec pre-allocates Vec with +/// capacity = batch_size in `get_matched_indices_with_limit_offset`. +/// Cannot use 10M+ because output batches become multi-GB and cause OOM. +const GHJ_OUTPUT_BATCH_SIZE: usize = 1_000_000; + +/// Create a TaskContext with a larger output batch size for HashJoinExec. +/// +/// Input splitting is handled by StreamSourceExec (not batch_size). +fn context_for_join_output(context: &Arc) -> Arc { + let batch_size = GHJ_OUTPUT_BATCH_SIZE.max(context.session_config().batch_size()); + Arc::new(TaskContext::new( + context.task_id(), + context.session_id(), + context.session_config().clone().with_batch_size(batch_size), + context.scalar_functions().clone(), + context.aggregate_functions().clone(), + context.window_functions().clone(), + context.runtime_env(), + )) +} + +/// Create a `StreamSourceExec` that yields `data` batches without splitting. +/// +/// Unlike `DataSourceExec(MemorySourceConfig)`, `StreamSourceExec` does NOT +/// wrap its output in `BatchSplitStream`. This is critical for the build side +/// because Arrow's zero-copy `batch.slice()` shares underlying buffers, so +/// `get_record_batch_memory_size()` reports the full buffer size for every +/// slice — causing `collect_left_input` to vastly over-count memory and +/// trigger spurious OOM. Additionally, using `batch_size` large enough to +/// prevent splitting can cause Arrow i32 offset overflow for string columns. +fn memory_source_exec( + data: Vec, + schema: &SchemaRef, +) -> DFResult> { + let schema_clone = Arc::clone(schema); + let stream = + RecordBatchStreamAdapter::new(Arc::clone(schema), stream::iter(data.into_iter().map(Ok))); + Ok(Arc::new(StreamSourceExec::new( + Box::pin(stream), + schema_clone, + ))) +} + +/// Join a single partition: reads build-side spill (if any) via spawn_blocking, +/// then delegates to `join_with_spilled_probe` or `join_partition_recursive`. +/// Returns the resulting streams for this partition. 
+/// +/// Takes all owned data so it can be called inside `tokio::spawn`. +#[allow(clippy::too_many_arguments)] +async fn join_single_partition( + partition: FinishedPartition, + original_on: Vec<(Arc, Arc)>, + filter: Option, + join_type: JoinType, + build_left: bool, + build_schema: SchemaRef, + probe_schema: SchemaRef, + context: Arc, +) -> DFResult> { + // Get build-side batches (from memory or disk — build side is typically small). + // Use spawn_blocking for spill reads to avoid blocking the async executor. + let mut build_batches = partition.build_batches; + if !partition.build_spill_files.is_empty() { + let schema = Arc::clone(&build_schema); + let spill_files = partition.build_spill_files; + let spilled = tokio::task::spawn_blocking(move || { + let mut all = Vec::new(); + for spill_file in &spill_files { + all.extend(read_spilled_batches(spill_file, &schema)?); + } + Ok::<_, DataFusionError>(all) + }) + .await + .map_err(|e| { + DataFusionError::Execution(format!("GraceHashJoin: build spill read task failed: {e}")) + })??; + build_batches.extend(spilled); + } + + // Coalesce many tiny sub-batches into single batches to reduce per-batch + // overhead in HashJoinExec. Per-partition data is bounded by + // TARGET_PARTITION_BUILD_SIZE so concat won't hit i32 offset overflow. + let build_batches = if build_batches.len() > 1 { + vec![concat_batches(&build_schema, &build_batches)?] + } else { + build_batches + }; + + let mut streams = Vec::new(); + + if !partition.probe_spill_files.is_empty() { + // Probe side has spill file(s). Also include any in-memory probe + // batches (possible after merging adjacent partitions). 
+ join_with_spilled_probe( + build_batches, + partition.probe_spill_files, + partition.probe_batches, + &original_on, + &filter, + &join_type, + build_left, + &build_schema, + &probe_schema, + &context, + &mut streams, + )?; + } else { + // Probe side is in-memory: coalesce before joining + let probe_batches = if partition.probe_batches.len() > 1 { + vec![concat_batches(&probe_schema, &partition.probe_batches)?] + } else { + partition.probe_batches + }; + join_partition_recursive( + build_batches, + probe_batches, + &original_on, + &filter, + &join_type, + build_left, + &build_schema, + &probe_schema, + &context, + 1, + &mut streams, + )?; + } + + Ok(streams) +} + +/// Join a partition where the probe side was spilled to disk. +/// Uses SpillReaderExec to stream probe data from the spill file instead of +/// loading it all into memory. The build side (typically small) is loaded +/// into a MemorySourceConfig for the hash table. +#[allow(clippy::too_many_arguments)] +fn join_with_spilled_probe( + build_batches: Vec, + probe_spill_files: Vec, + probe_in_memory: Vec, + original_on: JoinOnRef<'_>, + filter: &Option, + join_type: &JoinType, + build_left: bool, + build_schema: &SchemaRef, + probe_schema: &SchemaRef, + context: &Arc, + streams: &mut Vec, +) -> DFResult<()> { + let probe_spill_files_count = probe_spill_files.len(); + + // Skip if build side is empty and join type requires it + let build_empty = build_batches.is_empty(); + let skip = match join_type { + JoinType::Inner | JoinType::LeftSemi | JoinType::LeftAnti => { + if build_left { + build_empty + } else { + false // probe emptiness unknown without reading + } + } + JoinType::Left | JoinType::LeftMark => { + if build_left { + build_empty + } else { + false + } + } + JoinType::Right => { + if !build_left { + build_empty + } else { + false + } + } + _ => false, + }; + if skip { + return Ok(()); + } + + let build_size: usize = build_batches + .iter() + .map(|b| b.get_array_memory_size()) + .sum(); + let 
build_rows: usize = build_batches.iter().map(|b| b.num_rows()).sum(); + info!( + "GraceHashJoin: join_with_spilled_probe build: {} batches/{} rows/{} bytes, \ + probe: streaming from spill file", + build_batches.len(), + build_rows, + build_size, + ); + + // If build side exceeds the target partition size, fall back to eager + // read + recursive repartitioning. This prevents creating HashJoinExec + // with oversized build sides that expand into huge hash tables. + let needs_repartition = build_size > TARGET_PARTITION_BUILD_SIZE; + + if needs_repartition { + info!( + "GraceHashJoin: build too large for streaming probe ({} bytes > {} target), \ + falling back to eager read + repartition", + build_size, TARGET_PARTITION_BUILD_SIZE, + ); + let mut probe_batches = probe_in_memory; + for spill_file in &probe_spill_files { + probe_batches.extend(read_spilled_batches(spill_file, probe_schema)?); + } + return join_partition_recursive( + build_batches, + probe_batches, + original_on, + filter, + join_type, + build_left, + build_schema, + probe_schema, + context, + 1, + streams, + ); + } + + // Concatenate build side into single batch. Per-partition data is bounded + // by TARGET_PARTITION_BUILD_SIZE so this won't hit i32 offset overflow. + let build_data = if build_batches.is_empty() { + vec![RecordBatch::new_empty(Arc::clone(build_schema))] + } else if build_batches.len() == 1 { + build_batches + } else { + vec![concat_batches(build_schema, &build_batches)?] + }; + + // Build side: StreamSourceExec to avoid BatchSplitStream splitting + let build_source = memory_source_exec(build_data, build_schema)?; + + // Probe side: streaming from spill file(s). + // With a single spill file and no in-memory batches, use the streaming + // SpillReaderExec. Otherwise read eagerly since the merged group sizes + // are bounded by TARGET_PARTITION_BUILD_SIZE. 
+ let probe_source: Arc = + if probe_spill_files.len() == 1 && probe_in_memory.is_empty() { + Arc::new(SpillReaderExec::new( + probe_spill_files.into_iter().next().unwrap(), + Arc::clone(probe_schema), + )) + } else { + let mut probe_batches = probe_in_memory; + for spill_file in &probe_spill_files { + probe_batches.extend(read_spilled_batches(spill_file, probe_schema)?); + } + let probe_data = if probe_batches.is_empty() { + vec![RecordBatch::new_empty(Arc::clone(probe_schema))] + } else { + vec![concat_batches(probe_schema, &probe_batches)?] + }; + memory_source_exec(probe_data, probe_schema)? + }; + + // HashJoinExec expects left=build in CollectLeft mode + let (left_source, right_source) = if build_left { + (build_source as Arc, probe_source) + } else { + (probe_source, build_source as Arc) + }; + + info!( + "GraceHashJoin: SPILLED PROBE PATH creating HashJoinExec, \ + build_left={}, build_size={}, probe_source={}", + build_left, + build_size, + if probe_spill_files_count == 1 { + "SpillReaderExec" + } else { + "StreamSourceExec" + }, + ); + + let stream = if build_left { + let hash_join = HashJoinExec::try_new( + left_source, + right_source, + original_on.to_vec(), + filter.clone(), + join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?; + info!( + "GraceHashJoin: SPILLED PROBE PATH plan:\n{}", + DisplayableExecutionPlan::new(&hash_join).indent(true) + ); + hash_join.execute(0, context_for_join_output(context))? + } else { + let hash_join = Arc::new(HashJoinExec::try_new( + left_source, + right_source, + original_on.to_vec(), + filter.clone(), + join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?); + let swapped = hash_join.swap_inputs(PartitionMode::CollectLeft)?; + info!( + "GraceHashJoin: SPILLED PROBE PATH (swapped) plan:\n{}", + DisplayableExecutionPlan::new(swapped.as_ref()).indent(true) + ); + swapped.execute(0, context_for_join_output(context))? 
+ }; + + streams.push(stream); + Ok(()) +} + +/// Join a single partition, recursively repartitioning if the build side is too large. +/// +/// `build_keys` / `probe_keys` for repartitioning are extracted from `original_on` +/// based on `build_left`. +#[allow(clippy::too_many_arguments)] +fn join_partition_recursive( + build_batches: Vec, + probe_batches: Vec, + original_on: JoinOnRef<'_>, + filter: &Option, + join_type: &JoinType, + build_left: bool, + build_schema: &SchemaRef, + probe_schema: &SchemaRef, + context: &Arc, + recursion_level: usize, + streams: &mut Vec, +) -> DFResult<()> { + // Skip partitions that cannot produce output based on join type. + // The join type uses Spark's left/right semantics. Map build/probe + // back to left/right based on build_left. + let (left_empty, right_empty) = if build_left { + (build_batches.is_empty(), probe_batches.is_empty()) + } else { + (probe_batches.is_empty(), build_batches.is_empty()) + }; + let skip = match join_type { + JoinType::Inner | JoinType::LeftSemi | JoinType::LeftAnti => left_empty || right_empty, + JoinType::Left | JoinType::LeftMark => left_empty, + JoinType::Right => right_empty, + JoinType::Full => left_empty && right_empty, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + left_empty || right_empty + } + }; + if skip { + return Ok(()); + } + + // Check if build side is too large and needs recursive repartitioning. 
+ let build_size: usize = build_batches + .iter() + .map(|b| b.get_array_memory_size()) + .sum(); + let build_rows: usize = build_batches.iter().map(|b| b.num_rows()).sum(); + let probe_size: usize = probe_batches + .iter() + .map(|b| b.get_array_memory_size()) + .sum(); + let probe_rows: usize = probe_batches.iter().map(|b| b.num_rows()).sum(); + let pool_reserved = context.runtime_env().memory_pool.reserved(); + info!( + "GraceHashJoin: join_partition_recursive level={}, \ + build: {} batches/{} rows/{} bytes, \ + probe: {} batches/{} rows/{} bytes, \ + pool reserved={}", + recursion_level, + build_batches.len(), + build_rows, + build_size, + probe_batches.len(), + probe_rows, + probe_size, + pool_reserved, + ); + // Repartition if the build side exceeds the target size. This prevents + // creating HashJoinExec with oversized build sides whose hash tables + // can expand well beyond the raw data size and exhaust the memory pool. + let needs_repartition = build_size > TARGET_PARTITION_BUILD_SIZE; + if needs_repartition { + info!( + "GraceHashJoin: repartition needed at level {}: \ + build_size={} > target={}, pool reserved={}", + recursion_level, + build_size, + TARGET_PARTITION_BUILD_SIZE, + context.runtime_env().memory_pool.reserved(), + ); + } + + if needs_repartition { + if recursion_level >= MAX_RECURSION_DEPTH { + let total_build_rows: usize = build_batches.iter().map(|b| b.num_rows()).sum(); + return Err(DataFusionError::ResourcesExhausted(format!( + "GraceHashJoin: build side partition is still too large after {} levels of \ + repartitioning ({} bytes, {} rows). 
Consider increasing \ + spark.comet.exec.graceHashJoin.numPartitions or \ + spark.executor.memory.", + MAX_RECURSION_DEPTH, build_size, total_build_rows + ))); + } + + info!( + "GraceHashJoin: repartitioning oversized partition at level {} \ + (build: {} bytes, {} batches)", + recursion_level, + build_size, + build_batches.len() + ); + + return repartition_and_join( + build_batches, + probe_batches, + original_on, + filter, + join_type, + build_left, + build_schema, + probe_schema, + context, + recursion_level, + streams, + ); + } + + // Concatenate sub-batches into single batches to reduce per-batch overhead + // in HashJoinExec. Per-partition data is bounded by TARGET_PARTITION_BUILD_SIZE + // so this won't hit i32 offset overflow. + let build_data = if build_batches.is_empty() { + vec![RecordBatch::new_empty(Arc::clone(build_schema))] + } else if build_batches.len() == 1 { + build_batches + } else { + vec![concat_batches(build_schema, &build_batches)?] + }; + let probe_data = if probe_batches.is_empty() { + vec![RecordBatch::new_empty(Arc::clone(probe_schema))] + } else if probe_batches.len() == 1 { + probe_batches + } else { + vec![concat_batches(probe_schema, &probe_batches)?] + }; + + // Create per-partition hash join. + // HashJoinExec expects left=build (CollectLeft mode). + // Both sides use StreamSourceExec to avoid DataSourceExec's BatchSplitStream. 
+ let build_source = memory_source_exec(build_data, build_schema)?; + let probe_source = memory_source_exec(probe_data, probe_schema)?; + + let (left_source, right_source) = if build_left { + (build_source, probe_source) + } else { + (probe_source, build_source) + }; + + let pool_before_join = context.runtime_env().memory_pool.reserved(); + info!( + "GraceHashJoin: RECURSIVE PATH creating HashJoinExec at level={}, \ + build_left={}, build_size={}, probe_size={}, pool reserved={}", + recursion_level, build_left, build_size, probe_size, pool_before_join, + ); + + let stream = if build_left { + let hash_join = HashJoinExec::try_new( + left_source, + right_source, + original_on.to_vec(), + filter.clone(), + join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?; + info!( + "GraceHashJoin: RECURSIVE PATH plan (level={}):\n{}", + recursion_level, + DisplayableExecutionPlan::new(&hash_join).indent(true) + ); + hash_join.execute(0, context_for_join_output(context))? + } else { + let hash_join = Arc::new(HashJoinExec::try_new( + left_source, + right_source, + original_on.to_vec(), + filter.clone(), + join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?); + let swapped = hash_join.swap_inputs(PartitionMode::CollectLeft)?; + info!( + "GraceHashJoin: RECURSIVE PATH (swapped, level={}) plan:\n{}", + recursion_level, + DisplayableExecutionPlan::new(swapped.as_ref()).indent(true) + ); + swapped.execute(0, context_for_join_output(context))? + }; + + streams.push(stream); + Ok(()) +} + +/// Repartition build and probe batches into sub-partitions using a different +/// hash seed, then recursively join each sub-partition. 
+#[allow(clippy::too_many_arguments)] +fn repartition_and_join( + build_batches: Vec, + probe_batches: Vec, + original_on: JoinOnRef<'_>, + filter: &Option, + join_type: &JoinType, + build_left: bool, + build_schema: &SchemaRef, + probe_schema: &SchemaRef, + context: &Arc, + recursion_level: usize, + streams: &mut Vec, +) -> DFResult<()> { + let num_sub_partitions = DEFAULT_NUM_PARTITIONS; + + // Extract build/probe key expressions from original_on + let (build_keys, probe_keys): (Vec<_>, Vec<_>) = if build_left { + original_on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip() + } else { + original_on + .iter() + .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) + .unzip() + }; + + let mut scratch = ScratchSpace::default(); + + // Sub-partition the build side + let mut build_sub: Vec> = + (0..num_sub_partitions).map(|_| Vec::new()).collect(); + for batch in &build_batches { + let total_rows = batch.num_rows(); + scratch.compute_partitions(batch, &build_keys, num_sub_partitions, recursion_level)?; + for (i, sub_vec) in build_sub.iter_mut().enumerate() { + if scratch.partition_len(i) == 0 { + continue; + } + if scratch.partition_len(i) == total_rows { + sub_vec.push(batch.clone()); + } else if let Some(sub) = scratch.take_partition(batch, i)? { + sub_vec.push(sub); + } + } + } + + // Sub-partition the probe side + let mut probe_sub: Vec> = + (0..num_sub_partitions).map(|_| Vec::new()).collect(); + for batch in &probe_batches { + let total_rows = batch.num_rows(); + scratch.compute_partitions(batch, &probe_keys, num_sub_partitions, recursion_level)?; + for (i, sub_vec) in probe_sub.iter_mut().enumerate() { + if scratch.partition_len(i) == 0 { + continue; + } + if scratch.partition_len(i) == total_rows { + sub_vec.push(batch.clone()); + } else if let Some(sub) = scratch.take_partition(batch, i)? 
{ + sub_vec.push(sub); + } + } + } + + // Recursively join each sub-partition + for (build_part, probe_part) in build_sub.into_iter().zip(probe_sub.into_iter()) { + join_partition_recursive( + build_part, + probe_part, + original_on, + filter, + join_type, + build_left, + build_schema, + probe_schema, + context, + recursion_level + 1, + streams, + )?; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::datasource::source::DataSourceExec; + use datafusion::execution::memory_pool::FairSpillPool; + use datafusion::execution::runtime_env::RuntimeEnvBuilder; + use datafusion::physical_expr::expressions::Column; + use datafusion::prelude::SessionConfig; + use datafusion::prelude::SessionContext; + use futures::TryStreamExt; + + fn make_batch(ids: &[i32], values: &[&str]) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Utf8, false), + ])); + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(ids.to_vec())), + Arc::new(StringArray::from(values.to_vec())), + ], + ) + .unwrap() + } + + #[tokio::test] + async fn test_grace_hash_join_basic() -> DFResult<()> { + let ctx = SessionContext::new(); + let task_ctx = ctx.task_ctx(); + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Utf8, false), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Utf8, false), + ])); + + let left_batches = vec![ + make_batch(&[1, 2, 3, 4, 5], &["a", "b", "c", "d", "e"]), + make_batch(&[6, 7, 8], &["f", "g", "h"]), + ]; + let right_batches = vec![ + make_batch(&[2, 4, 6, 8], &["x", "y", "z", "w"]), + make_batch(&[1, 3, 5, 7], &["p", "q", "r", "s"]), + ]; + + let 
left_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[left_batches], + Arc::clone(&left_schema), + None, + )?))); + let right_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[right_batches], + Arc::clone(&right_schema), + None, + )?))); + + let on = vec![( + Arc::new(Column::new("id", 0)) as Arc, + Arc::new(Column::new("id", 0)) as Arc, + )]; + + let grace_join = GraceHashJoinExec::try_new( + left_source, + right_source, + on, + None, + &JoinType::Inner, + 4, // Use 4 partitions for testing + true, + 10 * 1024 * 1024, // 10 MB fast path threshold + )?; + + let stream = grace_join.execute(0, task_ctx)?; + let result_batches: Vec = stream.try_collect().await?; + + // Count total rows - should be 8 (each left id matches exactly one right id) + let total_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 8, "Expected 8 matching rows for inner join"); + + Ok(()) + } + + #[tokio::test] + async fn test_grace_hash_join_empty_partition() -> DFResult<()> { + let ctx = SessionContext::new(); + let task_ctx = ctx.task_ctx(); + + let left_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let right_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let left_batches = vec![RecordBatch::try_new( + Arc::clone(&left_schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?]; + let right_batches = vec![RecordBatch::try_new( + Arc::clone(&right_schema), + vec![Arc::new(Int32Array::from(vec![10, 20, 30]))], + )?]; + + let left_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[left_batches], + Arc::clone(&left_schema), + None, + )?))); + let right_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[right_batches], + Arc::clone(&right_schema), + None, + )?))); + + let on = vec![( + Arc::new(Column::new("id", 0)) as Arc, + Arc::new(Column::new("id", 0)) as Arc, 
+ )]; + + let grace_join = GraceHashJoinExec::try_new( + left_source, + right_source, + on, + None, + &JoinType::Inner, + 4, + true, + 10 * 1024 * 1024, // 10 MB fast path threshold + )?; + + let stream = grace_join.execute(0, task_ctx)?; + let result_batches: Vec = stream.try_collect().await?; + + let total_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 0, "Expected 0 rows for non-matching keys"); + + Ok(()) + } + + /// Helper to create a SessionContext with a bounded FairSpillPool. + fn context_with_memory_limit(pool_bytes: usize) -> SessionContext { + let pool = Arc::new(FairSpillPool::new(pool_bytes)); + let runtime = RuntimeEnvBuilder::new() + .with_memory_pool(pool) + .build_arc() + .unwrap(); + let config = SessionConfig::new(); + SessionContext::new_with_config_rt(config, runtime) + } + + /// Generate a batch of N rows with sequential IDs and a padding string + /// column to control memory size. Each row is ~100 bytes of padding. + fn make_large_batch(start_id: i32, count: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Utf8, false), + ])); + let ids: Vec = (start_id..start_id + count as i32).collect(); + let padding = "x".repeat(100); + let vals: Vec<&str> = (0..count).map(|_| padding.as_str()).collect(); + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(ids)), + Arc::new(StringArray::from(vals)), + ], + ) + .unwrap() + } + + /// Test that GHJ correctly repartitions a large build side instead of + /// creating an oversized HashJoinExec hash table that OOMs. + /// + /// Setup: 256 MB memory pool, ~80 MB build side, ~10 MB probe side. + /// Without repartitioning, the hash table would be ~240 MB and could + /// exhaust the 256 MB pool. With repartitioning (32 MB threshold), + /// the build side is split into sub-partitions of ~5 MB each. 
+ #[tokio::test] + async fn test_grace_hash_join_repartitions_large_build() -> DFResult<()> { + // 256 MB pool — tight enough that a 80 MB build → ~240 MB hash table fails + let ctx = context_with_memory_limit(256 * 1024 * 1024); + let task_ctx = ctx.task_ctx(); + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Utf8, false), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Utf8, false), + ])); + + // Build side: ~80 MB (800K rows × ~100 bytes) + let left_batches = vec![ + make_large_batch(0, 200_000), + make_large_batch(200_000, 200_000), + make_large_batch(400_000, 200_000), + make_large_batch(600_000, 200_000), + ]; + let build_bytes: usize = left_batches.iter().map(|b| b.get_array_memory_size()).sum(); + eprintln!( + "Test build side: {} bytes ({} MB)", + build_bytes, + build_bytes / (1024 * 1024) + ); + + // Probe side: small (~1 MB, 10K rows) + let right_batches = vec![make_large_batch(0, 10_000)]; + + let left_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[left_batches], + Arc::clone(&left_schema), + None, + )?))); + let right_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[right_batches], + Arc::clone(&right_schema), + None, + )?))); + + let on = vec![( + Arc::new(Column::new("id", 0)) as Arc, + Arc::new(Column::new("id", 0)) as Arc, + )]; + + // Disable fast path to force slow path + let grace_join = GraceHashJoinExec::try_new( + left_source, + right_source, + on, + None, + &JoinType::Inner, + 16, + true, // build_left + 0, // fast_path_threshold = 0 (disabled) + )?; + + let stream = grace_join.execute(0, task_ctx)?; + let result_batches: Vec = stream.try_collect().await?; + + let total_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum(); + // All 10K probe rows match (IDs 0..10000 exist in build) + assert_eq!(total_rows, 10_000, 
"Expected 10000 matching rows"); + + Ok(()) + } + + /// Same test but with build_left=false to exercise the swap_inputs path. + #[tokio::test] + async fn test_grace_hash_join_repartitions_large_build_right() -> DFResult<()> { + let ctx = context_with_memory_limit(256 * 1024 * 1024); + let task_ctx = ctx.task_ctx(); + + let left_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Utf8, false), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("val", DataType::Utf8, false), + ])); + + // Probe side (left): small + let left_batches = vec![make_large_batch(0, 10_000)]; + + // Build side (right): ~80 MB + let right_batches = vec![ + make_large_batch(0, 200_000), + make_large_batch(200_000, 200_000), + make_large_batch(400_000, 200_000), + make_large_batch(600_000, 200_000), + ]; + + let left_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[left_batches], + Arc::clone(&left_schema), + None, + )?))); + let right_source = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[right_batches], + Arc::clone(&right_schema), + None, + )?))); + + let on = vec![( + Arc::new(Column::new("id", 0)) as Arc, + Arc::new(Column::new("id", 0)) as Arc, + )]; + + let grace_join = GraceHashJoinExec::try_new( + left_source, + right_source, + on, + None, + &JoinType::Inner, + 16, + false, // build_left=false → right is build side + 0, // fast_path_threshold = 0 (disabled) + )?; + + let stream = grace_join.execute(0, task_ctx)?; + let result_batches: Vec = stream.try_collect().await?; + + let total_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 10_000, "Expected 10000 matching rows"); + + Ok(()) + } +} diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 07ee995367..ed1dce219e 100644 --- a/native/core/src/execution/operators/mod.rs 
+++ b/native/core/src/execution/operators/mod.rs @@ -32,6 +32,8 @@ mod iceberg_scan; mod parquet_writer; pub use parquet_writer::ParquetWriterExec; mod csv_scan; +mod grace_hash_join; +pub use grace_hash_join::GraceHashJoinExec; pub mod projection; mod scan; pub use csv_scan::init_csv_datasource_exec; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 094777e796..5086e44e4f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -61,7 +61,7 @@ use datafusion::{ physical_plan::{ aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy}, empty::EmptyExec, - joins::{utils::JoinFilter, HashJoinExec, PartitionMode, SortMergeJoinExec}, + joins::{utils::JoinFilter, SortMergeJoinExec}, limit::LocalLimitExec, projection::ProjectionExec, sorts::sort::SortExec, @@ -163,6 +163,8 @@ pub struct PhysicalPlanner { exec_context_id: i64, partition: i32, session_ctx: Arc<SessionContext>, + /// Spark configuration map, used to read comet-specific settings. + spark_conf: HashMap<String, String>, } impl Default for PhysicalPlanner { @@ -177,6 +179,7 @@ impl PhysicalPlanner { exec_context_id: TEST_EXEC_CONTEXT_ID, session_ctx, partition, + spark_conf: HashMap::new(), } } @@ -185,9 +188,14 @@ exec_context_id, partition: self.partition, session_ctx: Arc::clone(&self.session_ctx), + spark_conf: self.spark_conf, } } + pub fn with_spark_conf(self, spark_conf: HashMap<String, String>) -> Self { + Self { spark_conf, ..self } + } + /// Return session context of this planner. pub fn session_ctx(&self) -> &Arc<SessionContext> { &self.session_ctx @@ -1566,49 +1574,46 @@ let left = Arc::clone(&join_params.left.native_plan); let right = Arc::clone(&join_params.right.native_plan); - let hash_join = Arc::new(HashJoinExec::try_new( - left, - right, - join_params.join_on, - join_params.join_filter, - &join_params.join_type, - None, - PartitionMode::Partitioned, - // null doesn't equal to null in Spark join key. 
If the join key is - `EqualNullSafe`, Spark will rewrite it during planning. - NullEquality::NullEqualsNothing, - )?); - - // If the hash join is build right, we need to swap the left and right - if join.build_side == BuildSide::BuildLeft as i32 { - Ok(( - scans, - Arc::new(SparkPlan::new( - spark_plan.plan_id, - hash_join, - vec![join_params.left, join_params.right], - )), - )) - } else { - let swapped_hash_join = - hash_join.as_ref().swap_inputs(PartitionMode::Partitioned)?; + use crate::execution::spark_config::{ + SparkConfig, COMET_GRACE_HASH_JOIN_FAST_PATH_THRESHOLD, + COMET_GRACE_HASH_JOIN_NUM_PARTITIONS, SPARK_EXECUTOR_CORES, + }; - - let mut additional_native_plans = vec![]; - if swapped_hash_join.as_any().is::<ProjectionExec>() { - // a projection was added to the hash join - additional_native_plans.push(Arc::clone(swapped_hash_join.children()[0])); - } + + let num_partitions = self + .spark_conf + .get_usize(COMET_GRACE_HASH_JOIN_NUM_PARTITIONS, 16); + let executor_cores = + self.spark_conf.get_usize(SPARK_EXECUTOR_CORES, 1).max(1); + // The configured threshold is the total budget across all + // concurrent tasks. Divide by executor cores so each task's + // fast-path hash table stays within its fair share. 
+ let fast_path_threshold = self + .spark_conf + .get_usize(COMET_GRACE_HASH_JOIN_FAST_PATH_THRESHOLD, 10 * 1024 * 1024) + / executor_cores; + + let build_left = join.build_side == BuildSide::BuildLeft as i32; + + let grace_join = + Arc::new(crate::execution::operators::GraceHashJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + join_params.join_on, + join_params.join_filter, + &join_params.join_type, + num_partitions, + build_left, + fast_path_threshold, + )?); - Ok(( - scans, - Arc::new(SparkPlan::new_with_additional( - spark_plan.plan_id, - swapped_hash_join, - vec![join_params.left, join_params.right], - additional_native_plans, - )), - )) - } + Ok(( + scans, + Arc::new(SparkPlan::new( + spark_plan.plan_id, + grace_join, + vec![join_params.left, join_params.right], + )), + )) } OpStruct::Window(wnd) => { let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; @@ -3772,7 +3777,7 @@ mod tests { let (_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap(); - assert_eq!("HashJoinExec", hash_join_exec.native_plan.name()); + assert_eq!("GraceHashJoinExec", hash_join_exec.native_plan.name()); assert_eq!(2, hash_join_exec.children.len()); assert_eq!("ScanExec", hash_join_exec.children[0].native_plan.name()); assert_eq!("ScanExec", hash_join_exec.children[1].native_plan.name()); diff --git a/native/core/src/execution/spark_config.rs b/native/core/src/execution/spark_config.rs index 277c0eb43b..062437812c 100644 --- a/native/core/src/execution/spark_config.rs +++ b/native/core/src/execution/spark_config.rs @@ -23,6 +23,10 @@ pub(crate) const COMET_EXPLAIN_NATIVE_ENABLED: &str = "spark.comet.explain.nativ pub(crate) const COMET_MAX_TEMP_DIRECTORY_SIZE: &str = "spark.comet.maxTempDirectorySize"; pub(crate) const COMET_DEBUG_MEMORY: &str = "spark.comet.debug.memory"; pub(crate) const SPARK_EXECUTOR_CORES: &str = "spark.executor.cores"; +pub(crate) const COMET_GRACE_HASH_JOIN_NUM_PARTITIONS: &str = + 
"spark.comet.exec.graceHashJoin.numPartitions"; +pub(crate) const COMET_GRACE_HASH_JOIN_FAST_PATH_THRESHOLD: &str = + "spark.comet.exec.graceHashJoin.fastPathThreshold"; pub(crate) trait SparkConfig { fn get_bool(&self, name: &str) -> bool; diff --git a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala index a4d31a59ac..abbb1deaab 100644 --- a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala +++ b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala @@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.{SortExec, SparkPlan} import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo /** @@ -64,6 +66,28 @@ object RewriteJoin extends JoinSelectionHelper { case _ => plan } + /** + * Returns true if the build side is small enough to benefit from hash join over sort-merge + * join. When both sides are large, SMJ's streaming merge on pre-sorted data can outperform hash + * join's per-task hash table construction. 
+ */ + private def buildSideSmallEnough(smj: SortMergeJoinExec, buildSide: BuildSide): Boolean = { + val maxBuildSize = CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.get() + if (maxBuildSize <= 0) { + return true // no limit + } + smj.logicalLink match { + case Some(join: Join) => + val buildSize = buildSide match { + case BuildLeft => join.left.stats.sizeInBytes + case BuildRight => join.right.stats.sizeInBytes + } + buildSize <= maxBuildSize + case _ => + true // no stats available, allow the rewrite + } + } + def rewrite(plan: SparkPlan): SparkPlan = plan match { case smj: SortMergeJoinExec => getSmjBuildSide(smj) match { @@ -75,6 +99,12 @@ object RewriteJoin extends JoinSelectionHelper { "Cannot rewrite SortMergeJoin to HashJoin: " + s"BuildRight with ${smj.joinType} is not supported") plan + case Some(buildSide) if !buildSideSmallEnough(smj, buildSide) => + withInfo( + smj, + "Cannot rewrite SortMergeJoin to HashJoin: " + + "build side exceeds spark.comet.exec.replaceSortMergeJoin.maxBuildSize") + plan case Some(buildSide) => ShuffledHashJoinExec( smj.leftKeys, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index 8c75df1d45..2d2222129c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -225,6 +225,33 @@ object CometMetricNode { "join_time" -> SQLMetrics.createNanoTimingMetric(sc, "Total time for joining")) } + /** + * SQL Metrics for GraceHashJoin + */ + def graceHashJoinMetrics(sc: SparkContext): Map[String, SQLMetric] = { + Map( + "build_time" -> + SQLMetrics.createNanoTimingMetric(sc, "Total time for partitioning build-side"), + "probe_time" -> + SQLMetrics.createNanoTimingMetric(sc, "Total time for partitioning probe-side"), + "join_time" -> + SQLMetrics.createNanoTimingMetric(sc, "Total time for per-partition joins"), + "spill_count" -> 
SQLMetrics.createMetric(sc, "Count of spills"), + "spilled_bytes" -> SQLMetrics.createSizeMetric(sc, "Total spilled bytes"), + "build_input_rows" -> + SQLMetrics.createMetric(sc, "Number of rows consumed by build-side"), + "build_input_batches" -> + SQLMetrics.createMetric(sc, "Number of batches consumed by build-side"), + "input_rows" -> + SQLMetrics.createMetric(sc, "Number of rows consumed by probe-side"), + "input_batches" -> + SQLMetrics.createMetric(sc, "Number of batches consumed by probe-side"), + "output_batches" -> SQLMetrics.createMetric(sc, "Number of batches produced"), + "output_rows" -> SQLMetrics.createMetric(sc, "Number of rows produced"), + "elapsed_compute" -> + SQLMetrics.createNanoTimingMetric(sc, "Total elapsed compute time")) + } + /** * SQL Metrics for DataFusion SortMergeJoin */ diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..5c3d1919c7 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1724,7 +1724,7 @@ object CometHashJoinExec extends CometOperatorSerde[HashJoin] with CometHashJoin doConvert(join, builder, childOp: _*) override def createExec(nativeOp: Operator, op: HashJoin): CometNativeExec = { - CometHashJoinExec( + CometGraceHashJoinExec( nativeOp, op, op.output, @@ -1795,6 +1795,61 @@ case class CometHashJoinExec( CometMetricNode.hashJoinMetrics(sparkContext) } +case class CometGraceHashJoinExec( + override val nativeOp: Operator, + override val originalPlan: SparkPlan, + override val output: Seq[Attribute], + override val outputOrdering: Seq[SortOrder], + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + buildSide: BuildSide, + override val left: SparkPlan, + override val right: SparkPlan, + override val serializedPlanOpt: SerializedPlan) + extends 
CometBinaryExec { + + override def outputPartitioning: Partitioning = joinType match { + case _: InnerLike => + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case LeftExistence(_) => left.outputPartitioning + case x => + throw new IllegalArgumentException(s"GraceHashJoin should not take $x as the JoinType") + } + + override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = + this.copy(left = newLeft, right = newRight) + + override def stringArgs: Iterator[Any] = + Iterator(leftKeys, rightKeys, joinType, buildSide, condition, left, right) + + override def equals(obj: Any): Boolean = { + obj match { + case other: CometGraceHashJoinExec => + this.output == other.output && + this.leftKeys == other.leftKeys && + this.rightKeys == other.rightKeys && + this.condition == other.condition && + this.buildSide == other.buildSide && + this.left == other.left && + this.right == other.right && + this.serializedPlanOpt == other.serializedPlanOpt + case _ => + false + } + } + + override def hashCode(): Int = + Objects.hashCode(output, leftKeys, rightKeys, condition, buildSide, left, right) + + override lazy val metrics: Map[String, SQLMetric] = + CometMetricNode.graceHashJoinMetrics(sparkContext) +} + case class CometBroadcastHashJoinExec( override val nativeOp: Operator, override val originalPlan: SparkPlan, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index 6111b9c0d4..b476297dcf 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -19,17 +19,20 @@ package org.apache.comet.exec +import scala.util.Random + import org.scalactic.source.Position import 
org.scalatest.Tag import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometGraceHashJoinExec} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.{DataTypes, Decimal, StructField, StructType} import org.apache.comet.CometConf +import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} class CometJoinSuite extends CometTestBase { import testImplicits._ @@ -446,4 +449,253 @@ class CometJoinSuite extends CometTestBase { """.stripMargin)) } } + + // Common SQL config for Grace Hash Join tests + private val graceHashJoinConf: Seq[(String, String)] = Seq( + CometConf.COMET_EXEC_GRACE_HASH_JOIN_NUM_PARTITIONS.key -> "4", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") + + test("Grace HashJoin - all join types") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 100).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 100).map(i => (i % 10, i + 2)), "tbl_b") { + // Inner join + checkSparkAnswer( + sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Left join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Right join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Full outer join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Left semi join + 
checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT SEMI JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Left anti join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT ANTI JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + } + } + } + } + + test("Grace HashJoin - with filter condition") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 100).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 100).map(i => (i % 10, i + 2)), "tbl_b") { + checkSparkAnswer( + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")) + + checkSparkAnswer( + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")) + + checkSparkAnswer( + sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")) + } + } + } + } + + test("Grace HashJoin - various data types") { + withSQLConf(graceHashJoinConf: _*) { + // String keys + withParquetTable((0 until 50).map(i => (s"key_${i % 10}", i)), "str_a") { + withParquetTable((0 until 50).map(i => (s"key_${i % 5}", i * 2)), "str_b") { + checkSparkAnswer( + sql( + "SELECT /*+ SHUFFLE_HASH(str_a) */ * FROM str_a JOIN str_b ON str_a._1 = str_b._1")) + } + } + + // Decimal keys + withParquetTable((0 until 50).map(i => (Decimal(i % 10), i)), "dec_a") { + withParquetTable((0 until 50).map(i => (Decimal(i % 5), i * 2)), "dec_b") { + checkSparkAnswer( + sql( + "SELECT /*+ SHUFFLE_HASH(dec_a) */ * FROM dec_a JOIN dec_b ON dec_a._1 = dec_b._1")) + } + } + } + } + + test("Grace HashJoin - empty tables") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable(Seq.empty[(Int, Int)], "empty_a") { + withParquetTable((0 until 10).map(i => (i, i)), "nonempty_b") { + // Empty left side + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(empty_a) */ * FROM empty_a JOIN nonempty_b ON empty_a._1 = 
nonempty_b._1")) + + // Empty left with left join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(empty_a) */ * FROM empty_a LEFT JOIN nonempty_b ON empty_a._1 = nonempty_b._1")) + + // Empty right side + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(nonempty_b) */ * FROM nonempty_b JOIN empty_a ON nonempty_b._1 = empty_a._1")) + + // Empty right with right join + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(nonempty_b) */ * FROM nonempty_b RIGHT JOIN empty_a ON nonempty_b._1 = empty_a._1")) + } + } + } + } + + test("Grace HashJoin - self join") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 50).map(i => (i, i % 10)), "self_tbl") { + checkSparkAnswer( + sql("SELECT /*+ SHUFFLE_HASH(a) */ * FROM self_tbl a JOIN self_tbl b ON a._2 = b._2")) + } + } + } + + test("Grace HashJoin - build side selection") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 100).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 100).map(i => (i % 10, i + 2)), "tbl_b") { + // Build left (hint on left table) + checkSparkAnswer( + sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Build right (hint on right table) + checkSparkAnswer( + sql( + "SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Left join build right + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + + // Right join build left + checkSparkAnswer(sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")) + } + } + } + } + + test("Grace HashJoin - plan shows CometGraceHashJoinExec") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 50).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 50).map(i => (i % 10, i + 2)), "tbl_b") { + val df = sql( + "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") + 
checkSparkAnswerAndOperator(df, Seq(classOf[CometGraceHashJoinExec])) + } + } + } + } + + test("Grace HashJoin - multiple key columns") { + withSQLConf(graceHashJoinConf: _*) { + withParquetTable((0 until 50).map(i => (i, i % 5, i % 3)), "multi_a") { + withParquetTable((0 until 50).map(i => (i % 10, i % 5, i % 3)), "multi_b") { + checkSparkAnswer( + sql("SELECT /*+ SHUFFLE_HASH(multi_a) */ * FROM multi_a JOIN multi_b " + + "ON multi_a._2 = multi_b._2 AND multi_a._3 = multi_b._3")) + } + } + } + } + + // Schema with types that work well as join keys (no NaN/float issues) + private val fuzzJoinSchema = StructType( + Seq( + StructField("c_int", DataTypes.IntegerType), + StructField("c_long", DataTypes.LongType), + StructField("c_str", DataTypes.StringType), + StructField("c_date", DataTypes.DateType), + StructField("c_dec", DataTypes.createDecimalType(10, 2)), + StructField("c_short", DataTypes.ShortType), + StructField("c_bool", DataTypes.BooleanType))) + + private val joinTypes = + Seq("JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL JOIN", "LEFT SEMI JOIN", "LEFT ANTI JOIN") + + test("Grace HashJoin fuzz - all join types with generated data") { + val dataGenOptions = + DataGenOptions(allowNull = true, generateNegativeZero = false, generateNaN = false) + + withSQLConf(graceHashJoinConf: _*) { + withTempPath { dir => + val path1 = s"${dir.getAbsolutePath}/fuzz_left" + val path2 = s"${dir.getAbsolutePath}/fuzz_right" + val random = new Random(42) + + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + ParquetGenerator + .makeParquetFile(random, spark, path1, fuzzJoinSchema, 200, dataGenOptions) + ParquetGenerator + .makeParquetFile(random, spark, path2, fuzzJoinSchema, 200, dataGenOptions) + } + + spark.read.parquet(path1).createOrReplaceTempView("fuzz_l") + spark.read.parquet(path2).createOrReplaceTempView("fuzz_r") + + for (jt <- joinTypes) { + // Join on int column + checkSparkAnswer(sql( + s"SELECT /*+ SHUFFLE_HASH(fuzz_l) */ * FROM fuzz_l $jt fuzz_r ON fuzz_l.c_int = 
fuzz_r.c_int")) + + // Join on string column + checkSparkAnswer(sql( + s"SELECT /*+ SHUFFLE_HASH(fuzz_l) */ * FROM fuzz_l $jt fuzz_r ON fuzz_l.c_str = fuzz_r.c_str")) + + // Join on decimal column + checkSparkAnswer(sql( + s"SELECT /*+ SHUFFLE_HASH(fuzz_l) */ * FROM fuzz_l $jt fuzz_r ON fuzz_l.c_dec = fuzz_r.c_dec")) + } + } + } + } + + test("Grace HashJoin fuzz - with spilling") { + val dataGenOptions = + DataGenOptions(allowNull = true, generateNegativeZero = false, generateNaN = false) + + // Use very small memory pool to force spilling + withSQLConf( + (graceHashJoinConf ++ Seq( + CometConf.COMET_ONHEAP_MEMORY_OVERHEAD.key -> "10000000", + CometConf.COMET_EXEC_GRACE_HASH_JOIN_NUM_PARTITIONS.key -> "8")): _*) { + withTempPath { dir => + val path1 = s"${dir.getAbsolutePath}/spill_left" + val path2 = s"${dir.getAbsolutePath}/spill_right" + val random = new Random(99) + + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + ParquetGenerator + .makeParquetFile(random, spark, path1, fuzzJoinSchema, 500, dataGenOptions) + ParquetGenerator + .makeParquetFile(random, spark, path2, fuzzJoinSchema, 500, dataGenOptions) + } + + spark.read.parquet(path1).createOrReplaceTempView("spill_l") + spark.read.parquet(path2).createOrReplaceTempView("spill_r") + + for (jt <- joinTypes) { + checkSparkAnswer(sql( + s"SELECT /*+ SHUFFLE_HASH(spill_l) */ * FROM spill_l $jt spill_r ON spill_l.c_int = spill_r.c_int")) + } + } + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJoinBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJoinBenchmark.scala new file mode 100644 index 0000000000..01b413de15 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometJoinBenchmark.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.{CometConf, CometSparkSessionExtensions} + +/** + * Benchmark to compare join implementations: Spark Sort Merge Join, Comet Sort Merge Join, and + * Comet Grace Hash Join across all join types. + * + * To run this benchmark: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 make \ + * benchmark-org.apache.spark.sql.benchmark.CometJoinBenchmark + * }}} + * + * Results will be written to "spark/benchmarks/CometJoinBenchmark-**results.txt". 
+ */ +object CometJoinBenchmark extends CometBenchmarkBase { + override def getSparkSession: SparkSession = { + val conf = new SparkConf() + .setAppName("CometJoinBenchmark") + .set("spark.master", "local[5]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .set("spark.executor.memoryOverhead", "10g") + .set( + "spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + + val sparkSession = SparkSession.builder + .config(conf) + .withExtensions(new CometSparkSessionExtensions) + .getOrCreate() + + sparkSession.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(CometConf.COMET_ENABLED.key, "false") + sparkSession.conf.set(CometConf.COMET_EXEC_ENABLED.key, "false") + sparkSession.conf.set(CometConf.COMET_ONHEAP_MEMORY_OVERHEAD.key, "10g") + sparkSession.conf.set("parquet.enable.dictionary", "false") + sparkSession.conf.set("spark.sql.shuffle.partitions", "2") + + sparkSession + } + + /** Base Comet exec config — shuffle mode auto, no SMJ replacement by default. 
*/ + private val cometBaseConf = Map( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "auto", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") + + private def prepareTwoTables(dir: java.io.File, rows: Int, keyCardinality: Int): Unit = { + val left = spark + .range(rows) + .selectExpr( + s"id % $keyCardinality as key", + "id as l_val1", + "cast(id * 1.5 as double) as l_val2") + prepareTable(dir, left) + spark.read.parquet(dir.getCanonicalPath + "/parquetV1").createOrReplaceTempView("left_table") + + val rightDir = new java.io.File(dir, "right") + rightDir.mkdirs() + val right = spark + .range(rows) + .selectExpr( + s"id % $keyCardinality as key", + "id as r_val1", + "cast(id * 2.5 as double) as r_val2") + right.write + .mode("overwrite") + .option("compression", "snappy") + .parquet(rightDir.getCanonicalPath) + spark.read.parquet(rightDir.getCanonicalPath).createOrReplaceTempView("right_table") + } + + private def addJoinCases(benchmark: Benchmark, query: String): Unit = { + // 1. Spark Sort Merge Join (baseline — no Comet) + benchmark.addCase("Spark Sort Merge Join") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + "spark.sql.join.preferSortMergeJoin" -> "true") { + spark.sql(query).noop() + } + } + + // 2. Comet Sort Merge Join (Spark plans SMJ, Comet executes it natively) + benchmark.addCase("Comet Sort Merge Join") { _ => + withSQLConf( + (cometBaseConf ++ Map( + CometConf.COMET_REPLACE_SMJ.key -> "false", + "spark.sql.join.preferSortMergeJoin" -> "true")).toSeq: _*) { + spark.sql(query).noop() + } + } + + // 3. 
Comet Grace Hash Join (replace SMJ with ShuffledHashJoin, Comet executes with GHJ) + benchmark.addCase("Comet Grace Hash Join") { _ => + withSQLConf( + (cometBaseConf ++ Map( + CometConf.COMET_REPLACE_SMJ.key -> "true")).toSeq: _*) { + spark.sql(query).noop() + } + } + } + + private def joinBenchmark(joinType: String, rows: Int, keyCardinality: Int): Unit = { + val joinClause = joinType match { + case "Inner" => "JOIN" + case "Left" => "LEFT JOIN" + case "Right" => "RIGHT JOIN" + case "Full" => "FULL OUTER JOIN" + case "LeftSemi" => "LEFT SEMI JOIN" + case "LeftAnti" => "LEFT ANTI JOIN" + } + + val selectCols = joinType match { + case "LeftSemi" | "LeftAnti" => "l.key, l.l_val1, l.l_val2" + case _ => "l.key, l.l_val1, r.r_val1" + } + + val query = + s"SELECT $selectCols FROM left_table l $joinClause right_table r ON l.key = r.key" + + val benchmark = + new Benchmark( + s"$joinType Join (rows=$rows, cardinality=$keyCardinality)", + rows, + output = output) + + addJoinCases(benchmark, query) + benchmark.run() + } + + private def joinWithFilterBenchmark(rows: Int, keyCardinality: Int): Unit = { + val query = + "SELECT l.key, l.l_val1, r.r_val1 FROM left_table l " + + "JOIN right_table r ON l.key = r.key WHERE l.l_val1 > r.r_val1" + + val benchmark = + new Benchmark( + s"Inner Join with Filter (rows=$rows, cardinality=$keyCardinality)", + rows, + output = output) + + addJoinCases(benchmark, query) + benchmark.run() + } + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + val rows = 1024 * 1024 * 2 + val keyCardinality = rows / 10 // ~10 matches per key + + withTempPath { dir => + prepareTwoTables(dir, rows, keyCardinality) + + runBenchmark("Join Benchmark") { + for (joinType <- Seq("Inner", "Left", "Right", "Full", "LeftSemi", "LeftAnti")) { + joinBenchmark(joinType, rows, keyCardinality) + } + joinWithFilterBenchmark(rows, keyCardinality) + } + } + } +}