|
26 | 26 | #include "catch2/catch_test_macros.hpp" |
27 | 27 |
|
28 | 28 | // stl |
| 29 | +#include <sstream> |
29 | 30 | #include <unordered_map> |
30 | 31 | #include <unordered_set> |
31 | 32 | #include <vector> |
@@ -304,3 +305,147 @@ CATCH_TEMPLATE_TEST_CASE( |
304 | 305 | CATCH_REQUIRE(test_recall_2 > test_recall - epsilon); |
305 | 306 | } |
306 | 307 | } |
| 308 | + |
| 309 | +CATCH_TEST_CASE( |
| 310 | + "MultiMutableVamana Index Save and Load", "[index][vamana][multi][saveload]" |
| 311 | +) { |
| 312 | + using Eltype = float; |
| 313 | + using Distance = svs::DistanceL2; |
| 314 | + const size_t N = 128; |
| 315 | + const size_t num_threads = 4; |
| 316 | + const size_t num_neighbors = 10; |
| 317 | + const size_t max_degree = 64; |
| 318 | + |
| 319 | + const auto data = svs::data::SimpleData<Eltype, N>::load(test_dataset::data_svs_file()); |
| 320 | + const auto num_points = data.size(); |
| 321 | + const auto queries = test_dataset::queries(); |
| 322 | + const auto groundtruth = test_dataset::load_groundtruth(svs::distance_type_v<Distance>); |
| 323 | + |
| 324 | + const svs::index::vamana::VamanaBuildParameters build_parameters{ |
| 325 | + 1.2, max_degree, 10, 20, 10, true}; |
| 326 | + |
| 327 | + const auto search_parameters = svs::index::vamana::VamanaSearchParameters(); |
| 328 | + |
| 329 | + const float epsilon = 0.05f; |
| 330 | + |
| 331 | + std::vector<size_t> test_indices(num_points); |
| 332 | + const size_t per_label = 2; |
| 333 | + const auto num_labels = num_points / per_label; |
| 334 | + for (auto& i : test_indices) { |
| 335 | + i = std::rand() % num_labels; |
| 336 | + } |
| 337 | + |
| 338 | + auto index = svs::index::vamana::MultiMutableVamanaIndex( |
| 339 | + build_parameters, data, test_indices, Distance(), num_threads |
| 340 | + ); |
| 341 | + auto results = svs::QueryResult<size_t>(queries.size(), num_neighbors); |
| 342 | + index.search(results.view(), queries.view(), search_parameters); |
| 343 | + |
| 344 | + CATCH_SECTION("Load MultiMutableVamana Index being serialized natively to stream") { |
| 345 | + std::stringstream stream; |
| 346 | + index.save(stream); |
| 347 | + { |
| 348 | + auto deserializer = svs::lib::detail::Deserializer::build(stream); |
| 349 | + CATCH_REQUIRE(deserializer.is_native()); |
| 350 | + |
| 351 | + using Data_t = svs::data::SimpleData<Eltype, N>; |
| 352 | + using GraphType = svs::graphs::SimpleBlockedGraph<uint32_t>; |
| 353 | + |
| 354 | + auto loaded = svs::index::vamana::auto_multi_dynamic_assemble( |
| 355 | + stream, |
| 356 | + [&]() -> GraphType { return GraphType::load(stream); }, |
| 357 | + [&]() -> Data_t { return svs::lib::load_from_stream<Data_t>(stream); }, |
| 358 | + Distance(), |
| 359 | + num_threads |
| 360 | + ); |
| 361 | + |
| 362 | + CATCH_REQUIRE(loaded.size() == index.size()); |
| 363 | + CATCH_REQUIRE(loaded.dimensions() == index.dimensions()); |
| 364 | + CATCH_REQUIRE(loaded.get_alpha() == index.get_alpha()); |
| 365 | + CATCH_REQUIRE( |
| 366 | + loaded.get_construction_window_size() == |
| 367 | + index.get_construction_window_size() |
| 368 | + ); |
| 369 | + CATCH_REQUIRE(loaded.get_max_candidates() == index.get_max_candidates()); |
| 370 | + CATCH_REQUIRE(loaded.max_degree() == index.max_degree()); |
| 371 | + CATCH_REQUIRE(loaded.get_prune_to() == index.get_prune_to()); |
| 372 | + CATCH_REQUIRE( |
| 373 | + loaded.get_full_search_history() == index.get_full_search_history() |
| 374 | + ); |
| 375 | + CATCH_REQUIRE(loaded.view_data() == index.view_data()); |
| 376 | + |
| 377 | + auto loaded_results = svs::QueryResult<size_t>(queries.size(), num_neighbors); |
| 378 | + loaded.search(loaded_results.view(), queries.view(), search_parameters); |
| 379 | + for (size_t i = 0; i < results.n_queries(); ++i) { |
| 380 | + for (size_t j = 0; j < results.n_neighbors(); ++j) { |
| 381 | + CATCH_REQUIRE( |
| 382 | + results.indices().at(i, j) == loaded_results.indices().at(i, j) |
| 383 | + ); |
| 384 | + } |
| 385 | + } |
| 386 | + |
| 387 | + auto loaded_recall = svs::k_recall_at_n(groundtruth, loaded_results); |
| 388 | + auto test_recall = svs::k_recall_at_n(groundtruth, results); |
| 389 | + CATCH_REQUIRE(loaded_recall > test_recall - epsilon); |
| 390 | + } |
| 391 | + } |
| 392 | + |
| 393 | + CATCH_SECTION("Load MultiMutableVamana Index being serialized with intermediate files" |
| 394 | + ) { |
| 395 | + std::stringstream stream; |
| 396 | + svs::lib::UniqueTempDirectory tempdir{"svs_multivamana_save"}; |
| 397 | + const auto config_dir = tempdir.get() / "config"; |
| 398 | + const auto graph_dir = tempdir.get() / "graph"; |
| 399 | + const auto data_dir = tempdir.get() / "data"; |
| 400 | + std::filesystem::create_directories(config_dir); |
| 401 | + std::filesystem::create_directories(graph_dir); |
| 402 | + std::filesystem::create_directories(data_dir); |
| 403 | + index.save(config_dir, graph_dir, data_dir); |
| 404 | + svs::lib::DirectoryArchiver::pack(tempdir, stream); |
| 405 | + { |
| 406 | + using Data_t = svs::data::SimpleData<Eltype, N>; |
| 407 | + using GraphType = svs::graphs::SimpleBlockedGraph<uint32_t>; |
| 408 | + |
| 409 | + auto deserializer = svs::lib::detail::Deserializer::build(stream); |
| 410 | + CATCH_REQUIRE(!deserializer.is_native()); |
| 411 | + svs::lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic()); |
| 412 | + |
| 413 | + auto loaded = svs::index::vamana::auto_multi_dynamic_assemble( |
| 414 | + config_dir, |
| 415 | + GraphType::load(graph_dir), |
| 416 | + Data_t::load(data_dir), |
| 417 | + Distance(), |
| 418 | + num_threads |
| 419 | + ); |
| 420 | + |
| 421 | + CATCH_REQUIRE(loaded.size() == index.size()); |
| 422 | + CATCH_REQUIRE(loaded.dimensions() == index.dimensions()); |
| 423 | + CATCH_REQUIRE(loaded.get_alpha() == index.get_alpha()); |
| 424 | + CATCH_REQUIRE( |
| 425 | + loaded.get_construction_window_size() == |
| 426 | + index.get_construction_window_size() |
| 427 | + ); |
| 428 | + CATCH_REQUIRE(loaded.get_max_candidates() == index.get_max_candidates()); |
| 429 | + CATCH_REQUIRE(loaded.max_degree() == index.max_degree()); |
| 430 | + CATCH_REQUIRE(loaded.get_prune_to() == index.get_prune_to()); |
| 431 | + CATCH_REQUIRE( |
| 432 | + loaded.get_full_search_history() == index.get_full_search_history() |
| 433 | + ); |
| 434 | + CATCH_REQUIRE(loaded.view_data() == index.view_data()); |
| 435 | + |
| 436 | + auto loaded_results = svs::QueryResult<size_t>(queries.size(), num_neighbors); |
| 437 | + loaded.search(loaded_results.view(), queries.view(), search_parameters); |
| 438 | + for (size_t i = 0; i < results.n_queries(); ++i) { |
| 439 | + for (size_t j = 0; j < results.n_neighbors(); ++j) { |
| 440 | + CATCH_REQUIRE( |
| 441 | + results.indices().at(i, j) == loaded_results.indices().at(i, j) |
| 442 | + ); |
| 443 | + } |
| 444 | + } |
| 445 | + |
| 446 | + auto loaded_recall = svs::k_recall_at_n(groundtruth, loaded_results); |
| 447 | + auto test_recall = svs::k_recall_at_n(groundtruth, results); |
| 448 | + CATCH_REQUIRE(loaded_recall > test_recall - epsilon); |
| 449 | + } |
| 450 | + } |
| 451 | +} |
0 commit comments