Skip to content

Commit 1478cde

Browse files
authored
[C-API] Add getter and setter for index threadpool size (#305)
Adds `svs_index_get_num_threads` / `svs_index_set_num_threads` to the C API, enabling dynamic inspection and resizing of the search threadpool after index construction. ### ThreadPoolBuilder - Added `get_threads_num()` — delegates to the custom pool's `size()` op when `kind == CUSTOM`, otherwise returns the stored count - Added `resize(n)` — updates stored thread count; throws `std::invalid_argument` for `n == 0`, `SINGLE_THREAD`, or `CUSTOM` kinds (surfaced as `SVS_ERROR_INVALID_ARGUMENT` through `wrap_exceptions`) ### Index wrappers (`index.hpp`) - `Index` stores a `ThreadPoolBuilder`; `get_num_threads()` is pure-virtual — implemented in `IndexVamana` and `DynamicIndexVamana` by delegating to the wrapped `svs::Vamana` / `svs::DynamicVamana` instance, so the value reflects actual runtime state - `set_num_threads(n)` calls `pool_builder.resize(n)` then rebuilds and installs the threadpool via `set_threadpool()` ### C API (`svs_c.cpp` / `svs_c.h`) - Both entry points validate `index->impl` non-null before dereferencing (consistent with existing handle-check pattern) - Public header documents supported kinds and expected error codes for unsupported configurations
1 parent ec4260c commit 1478cde

6 files changed

Lines changed: 154 additions & 29 deletions

File tree

bindings/c/include/svs/c_api/svs_c.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum svs_error_code {
3232
SVS_ERROR_NOT_IMPLEMENTED = 5,
3333
SVS_ERROR_UNSUPPORTED_HW = 6,
3434
SVS_ERROR_RUNTIME = 7,
35+
SVS_ERROR_INVALID_OPERATION = 8,
3536
SVS_ERROR_UNKNOWN = 1000
3637
};
3738

@@ -501,6 +502,32 @@ SVS_API bool svs_index_dynamic_compact(
501502
svs_index_h index, size_t batchsize /*=0*/, svs_error_h out_err /*=NULL*/
502503
);
503504

505+
/// @brief Get number of threads used for search in the index's thread pool
506+
/// @param index The index handle
507+
/// @param out_num_threads Pointer to store the retrieved number of threads
508+
/// @param out_err An optional error handle to capture errors
509+
/// @return true on success, false on failure
510+
SVS_API bool svs_index_get_num_threads(
511+
svs_index_h index, size_t* out_num_threads, svs_error_h out_err /*=NULL*/
512+
);
513+
514+
/// @brief Set number of threads for search in the index's thread pool
515+
/// @param index The index handle
516+
/// @param num_threads The number of threads to set
517+
/// @param out_err An optional error handle to capture errors
518+
/// @return true on success, false on failure
519+
/// @remarks This function is only supported for indices built with threadpool kinds
520+
/// SVS_THREADPOOL_KIND_NATIVE or SVS_THREADPOOL_KIND_OMP. Attempting to call this
521+
/// function on indices built with SVS_THREADPOOL_KIND_CUSTOM or
522+
/// SVS_THREADPOOL_KIND_SINGLE_THREAD will fail and return false.
523+
/// @error On failure, if out_err is provided, it will contain:
524+
/// - SVS_ERROR_INVALID_OPERATION if the index's threadpool kind is unresizable
525+
/// - SVS_ERROR_INVALID_ARGUMENT if num_threads is invalid or zero
526+
/// - SVS_ERROR_RUNTIME for other runtime failures
527+
SVS_API bool svs_index_set_num_threads(
528+
svs_index_h index, size_t num_threads, svs_error_h out_err /*=NULL*/
529+
);
530+
504531
#ifdef __cplusplus
505532
}
506533
#endif

bindings/c/src/error.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ class not_implemented : public std::logic_error {
8787
using std::logic_error::logic_error;
8888
};
8989

