1818#include " svs/c_api/svs_c.h"
1919
2020#include " algorithm.hpp"
21+ #include " threadpool.hpp"
2122
2223#include < svs/concepts/data.h>
2324#include < svs/core/distance.h>
3233namespace svs ::c_runtime {
3334struct Index {
3435 svs_algorithm_type algorithm;
35- Index (svs_algorithm_type algorithm)
36- : algorithm(algorithm) {}
36+ ThreadPoolBuilder pool_builder;
37+ Index (svs_algorithm_type algorithm, ThreadPoolBuilder pool_builder)
38+ : algorithm(algorithm)
39+ , pool_builder(pool_builder) {}
3740 virtual ~Index () = default ;
3841 virtual svs::QueryResult<size_t > search (
3942 svs::data::ConstSimpleDataView<float > queries,
@@ -45,11 +48,13 @@ struct Index {
4548 virtual float get_distance (size_t id, std::span<const float > query) const = 0;
4649 virtual void
4750 reconstruct_at (svs::data::SimpleDataView<float > dst, std::span<const size_t > ids) = 0 ;
51+ virtual size_t get_num_threads () { return pool_builder.get_threads_num (); };
52+ virtual void set_num_threads (size_t num_threads) = 0;
4853};
4954
5055struct DynamicIndex : public Index {
51- DynamicIndex (svs_algorithm_type algorithm)
52- : Index(algorithm) {}
56+ DynamicIndex (svs_algorithm_type algorithm, ThreadPoolBuilder pool_builder )
57+ : Index(algorithm, pool_builder ) {}
5358 ~DynamicIndex () = default ;
5459
5560 virtual size_t add_points (
@@ -63,8 +68,8 @@ struct DynamicIndex : public Index {
6368
6469struct IndexVamana : public Index {
6570 svs::Vamana index;
66- IndexVamana (svs::Vamana&& index)
67- : Index{SVS_ALGORITHM_TYPE_VAMANA}
71+ IndexVamana (svs::Vamana&& index, ThreadPoolBuilder pool_builder )
72+ : Index{SVS_ALGORITHM_TYPE_VAMANA, pool_builder }
6873 , index(std::move(index)) {}
6974 ~IndexVamana () = default ;
7075 svs::QueryResult<size_t > search (
@@ -99,12 +104,17 @@ struct IndexVamana : public Index {
99104 override {
100105 index.reconstruct_at (dst, ids);
101106 }
107+
108+ void set_num_threads (size_t num_threads) override {
109+ pool_builder.resize (num_threads);
110+ index.set_threadpool (pool_builder.build ());
111+ }
102112};
103113
104114struct DynamicIndexVamana : public DynamicIndex {
105115 svs::DynamicVamana index;
106- DynamicIndexVamana (svs::DynamicVamana&& index)
107- : DynamicIndex(SVS_ALGORITHM_TYPE_VAMANA)
116+ DynamicIndexVamana (svs::DynamicVamana&& index, ThreadPoolBuilder pool_builder )
117+ : DynamicIndex(SVS_ALGORITHM_TYPE_VAMANA, pool_builder )
108118 , index(std::move(index)) {}
109119 ~DynamicIndexVamana () = default ;
110120
@@ -170,5 +180,10 @@ struct DynamicIndexVamana : public DynamicIndex {
170180 index.compact (batchsize);
171181 }
172182 }
183+
184+ void set_num_threads (size_t num_threads) override {
185+ pool_builder.resize (num_threads);
186+ index.set_threadpool (pool_builder.build ());
187+ }
173188};
174189} // namespace svs::c_runtime
0 commit comments