Skip to content

Commit e9f6dd8

Browse files
authored
Refactor Vamana Iterator: simplified API and performance enhancement (#105)
This PR simplifies the Vamana Iterator API and enhances its performance by improving buffer management and search behavior. *Key Changes:* **Eliminated `soft_clear()` with Extra Buffer Capacity:** The buffer size is now larger than the `search_window_size`, removing the need for `soft_clear()` when calling `next()`. This ensures smoother search continuation and better performance. **Search Triggered Exclusively by `next(batch_size)`:** The iterator is now empty upon creation or after being updated. Searches are performed only when `next(batch_size)` is explicitly called, providing better control over search execution. On each call to `next(batch_size)`, both the `search_window_size` and the `buffer_capacity` are increased by `batch_size`, and then the search is conducted. **Dynamic Buffer Adjustment with `batch_size`:** Introduced the `batch_size` argument to `next()`, allowing the search buffer size to be adjusted dynamically based on the batch size. **Simplified API by Removing Schedules:** Removed the use of schedules from the iterator, simplifying its design and making it easier to use.
1 parent d95e876 commit e9f6dd8

14 files changed

Lines changed: 331 additions & 571 deletions

File tree

.github/workflows/build-linux.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,8 @@ jobs:
8181
working-directory: ${{ runner.temp }}/build/tests
8282
run: ctest -C ${{ matrix.build_type }}
8383

84+
- name: Run Cpp Examples
85+
env:
86+
CTEST_OUTPUT_ON_FAILURE: 1
87+
working-directory: ${{ runner.temp }}/build/examples/cpp
88+
run: ctest -C RelWithDebugInfo

benchmark/include/svs-benchmark/vamana/iterator.h

Lines changed: 23 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -30,70 +30,11 @@
3030

3131
namespace svsbenchmark::vamana {
3232

33-
/// Pre-configuration for the linear schedule.
34-
struct LinearSchedulePrototype {
35-
size_t scale_search_window_;
36-
size_t scale_buffer_capacity_;
37-
int64_t enable_filter_after_;
38-
size_t batch_size_start_;
39-
size_t scale_batch_size_;
40-
// Whether search should be restarted on every iteration.
41-
bool restart_searches_;
42-
43-
///// Saving and Loading.
44-
static constexpr std::string_view serialization_schema =
45-
"svsbench_vamana_iter_schedule";
46-
static constexpr svs::lib::Version save_version{0, 0, 0};
47-
48-
svs::lib::SaveTable save() const {
49-
return svs::lib::SaveTable{
50-
serialization_schema,
51-
save_version,
52-
{SVS_LIST_SAVE_(scale_search_window),
53-
SVS_LIST_SAVE_(scale_buffer_capacity),
54-
SVS_LIST_SAVE_(enable_filter_after),
55-
SVS_LIST_SAVE_(batch_size_start),
56-
SVS_LIST_SAVE_(scale_batch_size),
57-
SVS_LIST_SAVE_(restart_searches)}};
58-
}
59-
60-
static LinearSchedulePrototype load(const svs::lib::ContextFreeLoadTable& table) {
61-
return LinearSchedulePrototype{
62-
SVS_LOAD_MEMBER_AT_(table, scale_search_window),
63-
SVS_LOAD_MEMBER_AT_(table, scale_buffer_capacity),
64-
SVS_LOAD_MEMBER_AT_(table, enable_filter_after),
65-
SVS_LOAD_MEMBER_AT_(table, batch_size_start),
66-
SVS_LOAD_MEMBER_AT_(table, scale_batch_size),
67-
SVS_LOAD_MEMBER_AT_(table, restart_searches)};
68-
}
69-
70-
// Return several representative examples for the schedule.
71-
static std::vector<LinearSchedulePrototype> examples() {
72-
return {{10, 20, -1, 10, 0, false}, {10, 10, 3, 10, 5, false}};
73-
}
74-
75-
// Should search be restarted from scratch every iteration.
76-
bool restart_every_iteration() const { return restart_searches_; }
77-
78-
// Materialize an actual schedule given a set of base parameters.
79-
// NOTE: This does not propagate the `restart_searches_` flag.
80-
svs::index::vamana::LinearSchedule
81-
materialize(const svs::index::vamana::VamanaSearchParameters& sp) const {
82-
return svs::index::vamana::LinearSchedule{
83-
sp,
84-
svs::lib::narrow<uint16_t>(scale_search_window_),
85-
svs::lib::narrow<uint16_t>(scale_buffer_capacity_),
86-
svs::lib::narrow<int16_t>(enable_filter_after_),
87-
svs::lib::narrow<uint16_t>(batch_size_start_),
88-
svs::lib::narrow<uint16_t>(scale_batch_size_)};
89-
}
90-
};
91-
9233
struct IteratorSearchParameters {
9334
public:
9435
///// Members
95-
// The schedules to try.
96-
std::vector<LinearSchedulePrototype> schedules_;
36+
// Batch sizes to use for the iterator.
37+
std::vector<size_t> batch_sizes_{{10, 20}};
9738
// target recalls relative to base number of neighbors.
9839
std::vector<svs::lib::Percent> target_recalls_;
9940
// The number of batches to yield.
@@ -108,7 +49,7 @@ struct IteratorSearchParameters {
10849

10950
static IteratorSearchParameters example() {
11051
return IteratorSearchParameters{
111-
.schedules_ = LinearSchedulePrototype::examples(),
52+
.batch_sizes_ = {10},
11253
.target_recalls_ = {svs::lib::Percent(0.9)},
11354
.num_batches_ = 5,
11455
.query_subsample_ = 10,
@@ -119,15 +60,15 @@ struct IteratorSearchParameters {
11960
return svs::lib::SaveTable{
12061
serialization_schema,
12162
save_version,
122-
{SVS_LIST_SAVE_(schedules),
63+
{SVS_LIST_SAVE_(batch_sizes),
12364
SVS_LIST_SAVE_(target_recalls),
12465
SVS_LIST_SAVE_(num_batches),
12566
SVS_LIST_SAVE_(query_subsample)}};
12667
}
12768

12869
static IteratorSearchParameters load(const svs::lib::ContextFreeLoadTable& table) {
12970
return IteratorSearchParameters{
130-
SVS_LOAD_MEMBER_AT_(table, schedules),
71+
SVS_LOAD_MEMBER_AT_(table, batch_sizes),
13172
SVS_LOAD_MEMBER_AT_(table, target_recalls),
13273
SVS_LOAD_MEMBER_AT_(table, num_batches),
13374
SVS_LOAD_MEMBER_AT_(table, query_subsample)};
@@ -279,30 +220,25 @@ struct YieldedResult {
279220

280221
// TODO: Make the dependence on `Report` looser.
281222
template <typename Index> struct QueryIteratorResult {
282-
LinearSchedulePrototype schedule_;
223+
size_t batch_size_;
283224
size_t num_batches_;
284225
double target_recall_;
285226
search::RunReport<Index> report_;
286-
// The search parameters used for each iteration.
287-
// Must be the same for all queries in the batch.
288-
std::vector<svs::index::vamana::VamanaSearchParameters> iteration_parameters_;
289227
// Outer vector: Results for each query.
290228
// Inner vector: Results within a query.
291229
std::vector<std::vector<YieldedResult>> results_;
292230

293231
///// Constructor
294232
QueryIteratorResult(
295-
const LinearSchedulePrototype& schedule,
233+
size_t batch_size,
296234
double target_recall,
297235
search::RunReport<Index> report,
298-
std::vector<svs::index::vamana::VamanaSearchParameters> iteration_parameters,
299236
std::vector<std::vector<YieldedResult>> results
300237
)
301-
: schedule_{schedule}
302-
, num_batches_{iteration_parameters.size()}
238+
: batch_size_{batch_size}
239+
, num_batches_{results.at(0).size()}
303240
, target_recall_{target_recall}
304241
, report_{std::move(report)}
305-
, iteration_parameters_{std::move(iteration_parameters)}
306242
, results_{std::move(results)} {
307243
// Ensure all the yielded results have the correct size.
308244
for (size_t i = 0, imax = results_.size(); i < imax; ++i) {
@@ -326,11 +262,10 @@ template <typename Index> struct QueryIteratorResult {
326262
return svs::lib::SaveTable{
327263
serialization_schema,
328264
save_version,
329-
{SVS_LIST_SAVE_(schedule),
265+
{SVS_LIST_SAVE_(batch_size),
330266
SVS_LIST_SAVE_(num_batches),
331267
SVS_LIST_SAVE_(target_recall),
332268
SVS_LIST_SAVE_(report),
333-
SVS_LIST_SAVE_(iteration_parameters),
334269
SVS_LIST_SAVE_(results)}};
335270
}
336271
};
@@ -360,15 +295,14 @@ std::vector<QueryIteratorResult<Index>> tune_and_search_iterator(
360295

361296
// Loop over each batchsize.
362297
auto query_iterator_results = std::vector<QueryIteratorResult<Index>>{};
363-
for (const auto& schedule : parameters.schedules_) {
364-
auto initial_batch_size = schedule.batch_size_start_;
298+
for (const auto& batch_size : parameters.batch_sizes_) {
365299
for (auto target_recall : parameters.target_recalls_) {
366300
// Calibrate the index for the given recall.
367301
auto config = traits::calibrate(
368302
index,
369303
query_set.training_set_,
370304
query_set.training_set_groundtruth_,
371-
initial_batch_size,
305+
batch_size,
372306
target_recall.value(),
373307
context,
374308
extra
@@ -379,7 +313,7 @@ std::vector<QueryIteratorResult<Index>> tune_and_search_iterator(
379313
index,
380314
query_set.test_set_,
381315
query_set.test_set_groundtruth_,
382-
initial_batch_size,
316+
batch_size,
383317
target_recall.value(),
384318
svsbenchmark::CalibrateContext::TestSetTune,
385319
config,
@@ -389,7 +323,7 @@ std::vector<QueryIteratorResult<Index>> tune_and_search_iterator(
389323
// Now we have a calibrated configuration - obtain a baseline report for
390324
// searching with this batchsize.
391325
auto report = svsbenchmark::search::search_with_config(
392-
index, config, query_test, groundtruth_test, initial_batch_size
326+
index, config, query_test, groundtruth_test, batch_size
393327
);
394328

395329
// `result_buffer`: All results that have been returned by the iterator.
@@ -452,38 +386,34 @@ std::vector<QueryIteratorResult<Index>> tune_and_search_iterator(
452386

453387
// The first call to `iterator` kick-starts graph search.
454388
auto tic = svs::lib::now();
455-
auto iterator = make_iterator(index, query, config, schedule);
389+
auto iterator = make_iterator(index, query);
390+
iterator.next(config.buffer_config_.get_search_window_size());
456391
auto elapsed = svs::lib::time_difference(tic);
457392
if (i == 0) {
458-
iteration_parameters.push_back(iterator.parameters_for_current_batch());
393+
iteration_parameters.push_back(
394+
iterator.parameters_for_current_iteration()
395+
);
459396
}
460397

461398
timings_for_this_query.push_back(tally(iterator, i, 0, elapsed));
462399
for (size_t j = 0; j < parameters.num_batches_; ++j) {
463-
// If requested by the parent schedule, reset search for this
464-
// iteration.
465-
if (schedule.restart_every_iteration()) {
466-
iterator.restart_next_search();
467-
}
468-
469400
tic = svs::lib::now();
470-
iterator.next();
401+
iterator.next(batch_size);
471402
elapsed = svs::lib::time_difference(tic);
472403
timings_for_this_query.push_back(tally(iterator, i, j + 1, elapsed));
473404
if (i == 0) {
474405
iteration_parameters.push_back(
475-
iterator.parameters_for_current_batch()
406+
iterator.parameters_for_current_iteration()
476407
);
477408
}
478409
}
479410
}
480411

481412
// Finish up summarizing these results.
482413
query_iterator_results.emplace_back(
483-
schedule,
414+
batch_size,
484415
target_recall.value(),
485416
std::move(report),
486-
std::move(iteration_parameters),
487417
std::move(yielded_results)
488418
);
489419
do_checkpoint(query_iterator_results);
@@ -522,9 +452,7 @@ toml::table tune_and_search_iterator(
522452
job.parameters_,
523453
query_set,
524454
svsbenchmark::CalibrateContext::InitialTrainingSet,
525-
[](const auto& index, const auto& query, const auto& config, const auto& schedule) {
526-
return index.batch_iterator(query, schedule.materialize(config));
527-
},
455+
[](const auto& index, const auto& query) { return index.batch_iterator(query); },
528456
do_checkpoint,
529457
svsbenchmark::IndexTraits<Index>::regression_optimization()
530458
);

examples/cpp/CMakeLists.txt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ endfunction()
3737
create_simple_example(saveload test_saveload saveload.cpp)
3838
create_simple_example(types test_types types.cpp)
3939
create_simple_example(vamana_iterator test_vamana_iterator vamana_iterator.cpp)
40-
create_simple_example(custom_thread_pool test_custom_thread_pool custom_thread_pool.cpp)
4140

4241
## More complicated examples involving more extensive setup.
4342

@@ -48,6 +47,7 @@ create_simple_example(custom_thread_pool test_custom_thread_pool custom_thread_p
4847
configure_file(../../data/test_dataset/data_f32.fvecs . COPYONLY)
4948
configure_file(../../data/test_dataset/queries_f32.fvecs . COPYONLY)
5049
configure_file(../../data/test_dataset/groundtruth_euclidean.ivecs . COPYONLY)
50+
5151
# The vamana test executable.
5252
add_executable(vamana vamana.cpp)
5353
target_include_directories(vamana PRIVATE ${CMAKE_CURRENT_LIST_DIR})
@@ -61,6 +61,20 @@ add_test(
6161
groundtruth_euclidean.ivecs
6262
)
6363

64+
# The custom thread pool executable.
65+
add_executable(custom_thread_pool custom_thread_pool.cpp)
66+
target_include_directories(custom_thread_pool PRIVATE ${CMAKE_CURRENT_LIST_DIR})
67+
target_link_libraries(custom_thread_pool ${SVS_LIB} svs_compile_options svs_native_options)
68+
add_test(
69+
NAME test_custom_thread_pool
70+
COMMAND
71+
custom_thread_pool
72+
data_f32.fvecs
73+
queries_f32.fvecs
74+
groundtruth_euclidean.ivecs
75+
)
76+
77+
6478
#####
6579
##### Dispatcher
6680
#####

examples/cpp/vamana_iterator.cpp

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void check(bool value, std::string_view expr, svs::lib::detail::LineInfo linfo)
4040
#define CHECK(expr) check(expr, #expr, SVS_LINEINFO);
4141

4242
//! [Example Index Construction]
43-
[[nodiscard]] svs::Vamana make_example_index() {
43+
[[nodiscard]] auto make_example_index() {
4444
// Build the index.
4545
auto build_parameters = svs::index::vamana::VamanaBuildParameters{
4646
1.2, // alpha
@@ -68,42 +68,34 @@ void check(bool value, std::string_view expr, svs::lib::detail::LineInfo linfo)
6868
}
6969

7070
// Build the index.
71-
return svs::Vamana::build<float>(build_parameters, std::move(data), svs::DistanceL2{});
71+
return svs::index::vamana::auto_build(
72+
build_parameters, std::move(data), svs::DistanceL2{}, 1
73+
);
7274
}
7375
//! [Example Index Construction]
7476

75-
void demonstrate_default_schedule() {
77+
void demonstrate_iterator() {
7678
//! [Setup]
7779
auto index = make_example_index();
7880

79-
// Base search parameters for the iterator schedule.
80-
// This uses a search window size/capacity of 4.
81-
auto base_parameters = svs::index::vamana::VamanaSearchParameters{}.buffer_config({4});
82-
83-
// The default schedule take base parameters and a batch size.
8481
// Each iteration will yield 3 elements that have not been yielded previously.
8582
size_t batchsize = 3;
86-
auto schedule = svs::index::vamana::DefaultSchedule{base_parameters, batchsize};
8783

8884
// Create a batch iterator over the index for the query.
89-
// After the constructor returns, the contents of the first batch will be available.
9085
auto itr = [&]() {
9186
// Construct a query in a scoped block to demonstrate that the iterator
9287
// maintains an internal copy.
9388
auto query = std::vector<float>{3.25, 3.25, 3.25, 3.25};
9489

95-
// Make a batch iterator for the query using the provided schedule.
96-
return index.batch_iterator(svs::lib::as_const_span(query), schedule);
90+
// Make a batch iterator for the query.
91+
return svs::VamanaIterator(index, svs::lib::as_const_span(query));
9792
}();
9893
//! [Setup]
9994

100-
//! [Initial Checks]
101-
// The iterator was configured to yield three neighbors on each invocation.
102-
// This information is available through the `size()` method.
95+
//! [First Iteration]
96+
itr.next(batchsize);
10397
CHECK(itr.size() == 3);
104-
105-
// The contents of the iterator are for batch 0.
106-
CHECK(itr.batch() == 0);
98+
CHECK(itr.batch_number() == 1);
10799

108100
// Obtain a view of the current list candidates.
109101
std::span<const svs::Neighbor<size_t>> results = itr.results();
@@ -117,18 +109,15 @@ void demonstrate_default_schedule() {
117109
CHECK(results[0].id() == 3);
118110
CHECK(results[1].id() == 4);
119111
CHECK(results[2].id() == 2);
120-
//! [Initial Checks]
112+
//! [First Iteration]
121113

122114
//! [Next Iteration]
123-
// Once we've finished with the current batch of neighbors, we can step the iterator
124-
// to the next batch.
125-
//
126-
// Using the `DefaultSchedule`, we will retrieve at most 3 new candidates.
127-
itr.next();
115+
// This will yield the next batch of neighbors.
116+
itr.next(batchsize);
128117
CHECK(itr.size() == 3);
129118

130-
// The contents of the iterator are for batch 1.
131-
CHECK(itr.batch() == 1);
119+
// The contents of the iterator are for batch 2.
120+
CHECK(itr.batch_number() == 2);
132121

133122
// Update and inspect the results.
134123
results = itr.results();
@@ -144,7 +133,7 @@ void demonstrate_default_schedule() {
144133
//! [Final Iteration]
145134
// So far, the iterator has yielded 6 of the 7 vectors in the dataset.
146135
// This call to `next()` should only yield a single neighbor - the last one in the index.
147-
itr.next();
136+
itr.next(batchsize);
148137
CHECK(itr.size() == 1);
149138
CHECK(itr.done());
150139
results = itr.results();
@@ -154,15 +143,15 @@ void demonstrate_default_schedule() {
154143

155144
//! [Beyond Final Iteration]
156145
// Calling `next()` again should yield no more candidates.
157-
itr.next();
146+
itr.next(batchsize);
158147
CHECK(itr.size() == 0);
159148
CHECK(itr.done());
160149
//! [Beyond Final Iteration]
161150
}
162151

163152
// Alternative main definition
164153
int svs_main(std::vector<std::string> SVS_UNUSED(args)) {
165-
demonstrate_default_schedule();
154+
demonstrate_iterator();
166155
return 0;
167156
}
168157

0 commit comments

Comments
 (0)