
[Proposal] PaddleSpatial SA-GNN Integration for GeaFlow-Infer #776

@kitalkuyo-gita


1. Background and Motivation

1.1 Problem Statement

GeaFlow's existing inference subsystem (geaflow-infer) was built exclusively around the PyTorch framework. The entire execution path — from environment installation (install-infer-env.sh) to the Python inference session (inferSession.py) — hardcodes PyTorch assumptions. As a result, users who wish to deploy models from the PaddlePaddle ecosystem are completely blocked:

  • install-infer-env.sh only installs PyTorch and its dependencies.
  • infer_server.py instantiates TorchInferSession unconditionally.
  • There is no abstraction layer separating "session protocol" from "framework implementation".
  • Graph algorithms in GeaFlow DSL have no access to spatial GNN models such as those provided by PaddleSpatial.

Meanwhile, PaddlePaddle is widely adopted in China's industrial AI ecosystem, and PaddleSpatial provides production-grade spatial graph neural network algorithms — notably the Spatial Adaptive GNN (SA-GNN) — that are directly applicable to location-aware graph analytics (ride-hailing networks, urban road graphs, spatial knowledge graphs, etc.).

1.2 Motivation: Why SA-GNN?

SA-GNN (Spatial Adaptive Graph Neural Network) improves upon direction-agnostic GNNs such as GraphSAGE by:

  1. Partitioning neighbours into directional spatial sectors (e.g., 8 compass directions) based on coordinate angles, then aggregating each sector independently (SpatialOrientedAGG).
  2. Performing degree-normalised local aggregation (SpatialLocalAGG) as a fast, GCN-like first layer.
  3. Encoding location as a first-class feature, enabling the model to distinguish a node's north neighbours from its south neighbours — a distinction invisible to standard GNNs.

| Model | Spatial awareness | Multi-hop | Framework |
|---|---|---|---|
| GraphSAGE (existing) | None | Yes | PyTorch |
| SA-GNN (this proposal) | Directional sectors | Yes | PaddlePaddle / PGL |
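
The directional partitioning described in item 1 can be illustrated with a small, framework-free sketch. It is not the PaddleSpatial implementation: the function names, the choice of sector 0 starting at the positive x-axis, and the equal-width sectors are all assumptions made for illustration, and the catch-all bucket mentioned later in this proposal is not modelled.

```python
import math


def sector_index(center, neighbour, num_sectors=8):
    """Return the angular sector of `neighbour` as seen from `center`.

    Assumption: sector 0 starts at the positive x-axis and sectors are
    numbered counter-clockwise with equal width.
    """
    dx = neighbour[0] - center[0]
    dy = neighbour[1] - center[1]
    # atan2 returns (-pi, pi]; the modulo maps it to [0, 2*pi).
    angle = math.atan2(dy, dx) % (2 * math.pi)
    width = 2 * math.pi / num_sectors
    # The trailing modulo guards against rounding up to exactly 2*pi.
    return int(angle // width) % num_sectors


def partition_by_sector(center, neighbours, num_sectors=8):
    """Group neighbour indices by sector so each bucket can be aggregated separately."""
    buckets = [[] for _ in range(num_sectors)]
    for idx, coord in enumerate(neighbours):
        buckets[sector_index(center, coord, num_sectors)].append(idx)
    return buckets
```

Each bucket would then be aggregated independently, which is what lets the model treat a neighbour to the north differently from one to the south.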

1.3 Design Goals

  1. G-1 Add first-class PaddlePaddle support to geaflow-infer without breaking existing PyTorch workflows.
  2. G-2 Introduce a framework-agnostic inference session abstraction (BaseInferSession) so future frameworks (TensorFlow, MindSpore, etc.) can be plugged in with minimal effort.
  3. G-3 Implement the SA-GNN algorithm from PaddleSpatial as a built-in GQL graph algorithm (SAGNN), callable via CALL SAGNN(...).
  4. G-4 Allow flexible Python runtime selection: users may use a Miniconda-managed virtual environment or a system-installed Python (useful for local development and CI).
  5. G-5 All changes must be backward compatible: existing jobs using geaflow.infer.framework.type=TORCH (or unset) must continue to work without modification.

1.4 Non-Goals

  • Replacing or deprecating PyTorch support.
  • Supporting PaddlePaddle static-graph inference (Paddle Inference / TensorRT) in this proposal — only dynamic eager mode is shipped initially.
  • Training models inside GeaFlow; this proposal is for inference only.
  • Multi-GPU distributed inference inside a single Python worker process.

2. Constraints

2.1 Architectural Constraints

C-1: The Python inference sub-process must remain a single-threaded worker loop.

infer_server.py runs a tight while True loop reading from shared memory and writing results back. All session implementations must be safe to call from this loop and must not spawn additional threads.

C-2: Data crossing the Java–Python bridge must be pickle-serialisable.

The PicklerDataBridger uses Python's pickle protocol. paddle.Tensor is not directly picklable; all tensor data must be converted to numpy arrays or Python-native types before crossing the bridge. This is enforced by PaddleInferSession._coerce_to_native().

C-3: The TransFormFunction interface contract must not change.

Existing PyTorch UDFs implement load_model(), transform_pre(), and transform_post(). The new BaseInferSession abstraction must accept any conforming class without requiring UDF authors to change their code.

C-4: The Java InferContext / InferEnvironmentManager lifecycle must remain single-instance per JVM process.

InferEnvironmentManager is a process-level singleton (guarded by AtomicBoolean INITIALIZED). The Paddle path must work within this constraint.

2.2 Compatibility Constraints

C-5: Default framework is TORCH — zero configuration change required for existing jobs.

The new config key geaflow.infer.framework.type defaults to "TORCH". All code paths that handle null or empty values fall back to TORCH.

C-6: The --tfClassName CLI flag is kept for backward compatibility.

infer_server.py still accepts --tfClassName; --modelClassName is the new preferred alias. The resolution order is: --modelClassName → --tfClassName (the first one that is set wins).
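
A minimal sketch of how this flag resolution could look with argparse. The helper name resolve_model_class_name and the error raised when neither flag is present are assumptions for illustration; only the two flag names come from this proposal.

```python
import argparse


def resolve_model_class_name(argv):
    """Prefer --modelClassName and fall back to the legacy --tfClassName."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--modelClassName", default=None)
    parser.add_argument("--tfClassName", default=None)
    # parse_known_args ignores other flags such as --framework.
    args, _unknown = parser.parse_known_args(argv)
    name = args.modelClassName or args.tfClassName
    if not name:
        # Assumption: a missing class name is treated as a hard error.
        raise ValueError("either --modelClassName or --tfClassName is required")
    return name
```

With this shape, existing launch commands that pass only --tfClassName keep working unchanged.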

C-7: System Python mode is opt-in.

geaflow.infer.env.use.system.python defaults to false. When false, the existing Miniconda virtual environment path is used unchanged.


3. Current State Analysis

3.1 Existing Inference Subsystem (before this change)

geaflow-infer/
├── InferEnvironmentManager.java   # Creates Miniconda venv, runs install shell
├── InferEnvironmentContext.java   # Holds paths (pythonExec, inferScript, libPath)
├── InferContext.java              # Drives Java→Python lifecycle
└── resources/infer/
    ├── env/install-infer-env.sh   # Shell: downloads Miniconda, installs requirements.txt
    └── inferRuntime/
        ├── infer_server.py        # Python entry point (hardcoded TorchInferSession)
        └── inferSession.py        # TorchInferSession (standalone, no base class)

Problems identified:

  • infer_server.py:56 instantiates TorchInferSession unconditionally — no dispatch mechanism.
  • inferSession.py has no abstract interface — copy-paste would be required to add any new framework.
  • install-infer-env.sh has no conditional branch for non-PyTorch frameworks.
  • FrameworkConfigKeys has no keys related to framework selection.

3.2 Existing GQL Graph Algorithms

GraphSAGE (GraphSAGE.java) is the only GNN algorithm registered in BuildInSqlFunctionTable. It uses InferContextPool backed exclusively by PyTorch. SA-GNN is a distinct algorithm family that cannot be expressed as a configuration variant of GraphSAGE due to its direction-aware aggregation and PGL graph primitives.


4. Design

4.1 Architecture Overview

GeaFlow Job (Java)
       │
       │  config: geaflow.infer.framework.type = PADDLE
       ▼
InferEnvironmentManager
  ├── (TORCH path) → install-infer-env.sh [PyTorch]  → inferFiles/
  └── (PADDLE path)→ install-infer-env.sh [Paddle]   → inferFiles/
                          │  install_paddlepaddle()
                          │  install_requirements(requirements_paddle.txt)
                          ▼
InferContext
  └── runs: python3 infer_server.py
              --modelClassName=SAGNNTransFormFunction
              --framework=PADDLE
              --input_queue_shm_id=...
              --output_queue_shm_id=...

infer_server.py
  └── _create_infer_session(framework="PADDLE", transform_class)
        └── PaddleInferSession(transform_class)   ← new
              │  inherits BaseInferSession         ← new
              ▼
        SAGNNTransFormFunction (user UDF)
              │  uses PGL mini-graph
              │  calls SAGNNModel.forward()
              ▼
        List[float] embedding  →  pickle bridge  →  Java List<Double>

4.2 Python Session Layer Refactor

A new abstract base class is introduced:

# baseInferSession.py (new)
import abc

class BaseInferSession(abc.ABC):
    def __init__(self, transform_class): ...

    @abc.abstractmethod
    def run(self, *inputs): ...

Existing PyTorch session becomes a concrete subclass:

# inferSession.py (updated)
class TorchInferSession(BaseInferSession):
    def run(self, *inputs):
        a, b = self._transform.transform_pre(*inputs)
        return self._transform.transform_post(a)

New PaddlePaddle session:

# paddleInferSession.py (new)
class PaddleInferSession(BaseInferSession):
    def run(self, *inputs):
        pre_result, aux = self._transform.transform_pre(*inputs)
        post_result = self._transform.transform_post(pre_result)
        return self._coerce_to_native(post_result)   # Tensor → list

    @staticmethod
    def _coerce_to_native(obj):
        # Recursively unwraps paddle.Tensor → Python list
        ...
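
The body of _coerce_to_native() is elided above. One possible shape, written as a standalone function and using duck typing so that it runs without PaddlePaddle installed, is sketched below. The function name and the duck-typed checks are assumptions, not the actual implementation; the sketch relies only on the tensor exposing numpy() or tolist().

```python
def coerce_to_native(obj):
    """Recursively convert tensor-like values into plain Python containers."""
    # Tensor-like objects: go through numpy() when available.
    if hasattr(obj, "numpy") and callable(obj.numpy):
        return obj.numpy().tolist()
    # Array-like objects (for example numpy arrays) expose tolist() directly.
    if hasattr(obj, "tolist") and callable(obj.tolist):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: coerce_to_native(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Tuples become lists, which keeps the result simple on the Java side.
        return [coerce_to_native(item) for item in obj]
    # Scalars, strings and None are already pickle-friendly.
    return obj
```

Converting before the value reaches the bridge keeps constraint C-2 local to the session class, so UDF authors can return tensors from transform_post() without handling serialisation themselves.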

Framework dispatch in infer_server.py:

def _create_infer_session(framework, transform_class):
    # A missing or empty value falls back to TORCH (see C-5).
    if (framework or "TORCH").upper() == "PADDLE":
        from paddleInferSession import PaddleInferSession
        return PaddleInferSession(transform_class)
    else:
        from inferSession import TorchInferSession
        return TorchInferSession(transform_class)

4.3 Environment Bootstrap

install-infer-env.sh receives three new positional arguments:

| Position | Name | Default |
|---|---|---|
| $4 | FRAMEWORK_TYPE | TORCH |
| $5 | PADDLE_GPU_ENABLE | false |
| $6 | PADDLE_CUDA_VERSION | 11.7 |

When FRAMEWORK_TYPE=PADDLE, the script calls install_paddlepaddle() before install_requirements(), because pgl and paddlespatial depend on paddlepaddle being present at install time.

GPU wheel selection:

if [[ "${PADDLE_GPU_ENABLE}" == "true" ]]; then
    cuda_postfix=$(echo "${PADDLE_CUDA_VERSION}" | tr -d '.')   # "11.7" → "117"
    PADDLE_WHEEL="paddlepaddle-gpu==2.6.0.post${cuda_postfix}"
else
    PADDLE_WHEEL="paddlepaddle==2.6.0"
fi

A dedicated requirements_paddle.txt is provided for PGL and PaddleSpatial dependencies:

pgl>=2.2.4
paddlespatial>=0.1.0
numpy>=1.21.0,<2.0.0
scipy>=1.7.0
psutil>=5.9.0

4.4 Java Configuration Layer

InferContext.runInferTask() now passes --framework to the Python process:

String frameworkType = config.getString(INFER_FRAMEWORK_TYPE);
if (frameworkType == null || frameworkType.isEmpty()) frameworkType = "TORCH";
runCommands.add(inferEnvironmentContext.getInferFrameworkParam(frameworkType));

InferEnvironmentManager.createInferVirtualEnv() forwards the new arguments to the shell script:

execParams.add(frameworkType.toUpperCase());
execParams.add(String.valueOf(paddleGpu));
execParams.add(cudaVersion);

A system Python bypass (INFER_ENV_USE_SYSTEM_PYTHON = true) skips Miniconda installation entirely, using a pre-installed Python interpreter on the host. This is useful for development environments and CI pipelines where conda overhead is undesirable.

4.5 SA-GNN Algorithm (SAGNN.java)

The SAGNN class implements AlgorithmUserFunction and is callable via GQL:

CALL SAGNN([numSamples, [numLayers]]) YIELD (vid, embedding)

Algorithm iterations:

| Iteration | Action |
|---|---|
| 1 | Each vertex samples up to numSamples neighbours and broadcasts its feature vector. |
| 2 | Collect received features into neighbourFeatureCache; re-broadcast own features. |
| 3 … numLayers+1 | Invoke Python SA-GNN model with vertex features + cached neighbour features; store resulting embedding. |

Feature vector convention:

The last 2 elements of the vertex feature vector are treated as (coord_x, coord_y) by the Python model. All preceding elements are semantic features. If a vertex has fewer than 64 features, the vector is zero-padded; if more, it is truncated to 64 dimensions.
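
The convention above can be sketched as follows. This is an interpretation, not the shipped _split_feat_coord(): it assumes the coordinates are split off first and that the 64-dimension padding and truncation then apply to the remaining semantic features. The handling of vectors shorter than the coordinate width is also an assumption.

```python
def split_feat_coord(vector, feature_dim=64, coord_dim=2):
    """Split a flat feature vector into (semantic features, coordinates)."""
    values = [float(v) for v in vector]
    if len(values) < coord_dim:
        # Assumption: too short to carry coordinates, so use the origin
        # and treat the vertex as having no semantic features.
        coord = [0.0] * coord_dim
        feat = []
    else:
        coord = values[-coord_dim:]
        feat = values[:-coord_dim]
    # Truncate to feature_dim, then zero-pad up to feature_dim.
    feat = feat[:feature_dim] + [0.0] * max(0, feature_dim - len(feat))
    return feat, coord
```

A fixed feature width means every mini-graph fed to the model has the same input dimension regardless of how many attributes a vertex carries.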

4.6 SA-GNN Python Model (PaddleSpatialSAGNNTransFormFunctionUDF.py)

The UDF implements the full SA-GNN architecture using PGL primitives:

SAGNNModel
├── SpatialLocalAGG   (GCN-like, degree-normalised message passing)
├── SpatialOrientedAGG
│   ├── _partition_edges_by_sector()  (angle-based directional bucketing)
│   └── 9 × SpatialLocalAGG sub-convolutions (8 sectors + 1 catch-all)
└── Linear projection  (hidden_dim → output_dim)

Mini-graph construction per inference call:

  • Node 0: the centre vertex
  • Nodes 1..K: its sampled neighbours (K = numSamples)
  • Edges: directed (i → 0) for i in 1..K (neighbours → centre for message passing)
  • Node features: (num_nodes, input_dim) float32 array
  • Node coordinates: (num_nodes, 2) float32 array (passed via graph.node_feat['coord'])
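
The mini-graph layout above can be expressed with plain Python lists, before anything is handed to PGL. The function name and the returned dictionary layout are assumptions for illustration. The self-loop for a vertex with no sampled neighbours follows the "self-loop fallback for isolated nodes" mentioned in the testing plan, and its exact form here is a guess.

```python
def build_mini_graph(center_feat, center_coord, neighbour_feats, neighbour_coords):
    """Assemble the star-shaped mini-graph for one inference call."""
    if len(neighbour_feats) != len(neighbour_coords):
        raise ValueError("each neighbour needs both features and coordinates")
    k = len(neighbour_feats)
    # Node 0 is the centre vertex; nodes 1..K are its sampled neighbours.
    node_feat = [list(center_feat)] + [list(f) for f in neighbour_feats]
    node_coord = [list(center_coord)] + [list(c) for c in neighbour_coords]
    # Messages flow from each neighbour to the centre: (i -> 0).
    edges = [(i, 0) for i in range(1, k + 1)]
    if not edges:
        # Assumption: an isolated vertex gets a self-loop so aggregation
        # still has one incoming message.
        edges = [(0, 0)]
    return {
        "num_nodes": k + 1,
        "edges": edges,
        "node_feat": node_feat,
        "coord": node_coord,
    }
```

In the real UDF these lists would be converted to float32 arrays and wrapped in a PGL graph, with the coordinates exposed as graph.node_feat['coord'].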

5. Configuration Reference

5.1 New Configuration Keys

| Key | Type | Default | Description |
|---|---|---|---|
| geaflow.infer.framework.type | String | TORCH | Inference framework: TORCH or PADDLE |
| geaflow.infer.env.paddle.gpu.enable | Boolean | false | Install paddlepaddle-gpu instead of CPU-only |
| geaflow.infer.env.paddle.cuda.version | String | 11.7 | CUDA version for GPU wheel (e.g. 11.7, 12.0) |
| geaflow.infer.env.use.system.python | Boolean | false | Skip Miniconda, use system Python instead |
| geaflow.infer.env.system.python.path | String | (none) | Absolute path to system Python (e.g. /usr/bin/python3) |

5.2 Existing Keys (unchanged, required for SAGNN)

| Key | Required value for SAGNN |
|---|---|
| geaflow.infer.env.enable | true |
| geaflow.infer.env.user.transform.classname | SAGNNTransFormFunction |
| geaflow.infer.env.conda.url | miniconda installer URL (if not using system Python) |

5.3 Minimal Configuration Example

# Required
geaflow.infer.env.enable=true
geaflow.infer.framework.type=PADDLE
geaflow.infer.env.user.transform.classname=SAGNNTransFormFunction
geaflow.infer.env.conda.url=https://example.com/Miniconda3-latest-Linux-x86_64.sh

# Optional: GPU
geaflow.infer.env.paddle.gpu.enable=true
geaflow.infer.env.paddle.cuda.version=11.7

6. Changed Files

6.1 New Files

| File | Description |
|---|---|
| geaflow-infer/…/inferRuntime/baseInferSession.py | Abstract base class for all inference sessions |
| geaflow-infer/…/inferRuntime/paddleInferSession.py | PaddlePaddle inference session implementation |
| geaflow-infer/…/inferRuntime/requirements_paddle.txt | Pip requirements for PGL / PaddleSpatial |
| geaflow-dsl-plan/…/udf/graph/SAGNN.java | SA-GNN GQL algorithm UDF |
| geaflow-dsl-plan/…/resources/PaddleSpatialSAGNNTransFormFunctionUDF.py | Reference user UDF for SA-GNN |

6.2 Modified Files

| File | Change summary |
|---|---|
| FrameworkConfigKeys.java | Added 5 new config keys (INFER_FRAMEWORK_TYPE, INFER_ENV_PADDLE_*, INFER_ENV_USE_SYSTEM_PYTHON, INFER_ENV_SYSTEM_PYTHON_PATH) |
| InferContext.java | Pass --framework to Python sub-process; improve env-ready polling with ScheduledExecutorService |
| InferEnvironmentContext.java | Support system Python path resolution; add getInferFrameworkParam() and getInferModelClassNameParam() |
| InferEnvironmentManager.java | Add constructSystemPythonEnvironment(); forward Paddle args to shell script |
| install-infer-env.sh | Add install_paddlepaddle() function; accept $4–$6 args; branch on FRAMEWORK_TYPE |
| infer_server.py | Add --framework CLI arg; implement _create_infer_session() factory; accept --modelClassName alias |
| inferSession.py | Refactor TorchInferSession to extend BaseInferSession |
| BuildInSqlFunctionTable.java | Register SAGNN as a built-in graph algorithm |

7. API / GQL Changes

7.1 New GQL Syntax

-- Basic usage (10 neighbours, 2 layers by default)
CALL SAGNN() YIELD (vid, embedding)

-- With custom parameters
CALL SAGNN(20, 3) YIELD (vid, embedding)
-- Parameters:
--   arg[0]: numSamples (int, default 10) — neighbours to sample per vertex
--   arg[1]: numLayers  (int, default 2)  — number of SA-GNN layers

Output schema: (vid ANY, embedding STRING) — embedding is a JSON-serialised List<Double>.
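
Because the embedding column is a string, downstream consumers need to decode it before using the vector. A minimal sketch of that step is shown below; the helper name parse_embedding is hypothetical, and the input is assumed to be a JSON array of numbers as described above.

```python
import json


def parse_embedding(embedding):
    """Decode the JSON-serialised embedding column into a list of floats."""
    values = json.loads(embedding)
    if not isinstance(values, list):
        raise ValueError("embedding must be a JSON array")
    return [float(v) for v in values]
```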

7.2 TransFormFunction Interface (unchanged)

Existing UDFs are not affected. The three-method contract (load_model, transform_pre, transform_post) remains the extension point for both frameworks:

class TransFormFunction(abc.ABC):
    def __init__(self, input_size: int): ...
    def load_model(self, *args): ...
    def transform_pre(self, *args): ...
    def transform_post(self, *args): ...
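
To make the contract concrete, the sketch below defines a local stand-in for the base class and a toy UDF that averages feature vectors. Everything except the three method names and the input_size constructor argument is an assumption: the MeanPoolTransform class, the run_once helper, the meaning given to input_size, and the (result, aux) return shape of transform_pre, which mirrors the session code shown in section 4.2.

```python
import abc


class TransFormFunction(abc.ABC):
    """Local stand-in for the GeaFlow base class, for illustration only."""

    def __init__(self, input_size):
        self.input_size = input_size

    @abc.abstractmethod
    def load_model(self, *args): ...

    @abc.abstractmethod
    def transform_pre(self, *args): ...

    @abc.abstractmethod
    def transform_post(self, *args): ...


class MeanPoolTransform(TransFormFunction):
    """Toy UDF: averages the vertex and neighbour features, then scales them."""

    def __init__(self):
        # Assumption: input_size is the number of positional inputs.
        super().__init__(input_size=2)
        self.load_model()

    def load_model(self, *args):
        self.scale = 0.5  # stands in for real model weights

    def transform_pre(self, *args):
        vertex_feat, neighbour_feats = args
        rows = [vertex_feat] + list(neighbour_feats)
        mean = [sum(col) / len(rows) for col in zip(*rows)]
        return mean, vertex_feat  # (result, auxiliary value)

    def transform_post(self, *args):
        (mean,) = args
        return [self.scale * v for v in mean]


def run_once(transform, *inputs):
    """Mirror of the session run() flow: pre-process, then post-process."""
    pre_result, _aux = transform.transform_pre(*inputs)
    return transform.transform_post(pre_result)
```

The same UDF class could be driven by either session type, since neither session needs to know what happens inside the three methods.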

8. Testing Plan

8.1 Unit Tests

  • PaddleInferSession._coerce_to_native(): verify paddle.Tensor, nested lists, dicts, and scalars are all correctly coerced to Python-native types.
  • SAGNNTransFormFunction._split_feat_coord(): verify padding, truncation, and edge cases (empty vector, vector shorter than coord_dim).
  • SAGNNTransFormFunction._build_mini_graph(): verify graph has correct node count, edge directions, and self-loop fallback for isolated nodes.
  • InferEnvironmentContext: verify system Python path detection for /opt/homebrew/bin/python3, /usr/bin/python3.

8.2 Integration Tests

| Test class | Coverage |
|---|---|
| SAGNNAlgorithmTest | End-to-end GQL query with mock Python process |
| SAGNNInferIntegrationTest | Full Java→Python→Java round-trip using system Python and a randomly-initialised SA-GNN model |

Test data files:

  • data/sagnn_vertex.txt: vertex features including 2 coordinate columns at the end
  • data/sagnn_edge.txt: edge list
  • expect/gql_sagnn_001.txt, expect/gql_sagnn_002.txt: expected output snapshots

9. Rollout Plan

  1. Phase 1 (this PR): Merge all framework-layer changes + SAGNN algorithm + BaseInferSession abstraction.
  2. Phase 2 (follow-up): Add a PaddlePaddle static-graph inference mode (paddle.inference.create_predictor) for production performance tuning.
  3. Phase 3 (future): Generalise the framework dispatch to support additional backends (TensorFlow Lite, ONNX Runtime) using the same BaseInferSession interface.

10. References

  • PaddleSpatial GitHub
  • PGL (Paddle Graph Learning)
  • SA-GNN paper: Spatial Adaptive Graph Neural Network for Location-Aware Services (Baidu Research)
  • GeaFlow-Infer design doc: geaflow/geaflow-infer/README.md
  • Related issue: GeaFlow GraphSAGE integration (for design reference)
