Skip to content

Commit fc0d9fa

Browse files
committed
[C-API] Add Index threadpool size getter and setter
1 parent 3103a68 commit fc0d9fa

5 files changed

Lines changed: 123 additions & 29 deletions

File tree

bindings/c/include/svs/c_api/svs_c.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,24 @@ SVS_API bool svs_index_dynamic_compact(
501501
svs_index_h index, size_t batchsize /*=0*/, svs_error_h out_err /*=NULL*/
502502
);
503503

504+
/// @brief Get number of threads used for search in the index's thread pool
505+
/// @param index The index handle
506+
/// @param out_num_threads Pointer to store the retrieved number of threads
507+
/// @param out_err An optional error handle to capture errors
508+
/// @return true on success, false on failure
509+
SVS_API bool svs_index_get_num_threads(
510+
svs_index_h index, size_t* out_num_threads, svs_error_h out_err /*=NULL*/
511+
);
512+
513+
/// @brief Set number of threads for search in the index's thread pool
514+
/// @param index The index handle
515+
/// @param num_threads The number of threads to set
516+
/// @param out_err An optional error handle to capture errors
517+
/// @return true on success, false on failure
518+
SVS_API bool svs_index_set_num_threads(
519+
svs_index_h index, size_t num_threads, svs_error_h out_err /*=NULL*/
520+
);
521+
504522
#ifdef __cplusplus
505523
}
506524
#endif

bindings/c/src/index.hpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "svs/c_api/svs_c.h"
1919

2020
#include "algorithm.hpp"
21+
#include "threadpool.hpp"
2122

