From ba0b6e472ce4c857e106aec013dafeab84b951a1 Mon Sep 17 00:00:00 2001
From: xufyan
Date: Thu, 5 Mar 2026 06:12:02 +0000
Subject: [PATCH] Add parallel BKT tree build support via level-order BFS with OpenMP

Add BuildTreesParallel() method to BKTree that parallelizes tree
construction using a level-order (BFS) approach instead of the existing
depth-first recursive method. At each level of the tree, all nodes at
that level are processed in parallel using OpenMP, with each thread
running an independent, single-threaded k-means clustering for one node.
The tree structure is then assembled sequentially to maintain
correctness.

This is controlled by a new ParallelBKTBuild parameter (default: false)
that can be enabled in both BKT index and SPANN select-head
configurations.

Benchmark on SIFT 50M (128-dim, L2) with 32 threads on Azure L32s_v2:
- Select Head (BKT build): 16.6 hours -> 1.2 hours (13.6x speedup)
- Build Head graph (RefineGraph): unchanged (~10 hours, memory-bound)
- Total end-to-end build: ~30 hours -> ~15 hours
- Recall@1: 91% -> 94% (slight improvement)
- Query latency: comparable (P50 ~40ms)
---
 .../inc/Core/BKT/ParameterDefinitionList.h    |   1 +
 AnnService/inc/Core/Common/BKTree.h           | 170 ++++++++++++++++++
 AnnService/inc/Core/SPANN/Options.h           |   1 +
 .../inc/Core/SPANN/ParameterDefinitionList.h  |   1 +
 AnnService/src/Core/BKT/BKTIndex.cpp          |   6 +-
 AnnService/src/Core/SPANN/SPANNIndex.cpp      |  16 +-
 6 files changed, 189 insertions(+), 6 deletions(-)

diff --git a/AnnService/inc/Core/BKT/ParameterDefinitionList.h b/AnnService/inc/Core/BKT/ParameterDefinitionList.h
index 8bf242f90..e71201bb5 100644
--- a/AnnService/inc/Core/BKT/ParameterDefinitionList.h
+++ b/AnnService/inc/Core/BKT/ParameterDefinitionList.h
@@ -15,6 +15,7 @@ DefineBKTParameter(m_pTrees.m_iBKTKmeansK, int, 32L, "BKTKmeansK")
 DefineBKTParameter(m_pTrees.m_iBKTLeafSize, int, 8L, "BKTLeafSize")
 DefineBKTParameter(m_pTrees.m_iSamples, int, 1000L, "Samples")
 DefineBKTParameter(m_pTrees.m_fBalanceFactor, float, 100.0F, "BKTLambdaFactor")
