diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake
index d40ae17e40545..d59c944c8926f 100644
--- a/cmake/onnxruntime_providers_vitisai.cmake
+++ b/cmake/onnxruntime_providers_vitisai.cmake
@@ -19,7 +19,16 @@
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc"
)
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs})
- onnxruntime_add_shared_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs})
+ set(onnxruntime_providers_vitisai_all_srcs ${onnxruntime_providers_vitisai_cc_srcs})
+ if(WIN32)
+ # Sets the DLL version info on Windows: https://learn.microsoft.com/en-us/windows/win32/menurc/versioninfo-resource
+ list(APPEND onnxruntime_providers_vitisai_all_srcs "${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_providers_vitisai.rc")
+ endif()
+ onnxruntime_add_shared_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_all_srcs})
+ if(WIN32)
+ # FILE_NAME preprocessor definition is used in onnxruntime_providers_vitisai.rc
+ target_compile_definitions(onnxruntime_providers_vitisai PRIVATE FILE_NAME=\"onnxruntime_providers_vitisai.dll\")
+ endif()
onnxruntime_add_include_to_target(onnxruntime_providers_vitisai ${ONNXRUNTIME_PROVIDERS_SHARED} ${GSL_TARGET} safeint_interface flatbuffers::flatbuffers Boost::mp11)
target_link_libraries(onnxruntime_providers_vitisai PRIVATE ${ONNXRUNTIME_PROVIDERS_SHARED} ${ABSEIL_LIBS})
if(MSVC)
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index d4e023b0f86a0..9ae3e79d86443 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1786,7 +1786,7 @@ endif()
endif()
endif()
-if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
+if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_CUDA_MINIMAL)
set(custom_op_src_patterns
"${TEST_SRC_DIR}/testdata/custom_op_library/*.h"
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 81d4f2589151b..fa1914f2a927b 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -452,6 +452,36 @@ public struct OrtApi
public IntPtr Graph_GetModelMetadata;
public IntPtr GetModelCompatibilityForEpDevices;
public IntPtr CreateExternalInitializerInfo;
+
+ // v1.24 APIs
+ public IntPtr TensorTypeAndShape_HasShape;
+ public IntPtr KernelInfo_GetConfigEntries;
+ public IntPtr KernelInfo_GetOperatorDomain;
+ public IntPtr KernelInfo_GetOperatorType;
+ public IntPtr KernelInfo_GetOperatorSinceVersion;
+ public IntPtr GetInteropApi;
+ public IntPtr SessionGetEpDeviceForOutputs;
+ public IntPtr GetNumHardwareDevices;
+ public IntPtr GetHardwareDevices;
+ public IntPtr GetHardwareDeviceEpIncompatibilityDetails;
+ public IntPtr DeviceEpIncompatibilityDetails_GetReasonsBitmask;
+ public IntPtr DeviceEpIncompatibilityDetails_GetNotes;
+ public IntPtr DeviceEpIncompatibilityDetails_GetErrorCode;
+ public IntPtr ReleaseDeviceEpIncompatibilityDetails;
+ public IntPtr GetCompatibilityInfoFromModel;
+ public IntPtr GetCompatibilityInfoFromModelBytes;
+ public IntPtr CreateEnvWithOptions;
+ public IntPtr Session_GetEpGraphAssignmentInfo;
+ public IntPtr EpAssignedSubgraph_GetEpName;
+ public IntPtr EpAssignedSubgraph_GetNodes;
+ public IntPtr EpAssignedNode_GetName;
+ public IntPtr EpAssignedNode_GetDomain;
+ public IntPtr EpAssignedNode_GetOperatorType;
+ public IntPtr RunOptionsSetSyncStream;
+ public IntPtr GetTensorElementTypeAndShapeDataReference;
+ // v1.25 APIs
+ public IntPtr RunOptionsEnableProfiling;
+ public IntPtr RunOptionsDisableProfiling;
}
internal static class NativeMethods
@@ -884,6 +914,16 @@ static NativeMethods()
(DOrtCopyTensors)Marshal.GetDelegateForFunctionPointer(
api_.CopyTensors,
typeof(DOrtCopyTensors));
+
+ OrtGetCompatibilityInfoFromModel =
+ (DOrtGetCompatibilityInfoFromModel)Marshal.GetDelegateForFunctionPointer(
+ api_.GetCompatibilityInfoFromModel,
+ typeof(DOrtGetCompatibilityInfoFromModel));
+
+ OrtGetCompatibilityInfoFromModelBytes =
+ (DOrtGetCompatibilityInfoFromModelBytes)Marshal.GetDelegateForFunctionPointer(
+ api_.GetCompatibilityInfoFromModelBytes,
+ typeof(DOrtGetCompatibilityInfoFromModelBytes));
}
internal class NativeLib
@@ -3092,6 +3132,31 @@ public delegate IntPtr DOrtEpSelectionDelegate(
public static DOrtReleasePrepackedWeightsContainer OrtReleasePrepackedWeightsContainer;
+ /// <summary>
+ /// Extract EP compatibility info from a precompiled model file.
+ /// </summary>
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtGetCompatibilityInfoFromModel(
+ byte[] /* const ORTCHAR_T* */ model_path,
+ byte[] /* const char* */ ep_type,
+ IntPtr /* OrtAllocator* */ allocator,
+ out IntPtr /* char** */ compatibility_info);
+
+ public static DOrtGetCompatibilityInfoFromModel OrtGetCompatibilityInfoFromModel;
+
+ /// <summary>
+ /// Extract EP compatibility info from precompiled model bytes in memory.
+ /// </summary>
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtGetCompatibilityInfoFromModelBytes(
+ byte[] /* const void* */ model_data,
+ UIntPtr /* size_t */ model_data_length,
+ byte[] /* const char* */ ep_type,
+ IntPtr /* OrtAllocator* */ allocator,
+ out IntPtr /* char** */ compatibility_info);
+
+ public static DOrtGetCompatibilityInfoFromModelBytes OrtGetCompatibilityInfoFromModelBytes;
+
#endregion
} // class NativeMethods
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
index 22f541e2207fa..0876db3f21209 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
@@ -524,6 +524,75 @@ public OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices(
return (OrtCompiledModelCompatibility)status;
}
+ /// <summary>
+ /// Extract EP compatibility info from a precompiled model file.
+ /// </summary>
+ /// <remarks>
+ /// Parses the model file to extract the compatibility info string for a specific execution provider
+ /// from the model's metadata properties. This is only applicable to models that have been precompiled
+ /// for an EP. Standard ONNX models do not contain this information.
+ /// The compatibility info can then be passed to <see cref="GetModelCompatibilityForEpDevices"/> to
+ /// check if a precompiled model is compatible with the current system.
+ /// </remarks>
+ /// <param name="modelPath">Path to the ONNX model file.</param>
+ /// <param name="epType">The execution provider type string. Use <see cref="OrtEpDevice.EpName"/> to get this value.</param>
+ /// <returns>The compatibility info string, or null if no compatibility info exists for the specified EP.</returns>
+ /// <exception cref="ArgumentException">If modelPath or epType is null or empty.</exception>
+ /// <exception cref="OnnxRuntimeException">If the model file cannot be read or parsed.</exception>
+ public string GetCompatibilityInfoFromModel(string modelPath, string epType)
+ {
+ if (string.IsNullOrEmpty(modelPath))
+ throw new ArgumentException("modelPath must be non-empty", nameof(modelPath));
+ if (string.IsNullOrEmpty(epType))
+ throw new ArgumentException("epType must be non-empty", nameof(epType));
+
+ var allocator = OrtAllocator.DefaultInstance;
+ var pathBytes = NativeOnnxValueHelper.GetPlatformSerializedString(modelPath);
+ var epTypeUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(epType);
+
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.OrtGetCompatibilityInfoFromModel(
+ pathBytes, epTypeUtf8, allocator.Pointer, out IntPtr compatInfoPtr));
+
+ if (compatInfoPtr == IntPtr.Zero)
+ return null;
+
+ return NativeOnnxValueHelper.StringFromNativeUtf8(compatInfoPtr, allocator);
+ }
+
+ /// <summary>
+ /// Extract EP compatibility info from precompiled model bytes in memory.
+ /// </summary>
+ /// <remarks>
+ /// Same as <see cref="GetCompatibilityInfoFromModel"/> but reads from a memory buffer instead of a file.
+ /// Useful when precompiled models are loaded from encrypted storage, network, or other non-file sources.
+ /// </remarks>
+ /// <param name="modelData">The model data bytes.</param>
+ /// <param name="epType">The execution provider type string. Use <see cref="OrtEpDevice.EpName"/> to get this value.</param>
+ /// <returns>The compatibility info string, or null if no compatibility info exists for the specified EP.</returns>
+ /// <exception cref="ArgumentException">If modelData is null/empty or epType is null or empty.</exception>
+ /// <exception cref="OnnxRuntimeException">If the model data cannot be parsed.</exception>
+ public string GetCompatibilityInfoFromModelBytes(byte[] modelData, string epType)
+ {
+ if (modelData == null || modelData.Length == 0)
+ throw new ArgumentException("modelData must be non-empty", nameof(modelData));
+ if (string.IsNullOrEmpty(epType))
+ throw new ArgumentException("epType must be non-empty", nameof(epType));
+
+ var allocator = OrtAllocator.DefaultInstance;
+ var epTypeUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(epType);
+
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.OrtGetCompatibilityInfoFromModelBytes(
+ modelData, (UIntPtr)modelData.Length, epTypeUtf8,
+ allocator.Pointer, out IntPtr compatInfoPtr));
+
+ if (compatInfoPtr == IntPtr.Zero)
+ return null;
+
+ return NativeOnnxValueHelper.StringFromNativeUtf8(compatInfoPtr, allocator);
+ }
+
///
/// Get/Set log level property of OrtEnv instance
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs
index 103fe5bc10106..f1b792454f205 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs
@@ -10,6 +10,8 @@ namespace Microsoft.ML.OnnxRuntime.Tests;
using System.Linq;
using Xunit;
using System.Collections.Generic;
+using Google.Protobuf;
+using Onnx;
public class EpCompatibilityTests
{
@@ -23,6 +25,35 @@ private IReadOnlyList GetDevices()
return epDevices;
}
+ /// <summary>
+ /// Creates a minimal valid ONNX ModelProto with optional compatibility metadata.
+ /// </summary>
+ private static byte[] CreateModelWithCompatibilityMetadata(
+ Dictionary<string, string> epCompatibilityInfo = null)
+ {
+ var modelProto = new ModelProto();
+ modelProto.IrVersion = (long)Onnx.Version.IrVersion;
+ modelProto.Graph = new GraphProto { Name = "test_graph" };
+
+ var opset = new OperatorSetIdProto();
+ opset.Domain = "";
+ opset.Version = 13;
+ modelProto.OpsetImport.Add(opset);
+
+ if (epCompatibilityInfo != null)
+ {
+ foreach (var kvp in epCompatibilityInfo)
+ {
+ var prop = new StringStringEntryProto();
+ prop.Key = "ep_compatibility_info." + kvp.Key;
+ prop.Value = kvp.Value;
+ modelProto.MetadataProps.Add(prop);
+ }
+ }
+
+ return modelProto.ToByteArray();
+ }
+
[Fact]
public void GetEpCompatibility_InvalidArgs()
{
@@ -45,5 +76,103 @@ public void GetEpCompatibility_SingleDeviceCpuProvider()
// CPU defaults to not applicable in this scenario
Assert.Equal(OrtCompiledModelCompatibility.EP_NOT_APPLICABLE, status);
}
+
+ [Fact]
+ public void GetCompatibilityInfoFromModel_InvalidArgs()
+ {
+ Assert.Throws<ArgumentException>(() => ortEnvInstance.GetCompatibilityInfoFromModel(null, "TestEP"));
+ Assert.Throws<ArgumentException>(() => ortEnvInstance.GetCompatibilityInfoFromModel("", "TestEP"));
+ Assert.Throws<ArgumentException>(() => ortEnvInstance.GetCompatibilityInfoFromModel("model.onnx", null));
+ Assert.Throws<ArgumentException>(() => ortEnvInstance.GetCompatibilityInfoFromModel("model.onnx", ""));
+ }
+
+ [Fact]
+ public void GetCompatibilityInfoFromModel_FileNotFound()
+ {
+ Assert.Throws<OnnxRuntimeException>(
+ () => ortEnvInstance.GetCompatibilityInfoFromModel("nonexistent_model_path.onnx", "TestEP"));
+ }
+
+ [Fact]
+ public void GetCompatibilityInfoFromModelBytes_InvalidArgs()
+ {
+ Assert.Throws<ArgumentException>(() => ortEnvInstance.GetCompatibilityInfoFromModelBytes(null, "TestEP"));
+ Assert.Throws<ArgumentException>(() => ortEnvInstance.GetCompatibilityInfoFromModelBytes(new byte[0], "TestEP"));
+ Assert.Throws<ArgumentException>(() => ortEnvInstance.GetCompatibilityInfoFromModelBytes(new byte[] { 1, 2, 3 }, null));
+ Assert.Throws<ArgumentException>(() => ortEnvInstance.GetCompatibilityInfoFromModelBytes(new byte[] { 1, 2, 3 }, ""));
+ }
+
+ [Fact]
+ public void GetCompatibilityInfoFromModel_WithMetadata()
+ {
+ const string epType = "TestCompatEP";
+ const string expectedCompatInfo = "test_compat_v1.0_driver_123";
+
+ byte[] modelData = CreateModelWithCompatibilityMetadata(
+ new Dictionary<string, string> { { epType, expectedCompatInfo } });
+
+ string tempModelPath = System.IO.Path.Combine(
+ System.IO.Path.GetTempPath(),
+ System.IO.Path.GetRandomFileName() + ".onnx");
+
+ System.IO.File.WriteAllBytes(tempModelPath, modelData);
+
+ try
+ {
+ string result = ortEnvInstance.GetCompatibilityInfoFromModel(tempModelPath, epType);
+ Assert.NotNull(result);
+ Assert.Equal(expectedCompatInfo, result);
+ }
+ finally
+ {
+ if (System.IO.File.Exists(tempModelPath))
+ {
+ System.IO.File.Delete(tempModelPath);
+ }
+ }
+ }
+
+ [Fact]
+ public void GetCompatibilityInfoFromModelBytes_InvalidModelData()
+ {
+ byte[] invalidData = System.Text.Encoding.UTF8.GetBytes("this is not a valid ONNX model");
+ Assert.Throws<OnnxRuntimeException>(
+ () => ortEnvInstance.GetCompatibilityInfoFromModelBytes(invalidData, "TestEP"));
+ }
+
+ [Fact]
+ public void GetCompatibilityInfoFromModelBytes_WithMetadata()
+ {
+ const string epType = "TestCompatEP";
+ const string expectedCompatInfo = "test_compat_v1.0_driver_123";
+
+ byte[] modelData = CreateModelWithCompatibilityMetadata(
+ new Dictionary<string, string> { { epType, expectedCompatInfo } });
+
+ string result = ortEnvInstance.GetCompatibilityInfoFromModelBytes(modelData, epType);
+ Assert.NotNull(result);
+ Assert.Equal(expectedCompatInfo, result);
+ }
+
+ [Fact]
+ public void GetCompatibilityInfoFromModelBytes_NotFound()
+ {
+ // Create model with metadata for a different EP
+ byte[] modelData = CreateModelWithCompatibilityMetadata(
+ new Dictionary<string, string> { { "DifferentEP", "some_value" } });
+
+ string result = ortEnvInstance.GetCompatibilityInfoFromModelBytes(modelData, "NonExistentEP");
+ Assert.Null(result);
+ }
+
+ [Fact]
+ public void GetCompatibilityInfoFromModelBytes_NoMetadata()
+ {
+ // Create model without any compatibility metadata
+ byte[] modelData = CreateModelWithCompatibilityMetadata();
+
+ string result = ortEnvInstance.GetCompatibilityInfoFromModelBytes(modelData, "AnyEP");
+ Assert.Null(result);
+ }
}
#endif
diff --git a/js/react_native/android/CMakeLists.txt b/js/react_native/android/CMakeLists.txt
index 0bcf552ff9e41..a23f5ba7cd8ab 100644
--- a/js/react_native/android/CMakeLists.txt
+++ b/js/react_native/android/CMakeLists.txt
@@ -1,5 +1,5 @@
+cmake_minimum_required(VERSION 3.13)
project(OnnxruntimeJSI)
-cmake_minimum_required(VERSION 3.9.0)
set(PACKAGE_NAME "onnxruntime-react-native")
set(BUILD_DIR ${CMAKE_SOURCE_DIR}/build)
@@ -97,3 +97,6 @@ target_link_libraries(
${log-lib} # <-- Logcat logger
android # <-- Android JNI core
)
+
+# 16KB page size support (Android 15+ requirement)
+target_link_options(onnxruntimejsi PRIVATE "-Wl,-z,max-page-size=16384")
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
deleted file mode 100644
index 651f270230a75..0000000000000
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
+++ /dev/null
@@ -1,341 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "contrib_ops/cpu/bert/attention_base.h"
-#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
-#include "core/providers/common.h"
-
-namespace onnxruntime {
-namespace contrib {
-
-Status AttentionBase::CheckInputs(const TensorShape& input_shape,
- const TensorShape& weights_shape,
- const TensorShape& bias_shape,
- const Tensor*& mask_index,
- const Tensor* past,
- const Tensor* attention_bias,
- void* parameters,
- const Tensor* past_seq_len) const {
- // Abbreviation and Meanings:
- // B: batch_size
- // S: sequence_length (input sequence length of query)
- // P: past_sequence_length (past sequence length of key or value)
- // L: kv_sequence_length (input sequence length of key or value)
- // M: max_sequence_length
- // T: total_sequence_length = past_sequence_length + kv_sequence_length
- // N: num_heads
- // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size
- // H_v: v_head_size
- // D_i: input hidden size
- // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
- // D_v: v_hidden_size = num_heads * v_head_size
-
- // When past state is used, Q, K and V should have same hidden size (unless we split it into past_key and past_value).
-
- // Input shapes:
- // input (Q/K/V) : (B, S, D_i)
- // weights (Q/K/V) : (D_i, D + D + D_v)
- // bias (Q/K/V) : (D + D + D_v)
- // mask_index : see below
- // past (K/V) : (2, B, N, P, H) or NULL
- // attention_bias : (B or 1, N or 1, S, T) or NULL
-
- // For mask_index, the following shapes are supported:
- // NULL, (B, 1), (1, 1)
- // (B), (2 * B), (3 * B + 2)
- // (B, T)
- // (B, S, T)
- // (B, 1, M, M)
- //
- // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger
- // than hidden dimension of Q, K and V.
-
- if (past != nullptr && attention_bias != nullptr) {
- // past is used on GPT-2 model with past state, we don't have a case for attention bias yet
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and attention_bias");
- }
-
- const auto& dims = input_shape.GetDims();
- if (dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
- dims.size());
- }
-
- auto& batch_size = dims[0];
- auto& sequence_length = dims[1];
- int64_t input_hidden_size = dims[2];
-
- const auto& bias_dims = bias_shape.GetDims();
- if (bias_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
- bias_dims.size());
- }
-
- const auto& weights_dims = weights_shape.GetDims();
- if (weights_dims.size() != 2) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'weights' is expected to have 2 dimensions, got ",
- weights_dims.size());
- }
- if (weights_dims[0] != input_hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 1 dimension 0 should have same length as dimension 2 of input 0");
- }
-
- if (bias_dims[0] != weights_dims[1]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'bias' dimension 0 should have same length as dimension 1 of input 'weights'");
- }
-
- int64_t q_hidden_size = bias_dims[0] / static_cast(3);
- int64_t k_hidden_size = q_hidden_size;
- int64_t v_hidden_size = k_hidden_size;
- if (qkv_hidden_sizes_.size() != 0) {
- if (qkv_hidden_sizes_.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "qkv_hidden_sizes attribute should have 3 elements");
- }
-
- for (size_t i = 0; i < qkv_hidden_sizes_.size(); i++) {
- if (qkv_hidden_sizes_[i] % num_heads_ != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "hidden_size should be divisible by num_heads:", qkv_hidden_sizes_[i]);
- }
- }
-
- q_hidden_size = qkv_hidden_sizes_[0];
- k_hidden_size = qkv_hidden_sizes_[1];
- v_hidden_size = qkv_hidden_sizes_[2];
- }
-
- int64_t kv_sequence_length = sequence_length;
-
- if (q_hidden_size != k_hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "qkv_hidden_sizes first element should be same as the second");
- }
-
- if (this->require_same_hidden_size_ && k_hidden_size != v_hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Hidden size of Q, K and V shall be same");
- }
-
- if (bias_dims[0] != q_hidden_size + k_hidden_size + v_hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'bias' dimension 0 should have same length as sum of Q/K/V hidden sizes:",
- " q_hidden_size=", q_hidden_size, " k_hidden_size=", k_hidden_size, " v_hidden_size=",
- v_hidden_size, "bias_dims[0]=", bias_dims[0]);
- }
-
- int64_t past_sequence_length = 0;
- if (past != nullptr) { // past is optional
- if (k_hidden_size != v_hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' expect k_hidden_size == v_hidden_size");
- }
-
- const auto& past_dims = past->Shape().GetDims();
- if (past_dims.size() != 5) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is expected to have 5 dimension, got ",
- past_dims.size());
- }
-
- if (past_dims[0] != 2) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 0 shall have length of 2");
- }
-
- if (past_dims[1] != batch_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0");
- }
-
- if (static_cast(past_dims[2]) != num_heads_) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Inputs 'past' dimension 2 shall have length of num_heads", num_heads_);
- }
-
- if (static_cast(past_dims[4]) != k_hidden_size / num_heads_) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Inputs 'past' dimension 2 shall have length of ", k_hidden_size / num_heads_);
- }
-
- if (!past_present_share_buffer_) {
- past_sequence_length = past_dims[3];
- } else {
- if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "past_sequence_length tensor must be of one element when past_present_share_buffer is set");
- }
- past_sequence_length = *past_seq_len->Data();
- }
- }
-
- int64_t total_sequence_length = kv_sequence_length + past_sequence_length;
- if (past != nullptr && past_present_share_buffer_) {
- const auto& past_dims = past->Shape().GetDims();
- if (past_dims[3] < total_sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "when past_present_share_buffer, past tensor sequence must not smaller than total_sequqnce_length ");
- }
- }
-
- int64_t max_sequence_length = -1;
- AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
- if (mask_index != nullptr) { // mask_index is optional
- mask_type = AttentionMaskType::MASK_UNKNOWN;
- auto status = this->CheckMask(mask_index, mask_type,
- max_sequence_length, batch_size, sequence_length, total_sequence_length);
- if (status != Status::OK()) {
- return status;
- }
-
- if (mask_type == AttentionMaskType::MASK_2D_DUMMY) {
- mask_index = nullptr;
- mask_type = AttentionMaskType::MASK_NONE;
- }
- }
-
- gsl::span attention_bias_dims;
- if (attention_bias != nullptr) {
- attention_bias_dims = attention_bias->Shape().GetDims();
-
- ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckAttentionBias(
- attention_bias_dims, batch_size, num_heads_, sequence_length, total_sequence_length));
- }
-
- if (past != nullptr && past_present_share_buffer_) {
- if (max_sequence_length <= 0) {
- max_sequence_length = past->Shape().GetDims()[3];
- }
- if (max_sequence_length != past->Shape().GetDims()[3]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "max_sequence_length not matching from mask and past when past_present_share_buffer_ is set");
- }
- }
-
- if (parameters != nullptr) {
- AttentionParameters* output_parameters = reinterpret_cast(parameters);
- output_parameters->batch_size = static_cast(batch_size);
- output_parameters->sequence_length = static_cast(sequence_length);
- output_parameters->past_sequence_length = static_cast(past_sequence_length);
- output_parameters->kv_sequence_length = static_cast(kv_sequence_length);
- output_parameters->total_sequence_length = static_cast(total_sequence_length);
- output_parameters->max_sequence_length = static_cast(max_sequence_length);
- output_parameters->input_hidden_size = static_cast(input_hidden_size);
- output_parameters->hidden_size = static_cast(q_hidden_size);
- output_parameters->v_hidden_size = static_cast(v_hidden_size);
- output_parameters->head_size = static_cast(q_hidden_size) / num_heads_;
- output_parameters->v_head_size = static_cast(v_hidden_size) / num_heads_;
- output_parameters->num_heads = num_heads_;
- output_parameters->is_unidirectional = is_unidirectional_;
- output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
- output_parameters->do_rotary = do_rotary_;
- output_parameters->rotary_dim = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;
- output_parameters->mask_filter_value = mask_filter_value_;
- output_parameters->scale = scale_;
- output_parameters->mask_type = mask_type;
- output_parameters->broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1;
- output_parameters->broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1;
- output_parameters->qkv_format = Q_K_V_BNSH;
- }
-
- return Status::OK();
-}
-
-Status AttentionBase::CheckMask(const Tensor* mask_index,
- AttentionMaskType& mask_type,
- int64_t& max_sequence_length,
- int64_t batch_size,
- int64_t sequence_length,
- int64_t total_sequence_length) const {
- const auto& mask_dims = mask_index->Shape().GetDims();
- if (mask_dims.size() == 1) {
- if (mask_dims[0] != batch_size && mask_dims[0] != 2 * batch_size && mask_dims[0] != 3 * batch_size + 2) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size or 3 * batch_size + 2");
- }
- mask_type = (mask_dims[0] == batch_size ? AttentionMaskType::MASK_1D_KEY_SEQ_LEN : mask_dims[0] == 2 * batch_size ? AttentionMaskType::MASK_1D_END_START
- : AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START);
- } else if (mask_dims.size() == 2) {
- if (mask_dims[0] == batch_size && mask_dims[1] == total_sequence_length) {
- mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
- } else {
- // Add operator supports broadcasting. Here we handle a case with only one element in the 2nd dimension.
- if ((mask_dims[0] == batch_size || mask_dims[0] == 1) && mask_dims[1] == 1) {
- // Mask will have same value after propagation, which has same effect as no mask.
- mask_type = AttentionMaskType::MASK_2D_DUMMY;
- } else {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Inputs 'mask_index' with 2D data shall have shape "
- "batch_size x total_sequence_length");
- }
- }
- } else if (mask_dims.size() == 3) {
- if (mask_dims[0] != batch_size || mask_dims[1] != sequence_length || mask_dims[2] != total_sequence_length) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Inputs 'mask_index' with 3D data shall have shape "
- "batch_size x sequence_length x total_sequence_length");
- }
- mask_type = AttentionMaskType::MASK_3D_ATTENTION;
- } else if (mask_dims.size() == 4) {
- if (mask_dims[0] != batch_size || mask_dims[1] != 1 || mask_dims[2] != mask_dims[3] ||
- mask_dims[2] < total_sequence_length) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Inputs 'mask_index' with 4D data shall have shape "
- "batch_size x 1 x max_sequence_length x max_sequence_length)");
- }
- max_sequence_length = mask_dims[3];
- mask_type = AttentionMaskType::MASK_4D_MEGATRON;
- if (this->is_unidirectional_) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Inputs 'mask_index' with 4D data shall have is_unidirectional set to false");
- }
- } else {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got ",
- mask_dims.size());
- }
-
- return Status::OK();
-}
-
-Status AttentionBase::CheckInputs(const TensorShape& input_shape,
- const TensorShape& weights_shape,
- const TensorShape& bias_shape,
- const Tensor*& mask_index,
- const Tensor* past,
- const Tensor* attention_bias,
- void* parameters,
- const int max_threads_per_block,
- const Tensor* past_seq_len) const {
- if (num_heads_ > max_threads_per_block) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
- }
-
- return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, attention_bias, parameters, past_seq_len);
-}
-
-Tensor* AttentionBase::GetPresent(OpKernelContext* context,
- const Tensor* past,
- int batch_size,
- int head_size,
- int kv_sequence_length,
- int& past_sequence_length) const {
- // Input and output shapes:
- // past : (2, batch_size, num_heads, past_sequence_length, head_size)
- // present : (2, batch_size, num_heads, past_sequence_length + kv_sequence_length, head_size)
-
- past_sequence_length = (nullptr != past) ? static_cast(past->Shape().GetDims()[3]) : 0;
- std::array present_dims{2, batch_size, num_heads_, static_cast(kv_sequence_length) + past_sequence_length, head_size};
-
- TensorShape present_shape(present_dims);
- Tensor* present = context->Output(1, present_shape);
- if (nullptr != past && nullptr == present) {
- ORT_THROW("Expect to have present state output when past state input is given");
- }
-
- return present;
-}
-
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index bd7f03379b2f0..2872fcfda5bbf 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -3,12 +3,19 @@
#pragma once
+#include
#include
#include "core/common/common.h"
-#include "core/framework/op_kernel.h"
#include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h"
+#ifndef SHARED_PROVIDER
+#include "core/framework/op_kernel.h"
+#include "core/providers/common.h"
+#endif
#include "contrib_ops/cpu/bert/attention_common.h"
#include "contrib_ops/cpu/bert/attention_parameters.h"
+#ifndef SHARED_PROVIDER
+#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
+#endif
namespace onnxruntime {
namespace contrib {
@@ -25,14 +32,25 @@ class AttentionBase {
const int max_threads_per_block, // for CUDA
const Tensor* past_seq_len = nullptr) const;
+#ifdef SHARED_PROVIDER
Tensor* GetPresent(OpKernelContext* context,
const Tensor* past,
int batch_size,
int head_size,
int kv_sequence_length,
int& past_sequence_length) const;
+#else
+ template <typename TOpKernelContext>
+ Tensor* GetPresent(TOpKernelContext* context,
+ const Tensor* past,
+ int batch_size,
+ int head_size,
+ int kv_sequence_length,
+ int& past_sequence_length) const;
+#endif
protected:
+ // Keep the class layout identical in SHARED_PROVIDER and non-SHARED_PROVIDER builds.
MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_;
template
@@ -54,7 +72,9 @@ class AttentionBase {
require_same_hidden_size_ = require_same_hidden_size;
+#ifndef SHARED_PROVIDER
SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions());
+#endif
}
Status CheckMask(const Tensor* mask_index,
@@ -84,5 +104,299 @@ class AttentionBase {
float scale_; // the scale to be used for softmax
};
+#ifndef SHARED_PROVIDER
+// Inline implementations of out-of-line methods for non-SHARED_PROVIDER builds
+// (attention_base.cc definitions are used only in the SHARED_PROVIDER bridge path).
+inline Status AttentionBase::CheckMask(const Tensor* mask_index,
+ AttentionMaskType& mask_type,
+ int64_t& max_sequence_length,
+ int64_t batch_size,
+ int64_t sequence_length,
+ int64_t total_sequence_length) const {
+ const auto& mask_dims = mask_index->Shape().GetDims();
+ if (mask_dims.size() == 1) {
+ if (mask_dims[0] != batch_size && mask_dims[0] != 2 * batch_size && mask_dims[0] != 3 * batch_size + 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size or 3 * batch_size + 2");
+ }
+ mask_type = (mask_dims[0] == batch_size ? AttentionMaskType::MASK_1D_KEY_SEQ_LEN : mask_dims[0] == 2 * batch_size ? AttentionMaskType::MASK_1D_END_START
+ : AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START);
+ } else if (mask_dims.size() == 2) {
+ if (mask_dims[0] == batch_size && mask_dims[1] == total_sequence_length) {
+ mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
+ } else {
+ if ((mask_dims[0] == batch_size || mask_dims[0] == 1) && mask_dims[1] == 1) {
+ mask_type = AttentionMaskType::MASK_2D_DUMMY;
+ } else {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Inputs 'mask_index' with 2D data shall have shape "
+ "batch_size x total_sequence_length");
+ }
+ }
+ } else if (mask_dims.size() == 3) {
+ if (mask_dims[0] != batch_size || mask_dims[1] != sequence_length || mask_dims[2] != total_sequence_length) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Inputs 'mask_index' with 3D data shall have shape "
+ "batch_size x sequence_length x total_sequence_length");
+ }
+ mask_type = AttentionMaskType::MASK_3D_ATTENTION;
+ } else if (mask_dims.size() == 4) {
+ if (mask_dims[0] != batch_size || mask_dims[1] != 1 || mask_dims[2] != mask_dims[3] ||
+ mask_dims[2] < total_sequence_length) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Inputs 'mask_index' with 4D data shall have shape "
+ "batch_size x 1 x max_sequence_length x max_sequence_length)");
+ }
+ max_sequence_length = mask_dims[3];
+ mask_type = AttentionMaskType::MASK_4D_MEGATRON;
+ if (this->is_unidirectional_) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Inputs 'mask_index' with 4D data shall have is_unidirectional set to false");
+ }
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got ",
+ mask_dims.size());
+ }
+
+ return Status::OK();
+}
+
+inline Status AttentionBase::CheckInputs(const TensorShape& input_shape,
+ const TensorShape& weights_shape,
+ const TensorShape& bias_shape,
+ const Tensor*& mask_index,
+ const Tensor* past,
+ const Tensor* attention_bias,
+ void* parameters,
+ const Tensor* past_seq_len) const {
+ if (past != nullptr && attention_bias != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and attention_bias");
+ }
+
+ const auto& dims = input_shape.GetDims();
+ if (dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
+ dims.size());
+ }
+
+ auto& batch_size = dims[0];
+ auto& sequence_length = dims[1];
+ int64_t input_hidden_size = dims[2];
+
+ const auto& bias_dims = bias_shape.GetDims();
+ if (bias_dims.size() != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
+ bias_dims.size());
+ }
+
+ const auto& weights_dims = weights_shape.GetDims();
+ if (weights_dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'weights' is expected to have 2 dimensions, got ",
+ weights_dims.size());
+ }
+ if (weights_dims[0] != input_hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 1 dimension 0 should have same length as dimension 2 of input 0");
+ }
+
+ if (bias_dims[0] != weights_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'bias' dimension 0 should have same length as dimension 1 of input 'weights'");
+ }
+
+ int64_t q_hidden_size = bias_dims[0] / static_cast(3);
+ int64_t k_hidden_size = q_hidden_size;
+ int64_t v_hidden_size = k_hidden_size;
+ if (qkv_hidden_sizes_.size() != 0) {
+ if (qkv_hidden_sizes_.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "qkv_hidden_sizes attribute should have 3 elements");
+ }
+
+ for (size_t i = 0; i < qkv_hidden_sizes_.size(); i++) {
+ if (qkv_hidden_sizes_[i] % num_heads_ != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "hidden_size should be divisible by num_heads:", qkv_hidden_sizes_[i]);
+ }
+ }
+
+ q_hidden_size = qkv_hidden_sizes_[0];
+ k_hidden_size = qkv_hidden_sizes_[1];
+ v_hidden_size = qkv_hidden_sizes_[2];
+ }
+
+ int64_t kv_sequence_length = sequence_length;
+
+ if (q_hidden_size != k_hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "qkv_hidden_sizes first element should be same as the second");
+ }
+
+ if (this->require_same_hidden_size_ && k_hidden_size != v_hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Hidden size of Q, K and V shall be same");
+ }
+
+ if (bias_dims[0] != q_hidden_size + k_hidden_size + v_hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'bias' dimension 0 should have same length as sum of Q/K/V hidden sizes:",
+ " q_hidden_size=", q_hidden_size, " k_hidden_size=", k_hidden_size, " v_hidden_size=",
+ v_hidden_size, "bias_dims[0]=", bias_dims[0]);
+ }
+
+ int64_t past_sequence_length = 0;
+ if (past != nullptr) {
+ if (k_hidden_size != v_hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' expect k_hidden_size == v_hidden_size");
+ }
+
+ const auto& past_dims = past->Shape().GetDims();
+ if (past_dims.size() != 5) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is expected to have 5 dimension, got ",
+ past_dims.size());
+ }
+
+ if (past_dims[0] != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 0 shall have length of 2");
+ }
+
+ if (past_dims[1] != batch_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0");
+ }
+
+ if (static_cast(past_dims[2]) != num_heads_) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Inputs 'past' dimension 2 shall have length of num_heads", num_heads_);
+ }
+
+ if (static_cast(past_dims[4]) != k_hidden_size / num_heads_) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Inputs 'past' dimension 2 shall have length of ", k_hidden_size / num_heads_);
+ }
+
+ if (!past_present_share_buffer_) {
+ past_sequence_length = past_dims[3];
+ } else {
+ if (past_seq_len == nullptr || !::onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "past_sequence_length tensor must be of one element when past_present_share_buffer is set");
+ }
+ past_sequence_length = *past_seq_len->Data();
+ }
+ }
+
+ int64_t total_sequence_length = kv_sequence_length + past_sequence_length;
+ if (past != nullptr && past_present_share_buffer_) {
+ const auto& past_dims = past->Shape().GetDims();
+ if (past_dims[3] < total_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "when past_present_share_buffer, past tensor sequence must not smaller than total_sequence_length ");
+ }
+ }
+
+ int64_t max_sequence_length = -1;
+ AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
+ if (mask_index != nullptr) {
+ mask_type = AttentionMaskType::MASK_UNKNOWN;
+ auto status = this->CheckMask(mask_index, mask_type,
+ max_sequence_length, batch_size, sequence_length, total_sequence_length);
+ if (status != Status::OK()) {
+ return status;
+ }
+
+ if (mask_type == AttentionMaskType::MASK_2D_DUMMY) {
+ mask_index = nullptr;
+ mask_type = AttentionMaskType::MASK_NONE;
+ }
+ }
+
+ gsl::span attention_bias_dims;
+ if (attention_bias != nullptr) {
+ attention_bias_dims = attention_bias->Shape().GetDims();
+
+ ORT_RETURN_IF_ERROR(::onnxruntime::contrib::multihead_attention_helper::CheckAttentionBias(
+ attention_bias_dims, batch_size, num_heads_, sequence_length, total_sequence_length));
+ }
+
+ if (past != nullptr && past_present_share_buffer_) {
+ if (max_sequence_length <= 0) {
+ max_sequence_length = past->Shape().GetDims()[3];
+ }
+ if (max_sequence_length != past->Shape().GetDims()[3]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "max_sequence_length not matching from mask and past when past_present_share_buffer_ is set");
+ }
+ }
+
+ if (parameters != nullptr) {
+ AttentionParameters* output_parameters = reinterpret_cast(parameters);
+ output_parameters->batch_size = static_cast(batch_size);
+ output_parameters->sequence_length = static_cast(sequence_length);
+ output_parameters->past_sequence_length = static_cast(past_sequence_length);
+ output_parameters->kv_sequence_length = static_cast(kv_sequence_length);
+ output_parameters->total_sequence_length = static_cast(total_sequence_length);
+ output_parameters->max_sequence_length = static_cast(max_sequence_length);
+ output_parameters->input_hidden_size = static_cast(input_hidden_size);
+ output_parameters->hidden_size = static_cast(q_hidden_size);
+ output_parameters->v_hidden_size = static_cast(v_hidden_size);
+ output_parameters->head_size = static_cast(q_hidden_size) / num_heads_;
+ output_parameters->v_head_size = static_cast(v_hidden_size) / num_heads_;
+ output_parameters->num_heads = num_heads_;
+ output_parameters->is_unidirectional = is_unidirectional_;
+ output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
+ output_parameters->do_rotary = do_rotary_;
+ output_parameters->rotary_dim = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;
+ output_parameters->mask_filter_value = mask_filter_value_;
+ output_parameters->scale = scale_;
+ output_parameters->mask_type = mask_type;
+ output_parameters->broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1;
+ output_parameters->broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1;
+ output_parameters->qkv_format = Q_K_V_BNSH;
+ }
+
+ return Status::OK();
+}
+
+inline Status AttentionBase::CheckInputs(const TensorShape& input_shape,
+ const TensorShape& weights_shape,
+ const TensorShape& bias_shape,
+ const Tensor*& mask_index,
+ const Tensor* past,
+ const Tensor* attention_bias,
+ void* parameters,
+ const int max_threads_per_block,
+ const Tensor* past_seq_len) const {
+ if (num_heads_ > max_threads_per_block) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
+ }
+
+ return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, attention_bias, parameters, past_seq_len);
+}
+
+template
+inline Tensor* AttentionBase::GetPresent(TOpKernelContext* context,
+ const Tensor* past,
+ int batch_size,
+ int head_size,
+ int kv_sequence_length,
+ int& past_sequence_length) const {
+ past_sequence_length = (nullptr != past) ? static_cast(past->Shape().GetDims()[3]) : 0;
+ std::array present_dims{2, batch_size, num_heads_,
+ static_cast(kv_sequence_length) + past_sequence_length, head_size};
+
+ TensorShape present_shape(present_dims);
+ Tensor* present = context->Output(1, present_shape);
+ if (nullptr != past && nullptr == present) {
+ ORT_THROW("Expect to have present state output when past state input is given");
+ }
+
+ return present;
+}
+#endif // SHARED_PROVIDER
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
index 9a123e80adc18..f316a0dfdf91c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -10,33 +10,33 @@ namespace contrib {
// Parameters deduced from node attributes and inputs/outputs.
struct AttentionParameters {
- int batch_size;
- int sequence_length;
- int kv_sequence_length; // input sequence length of K or V
- int past_sequence_length; // sequence length in past state of K or V
- int total_sequence_length; // total sequence length of K or V
- int max_sequence_length; // max sequence length from 4D mask
- int input_hidden_size; // first dimension of weights for input projection
- int hidden_size; // hidden size of Q or K
- int head_size; // hidden size per head of Q or K
- int v_hidden_size; // hidden size of V
- int v_head_size; // hidden size per head of V
- int num_heads;
- int num_splits; // number of splits for splitkv
+ int batch_size = 0;
+ int sequence_length = 0;
+ int kv_sequence_length = 0; // input sequence length of K or V
+ int past_sequence_length = 0; // sequence length in past state of K or V
+ int total_sequence_length = 0; // total sequence length of K or V
+ int max_sequence_length = 0; // max sequence length from 4D mask
+ int input_hidden_size = 0; // first dimension of weights for input projection
+ int hidden_size = 0; // hidden size of Q or K
+ int head_size = 0; // hidden size per head of Q or K
+ int v_hidden_size = 0; // hidden size of V
+ int v_head_size = 0; // hidden size per head of V
+ int num_heads = 0;
+ int num_splits = 0; // number of splits for splitkv
int rotary_dim = 0; // rotary embedding dimension
- int beam_width;
+ int beam_width = 0;
bool is_unidirectional = false;
bool past_present_share_buffer = false;
bool is_packed_qkv = false; // whether qkv is packed
bool do_rotary = false;
bool broadcast_attn_bias_dim_0 = false;
bool broadcast_attn_bias_dim_1 = false;
- float mask_filter_value;
- float scale;
+ float mask_filter_value = 0.0f;
+ float scale = 0.0f;
bool use_tf32 = false;
bool is_output_bnsh = false; // whether the output format is BNSH
- AttentionMaskType mask_type;
- AttentionQkvFormat qkv_format;
+ AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
+ AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
};
// Parameters deduced from node attributes and inputs/outputs.
diff --git a/onnxruntime/contrib_ops/cpu/bert/longformer_attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/longformer_attention_base.cc
deleted file mode 100644
index 97f75d297d789..0000000000000
--- a/onnxruntime/contrib_ops/cpu/bert/longformer_attention_base.cc
+++ /dev/null
@@ -1,141 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "longformer_attention_base.h"
-
-namespace onnxruntime {
-namespace contrib {
-
-Status LongformerAttentionBase::CheckInputs(const TensorShape& input_shape,
- const TensorShape& weights_shape,
- const TensorShape& bias_shape,
- const TensorShape& attention_mask_shape,
- const TensorShape& global_weights_shape,
- const TensorShape& global_bias_shape,
- const TensorShape& global_mask_shape) const {
- // Input shapes:
- // input : (batch_size, sequence_length, hidden_size)
- // weights : (hidden_size, 3 * hidden_size) -- format 1
- // (3, hidden_size, hidden_size) -- format 0
- // bias : (3 * hidden_size) -- format 1 (bias for Q, K, V)
- // (5 * hidden_size) -- format 0 (bias for Q, K, V, Global_K, Global_V)
- // attention_mask : (batch_size, sequence_length)
- // global_weights : (hidden_size, 3 * hidden_size) -- format 1
- // (3, hidden_size, hidden_size) -- format 0
- // global_bias : (3 * hidden_size) -- format 1 (bias for Global_Q, Global_K, Global_V)
- // (1 * hidden_size) -- format 0 (bias for Global_Q)
- // global_attention_mask : (batch_size, sequence_length)
-
- const auto& dims = input_shape.GetDims();
- if (dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
- dims.size());
- }
-
- int batch_size = static_cast(dims[0]);
- int sequence_length = static_cast(dims[1]);
- auto hidden_size = dims[2];
- if (sequence_length % (2 * window_) != 0) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'input' dimension 1 should be divisible by 2W, where W is value of the window attribute.");
- }
- if (hidden_size % num_heads_ != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'input' dimension 2 should be divisible by value of the num_heads attribute.");
- }
-
- const auto& weights_dims = weights_shape.GetDims();
- bool use_merged_qkv_weights = (weights_shape.NumDimensions() == 2);
- if (use_merged_qkv_weights) {
- if (weights_dims[0] != hidden_size || weights_dims[1] != 3 * hidden_size) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'weights' shape should be (hidden_size, 3 * hidden_size) for format 1");
- }
- } else {
- if (weights_dims.size() != 3 ||
- weights_dims[0] != 3 || weights_dims[1] != hidden_size || weights_dims[2] != hidden_size) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'weights' shape should be (3, hidden_size, hidden_size) for format 0");
- }
- }
-
- const auto& bias_dims = bias_shape.GetDims();
- if (bias_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
- bias_dims.size());
- }
-
- if (use_merged_qkv_weights) {
- if (bias_dims[0] != 3 * hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'bias' shape should be (3 * hidden_size) for format 1");
- }
- } else {
- if (bias_dims[0] != 5 * hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'bias' shape should be (5 * hidden_size) for format 0");
- }
- }
-
- const auto& mask_dims = attention_mask_shape.GetDims();
- if (mask_dims.size() == 2) {
- if (static_cast(mask_dims[0]) != batch_size || static_cast(mask_dims[1]) != sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Inputs 'attention_mask' shape shall be (batch_size, sequence_length)");
- }
- } else {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'attention_mask' is expected to have 2 dimensions, got ", mask_dims.size());
- }
-
- const auto& global_weights_dims = global_weights_shape.GetDims();
- if (use_merged_qkv_weights) {
- if (global_weights_dims.size() != 2 ||
- global_weights_dims[0] != hidden_size || global_weights_dims[1] != 3 * hidden_size) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'global_weights' shape should be (hidden_size, 3 * hidden_size) for format 1");
- }
- } else {
- if (global_weights_dims.size() != 3 || global_weights_dims[0] != 3 ||
- global_weights_dims[1] != hidden_size || global_weights_dims[2] != hidden_size) {
- return ORT_MAKE_STATUS(
- ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'global_weights' shape should be (3, hidden_size, hidden_size) for format 0");
- }
- }
-
- const auto& global_bias_dims = global_bias_shape.GetDims();
- if (global_bias_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'global_bias' is expected to have 1 dimension, got ",
- global_bias_dims.size());
- }
-
- if (use_merged_qkv_weights) {
- if (global_bias_dims[0] != 3 * hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'global_bias' shape should be (3 * hidden_size) for format 1");
- }
- } else {
- if (global_bias_dims[0] != hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'global_bias' shape should be (hidden_size) for format 0");
- }
- }
-
- const auto& global_mask_dims = global_mask_shape.GetDims();
- if (global_mask_dims.size() != 2 ||
- static_cast(global_mask_dims[0]) != batch_size ||
- static_cast(global_mask_dims[1]) != sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'global_attention_mask' shape shall be (batch_size, sequence_length)");
- }
-
- return Status::OK();
-}
-
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/longformer_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/longformer_attention_base.h
index ac1cccaa83cf9..bb1dfea38ae80 100644
--- a/onnxruntime/contrib_ops/cpu/bert/longformer_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/longformer_attention_base.h
@@ -4,7 +4,9 @@
#pragma once
#include "core/common/common.h"
+#ifndef SHARED_PROVIDER
#include "core/framework/op_kernel.h"
+#endif
namespace onnxruntime {
namespace contrib {
@@ -20,7 +22,8 @@ class LongformerAttentionBase {
const TensorShape& global_attention_mask_shape) const;
protected:
- LongformerAttentionBase(const OpKernelInfo& info) {
+ template
+ LongformerAttentionBase(const KernelInfoType& info) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast(num_heads);
@@ -43,5 +46,126 @@ constexpr const char* kUseHalf4 = "ORT_LONGFORMER_USE_HALF4";
} // namespace longformer
+#ifndef SHARED_PROVIDER
+// Inline implementation of CheckInputs for non-SHARED_PROVIDER builds.
+inline Status LongformerAttentionBase::CheckInputs(const TensorShape& input_shape,
+ const TensorShape& weights_shape,
+ const TensorShape& bias_shape,
+ const TensorShape& attention_mask_shape,
+ const TensorShape& global_weights_shape,
+ const TensorShape& global_bias_shape,
+ const TensorShape& global_mask_shape) const {
+ const auto& dims = input_shape.GetDims();
+ if (dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
+ dims.size());
+ }
+
+ int batch_size = static_cast(dims[0]);
+ int sequence_length = static_cast(dims[1]);
+ auto hidden_size = dims[2];
+ if (sequence_length % (2 * window_) != 0) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'input' dimension 1 should be divisible by 2W, where W is value of the window attribute.");
+ }
+ if (hidden_size % num_heads_ != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'input' dimension 2 should be divisible by value of the num_heads attribute.");
+ }
+
+ const auto& weights_dims = weights_shape.GetDims();
+ bool use_merged_qkv_weights = (weights_shape.NumDimensions() == 2);
+ if (use_merged_qkv_weights) {
+ if (weights_dims[0] != hidden_size || weights_dims[1] != 3 * hidden_size) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'weights' shape should be (hidden_size, 3 * hidden_size) for format 1");
+ }
+ } else {
+ if (weights_dims.size() != 3 ||
+ weights_dims[0] != 3 || weights_dims[1] != hidden_size || weights_dims[2] != hidden_size) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'weights' shape should be (3, hidden_size, hidden_size) for format 0");
+ }
+ }
+
+ const auto& bias_dims = bias_shape.GetDims();
+ if (bias_dims.size() != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
+ bias_dims.size());
+ }
+
+ if (use_merged_qkv_weights) {
+ if (bias_dims[0] != 3 * hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'bias' shape should be (3 * hidden_size) for format 1");
+ }
+ } else {
+ if (bias_dims[0] != 5 * hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'bias' shape should be (5 * hidden_size) for format 0");
+ }
+ }
+
+ const auto& mask_dims = attention_mask_shape.GetDims();
+ if (mask_dims.size() == 2) {
+ if (static_cast(mask_dims[0]) != batch_size || static_cast(mask_dims[1]) != sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Inputs 'attention_mask' shape shall be (batch_size, sequence_length)");
+ }
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'attention_mask' is expected to have 2 dimensions, got ", mask_dims.size());
+ }
+
+ const auto& global_weights_dims = global_weights_shape.GetDims();
+ if (use_merged_qkv_weights) {
+ if (global_weights_dims.size() != 2 ||
+ global_weights_dims[0] != hidden_size || global_weights_dims[1] != 3 * hidden_size) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'global_weights' shape should be (hidden_size, 3 * hidden_size) for format 1");
+ }
+ } else {
+ if (global_weights_dims.size() != 3 || global_weights_dims[0] != 3 ||
+ global_weights_dims[1] != hidden_size || global_weights_dims[2] != hidden_size) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'global_weights' shape should be (3, hidden_size, hidden_size) for format 0");
+ }
+ }
+
+ const auto& global_bias_dims = global_bias_shape.GetDims();
+ if (global_bias_dims.size() != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'global_bias' is expected to have 1 dimension, got ",
+ global_bias_dims.size());
+ }
+
+ if (use_merged_qkv_weights) {
+ if (global_bias_dims[0] != 3 * hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'global_bias' shape should be (3 * hidden_size) for format 1");
+ }
+ } else {
+ if (global_bias_dims[0] != hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'global_bias' shape should be (hidden_size) for format 0");
+ }
+ }
+
+ const auto& global_mask_dims = global_mask_shape.GetDims();
+ if (global_mask_dims.size() != 2 ||
+ static_cast(global_mask_dims[0]) != batch_size ||
+ static_cast(global_mask_dims[1]) != sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'global_attention_mask' shape shall be (batch_size, sequence_length)");
+ }
+
+ return Status::OK();
+}
+#endif // SHARED_PROVIDER
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/crop.h b/onnxruntime/contrib_ops/cpu/crop.h
index 3b72ef429c1f7..97577304e948e 100644
--- a/onnxruntime/contrib_ops/cpu/crop.h
+++ b/onnxruntime/contrib_ops/cpu/crop.h
@@ -4,7 +4,9 @@
#pragma once
#include "core/common/common.h"
+#ifndef SHARED_PROVIDER
#include "core/framework/op_kernel.h"
+#endif
#include
@@ -13,9 +15,10 @@ namespace contrib {
class CropBase {
protected:
- CropBase(const OpKernelInfo& info)
- : border_(info.GetAttrsOrDefault("border")),
- scale_(info.GetAttrsOrDefault("scale")) {
+ template
+ CropBase(const KernelInfoType& info)
+ : border_(info.template GetAttrsOrDefault("border")),
+ scale_(info.template GetAttrsOrDefault("scale")) {
}
Status ValidateInput(const Tensor* X) const {
diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc
index a656abb098911..56bff8aa30f68 100644
--- a/onnxruntime/core/framework/allocator.cc
+++ b/onnxruntime/core/framework/allocator.cc
@@ -237,9 +237,22 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device_id),
mem_type1);
} else if (strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 ||
- strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) {
+ strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0 ||
+ // Accept pre-1.25 names "WebGPU_Buffer"/"WebNN_Tensor" for backward compatibility
+ // with released onnxruntime-genai that still uses the old names.
+ // Normalize to the current (short) constant so downstream name comparisons work.
+ // See: https://github.com/microsoft/onnxruntime/pull/27207
+ strcmp(name1, "WebGPU_Buffer") == 0 ||
+ strcmp(name1, "WebNN_Tensor") == 0) {
+ // Map old long names to current short constants to keep downstream name comparisons consistent.
+ const char* normalized_name = name1;
+ if (strcmp(name1, "WebGPU_Buffer") == 0) {
+ normalized_name = onnxruntime::WEBGPU_BUFFER;
+ } else if (strcmp(name1, "WebNN_Tensor") == 0) {
+ normalized_name = onnxruntime::WEBNN_TENSOR;
+ }
*out = new OrtMemoryInfo(
- name1, type,
+ normalized_name, type,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, device_id),
mem_type1);
diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc
index 6bcbdc401619a..bee7f048b7c6e 100644
--- a/onnxruntime/core/framework/tensorprotoutils.cc
+++ b/onnxruntime/core/framework/tensorprotoutils.cc
@@ -349,66 +349,112 @@ Status TensorProtoWithExternalDataToTensorProto(
return Status::OK();
}
-Status ValidateExternalDataPath(const std::filesystem::path& base_dir,
- const std::filesystem::path& location,
- const std::filesystem::path& model_path) {
- // Reject absolute paths
- ORT_RETURN_IF(location.is_absolute(),
- "Absolute paths not allowed for external data location");
- if (!base_dir.empty()) {
- // Resolve and verify the path stays within model directory
- auto base_canonical = std::filesystem::weakly_canonical(base_dir);
- // If the symlink exists, it resolves to the target path;
- // so if the symlink is outside the directory it would be caught here.
- auto resolved = std::filesystem::weakly_canonical(base_dir / location);
-
- // Check that resolved path starts with base directory
- auto [base_end, resolved_it] = std::mismatch(
- base_canonical.begin(), base_canonical.end(),
- resolved.begin(), resolved.end());
-
- if (base_end != base_canonical.end()) {
- // If validation against logical base_dir fails, we check against the
- // real (canonical) path of the model file to support symlinked models
- // (e.g. models in Hugging Face Hub local cache).
- if (!model_path.empty()) {
- auto real_model_dir = std::filesystem::weakly_canonical(model_path).parent_path();
-
- auto [real_base_end, real_resolved_it] = std::mismatch(
- real_model_dir.begin(), real_model_dir.end(),
- resolved.begin(), resolved.end());
-
- if (real_base_end == real_model_dir.end()) {
- return Status::OK();
- }
+// Wraps std::filesystem::weakly_canonical with error_code handling.
+static Status WeaklyCanonicalPath(const std::filesystem::path& path, std::filesystem::path& result) {
+ std::error_code ec;
+ result = std::filesystem::weakly_canonical(path, ec);
+ ORT_RETURN_IF(ec, "Failed to get the weakly canonical path: ", path, " - ", ec.message());
+ return Status::OK();
+}
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
- "External data path: ", location, " (resolved path: ", resolved,
- ") escapes both model directory: ", base_dir,
- " and real model directory: ", real_model_dir);
- }
+// Wraps std::filesystem::exists with error_code handling.
+static Status PathExists(const std::filesystem::path& path, bool& exists) {
+ std::error_code ec;
+ exists = std::filesystem::exists(path, ec);
+ ORT_RETURN_IF(ec, "Failed to check existence of path: ", path, " - ", ec.message());
+ return Status::OK();
+}
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
- "External data path: ", location, " (resolved path: ", resolved,
- ") escapes model directory: ", base_dir);
- }
- } else {
- // The basedir is empty, which occurs when 1) the session loads a model from bytes and 2) the application does not
- // set an external file folder path via the session config option
- // `kOrtSessionOptionsModelExternalInitializersFileFolderPath`.
-
- // We conservatively check that the normalized relative path does not contain ".." path components that would allow
- // access to arbitrary files outside of the current working directory. Based on ONNX checker validation.
- auto norm_location = location.lexically_normal();
-
- for (const auto& path_component : norm_location) {
- if (path_component == ORT_TSTR("..")) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "External data path: ", location,
- " (model loaded from bytes) escapes working directory");
- }
+// Checks whether `path` has the given path prefix.
+static bool HasPathComponentPrefix(const std::filesystem::path& prefix, const std::filesystem::path& path) {
+ auto [prefix_end, path_it] = std::mismatch(prefix.begin(), prefix.end(), path.begin(), path.end());
+ return prefix_end == prefix.end();
+}
+
+Status ValidateExternalDataPath(const std::filesystem::path& model_path,
+ const std::filesystem::path& external_data_path) {
+ ORT_RETURN_IF(external_data_path.empty(), "Empty external data path not allowed");
+
+ // Note: Use !root_path().empty() to reject paths like '/some/path` even on Windows.
+ ORT_RETURN_IF(!external_data_path.root_path().empty(), "Absolute path not allowed for external data location");
+
+#if defined(__wasm__)
+ std::error_code error_code;
+ std::filesystem::current_path(error_code);
+ if (error_code) {
+ // If we can't access the current working directory in a WASM build, we assume that the WASM
+ // environment does not have a virtual filesystem and defer validation to an ExternalDataLoader for
+ // a WASM EP.
+ return Status::OK();
+ }
+#endif
+
+ // Determine the model directory: use model file's parent directory if provided,
+ // otherwise use the current working directory.
+ std::filesystem::path model_dir = model_path.empty() || model_path.parent_path().empty()
+ ? std::filesystem::path{"."}
+ : model_path.parent_path();
+
+ // Resolve the model directory and the external data path to their weakly canonical forms, which
+ // resolves symlinks but does not require that the paths actually exist yet.
+ std::filesystem::path model_dir_canonical;
+ std::filesystem::path external_data_path_canonical;
+ ORT_RETURN_IF_ERROR(WeaklyCanonicalPath(model_dir, model_dir_canonical));
+ ORT_RETURN_IF_ERROR(WeaklyCanonicalPath(model_dir_canonical / external_data_path, external_data_path_canonical));
+
+ // Check that the external data path is contained by the model directory.
+ // If it is, check if the external data file actually exists.
+ if (HasPathComponentPrefix(model_dir_canonical, external_data_path_canonical)) {
+ bool path_exists = false;
+ ORT_RETURN_IF_ERROR(PathExists(external_data_path_canonical, path_exists));
+ ORT_RETURN_IF(!path_exists, "External data path does not exist: ", external_data_path_canonical);
+ return Status::OK();
+ }
+
+ if (model_path.empty()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+ "External data path for model loaded from bytes escapes working directory. ",
+ "External data path: ", external_data_path, " resolved path: ",
+ external_data_path_canonical, " ", "working directory: ", model_dir);
+ }
+
+ // The model file itself may be a symlink. Therefore, check against the real/canonical directory of the model
+ // after resolving all symlinks.
+ //
+ // This supports symlinked models (e.g., Hugging Face Hub local cache) where the canonical
+ // parent of the model file differs from the parent directory of the symlinked model file.
+ std::error_code ec;
+ if (!std::filesystem::is_symlink(model_path, ec)) {
+ // Note: is_symlink returns false if file is not a symlink, file does not exist, or an error
+ // occurred (e.g., permissions). In any of these cases, we just return an error.
+ std::string fs_error_msg;
+ if (ec) {
+ fs_error_msg = " filesystem::is_symlink error: " + ec.message();
}
+
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+ "External data path for model escapes model directory. ",
+ "External data path: ", external_data_path, " resolved path: ",
+ external_data_path_canonical, " ", "model directory: ", model_dir, fs_error_msg);
}
- return Status::OK();
+
+ std::filesystem::path real_model_path;
+ ORT_RETURN_IF_ERROR(WeaklyCanonicalPath(model_path, real_model_path));
+ auto real_model_dir = real_model_path.parent_path();
+
+ // Check that the external data path is contained by the real model directory.
+ // If it is, check if the external data file actually exists.
+ if (HasPathComponentPrefix(real_model_dir, external_data_path_canonical)) {
+ bool path_exists = false;
+ ORT_RETURN_IF_ERROR(PathExists(external_data_path_canonical, path_exists));
+ ORT_RETURN_IF(!path_exists, "External data path does not exist: ", external_data_path_canonical);
+ return Status::OK();
+ }
+
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+ "External data path: ", external_data_path, " (resolved path: ",
+ external_data_path_canonical, ") escapes both model directory: ", model_dir,
+ " and real model directory: ", real_model_dir);
}
Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h
index f3f33a32b8076..e7649c072416c 100644
--- a/onnxruntime/core/framework/tensorprotoutils.h
+++ b/onnxruntime/core/framework/tensorprotoutils.h
@@ -539,19 +539,33 @@ Status TensorProtoWithExternalDataToTensorProto(
ONNX_NAMESPACE::TensorProto& new_tensor_proto);
///
-/// Validates if the external data path is under the model directory.
-/// If the model is a symlink, it checks against both the logical model directory (base_dir)
-/// and the real/canonical directory of the model.
-/// If the `base_dir` is empty, the function only ensures that `location` is not an absolute path.
+/// Validates that the given external data path is not an absolute path, is under the model directory
+/// (after resolving symlinks), and exists.
+///
+/// The model path can be empty if the model is loaded from bytes and the application did not specify a directory
+/// for external data files. In this case, the external data path must be contained under the current working
+/// directory.
+///
+/// The model path can point to a non-existing model file if the model is loaded from bytes and the application
+/// specified a directory for external data files via the session config entry
+/// `kOrtSessionOptionsModelExternalInitializersFileFolderPath`. In this case, the model_path is set to
+/// " / virtual_model.onnx" and the external data path
+/// must be contained under `kOrtSessionOptionsModelExternalInitializersFileFolderPath`.
+///
+/// If the model itself is a symlink, this function checks against both the directory containing the symlink
+/// and the real/canonical directory of the model after resolving all symlinks.
+///
+/// On WASM builds, this function skips most validation (except checks for non-empty/non-absolute path) if we are
+/// unable to query the current working directory, as this indicates that the WASM environment does not have
+/// a valid filesystem. If skipped, an ExternalDataLoader will validate the location and contents of the
+/// external data file at the time of access.
///
-/// Logical model location directory
-/// Location string retrieved from TensorProto external data
-/// Optional path to the model file, used for canonical path validation if base_dir check fails
-/// The function will fail if the resolved full path is not under the logical model directory
-/// nor the real directory of the model path
-Status ValidateExternalDataPath(const std::filesystem::path& base_dir,
- const std::filesystem::path& location,
- const std::filesystem::path& model_path = {});
+/// Path to the model file. Can be empty or point to a virtual file.
+/// External data file path to be validated.
+/// Retrieved from TensorProto external data info
+/// The function will fail if the resolved `external_data_path` path is not under the model directory
+Status ValidateExternalDataPath(const std::filesystem::path& model_path,
+ const std::filesystem::path& external_data_path);
#endif // !defined(SHARED_PROVIDER)
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 5aa466ecb5bc7..3599edbfcd357 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -3742,10 +3742,7 @@ Status Graph::ConvertInitializersIntoOrtValues() {
FindAllSubgraphs(all_subgraphs);
const auto& model_path = GetModel().ModelPath();
- PathString model_dir;
- if (!model_path.empty()) {
- ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, model_dir));
- }
+ std::unordered_set<PathString> validated_external_data_paths;
auto put_weights_maybe_in_memory_func = [&](Graph& graph) -> Status {
// if we have any initializers that are not in memory, put them there.
@@ -3771,11 +3768,17 @@ Status Graph::ConvertInitializersIntoOrtValues() {
std::unique_ptr<ExternalDataInfo> external_data_info;
ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));
const auto& location = external_data_info->GetRelPath();
- auto st = utils::ValidateExternalDataPath(model_dir, location, model_path);
- if (!st.IsOK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
- "External data path validation failed for initializer: ", tensor_proto.name(),
- ". Error: ", st.ErrorMessage());
+
+ if (validated_external_data_paths.count(location) == 0) {
+ auto st = utils::ValidateExternalDataPath(model_path, location);
+
+ if (!st.IsOK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+ "External data path validation failed for initializer: ", tensor_proto.name(),
+ ". Error: ", st.ErrorMessage());
+ }
+
+ validated_external_data_paths.insert(location);
}
}
continue;
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
index da2e8fc37382a..fdc0818e8437b 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
@@ -43,7 +43,7 @@ bool IsDQWeightSigned(int32_t dt_weight) {
}
// Holds transposed weight/scale/zp tensors and their TensorProtos for MatMulNBits.
-// Used by both DQMatMulToMatMulNBitsAction and DQCastMatMulToMatMulNBitsAction.
+// Used by DQMatMulToMatMulNBitsAction.
struct TransposedQuantizedTensors {
Tensor weight;
Tensor scale;
@@ -486,146 +486,6 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph,
return Status::OK();
}
-DQCastMatMulToMatMulNBitsAction::DQCastMatMulToMatMulNBitsAction(
- int64_t accuracy_level,
- concurrency::ThreadPool* intra_op_thread_pool)
- : accuracy_level_{accuracy_level},
- intra_op_thread_pool_{intra_op_thread_pool} {
- ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4");
-}
-
-Status DQCastMatMulToMatMulNBitsAction::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
- // Selected nodes layout (from DQCastMatMulToMatMulNBitsSelector):
- // Input(0) = DQ node
- // Input(1) = Cast on input B (between DQ and MatMul)
- // Target() = MatMul node
- auto* dq_node = selected_nodes.Input(0);
- auto* cast_b_node = selected_nodes.Input(1);
- auto& matmul_node = selected_nodes.Target();
-
- // --- Transpose DQ weights/scales/zp via shared helper ---
- TransposedQuantizedTensors transposed;
- ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits(
- graph, *dq_node, "fused_DQ_Cast_MatMul", intra_op_thread_pool_, transposed));
-
- // MatMulNBits operates in the DQ scale dtype.
- // Always insert Cast on input A (to DQ dtype) and Cast on output (DQ dtype to MatMul output dtype).
- // ORT's redundant cast elimination optimizer will clean up unnecessary casts later.
-
- // Determine DQ output element type (e.g., fp16)
- int32_t dq_output_dtype = cast_b_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
- // Determine MatMul output element type (e.g., fp32)
- int32_t matmul_output_dtype = matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-
- const auto& dq_attrs = dq_node->GetAttributes();
- const auto* weight_arg = dq_node->InputDefs()[0];
- auto K = weight_arg->Shape()->dim(0).dim_value();
- auto N = weight_arg->Shape()->dim(1).dim_value();
- auto block_size = dq_attrs.at("block_size").i();
- int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type();
- auto bits = DQWeightBits(dt_weight);
-
- // --- Create fp16 NodeArg for MatMulNBits input A ---
- NodeArg* matmul_input_a = matmul_node.MutableInputDefs()[0];
- ONNX_NAMESPACE::TypeProto input_a_fp16_type;
- input_a_fp16_type.mutable_tensor_type()->set_elem_type(dq_output_dtype);
- if (matmul_input_a->Shape()) {
- *input_a_fp16_type.mutable_tensor_type()->mutable_shape() =
- matmul_input_a->TypeAsProto()->tensor_type().shape();
- }
- auto cast_a_out_name = graph.GenerateNodeArgName(matmul_node.Name() + "_input_a_cast");
- NodeArg* input_a_arg = &graph.GetOrCreateNodeArg(cast_a_out_name, &input_a_fp16_type);
-
- // --- Create fp16 NodeArg for MatMulNBits output ---
- ONNX_NAMESPACE::TypeProto output_fp16_type;
- output_fp16_type.mutable_tensor_type()->set_elem_type(dq_output_dtype);
- if (matmul_node.OutputDefs()[0]->Shape()) {
- *output_fp16_type.mutable_tensor_type()->mutable_shape() =
- matmul_node.OutputDefs()[0]->TypeAsProto()->tensor_type().shape();
- }
- auto mnb_out_name = graph.GenerateNodeArgName(matmul_node.Name() + "_matmulnbits_out");
- NodeArg* mnb_output_arg = &graph.GetOrCreateNodeArg(mnb_out_name, &output_fp16_type);
-
- // --- Create MatMulNBits node ---
- NodeAttributes attrs;
- utils::SetNodeAttribute(utils::MakeAttribute("K", K), attrs);
- utils::SetNodeAttribute(utils::MakeAttribute("N", N), attrs);
- utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs);
- utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs);
- utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), attrs);
-
- auto& new_node = graph.AddNode(
- graph.GenerateNodeName(matmul_node.Name() + "_MatMulNBits"),
- "MatMulNBits",
- "Fused DQ+Cast+MatMul to MatMulNBits",
- {input_a_arg},
- {mnb_output_arg},
- &attrs,
- kMSDomain);
-
- const auto& target_provider = matmul_node.GetExecutionProviderType();
- new_node.SetExecutionProviderType(target_provider.empty() ? kCpuExecutionProvider : target_provider);
-
- // Add transposed weight, scale, zp to inputs
- auto& input_defs = new_node.MutableInputDefs();
- input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight)));
- new_node.MutableInputArgsCount().push_back(1);
-
- input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.scale_proto, std::move(transposed.scale)));
- new_node.MutableInputArgsCount().push_back(1);
-
- if (transposed.zero_point_proto) {
- input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, *transposed.zero_point_proto, std::move(*transposed.zero_point)));
- new_node.MutableInputArgsCount().push_back(1);
- }
-
- // --- Insert Cast on input A: matmul_input_dtype -> dq_output_dtype ---
- {
- NodeAttributes cast_attrs;
- utils::SetNodeAttribute(
- utils::MakeAttribute("to", static_cast<int64_t>(dq_output_dtype)),
- cast_attrs);
- auto& cast_node = graph.AddNode(
- graph.GenerateNodeName(matmul_node.Name() + "_Cast_input_a"),
- "Cast", "",
- {matmul_input_a},
- {input_a_arg},
- &cast_attrs,
- kOnnxDomain);
- cast_node.SetExecutionProviderType(new_node.GetExecutionProviderType());
- }
-
- // --- Insert Cast on output: dq_output_dtype -> matmul_output_dtype ---
- {
- NodeAttributes cast_attrs;
- utils::SetNodeAttribute(
- utils::MakeAttribute("to", static_cast<int64_t>(matmul_output_dtype)),
- cast_attrs);
- auto& cast_node = graph.AddNode(
- graph.GenerateNodeName(matmul_node.Name() + "_Cast_output"),
- "Cast", "",
- {mnb_output_arg},
- {const_cast<NodeArg*>(matmul_node.OutputDefs()[0])},
- &cast_attrs,
- kOnnxDomain);
- cast_node.SetExecutionProviderType(new_node.GetExecutionProviderType());
- }
-
- // --- Remove original nodes ---
- auto remove_node = [&graph](Node* node) {
- if (node) {
- graph_utils::RemoveNodeOutputEdges(graph, *node);
- graph.RemoveNode(node->Index());
- }
- };
-
- remove_node(&matmul_node);
- remove_node(cast_b_node);
- remove_node(dq_node);
-
- return Status::OK();
-}
-
static std::vector<NodeAndMoveInfo> GetGemmMoveInfo(bool does_q_node_exist) {
NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0};
NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1};
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
index e112959cc58da..02a8353707599 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
@@ -107,20 +107,6 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew {
concurrency::ThreadPool* intra_op_thread_pool_;
};
-// Used together with DQCastMatMulToMatMulNBitsSelector.
-// Handles DQ -> Cast(fp16->fp32) -> MatMul fusion to MatMulNBits,
-// including optional Cast on input A and output type alignment.
-struct DQCastMatMulToMatMulNBitsAction : public Action {
- DQCastMatMulToMatMulNBitsAction(int64_t accuracy_level,
- concurrency::ThreadPool* intra_op_thread_pool);
-
- Status Run(Graph&, const NodesToOptimize& selected_nodes) const override;
-
- private:
- int64_t accuracy_level_;
- concurrency::ThreadPool* intra_op_thread_pool_;
-};
-
struct GemmReplaceWithQuant : public Action {
GemmReplaceWithQuant();
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
index 0b04445692c9b..8cab6911646f2 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
@@ -7,6 +7,7 @@
#include
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
+
#include "core/mlas/inc/mlas.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h"
@@ -306,7 +307,12 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi
intra_op_thread_pool);
#if !defined(ORT_MINIMAL_BUILD)
- std::vector<const char*> providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider};
+ // Include "" (empty string) to match nodes not yet assigned to an EP.
+ // For FP16 models on CPU EP, FP16 MatMul nodes are not claimed during partitioning
+ // (no FP16 MatMul kernel on CPU), leaving their EP unassigned. The DQ->MatMul fusion
+ // should still apply; the action assigns kCpuExecutionProvider to the resulting
+ // MatMulNBits node (which has both float and float16 CPU kernels).
+ std::vector<const char*> providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider, ""};
std::unique_ptr<NodeSelector> selector = std::make_unique<DQMatMulToMatMulNBitsSelector>(providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"MatMul", {}}},
@@ -316,25 +322,6 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi
#else
qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
#endif
-
- // DQ -> Cast(fp16->fp32) -> MatMul pattern.
- // Handles FP16 models where Cast nodes are inserted between DQ and MatMul.
- const std::string cast_action_name{"DQCastMatMulToMatMulNBits"};
-
- std::unique_ptr<Action> cast_action =
- std::make_unique<DQCastMatMulToMatMulNBitsAction>(qdq_matmulnbits_accuracy_level,
- intra_op_thread_pool);
-
-#if !defined(ORT_MINIMAL_BUILD)
- std::unique_ptr<NodeSelector> cast_selector =
- std::make_unique<DQCastMatMulToMatMulNBitsSelector>(providers);
- qdq_selector_action_registry.RegisterSelectorAndAction(cast_action_name,
- {{"MatMul", {}}},
- std::move(cast_selector),
- std::move(cast_action));
-#else
- qdq_selector_action_registry.RegisterAction(cast_action_name, std::move(cast_action));
-#endif
}
void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
@@ -416,7 +403,9 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer(
apply_context,
// this transformer is compatible with CPU, DML, ACL and CUDA EP.
// There is further EP control on the rule level.
- {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, kCudaExecutionProvider}} {
+ // Also accept nodes with empty EP (unassigned) so that individual selectors
+ // that include "" in their compatible providers can match unassigned nodes.
+ {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, kCudaExecutionProvider, ""}} {
}
} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index c39dfeb082e35..8a00fe11ff3fd 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -651,75 +651,6 @@ bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Nod
return ValidateBlockwiseDQForMatMulNBits(graph, *dq_nodes[0]);
}
-std::optional<NodesToOptimizeIndices>
-DQCastMatMulToMatMulNBitsSelector::Select(const GraphViewer& graph_viewer, const Node& node) const {
- // Check EP compatibility
- const std::string_view node_ep = node.GetExecutionProviderType();
- if (!compatible_providers_.empty() &&
- std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) {
- return std::nullopt;
- }
-
- const auto& graph = graph_viewer.GetGraph();
-
- // node must be MatMul
- if (node.OpType() != "MatMul") {
- return std::nullopt;
- }
-
- if (node.InputDefs().size() < 2) {
- return std::nullopt;
- }
-
- // Check input B: must be Cast(fp16->fp32)
- const Node* cast_b = graph_viewer.GetProducerNode(node.InputDefs()[1]->Name());
- if (!cast_b || cast_b->OpType() != "Cast") {
- return std::nullopt;
- }
-
- const auto& cast_b_attrs = cast_b->GetAttributes();
- auto to_iter = cast_b_attrs.find("to");
- if (to_iter == cast_b_attrs.end() ||
- to_iter->second.i() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) {
- return std::nullopt;
- }
-
- // Cast B input must be fp16
- if (!cast_b->InputDefs()[0]->TypeAsProto() ||
- cast_b->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() !=
- ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) {
- return std::nullopt;
- }
-
- // Cast B must have exactly 1 output edge (to MatMul) and not be a graph output
- if (!optimizer_utils::CheckOutputEdges(graph, *cast_b, 1)) {
- return std::nullopt;
- }
-
- // Cast B's input must come from a DQ node
- const Node* dq_node = graph_viewer.GetProducerNode(cast_b->InputDefs()[0]->Name());
- if (!dq_node || dq_node->OpType() != QDQ::DQOpName) {
- return std::nullopt;
- }
-
- // DQ must have exactly 1 output edge (to Cast B) and not be a graph output
- if (!optimizer_utils::CheckOutputEdges(graph, *dq_node, 1)) {
- return std::nullopt;
- }
-
- if (!ValidateBlockwiseDQForMatMulNBits(graph, *dq_node)) {
- return std::nullopt;
- }
-
- // Build selection
- NodesToOptimizeIndicesBuilder builder;
- builder.input_nodes.push_back(dq_node->Index());
- builder.input_nodes.push_back(cast_b->Index());
- builder.target_node = node.Index();
-
- return builder.Build();
-}
-
bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
index 5c10668733785..79c374b301442 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -461,27 +461,6 @@ class DQMatMulToMatMulNBitsSelector : public BaseSelector {
: BaseSelector(std::make_unique<DQMatMulNodeGroupSelector>(), compatible_providers) {}
};
-// Convert "DQ -> Cast(fp16->fp32) -> MatMul" to "MatMulNBits".
-// Handles Cast(fp16->fp32) between DQ and MatMul on input B, and optionally on input A.
-// Selection layout:
-// input_nodes[0] = DQ node
-// input_nodes[1] = Cast on input B (between DQ and MatMul)
-// target_node = MatMul
-// output_nodes = {}
-class DQCastMatMulToMatMulNBitsSelector : public NodeSelector {
- public:
- explicit DQCastMatMulToMatMulNBitsSelector(gsl::span<const char* const> compatible_providers = {})
- : compatible_providers_(compatible_providers.begin(), compatible_providers.end()) {}
-
- DQCastMatMulToMatMulNBitsSelector(DQCastMatMulToMatMulNBitsSelector&& rhs) noexcept
- : compatible_providers_(std::move(rhs.compatible_providers_)) {}
-
- std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const override;
-
- private:
- std::vector compatible_providers_;
-};
-
// Input: DQ nodes for A, B and optional C
// Output: optional Q node for Y
class GemmSelector : public BaseSelector {
diff --git a/onnxruntime/core/providers/cpu/math/cumsum.cc b/onnxruntime/core/providers/cpu/math/cumsum.cc
index 8321b81021d19..14ea6712f7f46 100644
--- a/onnxruntime/core/providers/cpu/math/cumsum.cc
+++ b/onnxruntime/core/providers/cpu/math/cumsum.cc
@@ -13,29 +13,6 @@ using namespace onnxruntime;
namespace onnxruntime {
-namespace cumsum_op {
-Status GetAxis(const Tensor* axis_tensor, int64_t input_rank, int64_t& axis_out) {
- if (!axis_tensor)
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor must be provided to the CumSum op");
-
- if (axis_tensor->Shape().NumDimensions() > 1)
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor should be 0D or 1D");
-
- if (axis_tensor->IsDataType<int32_t>()) {
- axis_out = static_cast<int64_t>(axis_tensor->Data<int32_t>()[0]);
- } else if (axis_tensor->IsDataType<int64_t>()) {
- axis_out = axis_tensor->Data<int64_t>()[0];
- } else {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor should be of type `int32_t` or `int64_t`");
- }
-
- axis_out = HandleNegativeAxis(axis_out, input_rank);
-
- return Status::OK();
-}
-
-} // namespace cumsum_op
-
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(
CumSum,
11,
diff --git a/onnxruntime/core/providers/cpu/math/cumsum.h b/onnxruntime/core/providers/cpu/math/cumsum.h
index fa1c1ceb0df10..b7443ada40861 100644
--- a/onnxruntime/core/providers/cpu/math/cumsum.h
+++ b/onnxruntime/core/providers/cpu/math/cumsum.h
@@ -1,11 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#pragma once
#include "core/common/common.h"
+#include "core/providers/common.h"
+
+#ifndef SHARED_PROVIDER
#include "core/framework/op_kernel.h"
+#endif
namespace onnxruntime {
+#ifndef SHARED_PROVIDER
template <typename T>
class CumSum final : public OpKernel {
public:
@@ -17,10 +23,33 @@ class CumSum final : public OpKernel {
int64_t exclusive_;
int64_t reverse_;
};
+#endif
namespace cumsum_op {
+#ifdef SHARED_PROVIDER
Status GetAxis(const Tensor* axis_tensor, int64_t input_rank, int64_t& axis_out);
+#else
+inline Status GetAxis(const Tensor* axis_tensor, int64_t input_rank, int64_t& axis_out) {
+ if (!axis_tensor)
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor must be provided to the CumSum op");
+
+ if (axis_tensor->Shape().NumDimensions() > 1)
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor should be 0D or 1D");
+
+ if (axis_tensor->IsDataType<int32_t>()) {
+ axis_out = static_cast<int64_t>(axis_tensor->Data<int32_t>()[0]);
+ } else if (axis_tensor->IsDataType<int64_t>()) {
+ axis_out = axis_tensor->Data<int64_t>()[0];
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Axis tensor should be of type `int32_t` or `int64_t`");
+ }
+
+ axis_out = HandleNegativeAxis(axis_out, input_rank);
+
+ return Status::OK();
+}
+#endif // SHARED_PROVIDER
} // namespace cumsum_op
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.cc b/onnxruntime/core/providers/cpu/object_detection/roialign.cc
index 87958a9f7e2dd..0680be3aea49c 100644
--- a/onnxruntime/core/providers/cpu/object_detection/roialign.cc
+++ b/onnxruntime/core/providers/cpu/object_detection/roialign.cc
@@ -258,76 +258,6 @@ void RoiAlignForward(const TensorShape& output_shape, const T* bottom_data, floa
}
} // namespace
-Status CheckROIAlignValidInput(const Tensor* X_ptr, const Tensor* rois_ptr, const Tensor* batch_indices_ptr) {
- constexpr int64_t EXPECTED_NUM_ROI_DIMS = 2;
- constexpr int64_t EXPECTED_SECOND_ROI_DIM = 4;
- if (!X_ptr) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Null input X ptr");
- }
- if (!rois_ptr) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Null rois_ptr");
- }
- if (!batch_indices_ptr) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Null batch_indices_ptr");
- }
-
- const auto& rois_dims = rois_ptr->Shape();
- const auto& batch_indices_dims = batch_indices_ptr->Shape();
-
- if (batch_indices_dims.NumDimensions() != 1) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
- "Number of dimensions for batch indices should be exactly 1");
- }
-
- // validate rois_dims
- if (rois_dims.NumDimensions() != EXPECTED_NUM_ROI_DIMS) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
- "Number of dimensions for rois should be exactly " + std::to_string(EXPECTED_NUM_ROI_DIMS));
- }
- if (rois_dims[1] != EXPECTED_SECOND_ROI_DIM) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
- "Second dimension for rois should be exactly " + std::to_string(EXPECTED_SECOND_ROI_DIM));
- }
-
- // first dimension of batch_indices and rois should match
- if (batch_indices_dims[0] != rois_dims[0]) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
- "First dimension (num_rois) of batch_indices and rois don't match");
- }
-
- if (batch_indices_ptr->Location().device.Type() == OrtDevice::CPU) {
- // Validate batch_indices values are within [0, batch_size) when the tensor
- // data is accessible from the host (CPU).
- const int64_t batch_size = X_ptr->Shape()[0];
- const int64_t num_rois = batch_indices_dims[0];
-
- auto check_bounds = [batch_size, num_rois](const auto* batch_indices_data) -> Status {
- for (int64_t i = 0; i < num_rois; ++i) {
- if (batch_indices_data[i] < 0 || batch_indices_data[i] >= batch_size) {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
- "batch_indices value " + std::to_string(batch_indices_data[i]) +
- " at index " + std::to_string(i) +
- " is out of range [0, " + std::to_string(batch_size) + ")");
- }
- }
- return Status::OK();
- };
-
- if (batch_indices_ptr->IsDataType()) {
- auto status = check_bounds(batch_indices_ptr->Data());
- if (!status.IsOK()) return status;
- } else if (batch_indices_ptr->IsDataType()) {
- auto status = check_bounds(batch_indices_ptr->Data());
- if (!status.IsOK()) return status;
- } else {
- return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
- "batch_indices must be of type int64_t or int32_t");
- }
- }
-
- return Status::OK();
-}
-
template
Status RoiAlign::Compute(OpKernelContext* context) const {
const auto* X_ptr = context->Input(0);
diff --git a/onnxruntime/core/providers/cpu/object_detection/roialign.h b/onnxruntime/core/providers/cpu/object_detection/roialign.h
index 1bb8bd34c5cb2..bb97de158369b 100644
--- a/onnxruntime/core/providers/cpu/object_detection/roialign.h
+++ b/onnxruntime/core/providers/cpu/object_detection/roialign.h
@@ -3,12 +3,86 @@
#pragma once
-#include "core/framework/op_kernel.h"
+#include
#include
+#include
+
+#include "core/common/common.h"
+#ifndef SHARED_PROVIDER
+#include "core/framework/op_kernel.h"
+#endif
namespace onnxruntime {
+#ifdef SHARED_PROVIDER
Status CheckROIAlignValidInput(const Tensor* X_ptr, const Tensor* rois_ptr, const Tensor* batch_indices_ptr);
+#else
+inline Status CheckROIAlignValidInput(const Tensor* X_ptr, const Tensor* rois_ptr, const Tensor* batch_indices_ptr) {
+ constexpr int64_t EXPECTED_NUM_ROI_DIMS = 2;
+ constexpr int64_t EXPECTED_SECOND_ROI_DIM = 4;
+ if (!X_ptr) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Null input X ptr");
+ }
+ if (!rois_ptr) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Null rois_ptr");
+ }
+ if (!batch_indices_ptr) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Null batch_indices_ptr");
+ }
+
+ const auto& rois_dims = rois_ptr->Shape();
+ const auto& batch_indices_dims = batch_indices_ptr->Shape();
+
+ if (batch_indices_dims.NumDimensions() != 1) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "Number of dimensions for batch indices should be exactly 1");
+ }
+
+ if (rois_dims.NumDimensions() != EXPECTED_NUM_ROI_DIMS) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "Number of dimensions for rois should be exactly " + std::to_string(EXPECTED_NUM_ROI_DIMS));
+ }
+ if (rois_dims[1] != EXPECTED_SECOND_ROI_DIM) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "Second dimension for rois should be exactly " + std::to_string(EXPECTED_SECOND_ROI_DIM));
+ }
+
+ if (batch_indices_dims[0] != rois_dims[0]) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "First dimension (num_rois) of batch_indices and rois don't match");
+ }
+
+ if (batch_indices_ptr->Location().device.Type() == OrtDevice::CPU) {
+ const int64_t batch_size = X_ptr->Shape()[0];
+ const int64_t num_rois = batch_indices_dims[0];
+
+ auto check_bounds = [batch_size, num_rois](const auto* batch_indices_data) -> Status {
+ for (int64_t i = 0; i < num_rois; ++i) {
+ if (batch_indices_data[i] < 0 || batch_indices_data[i] >= batch_size) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "batch_indices value " + std::to_string(batch_indices_data[i]) +
+ " at index " + std::to_string(i) +
+ " is out of range [0, " + std::to_string(batch_size) + ")");
+ }
+ }
+ return Status::OK();
+ };
+
+ if (batch_indices_ptr->IsDataType()) {
+ auto status = check_bounds(batch_indices_ptr->Data());
+ if (!status.IsOK()) return status;
+ } else if (batch_indices_ptr->IsDataType()) {
+ auto status = check_bounds(batch_indices_ptr->Data());
+ if (!status.IsOK()) return status;
+ } else {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "batch_indices must be of type int64_t or int32_t");
+ }
+ }
+
+ return Status::OK();
+}
+#endif
enum struct RoiAlignMode {
avg = 0,
@@ -17,10 +91,10 @@ enum struct RoiAlignMode {
class RoiAlignBase {
public:
- explicit RoiAlignBase(const OpKernelInfo& info) {
- // mode
+ template
+ explicit RoiAlignBase(const TKernelInfo& info) {
std::string mode;
- if (info.GetAttr("mode", &mode).IsOK()) {
+ if (info.template GetAttr("mode", &mode).IsOK()) {
std::transform(mode.begin(), mode.end(), mode.begin(), [](char i) { return static_cast(::tolower(i)); });
if (mode == "avg") {
mode_ = RoiAlignMode::avg;
@@ -31,41 +105,33 @@ class RoiAlignBase {
}
}
- // output_height
int64_t output_height_tmp;
- if (info.GetAttr("output_height", &output_height_tmp).IsOK()) {
+ if (info.template GetAttr("output_height", &output_height_tmp).IsOK()) {
output_height_ = output_height_tmp;
}
- // output_width
int64_t output_width_tmp;
- if (info.GetAttr("output_width", &output_width_tmp).IsOK()) {
+ if (info.template GetAttr("output_width", &output_width_tmp).IsOK()) {
output_width_ = output_width_tmp;
}
- // sampling_ratio
int64_t sampling_ratio_tmp;
- if (info.GetAttr("sampling_ratio", &sampling_ratio_tmp).IsOK()) {
+ if (info.template GetAttr("sampling_ratio", &sampling_ratio_tmp).IsOK()) {
sampling_ratio_ = sampling_ratio_tmp;
ORT_ENFORCE(sampling_ratio_ >= 0, "Sampling ratio should be >=0, but it was ", sampling_ratio_);
}
- // spatial_scale
float spatial_scale_tmp;
- if (info.GetAttr("spatial_scale", &spatial_scale_tmp).IsOK()) {
+ if (info.template GetAttr("spatial_scale", &spatial_scale_tmp).IsOK()) {
spatial_scale_ = spatial_scale_tmp;
}
std::string coordinate_transformation_mode;
- if (info.GetAttr("coordinate_transformation_mode", &coordinate_transformation_mode).IsOK()) {
- if (coordinate_transformation_mode == "half_pixel")
- half_pixel_ = true;
- else
- half_pixel_ = false;
+ if (info.template GetAttr("coordinate_transformation_mode", &coordinate_transformation_mode).IsOK()) {
+ half_pixel_ = coordinate_transformation_mode == "half_pixel";
}
if (mode_ == RoiAlignMode::max && sampling_ratio_ != 1) {
- // TODO(fdwr): Issue #6146. ORT 1.13 will correct the incorrect summation of max mode with PR #7354.
LOGS_DEFAULT(WARNING) << "The existing summation for max mode and sampling ratios besides 1 is incorrect "
<< "and will be fixed in the next ORT 1.13 release. Thus the results of RoiAlign "
<< "will be different.";
diff --git a/onnxruntime/core/providers/cpu/tensor/concat.cc b/onnxruntime/core/providers/cpu/tensor/concat.cc
index e3d5c0600420f..98d61ed1d3127 100644
--- a/onnxruntime/core/providers/cpu/tensor/concat.cc
+++ b/onnxruntime/core/providers/cpu/tensor/concat.cc
@@ -49,14 +49,6 @@ using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExec
Concat, Input, 0);
} // namespace
-// this method will be shared between 'Concat' (CPU and GPU) and
-// 'ConcatFromSequence' ('concat' and 'stack' modes) to validate inputs
-Status ConcatBase::PrepareForCompute(OpKernelContext* ctx,
- const InlinedTensorsVector& input_tensors,
- Prepare& p) const {
- return PrepareForComputeImpl(ctx, input_tensors, p);
-}
-
namespace {
TensorShapeVector StridesForStack(const TensorShapeVector& full_strides, uint64_t axis) {
// if we are stacking, skip the dimension that will be stacked along in the output strides
diff --git a/onnxruntime/core/providers/cpu/tensor/concatbase.h b/onnxruntime/core/providers/cpu/tensor/concatbase.h
index b9085b2a9318b..df2eb78c61180 100644
--- a/onnxruntime/core/providers/cpu/tensor/concatbase.h
+++ b/onnxruntime/core/providers/cpu/tensor/concatbase.h
@@ -209,8 +209,16 @@ class ConcatBase {
return Status::OK();
}
+#ifdef SHARED_PROVIDER
Status PrepareForCompute(OpKernelContext* ctx, const InlinedTensorsVector& input_tensors,
Prepare& p) const;
+#else
+ template
+ inline Status PrepareForCompute(KernelContextType* ctx, const InlinedTensorsVector& input_tensors,
+ Prepare& p) const {
+ return PrepareForComputeImpl(ctx, input_tensors, p);
+ }
+#endif
protected:
template
diff --git a/onnxruntime/core/providers/cpu/tensor/gather.cc b/onnxruntime/core/providers/cpu/tensor/gather.cc
index f171b33ee5f4f..3b3c67e7d818b 100644
--- a/onnxruntime/core/providers/cpu/tensor/gather.cc
+++ b/onnxruntime/core/providers/cpu/tensor/gather.cc
@@ -56,10 +56,6 @@ ONNX_CPU_OPERATOR_KERNEL(
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList()),
Gather);
-Status GatherBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
- return PrepareForComputeImpl(context, p);
-}
-
template
Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uint8_t* dst_base, bool is_string_type,
const size_t element_bytes, const int64_t block_size, const int64_t M,
diff --git a/onnxruntime/core/providers/cpu/tensor/gatherbase.h b/onnxruntime/core/providers/cpu/tensor/gatherbase.h
index 1f5e85c554a78..fc29c04290883 100644
--- a/onnxruntime/core/providers/cpu/tensor/gatherbase.h
+++ b/onnxruntime/core/providers/cpu/tensor/gatherbase.h
@@ -46,7 +46,14 @@ class GatherBase {
return Status::OK();
}
+#ifdef SHARED_PROVIDER
Status PrepareForCompute(OpKernelContext* context, Prepare& p) const;
+#else
+ template
+ inline Status PrepareForCompute(KernelContextType* context, Prepare& p) const {
+ return PrepareForComputeImpl(context, p);
+ }
+#endif
protected:
template
diff --git a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h
index 3218c8952d6ec..14a22fa7be0af 100644
--- a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h
+++ b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h
@@ -9,8 +9,9 @@ namespace onnxruntime {
class SpaceDepthBase {
protected:
- explicit SpaceDepthBase(const OpKernelInfo& info) {
- ORT_ENFORCE(info.GetAttr("blocksize", &blocksize_).IsOK(),
+ template
+ explicit SpaceDepthBase(const KernelInfoType& info) {
+ ORT_ENFORCE(info.template GetAttr("blocksize", &blocksize_).IsOK(),
"Attribute blocksize is not set.");
}
diff --git a/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc b/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc
index 1b6ee02061d34..5fdb57b1a5e35 100644
--- a/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc
+++ b/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc
@@ -77,57 +77,6 @@ ONNX_CPU_OPERATOR_KERNEL(
.TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
Unsqueeze);
-Status UnsqueezeBase::PrepareCompute(OpKernelContext* ctx, Prepare& p) const {
- const auto* X = ctx->Input(0);
- ORT_ENFORCE(X != nullptr);
- auto& input_tensor = *X;
-
- TensorShapeVector axes;
- size_t num_inputs = ctx->InputCount();
- if (num_inputs == 2) { // axes is an input
- const Tensor* axes_tensor = ctx->Input(1);
- ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
- ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 ||
- axes_tensor->Shape().NumDimensions() == 1,
- "An axes tensor must be a scalar or a 1-D tensor.");
- auto data_span = axes_tensor->template DataAsSpan();
- axes.assign(data_span.begin(), data_span.end());
- } else {
- axes.assign(axes_.begin(), axes_.end());
- }
-
- // New dimension count is the current dimensions + the number of entries in axes
- // Initialize output_dims to 0 in each axis initially
- TensorShapeVector output_dims(axes.size() + input_tensor.Shape().NumDimensions(), 0);
-
- // Set all axes indices to 1 in output_dims and check for duplicates
- for (int64_t axis : axes) {
- // Valid axis range is [0, output_rank - 1]
- axis = HandleNegativeAxis(axis, onnxruntime::narrow(output_dims.size()));
- if (axis < 0 || axis >= static_cast(output_dims.size()))
- return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has an out of range axis");
- if (output_dims[onnxruntime::narrow(axis)] != 0)
- return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'axes' has a duplicate axis");
- output_dims[onnxruntime::narrow(axis)] = 1;
- }
-
- // Now fill in the zero entries with the existing shape
- {
- auto begin = input_tensor.Shape().GetDims().begin();
- for (auto& axisSize : output_dims) {
- if (axisSize == 0)
- axisSize = *begin++;
- }
- assert(begin == input_tensor.Shape().GetDims().end());
- }
-
- TensorShape output_shape(output_dims);
- p.output_tensor = ctx->Output(0, output_shape);
- ORT_ENFORCE(nullptr != p.output_tensor);
- p.input_tensor = &input_tensor;
- return Status::OK();
-}
-
Status Unsqueeze::Compute(OpKernelContext* ctx) const {
Prepare p;
ORT_RETURN_IF_ERROR(PrepareCompute(ctx, p));
diff --git a/onnxruntime/core/providers/cpu/tensor/unsqueeze.h b/onnxruntime/core/providers/cpu/tensor/unsqueeze.h
index 5a8a318923da5..09a77c113e022 100644
--- a/onnxruntime/core/providers/cpu/tensor/unsqueeze.h
+++ b/onnxruntime/core/providers/cpu/tensor/unsqueeze.h
@@ -19,7 +19,57 @@ class UnsqueezeBase {
Tensor* output_tensor = nullptr;
};
+#ifdef SHARED_PROVIDER
Status PrepareCompute(OpKernelContext* context, Prepare& p) const;
+#else
+ template
+ inline Status PrepareCompute(KernelContextType* ctx, Prepare& p) const {
+ const auto* X = ctx->template Input(0);
+ ORT_ENFORCE(X != nullptr);
+ auto& input_tensor = *X;
+
+ TensorShapeVector axes;
+ size_t num_inputs = ctx->InputCount();
+ if (num_inputs == 2) {
+ const Tensor* axes_tensor = ctx->template Input(1);
+ ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
+ ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 ||
+ axes_tensor->Shape().NumDimensions() == 1,
+ "An axes tensor must be a scalar or a 1-D tensor.");
+ auto data_span = axes_tensor->template DataAsSpan();
+ axes.assign(data_span.begin(), data_span.end());
+ } else {
+ axes.assign(axes_.begin(), axes_.end());
+ }
+
+ TensorShapeVector output_dims(axes.size() + input_tensor.Shape().NumDimensions(), 0);
+
+ for (int64_t axis : axes) {
+ axis = HandleNegativeAxis(axis, onnxruntime::narrow(output_dims.size()));
+ if (axis < 0 || axis >= static_cast(output_dims.size()))
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "'axes' has an out of range axis");
+ if (output_dims[onnxruntime::narrow(axis)] != 0)
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "'axes' has a duplicate axis");
+ output_dims[onnxruntime::narrow(axis)] = 1;
+ }
+
+ {
+ auto begin = input_tensor.Shape().GetDims().begin();
+ for (auto& axis_size : output_dims) {
+ if (axis_size == 0)
+ axis_size = *begin++;
+ }
+ assert(begin == input_tensor.Shape().GetDims().end());
+ }
+
+ TensorShape output_shape(output_dims);
+ p.output_tensor = ctx->Output(0, output_shape);
+ ORT_ENFORCE(nullptr != p.output_tensor);
+ p.input_tensor = &input_tensor;
+ return Status::OK();
+ }
+#endif
+
static TensorShapeVector ComputeOutputShape(
const TensorShape& input_shape,
const TensorShapeVector& axes) {
diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc
index b533f1b7dc80b..87ba56bc45dad 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsample.cc
+++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc
@@ -2,10 +2,6 @@
// Licensed under the MIT License.
#include "core/providers/cpu/tensor/upsample.h"
-
-#include
-
-#include "core/common/inlined_containers.h"
#include "core/common/safeint.h"
#include "core/platform/threadpool.h"
#include "core/providers/cpu/tensor/upsample_antialias.h"
@@ -35,46 +31,6 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9);
REGISTER_VERSIONED_TYPED_KERNEL(int8_t, 9, 9);
REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9);
-void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims,
- InlinedVector& scales) const {
- // AspectRatioPolicy::STRETCH is default policy when opset < 18
- if (keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH) {
- return;
- }
-
- InlinedHashSet axes_set(axes_.begin(), axes_.end());
-
- float scale_in_policy = 0.0f;
- if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) {
- scale_in_policy = std::numeric_limits::max();
-
- for (size_t i = 0; i < scales.size(); i++) {
- if (axes_set.empty() || axes_set.count(i) > 0) {
- scale_in_policy = std::min(scale_in_policy, scales[i]);
- }
- }
- } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) {
- scale_in_policy = std::numeric_limits::min();
-
- for (size_t i = 0; i < scales.size(); i++) {
- if (axes_set.empty() || axes_set.count(i) > 0) {
- scale_in_policy = std::max(scale_in_policy, scales[i]);
- }
- }
- }
-
- for (size_t i = 0; i < scales.size(); i++) {
- // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes
- if (axes_set.empty() || axes_set.count(i) > 0) {
- scales[i] = scale_in_policy;
- output_dims[i] = static_cast(std::round(scales[i] * input_dims[i]));
- } else {
- scales[i] = 1.0f;
- output_dims[i] = input_dims[i];
- }
- }
-}
-
template
void UpsampleNearest2x(int64_t batch_size,
int64_t num_channels,
diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
index b0e309a70444f..7dcf88133e967 100644
--- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
+++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h
@@ -4,9 +4,12 @@
#pragma once
#include
+#include
+#include
#include
#include
#include
+#include
#include
#include
@@ -120,6 +123,49 @@ void PrintAntiAliasBuffers(std::ostream& os, gsl::span bounds, gsl::spa
os << std::endl;
}
+namespace upsamplebase_helper {
+
+inline void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims,
+ InlinedVector& scales, AspectRatioPolicy keep_aspect_ratio_policy,
+ const TensorShapeVector& axes) {
+ if (keep_aspect_ratio_policy == AspectRatioPolicy::STRETCH) {
+ return;
+ }
+
+ std::unordered_set axes_set(axes.begin(), axes.end());
+
+ float scale_in_policy = 0.0f;
+ if (keep_aspect_ratio_policy == AspectRatioPolicy::NOT_LARGER) {
+ scale_in_policy = std::numeric_limits::max();
+
+ for (size_t i = 0; i < scales.size(); ++i) {
+ if (axes_set.empty() || axes_set.count(static_cast(i)) > 0) {
+ scale_in_policy = std::min(scale_in_policy, scales[i]);
+ }
+ }
+ } else if (keep_aspect_ratio_policy == AspectRatioPolicy::NOT_SMALLER) {
+ scale_in_policy = std::numeric_limits::min();
+
+ for (size_t i = 0; i < scales.size(); ++i) {
+ if (axes_set.empty() || axes_set.count(static_cast(i)) > 0) {
+ scale_in_policy = std::max(scale_in_policy, scales[i]);
+ }
+ }
+ }
+
+ for (size_t i = 0; i < scales.size(); ++i) {
+ if (axes_set.empty() || axes_set.count(static_cast(i)) > 0) {
+ scales[i] = scale_in_policy;
+ output_dims[i] = static_cast(std::round(scales[i] * input_dims[i]));
+ } else {
+ scales[i] = 1.0f;
+ output_dims[i] = input_dims[i];
+ }
+ }
+}
+
+} // namespace upsamplebase_helper
+
class UpsampleBase {
public:
// Make this available in other EP via provider bridge
@@ -597,6 +643,13 @@ class UpsampleBase {
}
}; // UpsampleBase
+#ifndef SHARED_PROVIDER
+inline void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims,
+ InlinedVector& scales) const {
+ upsamplebase_helper::AdjustOutputSizeAsPolicy(output_dims, input_dims, scales, keep_aspect_ratio_policy_, axes_);
+}
+#endif
+
} // namespace onnxruntime
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 60ac16018f539..9cfc38c8f292f 100755
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -944,8 +944,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Resize);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, ReverseSequence);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, float, RoiAlign);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, double, RoiAlign);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, MLFloat16, RoiAlign);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int32_t, Slice);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, ThresholdedRelu);
@@ -1424,6 +1427,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 19, float, GridSample);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 21, MLFloat16, RoiAlign);
// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -1592,6 +1596,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, GridSample);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, RoiAlign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, RoiAlign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, RoiAlign);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, RoiAlign);
// Opset 23.
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention);
@@ -2027,8 +2035,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2676,6 +2687,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
// Opset 23
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign.cc b/onnxruntime/core/providers/cuda/object_detection/roialign.cc
index 71fb066c2898f..b4f63b3aa04f2 100644
--- a/onnxruntime/core/providers/cuda/object_detection/roialign.cc
+++ b/onnxruntime/core/providers/cuda/object_detection/roialign.cc
@@ -7,11 +7,37 @@
namespace onnxruntime {
namespace cuda {
-#define REGISTER_KERNEL_TYPED(T) \
- ONNX_OPERATOR_TYPED_KERNEL_EX( \
+#define ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
RoiAlign, \
kOnnxDomain, \
10, \
+ 15, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
+ RoiAlign);
+
+#define ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \
+ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
+ RoiAlign, \
+ kOnnxDomain, \
+ 16, \
+ 21, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
+ RoiAlign);
+
+#define ADD_TYPED_ROIALIGN_OP_22(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ RoiAlign, \
+ kOnnxDomain, \
+ 22, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
@@ -67,13 +93,19 @@ Status RoiAlign::ComputeInternal(OpKernelContext* context) const {
return Status::OK();
}
-#define SPECIALIZED_COMPUTE(T) \
- REGISTER_KERNEL_TYPED(T) \
+#define SPECIALIZED_COMPUTE(T) \
+ ADD_VERSIONED_TYPED_ROIALIGN_OP_10(T) \
+ ADD_VERSIONED_TYPED_ROIALIGN_OP_16(T) \
+ ADD_TYPED_ROIALIGN_OP_22(T) \
template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const;
SPECIALIZED_COMPUTE(float)
SPECIALIZED_COMPUTE(double)
-// SPECIALIZED_COMPUTE(MLFloat16)
+SPECIALIZED_COMPUTE(MLFloat16)
+
+// BFloat16 is available for RoiAlign op from version 22:
+ADD_TYPED_ROIALIGN_OP_22(BFloat16)
+template Status RoiAlign::ComputeInternal(OpKernelContext* ctx) const;
} // namespace cuda
}; // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu
index 7acfd9d075461..76f6f26fd8a02 100644
--- a/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu
+++ b/onnxruntime/core/providers/cuda/object_detection/roialign_impl.cu
@@ -17,12 +17,13 @@
#include "roialign_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/shared_inc/accumulation_type.h"
namespace onnxruntime {
namespace cuda {
template
-__device__ T bilinear_interpolate(
+__device__ AccumulationType_t bilinear_interpolate(
const T* bottom_data,
const int height,
const int width,
@@ -30,51 +31,61 @@ __device__ T bilinear_interpolate(
T x,
const bool is_mode_avg,
const int index /* index for debug only*/) {
+ using TAcc = AccumulationType_t;
+
+ TAcc y_acc = static_cast(y);
+ TAcc x_acc = static_cast(x);
+
// deal with cases that inverse elements are out of feature map boundary
- if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ if (y_acc < static_cast(-1.0f) || y_acc > static_cast(height) ||
+ x_acc < static_cast(-1.0f) || x_acc > static_cast(width)) {
// empty
- return 0;
+ return static_cast(0.0f);
}
- if (y <= 0) {
- y = 0;
+ if (y_acc <= static_cast(0.0f)) {
+ y_acc = static_cast(0.0f);
}
- if (x <= 0) {
- x = 0;
+ if (x_acc <= static_cast(0.0f)) {
+ x_acc = static_cast(0.0f);
}
- int y_low = (int)y;
- int x_low = (int)x;
+ int y_low = static_cast(y_acc);
+ int x_low = static_cast(x_acc);
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
- y = (T)y_low;
+ y_acc = static_cast(y_low);
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
- x = (T)x_low;
+ x_acc = static_cast(x_low);
} else {
x_high = x_low + 1;
}
- T ly = y - y_low;
- T lx = x - x_low;
- T hy = 1. - ly, hx = 1. - lx;
+ TAcc ly = y_acc - static_cast(y_low);
+ TAcc lx = x_acc - static_cast(x_low);
+ TAcc hy = static_cast(1.0f) - ly;
+ TAcc hx = static_cast(1.0f) - lx;
// do bilinear interpolation
- T v1 = bottom_data[y_low * width + x_low];
- T v2 = bottom_data[y_low * width + x_high];
- T v3 = bottom_data[y_high * width + x_low];
- T v4 = bottom_data[y_high * width + x_high];
- T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+ TAcc v1 = static_cast(bottom_data[y_low * width + x_low]);
+ TAcc v2 = static_cast(bottom_data[y_low * width + x_high]);
+ TAcc v3 = static_cast(bottom_data[y_high * width + x_low]);
+ TAcc v4 = static_cast(bottom_data[y_high * width + x_high]);
+ TAcc w1 = hy * hx;
+ TAcc w2 = hy * lx;
+ TAcc w3 = ly * hx;
+ TAcc w4 = ly * lx;
- T val = is_mode_avg
- ? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg
- : max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max
+ TAcc val = is_mode_avg
+ ? (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4) // mode Avg
+ : max(max(max(w1 * v1, w2 * v2), w3 * v3), w4 * v4); // mode Max
return val;
}
@@ -97,6 +108,8 @@ __global__ void RoIAlignForward(
const bool half_pixel,
const int64_t* batch_indices_ptr,
const int64_t batch_size) {
+ using TAcc = AccumulationType_t;
+
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
@@ -111,26 +124,27 @@ __global__ void RoIAlignForward(
// If the index is out of range, we set the output to 0 for this RoI element.
if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) {
CUDA_KERNEL_ASSERT(false && "batch_indices values are out of range");
- top_data[index] = 0;
+ top_data[index] = static_cast(0.0f);
continue;
}
// Do not using rounding; this implementation detail is critical
- T roi_offset = half_pixel ? T(0.5) : T(0);
- T roi_start_w = offset_bottom_rois[0] * spatial_scale - roi_offset;
- T roi_start_h = offset_bottom_rois[1] * spatial_scale - roi_offset;
- T roi_end_w = offset_bottom_rois[2] * spatial_scale - roi_offset;
- T roi_end_h = offset_bottom_rois[3] * spatial_scale - roi_offset;
-
- T roi_width = roi_end_w - roi_start_w;
- T roi_height = roi_end_h - roi_start_h;
+ const TAcc spatial_scale_acc = static_cast(spatial_scale);
+ const TAcc roi_offset = half_pixel ? static_cast(0.5f) : static_cast