2223
#include <svs/concepts/data.h>
2324
#include <svs/core/distance.h>
@@ -32,8 +33,10 @@
3233
namespace svs::c_runtime {
3334
struct Index {
3435
svs_algorithm_type algorithm;
35-
Index(svs_algorithm_type algorithm)
36-
: algorithm(algorithm) {}
36+
ThreadPoolBuilder pool_builder;
37+
Index(svs_algorithm_type algorithm, ThreadPoolBuilder pool_builder)
38+
: algorithm(algorithm)
39+
, pool_builder(pool_builder) {}
3740
virtual ~Index() = default;
3841
virtual svs::QueryResult<size_t> search(
3942
svs::data::ConstSimpleDataView<float> queries,
@@ -45,11 +48,13 @@ struct Index {
4548
virtual float get_distance(size_t id, std::span<const float> query) const = 0;
4649
virtual void
4750
reconstruct_at(svs::data::SimpleDataView<float> dst, std::span<const size_t> ids) = 0;
51+
virtual size_t get_num_threads() { return pool_builder.get_threads_num(); };
52+
virtual void set_num_threads(size_t num_threads) = 0;
4853
};
4954

5055
struct DynamicIndex : public Index {
51-
DynamicIndex(svs_algorithm_type algorithm)
52-
: Index(algorithm) {}
56+
DynamicIndex(svs_algorithm_type algorithm, ThreadPoolBuilder pool_builder)
57+
: Index(algorithm, pool_builder) {}
5358
~DynamicIndex() = default;
5459

5560
virtual size_t add_points(
@@ -63,8 +68,8 @@ struct DynamicIndex : public Index {
6368

6469
struct IndexVamana : public Index {
6570
svs::Vamana index;
66-
IndexVamana(svs::Vamana&& index)
67-
: Index{SVS_ALGORITHM_TYPE_VAMANA}
71+
IndexVamana(svs::Vamana&& index, ThreadPoolBuilder pool_builder)
72+
: Index{SVS_ALGORITHM_TYPE_VAMANA, pool_builder}
6873
, index(std::move(index)) {}
6974
~IndexVamana() = default;
7075
svs::QueryResult<size_t> search(
@@ -99,12 +104,17 @@ struct IndexVamana : public Index {
99104
override {
100105
index.reconstruct_at(dst, ids);
101106
}
107+
108+
void set_num_threads(size_t num_threads) override {
109+
pool_builder.resize(num_threads);
110+
index.set_threadpool(pool_builder.build());
111+
}
102112
};
103113

104114
struct DynamicIndexVamana : public DynamicIndex {
105115
svs::DynamicVamana index;
106-
DynamicIndexVamana(svs::DynamicVamana&& index)
107-
: DynamicIndex(SVS_ALGORITHM_TYPE_VAMANA)
116+
DynamicIndexVamana(svs::DynamicVamana&& index, ThreadPoolBuilder pool_builder)
117+
: DynamicIndex(SVS_ALGORITHM_TYPE_VAMANA, pool_builder)
108118
, index(std::move(index)) {}
109119
~DynamicIndexVamana() = default;
110120

@@ -170,5 +180,10 @@ struct DynamicIndexVamana : public DynamicIndex {
170180
index.compact(batchsize);
171181
}
172182
}
183+
184+
void set_num_threads(size_t num_threads) override {
185+
pool_builder.resize(num_threads);
186+
index.set_threadpool(pool_builder.build());
187+
}
173188
};
174189
} // namespace svs::c_runtime

bindings/c/src/index_builder.hpp

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,16 @@ struct IndexBuilder {
6969
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
7070
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);
7171

72-
auto index = std::make_shared<IndexVamana>(dispatch_vamana_index_build(
73-
vamana_algorithm->build_parameters(),
74-
data,
75-
storage.get(),
76-
to_distance_type(distance_metric),
77-
pool_builder.build()
78-
));
72+
auto index = std::make_shared<IndexVamana>(
73+
dispatch_vamana_index_build(
74+
vamana_algorithm->build_parameters(),
75+
data,
76+
storage.get(),
77+
to_distance_type(distance_metric),
78+
pool_builder.build()
79+
),
80+
pool_builder
81+
);
7982

8083
return index;
8184
}
@@ -86,13 +89,16 @@ struct IndexBuilder {
8689
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
8790
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);
8891

89-
auto index = std::make_shared<IndexVamana>(dispatch_vamana_index_load(
90-
vamana_algorithm->build_parameters(),
91-
directory,
92-
storage.get(),
93-
to_distance_type(distance_metric),
94-
pool_builder.build()
95-
));
92+
auto index = std::make_shared<IndexVamana>(
93+
dispatch_vamana_index_load(
94+
vamana_algorithm->build_parameters(),
95+
directory,
96+
storage.get(),
97+
to_distance_type(distance_metric),
98+
pool_builder.build()
99+
),
100+
pool_builder
101+
);
96102

97103
return index;
98104
}
@@ -107,16 +113,18 @@ struct IndexBuilder {
107113
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
108114
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);
109115

110-
auto index =
111-
std::make_shared<DynamicIndexVamana>(dispatch_dynamic_vamana_index_build(
116+
auto index = std::make_shared<DynamicIndexVamana>(
117+
dispatch_dynamic_vamana_index_build(
112118
vamana_algorithm->build_parameters(),
113119
data,
114120
ids,
115121
storage.get(),
116122
to_distance_type(distance_metric),
117123
pool_builder.build(),
118124
blocksize_bytes
119-
));
125+
),
126+
pool_builder
127+
);
120128

121129
return index;
122130
}
@@ -128,15 +136,17 @@ struct IndexBuilder {
128136
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
129137
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);
130138

131-
auto index =
132-
std::make_shared<DynamicIndexVamana>(dispatch_dynamic_vamana_index_load(
139+
auto index = std::make_shared<DynamicIndexVamana>(
140+
dispatch_dynamic_vamana_index_load(
133141
vamana_algorithm->build_parameters(),
134142
directory,
135143
storage.get(),
136144
to_distance_type(distance_metric),
137145
pool_builder.build(),
138146
blocksize_bytes
139-
));
147+
),
148+
pool_builder
149+
);
140150

141151
return index;
142152
}

bindings/c/src/svs_c.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,3 +787,33 @@ svs_index_dynamic_compact(svs_index_h index, size_t batchsize, svs_error_h out_e
787787
false
788788
);
789789
}
790+
791+
extern "C" bool
792+
svs_index_get_num_threads(svs_index_h index, size_t* out_num_threads, svs_error_h out_err) {
793+
using namespace svs::c_runtime;
794+
return wrap_exceptions(
795+
[&]() {
796+
EXPECT_ARG_NOT_NULL(index);
797+
EXPECT_ARG_NOT_NULL(out_num_threads);
798+
*out_num_threads = index->impl->get_num_threads();
799+
return true;
800+
},
801+
out_err,
802+
false
803+
);
804+
}
805+
806+
extern "C" bool
807+
svs_index_set_num_threads(svs_index_h index, size_t num_threads, svs_error_h out_err) {
808+
using namespace svs::c_runtime;
809+
return wrap_exceptions(
810+
[&]() {
811+
EXPECT_ARG_NOT_NULL(index);
812+
EXPECT_ARG_GT_THAN(num_threads, 0);
813+
index->impl->set_num_threads(num_threads);
814+
return true;
815+
},
816+
out_err,
817+
false
818+
);
819+
}

bindings/c/src/threadpool.hpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ class ThreadPoolBuilder {
7474

7575
ThreadPoolBuilder(svs_threadpool_kind kind, size_t num_threads)
7676
: kind(kind)
77-
, num_threads(num_threads) {
77+
, num_threads(kind == SVS_THREADPOOL_KIND_SINGLE_THREAD ? 1 : num_threads)
78+
, user_threadpool(nullptr) {
7879
if (kind == SVS_THREADPOOL_KIND_CUSTOM) {
7980
throw std::invalid_argument(
8081
"SVS_THREADPOOL_KIND_CUSTOM cannot be built automatically."
@@ -91,6 +92,26 @@ class ThreadPoolBuilder {
9192
return std::max(size_t{1}, size_t{std::thread::hardware_concurrency()});
9293
}
9394

95+
svs_threadpool_kind get_kind() const { return kind; }
96+
svs_threadpool_i get_user_threadpool() const { return user_threadpool; }
97+
98+
size_t get_threads_num() const {
99+
if (kind == SVS_THREADPOOL_KIND_CUSTOM) {
100+
return user_threadpool->ops.size(user_threadpool->self);
101+
}
102+
return num_threads;
103+
}
104+
105+
void resize(size_t new_num_threads) {
106+
if (kind == SVS_THREADPOOL_KIND_SINGLE_THREAD) {
107+
throw std::logic_error("Cannot resize a single-threaded threadpool.");
108+
}
109+
if (kind == SVS_THREADPOOL_KIND_CUSTOM) {
110+
throw std::logic_error("Cannot resize a custom threadpool.");
111+
}
112+
num_threads = new_num_threads;
113+
}
114+
94115
svs::threads::ThreadPoolHandle build() const {
95116
using namespace svs::threads;
96117
switch (kind) {

0 commit comments

Comments
 (0)