Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AnnService/inc/Core/BKT/ParameterDefinitionList.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ DefineBKTParameter(m_pTrees.m_iBKTKmeansK, int, 32L, "BKTKmeansK")
DefineBKTParameter(m_pTrees.m_iBKTLeafSize, int, 8L, "BKTLeafSize")
DefineBKTParameter(m_pTrees.m_iSamples, int, 1000L, "Samples")
DefineBKTParameter(m_pTrees.m_fBalanceFactor, float, 100.0F, "BKTLambdaFactor")
DefineBKTParameter(m_pTrees.m_parallelBuild, bool, false, "ParallelBKTBuild")

DefineBKTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber")
DefineBKTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")
Expand Down
170 changes: 170 additions & 0 deletions AnnService/inc/Core/Common/BKTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <vector>
#include <mutex>
#include <shared_mutex>
#include <atomic>
#include <omp.h>
#include "inc/Core/VectorIndex.h"

#include "CommonUtils.h"
Expand Down Expand Up @@ -655,6 +657,173 @@ break;
}
}

// Parallel BKTree Build - processes sibling nodes in parallel
template <typename T>
void BuildTreesParallel(const Dataset<T>& data, DistCalcMethod distMethod, int numOfThreads,
std::vector<SizeType>* indices = nullptr, std::vector<SizeType>* reverseIndices = nullptr,
bool dynamicK = false, IAbortOperation* abort = nullptr)
{
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Using PARALLEL BKTree build with %d threads.\n", numOfThreads);

// Helper struct for collecting parallel results
struct ParallelNodeResult {
SizeType parentIndex;
SizeType first, last;
std::vector<SizeType> childCenters;
std::vector<SizeType> childCounts;
bool isLeaf;
bool singleCluster;
SizeType singleClusterCenter;
};

struct BKTStackItem {
SizeType index, first, last;
bool debug;
BKTStackItem(SizeType index_ = -1, SizeType first_ = 0, SizeType last_ = 0, bool debug_ = false)
: index(index_), first(first_), last(last_), debug(debug_) {}
};

std::vector<SizeType> localindices;
if (indices == nullptr) {
localindices.resize(data.R());
for (SizeType i = 0; i < (SizeType)localindices.size(); i++) localindices[i] = i;
}
else {
localindices.assign(indices->begin(), indices->end());
}

// Create a shared KmeansArgs for DynamicFactorSelect (uses all threads)
KmeansArgs<T> sharedArgs(m_iBKTKmeansK, data.C(), (SizeType)localindices.size(), numOfThreads, distMethod, m_pQuantizer);

if (m_fBalanceFactor < 0) {
m_fBalanceFactor = DynamicFactorSelect(data, localindices, 0, (SizeType)localindices.size(), sharedArgs, m_iSamples);
}

std::mt19937 rg;
m_pSampleCenterMap.clear();

for (char treeIdx = 0; treeIdx < m_iTreeNumber; treeIdx++)
{
std::shuffle(localindices.begin(), localindices.end(), rg);

m_pTreeStart.push_back((SizeType)m_pTreeRoots.size());
m_pTreeRoots.emplace_back((SizeType)localindices.size());
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start to build BKTree %d (parallel)\n", treeIdx + 1);

// Level-order processing
std::vector<BKTStackItem> currentLevel, nextLevel;
currentLevel.push_back(BKTStackItem(m_pTreeStart[treeIdx], 0, (SizeType)localindices.size(), true));

int level = 0;
while (!currentLevel.empty()) {
if (abort && abort->ShouldAbort()) {
SPTAGLIB_LOG(Helper::LogLevel::LL_Warning, "Abort!!!\n");
return;
}

size_t levelSize = currentLevel.size();
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Processing level %d with %zu nodes...\n", level, levelSize);

std::vector<ParallelNodeResult> results(levelSize);

// Parallel phase: Run k-means for all nodes in this level
#pragma omp parallel for schedule(dynamic, 1) num_threads(numOfThreads)
for (int idx = 0; idx < (int)levelSize; idx++) {
BKTStackItem& item = currentLevel[idx];
ParallelNodeResult& result = results[idx];
result.parentIndex = item.index;
result.first = item.first;
result.last = item.last;
result.isLeaf = false;
result.singleCluster = false;

if (item.last - item.first <= m_iBKTLeafSize) {
// Leaf node
result.isLeaf = true;
for (SizeType j = item.first; j < item.last; j++) {
SizeType cid = (reverseIndices == nullptr) ? localindices[j] : reverseIndices->at(localindices[j]);
result.childCenters.push_back(cid);
}
} else {
// K-means clustering - use thread-local args with 1 thread
// (parallelism is at the node level, not within k-means)
// IMPORTANT: Must use full dataset size because KmeansAssign uses absolute indices
// (args.label[i] where i ranges from first to last, not 0 to rangeSize)
KmeansArgs<T> localArgs(m_iBKTKmeansK, data.C(), (SizeType)localindices.size(), 1, distMethod, m_pQuantizer);

int dk = m_iBKTKmeansK;
if (dynamicK) {
dk = std::min<int>((item.last - item.first) / m_iBKTLeafSize + 1, m_iBKTKmeansK);
dk = std::max<int>(dk, 2);
localArgs._DK = dk;
}

int numClusters = KmeansClustering(data, localindices, item.first, item.last, localArgs,
m_iSamples, m_fBalanceFactor, false, abort);

if (numClusters <= 1) {
result.singleCluster = true;
SizeType end = min(item.last + 1, (SizeType)localindices.size());
std::sort(localindices.begin() + item.first, localindices.begin() + end);
result.singleClusterCenter = (reverseIndices == nullptr) ? localindices[item.first] : reverseIndices->at(localindices[item.first]);
for (SizeType j = item.first + 1; j < end; j++) {
SizeType cid = (reverseIndices == nullptr) ? localindices[j] : reverseIndices->at(localindices[j]);
result.childCenters.push_back(cid);
}
} else {
SizeType pos = item.first;
for (int k = 0; k < m_iBKTKmeansK; k++) {
if (localArgs.counts[k] == 0) continue;
SizeType cid = (reverseIndices == nullptr) ? localindices[pos + localArgs.counts[k] - 1] : reverseIndices->at(localindices[pos + localArgs.counts[k] - 1]);
result.childCenters.push_back(cid);
result.childCounts.push_back(localArgs.counts[k]);
pos += localArgs.counts[k];
}
}
}
}

// Sequential phase: Build tree structure and prepare next level
nextLevel.clear();
for (size_t idx = 0; idx < levelSize; idx++) {
ParallelNodeResult& result = results[idx];
m_pTreeRoots[result.parentIndex].childStart = (SizeType)m_pTreeRoots.size();

if (result.isLeaf) {
for (SizeType cid : result.childCenters) {
m_pTreeRoots.emplace_back(cid);
}
} else if (result.singleCluster) {
m_pTreeRoots[result.parentIndex].centerid = result.singleClusterCenter;
m_pTreeRoots[result.parentIndex].childStart = -m_pTreeRoots[result.parentIndex].childStart;
for (SizeType cid : result.childCenters) {
m_pTreeRoots.emplace_back(cid);
m_pSampleCenterMap[cid] = result.singleClusterCenter;
}
m_pSampleCenterMap[-1 - result.singleClusterCenter] = result.parentIndex;
} else {
SizeType pos = result.first;
for (size_t c = 0; c < result.childCenters.size(); c++) {
SizeType nodeIdx = (SizeType)m_pTreeRoots.size();
m_pTreeRoots.emplace_back(result.childCenters[c]);
if (result.childCounts[c] > 1) {
nextLevel.push_back(BKTStackItem(nodeIdx, pos, pos + result.childCounts[c] - 1, false));
}
pos += result.childCounts[c];
}
}
m_pTreeRoots[result.parentIndex].childEnd = (SizeType)m_pTreeRoots.size();
}

currentLevel.swap(nextLevel);
level++;
}

m_pTreeRoots.emplace_back(-1);
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "%d BKTree built (parallel), %zu %zu\n", treeIdx + 1, m_pTreeRoots.size() - m_pTreeStart[treeIdx], localindices.size());
}
}

