@@ -151,6 +151,12 @@ class IVFIndex {
151151 , inter_query_threadpool_{threads::as_threadpool (std::move (threadpool_proto))}
152152 , intra_query_thread_count_{intra_query_thread_count}
153153 , logger_{std::move (logger)} {
154+ // Clamp thread pool: more threads than centroids causes OOB in
155+ // compute_centroid_distances and wastes resources.
156+ if (inter_query_threadpool_.size () > centroids_.size ()) {
157+ inter_query_threadpool_ =
158+ InterQueryThreadPool (threads::DefaultThreadPool (centroids_.size ()));
159+ }
154160 validate_thread_configuration ();
155161 initialize_thread_pools ();
156162 initialize_search_buffers ();
@@ -269,7 +275,8 @@ class IVFIndex {
269275 return scratchspace_type{
270276 create_centroid_buffer (sp.n_probes_ ),
271277 create_leaf_buffers (buffer_leaves_size),
272- extensions::per_thread_batch_search_setup (cluster0_, distance_)};
278+ extensions::per_thread_batch_search_setup (cluster0_, distance_)
279+ };
273280 }
274281
275282 // / @brief Return scratch space resources for external threading with default parameters
@@ -540,7 +547,8 @@ class IVFIndex {
540547 mutable std::vector<size_t > id_in_cluster_{};
541548 // Thread-safe initialization flag for ID mapping (wrapped in unique_ptr for movability)
542549 mutable std::unique_ptr<std::once_flag> id_mapping_init_flag_{
543- std::make_unique<std::once_flag>()};
550+ std::make_unique<std::once_flag>()
551+ };
544552
545553 // /// Threading Infrastructure /////
546554 InterQueryThreadPool inter_query_threadpool_; // Handles parallelism across queries
@@ -567,9 +575,11 @@ class IVFIndex {
567575 void initialize_thread_pools () {
568576 // Create thread pools for intra-query (cluster-level) parallelism
569577 for (size_t i = 0 ; i < inter_query_threadpool_.size (); i++) {
570- intra_query_threadpools_.push_back (threads::ThreadPoolHandle (
571- threads::DefaultThreadPool (intra_query_thread_count_)
572- ));
578+ intra_query_threadpools_.push_back (
579+ threads::ThreadPoolHandle (
580+ threads::DefaultThreadPool (intra_query_thread_count_)
581+ )
582+ );
573583 }
574584 }
575585
@@ -627,11 +637,13 @@ class IVFIndex {
627637
628638 void validate_query_batch_size (size_t query_size) const {
629639 if (query_size > MAX_QUERY_BATCH_SIZE) {
630- throw std::runtime_error (fmt::format (
631- " Query batch size {} exceeds maximum allowed {}" ,
632- query_size,
633- MAX_QUERY_BATCH_SIZE
634- ));
640+ throw std::runtime_error (
641+ fmt::format (
642+ " Query batch size {} exceeds maximum allowed {}" ,
643+ query_size,
644+ MAX_QUERY_BATCH_SIZE
645+ )
646+ );
635647 }
636648 }
637649
@@ -645,9 +657,11 @@ class IVFIndex {
645657 std::vector<SortedBuffer<Idx, distance::compare_t <Dist>>> buffers;
646658 buffers.reserve (intra_query_thread_count_);
647659 for (size_t j = 0 ; j < intra_query_thread_count_; j++) {
648- buffers.push_back (SortedBuffer<Idx, distance::compare_t <Dist>>(
649- buffer_size, distance::comparator (distance_)
650- ));
660+ buffers.push_back (
661+ SortedBuffer<Idx, distance::compare_t <Dist>>(
662+ buffer_size, distance::comparator (distance_)
663+ )
664+ );
651665 }
652666 return buffers;
653667 }
0 commit comments