Skip to content

Commit 5e52e82

Browse files
razdoburdinDmitry Razdoburdin
andauthored
Native serialization for MultiMutableVamana index (#289)
This PR introduce native serialization for `MultiMutableVamana` index. Should be merged after: #286 Main changes are: 1. New overload of `svs::index::vamana::auto_multi_dynamic_assemble` required for direct deserialization accepts lazy loaders and call them in a flexible order to cover legacy serialized models. 2. Added related tests. 3. `supports_saving` flag is keeped false, as far as it isn't used. 4. `MultiMutableVamana` doesn't have an orchestrator. So no changes on this side. --------- Co-authored-by: Dmitry Razdoburdin <drazdobu@intel.com>
1 parent 81c5d82 commit 5e52e82

2 files changed

Lines changed: 255 additions & 19 deletions

File tree

include/svs/index/vamana/multi.h

Lines changed: 110 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -563,16 +563,8 @@ class MultiMutableVamanaIndex {
563563
constexpr std::string_view name() const { return "multi dynamic vamana index"; }
564564

565565
static constexpr lib::Version save_version = lib::Version(0, 0, 0);
566-
void save(
567-
const std::filesystem::path& config_directory,
568-
const std::filesystem::path& graph_directory,
569-
const std::filesystem::path& data_directory
570-
) {
571-
// Post-consolidation, all entries should be "valid".
572-
// Therefore, we don't need to save the slot metadata.
573-
consolidate();
574-
compact();
575566

567+
auto get_labels() const {
576568
// Since data is in order of external ids,
577569
// convert a map of external ids to label types into a sorted vector of labels based
578570
// on external ids.
@@ -592,6 +584,34 @@ class MultiMutableVamanaIndex {
592584
[](const auto& ext_lab) { return ext_lab.second; }
593585
);
594586

587+
return labels;
588+
}
589+
590+
VamanaIndexParameters get_parameters() const {
591+
return {
592+
index_->entry_point_.front(),
593+
{get_alpha(),
594+
max_degree(),
595+
get_construction_window_size(),
596+
get_max_candidates(),
597+
get_prune_to(),
598+
get_full_search_history()},
599+
get_search_parameters()};
600+
}
601+
602+
void save(
603+
const std::filesystem::path& config_directory,
604+
const std::filesystem::path& graph_directory,
605+
const std::filesystem::path& data_directory
606+
) {
607+
// Post-consolidation, all entries should be "valid".
608+
// Therefore, we don't need to save the slot metadata.
609+
consolidate();
610+
compact();
611+
612+
auto labels = get_labels();
613+
size_t num_labels = labels.size();
614+
595615
// Save auxiliary data structures.
596616
lib::save_to_disk(
597617
lib::SaveOverride([&](const lib::SaveContext& ctx) {
@@ -601,16 +621,7 @@ class MultiMutableVamanaIndex {
601621
lib::write_binary(stream, labels);
602622

603623
// Save the construction parameters.
604-
auto parameters = VamanaIndexParameters{
605-
index_->entry_point_.front(),
606-
{get_alpha(),
607-
max_degree(),
608-
get_construction_window_size(),
609-
get_max_candidates(),
610-
get_prune_to(),
611-
get_full_search_history()},
612-
get_search_parameters()};
613-
624+
auto parameters = get_parameters();
614625
return lib::SaveTable(
615626
"multi_vamana_dynamic_auxiliary_parameters",
616627
save_version,
@@ -628,6 +639,32 @@ class MultiMutableVamanaIndex {
628639
// Graph
629640
lib::save_to_disk(index_->graph_, graph_directory);
630641
}
642+
643+
void save(std::ostream& os) {
644+
consolidate();
645+
compact();
646+
647+
auto labels = get_labels();
648+
size_t num_labels = labels.size();
649+
650+
lib::begin_serialization(os);
651+
652+
auto parameters = get_parameters();
653+
auto save_table = lib::SaveTable(
654+
"multi_vamana_dynamic_auxiliary_parameters",
655+
save_version,
656+
{{"name", lib::save(name())},
657+
{"parameters", lib::save(parameters)},
658+
{"num_labels", lib::save(num_labels)}}
659+
);
660+
lib::save_to_stream(save_table, os);
661+
lib::write_binary(os, labels);
662+
663+
// Save the dataset.
664+
lib::save_to_stream(index_->data_, os);
665+
// Save the graph.
666+
lib::save_to_stream(index_->graph_, os);
667+
}
631668
};
632669

633670
///// Deduction Guides.
@@ -789,4 +826,58 @@ auto auto_multi_dynamic_assemble(
789826
}
790827
}
791828

829+
template <
830+
typename LazyGraphLoader,
831+
typename LazyDataLoader,
832+
typename Distance,
833+
typename ThreadPoolProto>
834+
auto auto_multi_dynamic_assemble(
835+
std::istream& is,
836+
LazyGraphLoader graph_loader,
837+
LazyDataLoader data_loader,
838+
Distance distance,
839+
ThreadPoolProto threadpool_proto,
840+
svs::logging::logger_ptr logger = svs::logging::get()
841+
) {
842+
using label_type = size_t;
843+
844+
auto table = lib::detail::read_metadata(is);
845+
846+
auto parameters = lib::load<VamanaIndexParameters>(
847+
table.template cast<toml::table>().at("parameters").template cast<toml::table>()
848+
);
849+
850+
auto num_labels =
851+
lib::load<size_t>(table.template cast<toml::table>().at("num_labels"));
852+
853+
// Read labels binary data directly from the stream.
854+
std::vector<label_type> labels(num_labels);
855+
lib::read_binary(is, labels);
856+
857+
auto data = data_loader();
858+
auto graph = graph_loader();
859+
860+
auto datasize = data.size();
861+
auto graphsize = graph.n_nodes();
862+
if (datasize != graphsize) {
863+
throw ANNEXCEPTION(
864+
"Reloaded data has {} nodes while the graph has {} nodes!", datasize, graphsize
865+
);
866+
}
867+
868+
if (labels.size() != datasize) {
869+
throw ANNEXCEPTION("Labels has {} IDs but should have {}", labels.size(), datasize);
870+
}
871+
872+
auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
873+
return MultiMutableVamanaIndex{
874+
parameters,
875+
std::move(data),
876+
std::move(graph),
877+
std::move(distance),
878+
labels,
879+
std::move(threadpool),
880+
std::move(logger)};
881+
}
882+
792883
} // namespace svs::index::vamana

tests/svs/index/vamana/multi.cpp

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "catch2/catch_test_macros.hpp"
2727

2828
// stl
29+
#include <sstream>
2930
#include <unordered_map>
3031
#include <unordered_set>
3132
#include <vector>
@@ -304,3 +305,147 @@ CATCH_TEMPLATE_TEST_CASE(
304305
CATCH_REQUIRE(test_recall_2 > test_recall - epsilon);
305306
}
306307
}
308+
309+
CATCH_TEST_CASE(
310+
"MultiMutableVamana Index Save and Load", "[index][vamana][multi][saveload]"
311+
) {
312+
using Eltype = float;
313+
using Distance = svs::DistanceL2;
314+
const size_t N = 128;
315+
const size_t num_threads = 4;
316+
const size_t num_neighbors = 10;
317+
const size_t max_degree = 64;
318+
319+
const auto data = svs::data::SimpleData<Eltype, N>::load(test_dataset::data_svs_file());
320+
const auto num_points = data.size();
321+
const auto queries = test_dataset::queries();
322+
const auto groundtruth = test_dataset::load_groundtruth(svs::distance_type_v<Distance>);
323+
324+
const svs::index::vamana::VamanaBuildParameters build_parameters{
325+
1.2, max_degree, 10, 20, 10, true};
326+
327+
const auto search_parameters = svs::index::vamana::VamanaSearchParameters();
328+
329+
const float epsilon = 0.05f;
330+
331+
std::vector<size_t> test_indices(num_points);
332+
const size_t per_label = 2;
333+
const auto num_labels = num_points / per_label;
334+
for (auto& i : test_indices) {
335+
i = std::rand() % num_labels;
336+
}
337+
338+
auto index = svs::index::vamana::MultiMutableVamanaIndex(
339+
build_parameters, data, test_indices, Distance(), num_threads
340+
);
341+
auto results = svs::QueryResult<size_t>(queries.size(), num_neighbors);
342+
index.search(results.view(), queries.view(), search_parameters);
343+
344+
CATCH_SECTION("Load MultiMutableVamana Index being serialized natively to stream") {
345+
std::stringstream stream;
346+
index.save(stream);
347+
{
348+
auto deserializer = svs::lib::detail::Deserializer::build(stream);
349+
CATCH_REQUIRE(deserializer.is_native());
350+
351+
using Data_t = svs::data::SimpleData<Eltype, N>;
352+
using GraphType = svs::graphs::SimpleBlockedGraph<uint32_t>;
353+
354+
auto loaded = svs::index::vamana::auto_multi_dynamic_assemble(
355+
stream,
356+
[&]() -> GraphType { return GraphType::load(stream); },
357+
[&]() -> Data_t { return svs::lib::load_from_stream<Data_t>(stream); },
358+
Distance(),
359+
num_threads
360+
);
361+
362+
CATCH_REQUIRE(loaded.size() == index.size());
363+
CATCH_REQUIRE(loaded.dimensions() == index.dimensions());
364+
CATCH_REQUIRE(loaded.get_alpha() == index.get_alpha());
365+
CATCH_REQUIRE(
366+
loaded.get_construction_window_size() ==
367+
index.get_construction_window_size()
368+
);
369+
CATCH_REQUIRE(loaded.get_max_candidates() == index.get_max_candidates());
370+
CATCH_REQUIRE(loaded.max_degree() == index.max_degree());
371+
CATCH_REQUIRE(loaded.get_prune_to() == index.get_prune_to());
372+
CATCH_REQUIRE(
373+
loaded.get_full_search_history() == index.get_full_search_history()
374+
);
375+
CATCH_REQUIRE(loaded.view_data() == index.view_data());
376+
377+
auto loaded_results = svs::QueryResult<size_t>(queries.size(), num_neighbors);
378+
loaded.search(loaded_results.view(), queries.view(), search_parameters);
379+
for (size_t i = 0; i < results.n_queries(); ++i) {
380+
for (size_t j = 0; j < results.n_neighbors(); ++j) {
381+
CATCH_REQUIRE(
382+
results.indices().at(i, j) == loaded_results.indices().at(i, j)
383+
);
384+
}
385+
}
386+
387+
auto loaded_recall = svs::k_recall_at_n(groundtruth, loaded_results);
388+
auto test_recall = svs::k_recall_at_n(groundtruth, results);
389+
CATCH_REQUIRE(loaded_recall > test_recall - epsilon);
390+
}
391+
}
392+
393+
CATCH_SECTION("Load MultiMutableVamana Index being serialized with intermediate files"
394+
) {
395+
std::stringstream stream;
396+
svs::lib::UniqueTempDirectory tempdir{"svs_multivamana_save"};
397+
const auto config_dir = tempdir.get() / "config";
398+
const auto graph_dir = tempdir.get() / "graph";
399+
const auto data_dir = tempdir.get() / "data";
400+
std::filesystem::create_directories(config_dir);
401+
std::filesystem::create_directories(graph_dir);
402+
std::filesystem::create_directories(data_dir);
403+
index.save(config_dir, graph_dir, data_dir);
404+
svs::lib::DirectoryArchiver::pack(tempdir, stream);
405+
{
406+
using Data_t = svs::data::SimpleData<Eltype, N>;
407+
using GraphType = svs::graphs::SimpleBlockedGraph<uint32_t>;
408+
409+
auto deserializer = svs::lib::detail::Deserializer::build(stream);
410+
CATCH_REQUIRE(!deserializer.is_native());
411+
svs::lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic());
412+
413+
auto loaded = svs::index::vamana::auto_multi_dynamic_assemble(
414+
config_dir,
415+
GraphType::load(graph_dir),
416+
Data_t::load(data_dir),
417+
Distance(),
418+
num_threads
419+
);
420+
421+
CATCH_REQUIRE(loaded.size() == index.size());
422+
CATCH_REQUIRE(loaded.dimensions() == index.dimensions());
423+
CATCH_REQUIRE(loaded.get_alpha() == index.get_alpha());
424+
CATCH_REQUIRE(
425+
loaded.get_construction_window_size() ==
426+
index.get_construction_window_size()
427+
);
428+
CATCH_REQUIRE(loaded.get_max_candidates() == index.get_max_candidates());
429+
CATCH_REQUIRE(loaded.max_degree() == index.max_degree());
430+
CATCH_REQUIRE(loaded.get_prune_to() == index.get_prune_to());
431+
CATCH_REQUIRE(
432+
loaded.get_full_search_history() == index.get_full_search_history()
433+
);
434+
CATCH_REQUIRE(loaded.view_data() == index.view_data());
435+
436+
auto loaded_results = svs::QueryResult<size_t>(queries.size(), num_neighbors);
437+
loaded.search(loaded_results.view(), queries.view(), search_parameters);
438+
for (size_t i = 0; i < results.n_queries(); ++i) {
439+
for (size_t j = 0; j < results.n_neighbors(); ++j) {
440+
CATCH_REQUIRE(
441+
results.indices().at(i, j) == loaded_results.indices().at(i, j)
442+
);
443+
}
444+
}
445+
446+
auto loaded_recall = svs::k_recall_at_n(groundtruth, loaded_results);
447+
auto test_recall = svs::k_recall_at_n(groundtruth, results);
448+
CATCH_REQUIRE(loaded_recall > test_recall - epsilon);
449+
}
450+
}
451+
}

0 commit comments

Comments
 (0)