inline std::uint64_t BufferSize() const
{
return sizeof(int) + sizeof(SizeType) * m_iTreeNumber +
Expand Down Expand Up @@ -863,6 +1032,7 @@ break;
int m_iTreeNumber, m_iBKTKmeansK, m_iBKTLeafSize, m_iSamples, m_bfs;
float m_fBalanceFactor;
std::shared_ptr<SPTAG::COMMON::IQuantizer> m_pQuantizer;
bool m_parallelBuild = false;
};
}
}
Expand Down
1 change: 1 addition & 0 deletions AnnService/inc/Core/SPANN/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ namespace SPTAG {
bool m_recursiveCheckSmallCluster;
bool m_printSizeCount;
std::string m_selectType;
bool m_parallelBKTBuild;

// Section 3: for build head
bool m_buildHead;
Expand Down
1 change: 1 addition & 0 deletions AnnService/inc/Core/SPANN/ParameterDefinitionList.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ DefineSelectHeadParameter(m_headVectorCount, int, 0, "Count")
DefineSelectHeadParameter(m_recursiveCheckSmallCluster, bool, true, "RecursiveCheckSmallCluster")
DefineSelectHeadParameter(m_printSizeCount, bool, true, "PrintSizeCount")
DefineSelectHeadParameter(m_selectType, std::string, "BKT", "SelectHeadType")
DefineSelectHeadParameter(m_parallelBKTBuild, bool, false, "ParallelBKTBuild")
#endif

#ifdef DefineBuildHeadParameter
Expand Down
6 changes: 5 additions & 1 deletion AnnService/src/Core/BKT/BKTIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,11 @@ ErrorCode Index<T>::BuildIndex(const void *p_data, SizeType p_vectorNum, Dimensi
m_threadPool.init();

auto t1 = std::chrono::high_resolution_clock::now();
m_pTrees.BuildTrees<T>(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads);
if (m_pTrees.m_parallelBuild) {
m_pTrees.BuildTreesParallel<T>(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads);
} else {
m_pTrees.BuildTrees<T>(m_pSamples, m_iDistCalcMethod, m_iNumberOfThreads);
}
auto t2 = std::chrono::high_resolution_clock::now();
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Build Tree time (s): %lld\n",
std::chrono::duration_cast<std::chrono::seconds>(t2 - t1).count());
Expand Down
16 changes: 11 additions & 5 deletions AnnService/src/Core/SPANN/SPANNIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -938,16 +938,22 @@ bool Index<T>::SelectHeadInternal(std::shared_ptr<Helper::VectorSetReader> &p_re
bkt->m_iSamples = m_options.m_iSamples;
bkt->m_iTreeNumber = m_options.m_iTreeNumber;
bkt->m_fBalanceFactor = m_options.m_fBalanceFactor;
bkt->m_parallelBuild = m_options.m_parallelBKTBuild;
bkt->m_pQuantizer = m_pQuantizer;
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Start invoking BuildTrees.\n");
SPTAGLIB_LOG(
Helper::LogLevel::LL_Info,
"BKTKmeansK: %d, BKTLeafSize: %d, Samples: %d, BKTLambdaFactor:%f TreeNumber: %d, ThreadNum: %d.\n",
"BKTKmeansK: %d, BKTLeafSize: %d, Samples: %d, BKTLambdaFactor:%f TreeNumber: %d, ThreadNum: %d, ParallelBuild: %s.\n",
bkt->m_iBKTKmeansK, bkt->m_iBKTLeafSize, bkt->m_iSamples, bkt->m_fBalanceFactor, bkt->m_iTreeNumber,
m_options.m_iSelectHeadNumberOfThreads);

bkt->BuildTrees<InternalDataType>(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads,
nullptr, nullptr, true);
m_options.m_iSelectHeadNumberOfThreads, m_options.m_parallelBKTBuild ? "true" : "false");

if (bkt->m_parallelBuild) {
bkt->BuildTreesParallel<InternalDataType>(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads,
nullptr, nullptr, true);
} else {
bkt->BuildTrees<InternalDataType>(data, m_options.m_distCalcMethod, m_options.m_iSelectHeadNumberOfThreads,
nullptr, nullptr, true);
}
auto t2 = std::chrono::high_resolution_clock::now();
double elapsedSeconds = std::chrono::duration_cast<std::chrono::seconds>(t2 - t1).count();
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "End invoking BuildTrees.\n");
Expand Down