90+
class invalid_operation : public std::logic_error {
91+
public:
92+
using std::logic_error::logic_error;
93+
};
94+
9095
class unsupported_hw : public std::runtime_error {
9196
public:
9297
using std::runtime_error::runtime_error;
@@ -104,6 +109,9 @@ Result wrap_exceptions(Callable&& func, svs_error_h err, Result err_res = {}) no
104109
} catch (const svs::c_runtime::not_implemented& ex) {
105110
SET_ERROR(err, SVS_ERROR_NOT_IMPLEMENTED, ex.what());
106111
return err_res;
112+
} catch (const svs::c_runtime::invalid_operation& ex) {
113+
SET_ERROR(err, SVS_ERROR_INVALID_OPERATION, ex.what());
114+
return err_res;
107115
} catch (const svs::c_runtime::unsupported_hw& ex) {
108116
SET_ERROR(err, SVS_ERROR_UNSUPPORTED_HW, ex.what());
109117
return err_res;

bindings/c/src/index.hpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "svs/c_api/svs_c.h"
1919

2020
#include "algorithm.hpp"
21+
#include "threadpool.hpp"
2122

2223
#include <svs/concepts/data.h>
2324
#include <svs/core/distance.h>
@@ -32,8 +33,10 @@
3233
namespace svs::c_runtime {
3334
struct Index {
3435
svs_algorithm_type algorithm;
35-
Index(svs_algorithm_type algorithm)
36-
: algorithm(algorithm) {}
36+
ThreadPoolBuilder pool_builder;
37+
Index(svs_algorithm_type algorithm, ThreadPoolBuilder pool_builder)
38+
: algorithm(algorithm)
39+
, pool_builder(pool_builder) {}
3740
virtual ~Index() = default;
3841
virtual svs::QueryResult<size_t> search(
3942
svs::data::ConstSimpleDataView<float> queries,
@@ -45,11 +48,13 @@ struct Index {
4548
virtual float get_distance(size_t id, std::span<const float> query) const = 0;
4649
virtual void
4750
reconstruct_at(svs::data::SimpleDataView<float> dst, std::span<const size_t> ids) = 0;
51+
virtual size_t get_num_threads() const = 0;
52+
virtual void set_num_threads(size_t num_threads) = 0;
4853
};
4954

5055
struct DynamicIndex : public Index {
51-
DynamicIndex(svs_algorithm_type algorithm)
52-
: Index(algorithm) {}
56+
DynamicIndex(svs_algorithm_type algorithm, ThreadPoolBuilder pool_builder)
57+
: Index(algorithm, pool_builder) {}
5358
~DynamicIndex() = default;
5459

5560
virtual size_t add_points(
@@ -63,8 +68,8 @@ struct DynamicIndex : public Index {
6368

6469
struct IndexVamana : public Index {
6570
svs::Vamana index;
66-
IndexVamana(svs::Vamana&& index)
67-
: Index{SVS_ALGORITHM_TYPE_VAMANA}
71+
IndexVamana(svs::Vamana&& index, ThreadPoolBuilder pool_builder)
72+
: Index{SVS_ALGORITHM_TYPE_VAMANA, pool_builder}
6873
, index(std::move(index)) {}
6974
~IndexVamana() = default;
7075
svs::QueryResult<size_t> search(
@@ -99,12 +104,19 @@ struct IndexVamana : public Index {
99104
override {
100105
index.reconstruct_at(dst, ids);
101106
}
107+
108+
size_t get_num_threads() const override { return index.get_num_threads(); }
109+
110+
void set_num_threads(size_t num_threads) override {
111+
pool_builder.resize(num_threads);
112+
index.set_threadpool(pool_builder.build());
113+
}
102114
};
103115

104116
struct DynamicIndexVamana : public DynamicIndex {
105117
svs::DynamicVamana index;
106-
DynamicIndexVamana(svs::DynamicVamana&& index)
107-
: DynamicIndex(SVS_ALGORITHM_TYPE_VAMANA)
118+
DynamicIndexVamana(svs::DynamicVamana&& index, ThreadPoolBuilder pool_builder)
119+
: DynamicIndex(SVS_ALGORITHM_TYPE_VAMANA, pool_builder)
108120
, index(std::move(index)) {}
109121
~DynamicIndexVamana() = default;
110122

@@ -170,5 +182,12 @@ struct DynamicIndexVamana : public DynamicIndex {
170182
index.compact(batchsize);
171183
}
172184
}
185+
186+
size_t get_num_threads() const override { return index.get_num_threads(); }
187+
188+
void set_num_threads(size_t num_threads) override {
189+
pool_builder.resize(num_threads);
190+
index.set_threadpool(pool_builder.build());
191+
}
173192
};
174193
} // namespace svs::c_runtime

bindings/c/src/index_builder.hpp

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,16 @@ struct IndexBuilder {
6969
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
7070
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);
7171

72-
auto index = std::make_shared<IndexVamana>(dispatch_vamana_index_build(
73-
vamana_algorithm->build_parameters(),
74-
data,
75-
storage.get(),
76-
to_distance_type(distance_metric),
77-
pool_builder.build()
78-
));
72+
auto index = std::make_shared<IndexVamana>(
73+
dispatch_vamana_index_build(
74+
vamana_algorithm->build_parameters(),
75+
data,
76+
storage.get(),
77+
to_distance_type(distance_metric),
78+
pool_builder.build()
79+
),
80+
pool_builder
81+
);
7982

8083
return index;
8184
}
@@ -86,13 +89,16 @@ struct IndexBuilder {
8689
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
8790
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);
8891

89-
auto index = std::make_shared<IndexVamana>(dispatch_vamana_index_load(
90-
vamana_algorithm->build_parameters(),
91-
directory,
92-
storage.get(),
93-
to_distance_type(distance_metric),
94-
pool_builder.build()
95-
));
92+
auto index = std::make_shared<IndexVamana>(
93+
dispatch_vamana_index_load(
94+
vamana_algorithm->build_parameters(),
95+
directory,
96+
storage.get(),
97+
to_distance_type(distance_metric),
98+
pool_builder.build()
99+
),
100+
pool_builder
101+
);
96102

97103
return index;
98104
}
@@ -107,16 +113,18 @@ struct IndexBuilder {
107113
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
108114
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);
109115

110-
auto index =
111-
std::make_shared<DynamicIndexVamana>(dispatch_dynamic_vamana_index_build(
116+
auto index = std::make_shared<DynamicIndexVamana>(
117+
dispatch_dynamic_vamana_index_build(
112118
vamana_algorithm->build_parameters(),
113119
data,
114120
ids,
115121
storage.get(),
116122
to_distance_type(distance_metric),
117123
pool_builder.build(),
118124
blocksize_bytes
119-
));
125+
),
126+
pool_builder
127+
);
120128

121129
return index;
122130
}
@@ -128,15 +136,17 @@ struct IndexBuilder {
128136
if (algorithm->type == SVS_ALGORITHM_TYPE_VAMANA) {
129137
auto vamana_algorithm = std::static_pointer_cast<AlgorithmVamana>(algorithm);
130138

131-
auto index =
132-
std::make_shared<DynamicIndexVamana>(dispatch_dynamic_vamana_index_load(
139+
auto index = std::make_shared<DynamicIndexVamana>(
140+
dispatch_dynamic_vamana_index_load(
133141
vamana_algorithm->build_parameters(),
134142
directory,
135143
storage.get(),
136144
to_distance_type(distance_metric),
137145
pool_builder.build(),
138146
blocksize_bytes
139-
));
147+
),
148+
pool_builder
149+
);
140150

141151
return index;
142152
}

bindings/c/src/svs_c.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,3 +787,37 @@ svs_index_dynamic_compact(svs_index_h index, size_t batchsize, svs_error_h out_e
787787
false
788788
);
789789
}
790+
791+
extern "C" bool
792+
svs_index_get_num_threads(svs_index_h index, size_t* out_num_threads, svs_error_h out_err) {
793+
using namespace svs::c_runtime;
794+
return wrap_exceptions(
795+
[&]() {
796+
EXPECT_ARG_NOT_NULL(index);
797+
EXPECT_ARG_NOT_NULL(out_num_threads);
798+
auto& index_ptr = index->impl;
799+
INVALID_ARGUMENT_IF(index_ptr == nullptr, "Invalid index handle");
800+
*out_num_threads = index_ptr->get_num_threads();
801+
return true;
802+
},
803+
out_err,
804+
false
805+
);
806+
}
807+
808+
extern "C" bool
809+
svs_index_set_num_threads(svs_index_h index, size_t num_threads, svs_error_h out_err) {
810+
using namespace svs::c_runtime;
811+
return wrap_exceptions(
812+
[&]() {
813+
EXPECT_ARG_NOT_NULL(index);
814+
EXPECT_ARG_GT_THAN(num_threads, 0);
815+
auto& index_ptr = index->impl;
816+
INVALID_ARGUMENT_IF(index_ptr == nullptr, "Invalid index handle");
817+
index_ptr->set_num_threads(num_threads);
818+
return true;
819+
},
820+
out_err,
821+
false
822+
);
823+
}

bindings/c/src/threadpool.hpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "svs/c_api/svs_c.h"
1919

20+
#include "error.hpp"
2021
#include "types_support.hpp"
2122

2223
#include <svs/lib/threads.h>
@@ -74,7 +75,8 @@ class ThreadPoolBuilder {
7475

7576
ThreadPoolBuilder(svs_threadpool_kind kind, size_t num_threads)
7677
: kind(kind)
77-
, num_threads(num_threads) {
78+
, num_threads(kind == SVS_THREADPOOL_KIND_SINGLE_THREAD ? 1 : num_threads)
79+
, user_threadpool(nullptr) {
7880
if (kind == SVS_THREADPOOL_KIND_CUSTOM) {
7981
throw std::invalid_argument(
8082
"SVS_THREADPOOL_KIND_CUSTOM cannot be built automatically."
@@ -91,6 +93,31 @@ class ThreadPoolBuilder {
9193
return std::max(size_t{1}, size_t{std::thread::hardware_concurrency()});
9294
}
9395

96+
svs_threadpool_kind get_kind() const { return kind; }
97+
svs_threadpool_i get_user_threadpool() const { return user_threadpool; }
98+
99+
size_t get_threads_num() const {
100+
if (kind == SVS_THREADPOOL_KIND_CUSTOM) {
101+
return user_threadpool->ops.size(user_threadpool->self);
102+
}
103+
return num_threads;
104+
}
105+
106+
void resize(size_t new_num_threads) {
107+
if (new_num_threads == 0) {
108+
throw std::invalid_argument("Number of threads must be greater than zero.");
109+
}
110+
if (kind == SVS_THREADPOOL_KIND_SINGLE_THREAD) {
111+
throw svs::c_runtime::invalid_operation(
112+
"Cannot resize a single-threaded threadpool."
113+
);
114+
}
115+
if (kind == SVS_THREADPOOL_KIND_CUSTOM) {
116+
throw svs::c_runtime::invalid_operation("Cannot resize a custom threadpool.");
117+
}
118+
num_threads = new_num_threads;
119+
}
120+
94121
svs::threads::ThreadPoolHandle build() const {
95122
using namespace svs::threads;
96123
switch (kind) {

0 commit comments

Comments
 (0)