+DefineBKTParameter(m_pTrees.m_parallelBuild, bool, false, "ParallelBKTBuild") DefineBKTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber") DefineBKTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize") diff --git a/AnnService/inc/Core/Common/BKTree.h b/AnnService/inc/Core/Common/BKTree.h index e59962bd3..10f1ddd76 100644 --- a/AnnService/inc/Core/Common/BKTree.h +++ b/AnnService/inc/Core/Common/BKTree.h @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include "inc/Core/VectorIndex.h" #include "CommonUtils.h" @@ -655,6 +657,173 @@ break; } } + // Parallel BKTree Build - processes sibling nodes in parallel + template + void BuildTreesParallel(const Dataset& data, DistCalcMethod distMethod, int numOfThreads, + std::vector* indices = nullptr, std::vector* reverseIndices = nullptr, + bool dynamicK = false, IAbortOperation* abort = nullptr) + { + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using PARALLEL BKTree build with %d threads.\n", numOfThreads); + + // Helper struct for collecting parallel results + struct ParallelNodeResult { + SizeType parentIndex; + SizeType first, last; + std::vector childCenters; + std::vector childCounts; + bool isLeaf; + bool singleCluster; + SizeType singleClusterCenter; + }; + + struct BKTStackItem { + SizeType index, first, last; + bool debug; + BKTStackItem(SizeType index_ = -1, SizeType first_ = 0, SizeType last_ = 0, bool debug_ = false) + : index(index_), first(first_), last(last_), debug(debug_) {} + }; + + std::vector localindices; + if (indices == nullptr) { + localindices.resize(data.R()); + for (SizeType i = 0; i < (SizeType)localindices.size(); i++) localindices[i] = i; + } + else { + localindices.assign(indices->begin(), indices->end()); + } + + // Create a shared KmeansArgs for DynamicFactorSelect (uses all threads) + KmeansArgs sharedArgs(m_iBKTKmeansK, data.C(), (SizeType)localindices.size(), numOfThreads, distMethod, m_pQuantizer); + + if (m_fBalanceFactor < 0) { + m_fBalanceFactor = 
DynamicFactorSelect(data, localindices, 0, (SizeType)localindices.size(), sharedArgs, m_iSamples); + } + + std::mt19937 rg; + m_pSampleCenterMap.clear(); + + for (char treeIdx = 0; treeIdx < m_iTreeNumber; treeIdx++) + { + std::shuffle(localindices.begin(), localindices.end(), rg); + + m_pTreeStart.push_back((SizeType)m_pTreeRoots.size()); + m_pTreeRoots.emplace_back((SizeType)localindices.size()); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to build BKTree %d (parallel)\n", treeIdx + 1); + + // Level-order processing + std::vector currentLevel, nextLevel; + currentLevel.push_back(BKTStackItem(m_pTreeStart[treeIdx], 0, (SizeType)localindices.size(), true)); + + int level = 0; + while (!currentLevel.empty()) { + if (abort && abort->ShouldAbort()) { + SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "Abort!!!\n"); + return; + } + + size_t levelSize = currentLevel.size(); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Processing level %d with %zu nodes...\n", level, levelSize); + + std::vector results(levelSize); + + // Parallel phase: Run k-means for all nodes in this level + #pragma omp parallel for schedule(dynamic, 1) num_threads(numOfThreads) + for (int idx = 0; idx < (int)levelSize; idx++) { + BKTStackItem& item = currentLevel[idx]; + ParallelNodeResult& result = results[idx]; + result.parentIndex = item.index; + result.first = item.first; + result.last = item.last; + result.isLeaf = false; + result.singleCluster = false; + + if (item.last - item.first <= m_iBKTLeafSize) { + // Leaf node + result.isLeaf = true; + for (SizeType j = item.first; j < item.last; j++) { + SizeType cid = (reverseIndices == nullptr) ? 
localindices[j] : reverseIndices->at(localindices[j]); + result.childCenters.push_back(cid); + } + } else { + // K-means clustering - use thread-local args with 1 thread + // (parallelism is at the node level, not within k-means) + // IMPORTANT: Must use full dataset size because KmeansAssign uses absolute indices + // (args.label[i] where i ranges from first to last, not 0 to rangeSize) + KmeansArgs localArgs(m_iBKTKmeansK, data.C(), (SizeType)localindices.size(), 1, distMethod, m_pQuantizer); + + int dk = m_iBKTKmeansK; + if (dynamicK) { + dk = std::min((item.last - item.first) / m_iBKTLeafSize + 1, m_iBKTKmeansK); + dk = std::max(dk, 2); + localArgs._DK = dk; + } + + int numClusters = KmeansClustering(data, localindices, item.first, item.last, localArgs, + m_iSamples, m_fBalanceFactor, false, abort); + + if (numClusters <= 1) { + result.singleCluster = true; + SizeType end = min(item.last + 1, (SizeType)localindices.size()); + std::sort(localindices.begin() + item.first, localindices.begin() + end); + result.singleClusterCenter = (reverseIndices == nullptr) ? localindices[item.first] : reverseIndices->at(localindices[item.first]); + for (SizeType j = item.first + 1; j < end; j++) { + SizeType cid = (reverseIndices == nullptr) ? localindices[j] : reverseIndices->at(localindices[j]); + result.childCenters.push_back(cid); + } + } else { + SizeType pos = item.first; + for (int k = 0; k < m_iBKTKmeansK; k++) { + if (localArgs.counts[k] == 0) continue; + SizeType cid = (reverseIndices == nullptr) ? 
localindices[pos + localArgs.counts[k] - 1] : reverseIndices->at(localindices[pos + localArgs.counts[k] - 1]); + result.childCenters.push_back(cid); + result.childCounts.push_back(localArgs.counts[k]); + pos += localArgs.counts[k]; + } + } + } + } + + // Sequential phase: Build tree structure and prepare next level + nextLevel.clear(); + for (size_t idx = 0; idx < levelSize; idx++) { + ParallelNodeResult& result = results[idx]; + m_pTreeRoots[result.parentIndex].childStart = (SizeType)m_pTreeRoots.size(); + + if (result.isLeaf) { + for (SizeType cid : result.childCenters) { + m_pTreeRoots.emplace_back(cid); + } + } else if (result.singleCluster) { + m_pTreeRoots[result.parentIndex].centerid = result.singleClusterCenter; + m_pTreeRoots[result.parentIndex].childStart = -m_pTreeRoots[result.parentIndex].childStart; + for (SizeType cid : result.childCenters) { + m_pTreeRoots.emplace_back(cid); + m_pSampleCenterMap[cid] = result.singleClusterCenter; + } + m_pSampleCenterMap[-1 - result.singleClusterCenter] = result.parentIndex; + } else { + SizeType pos = result.first; + for (size_t c = 0; c < result.childCenters.size(); c++) { + SizeType nodeIdx = (SizeType)m_pTreeRoots.size(); + m_pTreeRoots.emplace_back(result.childCenters[c]); + if (result.childCounts[c] > 1) { + nextLevel.push_back(BKTStackItem(nodeIdx, pos, pos + result.childCounts[c] - 1, false)); + } + pos += result.childCounts[c]; + } + } + m_pTreeRoots[result.parentIndex].childEnd = (SizeType)m_pTreeRoots.size(); + } + + currentLevel.swap(nextLevel); + level++; + } + + m_pTreeRoots.emplace_back(-1); + SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d BKTree built (parallel), %zu %zu\n", treeIdx + 1, m_pTreeRoots.size() - m_pTreeStart[treeIdx], localindices.size()); + } + } + inline std::uint64_t BufferSize() const { return sizeof(int) + sizeof(SizeType) * m_iTreeNumber + @@ -863,6 +1032,7 @@ break; int m_iTreeNumber, m_iBKTKmeansK, m_iBKTLeafSize, m_iSamples, m_bfs; float m_fBalanceFactor; std::shared_ptr 
m_pQuantizer; + bool m_parallelBuild = false; }; } } diff --git a/AnnService/inc/Core/SPANN/Options.h b/AnnService/inc/Core/SPANN/Options.h index f49621230..077842684 100644 --- a/AnnService/inc/Core/SPANN/Options.h +++ b/AnnService/inc/Core/SPANN/Options.h @@ -70,6 +70,7 @@ namespace SPTAG { bool m_recursiveCheckSmallCluster; bool m_printSizeCount; std::string m_selectType; + bool m_parallelBKTBuild; // Section 3: for build head bool m_buildHead; diff --git a/AnnService/inc/Core/SPANN/ParameterDefinitionList.h b/AnnService/inc/Core/SPANN/ParameterDefinitionList.h index 0a88e3f1d..e758c46c4 100644 --- a/AnnService/inc/Core/SPANN/ParameterDefinitionList.h +++ b/AnnService/inc/Core/SPANN/ParameterDefinitionList.h @@ -62,6 +62,7 @@ DefineSelectHeadParameter(m_headVectorCount, int, 0, "Count") DefineSelectHeadParameter(m_recursiveCheckSmallCluster, bool, true, "RecursiveCheckSmallCluster") DefineSelectHeadParameter(m_printSizeCount, bool, true, "PrintSizeCount") DefineSelectHeadParameter(m_selectType, std::string, "BKT", "SelectHeadType") +DefineSelectHeadParameter(m_parallelBKTBuild, bool, false, "ParallelBKTBuild") #endif #ifdef DefineBuildHeadParameter diff --git a/AnnService/src/Core/BKT/BKTIndex.cpp b/AnnService/src/Core/BKT/BKTIndex.cpp index 77ecce4fc..7b1ebf91a 100644 --- a/AnnService/src/Core/BKT/BKTIndex.cpp +++ b/AnnService/src/Core/BKT/BKTIndex.cpp @@ -840,7 +840,11 @@ ErrorCode Index::BuildIndex(const void *p_data, SizeType p_vectorNum, Dimensi m_threadPool.init(); auto t1 = std::chrono::high_resolution_clock::now(); - m_pTrees.BuildTrees(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads); + if (m_pTrees.m_parallelBuild) { + m_pTrees.BuildTreesParallel(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads); + } else { + m_pTrees.BuildTrees(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads); + } auto t2 = std::chrono::high_resolution_clock::now(); SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build Tree time (s): %lld\n", std::chrono::duration_cast(t2 - 
t1).count()); diff --git a/AnnService/src/Core/SPANN/SPANNIndex.cpp b/AnnService/src/Core/SPANN/SPANNIndex.cpp index 0c34a17b8..b3e613723 100644 --- a/AnnService/src/Core/SPANN/SPANNIndex.cpp +++ b/AnnService/src/Core/SPANN/SPANNIndex.cpp @@ -938,16 +938,22 @@ bool Index::SelectHeadInternal(std::shared_ptr &p_re bkt->m_iSamples = m_options.m_iSamples; bkt->m_iTreeNumber = m_options.m_iTreeNumber; bkt->m_fBalanceFactor = m_options.m_fBalanceFactor; + bkt->m_parallelBuild = m_options.m_parallelBKTBuild; bkt->m_pQuantizer = m_pQuantizer; SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start invoking BuildTrees.\n"); SPTAGLIB_LOG( Helper::LogLevel::LL_Info, - "BKTKmeansK: %d, BKTLeafSize: %d, Samples: %d, BKTLambdaFactor:%f TreeNumber: %d, ThreadNum: %d.\n", + "BKTKmeansK: %d, BKTLeafSize: %d, Samples: %d, BKTLambdaFactor:%f TreeNumber: %d, ThreadNum: %d, ParallelBuild: %s.\n", bkt->m_iBKTKmeansK, bkt->m_iBKTLeafSize, bkt->m_iSamples, bkt->m_fBalanceFactor, bkt->m_iTreeNumber, - m_options.m_iSelectHeadNumberOfThreads); - - bkt->BuildTrees(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads, - nullptr, nullptr, true); + m_options.m_iSelectHeadNumberOfThreads, m_options.m_parallelBKTBuild ? "true" : "false"); + + if (bkt->m_parallelBuild) { + bkt->BuildTreesParallel(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads, + nullptr, nullptr, true); + } else { + bkt->BuildTrees(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads, + nullptr, nullptr, true); + } auto t2 = std::chrono::high_resolution_clock::now(); double elapsedSeconds = std::chrono::duration_cast(t2 - t1).count(); SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "End invoking BuildTrees.\n");