Skip to content

Commit 43c078a

Browse files
authored
fix: Multi-vector dynamic vamana index Save/Load functionality (#162)
* Extended _"Multi-vector dynamic vamana index"_ test section _"Save/Load"_ to enforce multi-vector initialization and compare search results before and after savein-loading. * Fixed labels saving/loading in `MultiMutableVamanaIndex`
1 parent 117ba52 commit 43c078a

2 files changed

Lines changed: 19 additions & 11 deletions

File tree

include/svs/index/vamana/multi.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -583,24 +583,22 @@ class MultiMutableVamanaIndex {
583583
return a.first < b.first;
584584
});
585585

586-
std::vector<label_type> labels(ext_lab_vec.size());
586+
size_t num_labels = ext_lab_vec.size();
587+
std::vector<label_type> labels(num_labels);
587588
std::transform(
588589
ext_lab_vec.begin(),
589590
ext_lab_vec.end(),
590591
labels.begin(),
591592
[](const auto& ext_lab) { return ext_lab.second; }
592593
);
593-
size_t num_labels = ext_lab_vec.size();
594594

595595
// Save auxiliary data structures.
596596
lib::save_to_disk(
597597
lib::SaveOverride([&](const lib::SaveContext& ctx) {
598598
// Save labels to a file.
599599
auto filename = ctx.generate_name("labels", "binary");
600600
auto stream = lib::open_write(filename);
601-
for (const auto& l : ext_lab_vec) {
602-
lib::write_binary(stream, l.first);
603-
}
601+
lib::write_binary(stream, labels);
604602

605603
// Save the construction parameters.
606604
auto parameters = VamanaIndexParameters{
@@ -678,13 +676,10 @@ struct MultiVamanaStateLoader {
678676
switch (load_from) {
679677
case MultiMutableVamanaLoad::FROM_MULTI: {
680678
auto num_labels = lib::load_at<size_t>(table, "num_labels");
681-
std::vector<label_type> labels;
682-
labels.reserve(num_labels);
679+
std::vector<label_type> labels(num_labels);
683680
auto resolved = table.resolve_at("filename");
684681
auto stream = lib::open_read(resolved);
685-
for (size_t i = 0; i < num_labels; ++i) {
686-
labels.push_back(lib::read_binary<label_type>(stream));
687-
}
682+
lib::read_binary(stream, labels);
688683
return MultiVamanaStateLoader{
689684
SVS_LOAD_MEMBER_AT_(table, parameters),
690685
IDTranslator{},

tests/svs/index/vamana/multi.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,12 @@ CATCH_TEMPLATE_TEST_CASE(
248248
auto graph_dir = dir / "graph";
249249
auto data_dir = dir / "data";
250250
std::vector<size_t> test_indices(num_points);
251+
// Fill the test indices with labels in the range of num_labels
252+
// to ensure that there are labels mapped to more than 1 vector.
253+
const size_t per_label = 2;
254+
const auto num_labels = num_points / per_label;
251255
for (auto& i : test_indices) {
252-
i = std::rand();
256+
i = std::rand() % num_labels;
253257
}
254258
auto test_index = svs::index::vamana::MultiMutableVamanaIndex(
255259
build_parameters, data, test_indices, Distance(), num_threads
@@ -271,6 +275,15 @@ CATCH_TEMPLATE_TEST_CASE(
271275
test_index_2.search(test_results_2.view(), queries.view(), search_parameters);
272276
auto test_recall_2 = svs::k_recall_at_n(groundtruth, test_results_2);
273277

278+
// Check that the results are the same
279+
CATCH_REQUIRE(test_results.n_neighbors() == test_results_2.n_neighbors());
280+
for (size_t i = 0; i < test_results.n_queries(); ++i) {
281+
for (size_t j = 0; j < test_results.n_neighbors(); ++j) {
282+
CATCH_REQUIRE(test_results.indices().at(i, j) ==
283+
test_results_2.indices().at(i, j));
284+
}
285+
}
286+
274287
CATCH_REQUIRE(test_index.size() == test_index_2.size());
275288
CATCH_REQUIRE(test_index.dimensions() == test_index_2.dimensions());
276289
// Index Properties

0 commit comments

Comments
 (0)