Skip to content

Commit 629a79c

Browse files
authored
Enhance heuristic pruning to handle duplicate clusters (#282)
This PR fixes #80 . It adds a post-pruning step to both: - `IterativePruneStrategy` - `ProgressivePruneStrategy` Approach: If a duplicate cluster is detected, the last (worst) slot in the `result` is replaced with the closest candidate from the pool that does not have the same distance. --------- Signed-off-by: Dilkhush Purohit <dilkhushpurohit01@gmail.com>
1 parent f601b56 commit 629a79c

2 files changed

Lines changed: 159 additions & 0 deletions

File tree

include/svs/index/vamana/prune.h

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ void heuristic_prune_neighbors(
130130

131131
auto pruned = std::vector<PruneState>(poolsize, PruneState::Available);
132132
float current_alpha = 1.0f;
133+
float anchor_dist = 0.0f;
134+
bool anchor_set = false;
135+
bool all_duplicates = true;
133136
while (result.size() < max_result_size && !cmp(alpha, current_alpha)) {
134137
size_t start = 0;
135138
while (result.size() < max_result_size && start < poolsize) {
@@ -145,6 +148,16 @@ void heuristic_prune_neighbors(
145148
const auto& query = accessor(dataset, id);
146149
distance::maybe_fix_argument(distance_function, query);
147150
result.push_back(detail::construct_as(lib::Type<I>(), pool[start]));
151+
152+
if (all_duplicates) {
153+
if (!anchor_set) {
154+
anchor_dist = pool[start].distance();
155+
anchor_set = true;
156+
} else if (pool[start].distance() != anchor_dist) {
157+
all_duplicates = false;
158+
}
159+
}
160+
148161
for (size_t t = start + 1; t < poolsize; ++t) {
149162
if (excluded(pruned[t])) {
150163
continue;
@@ -171,6 +184,40 @@ void heuristic_prune_neighbors(
171184
}
172185
current_alpha *= alpha;
173186
}
187+
188+
// Add a diversity edge if a duplicate cluster is detected
189+
if (all_duplicates && anchor_set && !result.empty()) {
190+
auto result_id = [](const I& r) -> size_t {
191+
if constexpr (std::integral<I>) {
192+
return static_cast<size_t>(r);
193+
} else {
194+
return static_cast<size_t>(r.id());
195+
}
196+
};
197+
for (size_t t = 0; t < poolsize; ++t) {
198+
const auto& candidate = pool[t];
199+
auto cid = candidate.id();
200+
if (cid == current_node_id || candidate.distance() == anchor_dist) {
201+
continue;
202+
}
203+
bool in_result = false;
204+
for (const auto& r : result) {
205+
if (result_id(r) == static_cast<size_t>(cid)) {
206+
in_result = true;
207+
break;
208+
}
209+
}
210+
assert(
211+
!in_result &&
212+
"Candidate with non-anchor distance should not already be in result"
213+
);
214+
if (in_result) {
215+
continue;
216+
}
217+
result.back() = detail::construct_as(lib::Type<I>(), candidate);
218+
break;
219+
}
220+
}
174221
}
175222

176223
template <
@@ -203,6 +250,9 @@ void heuristic_prune_neighbors(
203250
std::vector<float> pruned(poolsize, type_traits::tombstone_v<float, decltype(cmp)>);
204251

205252
float current_alpha = 1.0f;
253+
float anchor_dist = 0.0f;
254+
bool anchor_set = false;
255+
bool all_duplicates = true;
206256
while (result.size() < max_result_size && !cmp(alpha, current_alpha)) {
207257
size_t start = 0;
208258
while (result.size() < max_result_size && start < poolsize) {
@@ -218,6 +268,16 @@ void heuristic_prune_neighbors(
218268
const auto& query = accessor(dataset, id);
219269
distance::maybe_fix_argument(distance_function, query);
220270
result.push_back(detail::construct_as(lib::Type<I>(), pool[start]));
271+
272+
if (all_duplicates) {
273+
if (!anchor_set) {
274+
anchor_dist = pool[start].distance();
275+
anchor_set = true;
276+
} else if (pool[start].distance() != anchor_dist) {
277+
all_duplicates = false;
278+
}
279+
}
280+
221281
for (size_t t = start + 1; t < poolsize; ++t) {
222282
if (cmp(current_alpha, pruned[t])) {
223283
continue;
@@ -236,6 +296,40 @@ void heuristic_prune_neighbors(
236296
}
237297
current_alpha *= alpha;
238298
}
299+
300+
// Add a diversity edge if a duplicate cluster is detected
301+
if (all_duplicates && anchor_set && !result.empty()) {
302+
auto result_id = [](const I& r) -> size_t {
303+
if constexpr (std::integral<I>) {
304+
return static_cast<size_t>(r);
305+
} else {
306+
return static_cast<size_t>(r.id());
307+
}
308+
};
309+
for (size_t t = 0; t < poolsize; ++t) {
310+
const auto& candidate = pool[t];
311+
auto cid = candidate.id();
312+
if (cid == current_node_id || candidate.distance() == anchor_dist) {
313+
continue;
314+
}
315+
bool in_result = false;
316+
for (const auto& r : result) {
317+
if (result_id(r) == static_cast<size_t>(cid)) {
318+
in_result = true;
319+
break;
320+
}
321+
}
322+
assert(
323+
!in_result &&
324+
"Candidate with non-anchor distance should not already be in result"
325+
);
326+
if (in_result) {
327+
continue;
328+
}
329+
result.back() = detail::construct_as(lib::Type<I>(), candidate);
330+
break;
331+
}
332+
}
239333
}
240334

241335
///

tests/svs/index/vamana/prune.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
// header under test
1818
#include "svs/index/vamana/prune.h"
1919

20+
// core
21+
#include "svs/core/data/simple.h"
22+
#include "svs/core/distance/euclidean.h"
23+
2024
// catch2
2125
#include "catch2/catch_test_macros.hpp"
2226

@@ -46,4 +50,65 @@ CATCH_TEST_CASE("Pruning", "[index][vamana]") {
4650
CATCH_REQUIRE(v::excluded(v::PruneState::Pruned) == true);
4751
}
4852
}
53+
54+
CATCH_SECTION("Duplicate Cluster Trap") {
55+
auto data = svs::data::SimpleData<float>(6, 4);
56+
auto d0 = std::vector<float>{1.0f, 1.0f, 1.0f, 1.0f};
57+
auto d4 = std::vector<float>{2.0f, 1.0f, 1.0f, 1.0f};
58+
auto d5 = std::vector<float>{1.5f, 1.0f, 1.0f, 1.0f};
59+
60+
for (size_t i = 0; i < 4; ++i) {
61+
data.set_datum(i, d0);
62+
}
63+
data.set_datum(4, d4);
64+
data.set_datum(5, d5);
65+
66+
auto dist = svs::distance::DistanceL2();
67+
auto accessor = svs::data::GetDatumAccessor{};
68+
69+
std::vector<svs::Neighbor<size_t>> pool = {
70+
{size_t{0}, 0.0f},
71+
{size_t{1}, 0.0f},
72+
{size_t{2}, 0.0f},
73+
{size_t{3}, 0.0f},
74+
{size_t{4}, 1.0f}};
75+
76+
CATCH_SECTION("Iterative Strategy Fix") {
77+
std::vector<svs::Neighbor<size_t>> result;
78+
v::heuristic_prune_neighbors(
79+
v::IterativePruneStrategy{},
80+
2,
81+
1.3f,
82+
data,
83+
accessor,
84+
dist,
85+
size_t{5},
86+
std::span<const svs::Neighbor<size_t>>(pool),
87+
result
88+
);
89+
90+
CATCH_REQUIRE(result.size() == 2);
91+
CATCH_REQUIRE(result[0].id() == 0);
92+
CATCH_REQUIRE(result[1].id() == 4);
93+
}
94+
95+
CATCH_SECTION("Progressive Strategy Fix") {
96+
std::vector<svs::Neighbor<size_t>> result;
97+
v::heuristic_prune_neighbors(
98+
v::ProgressivePruneStrategy{},
99+
2,
100+
1.3f,
101+
data,
102+
accessor,
103+
dist,
104+
size_t{5},
105+
std::span<const svs::Neighbor<size_t>>(pool),
106+
result
107+
);
108+
109+
CATCH_REQUIRE(result.size() == 2);
110+
CATCH_REQUIRE(result[0].id() == 0);
111+
CATCH_REQUIRE(result[1].id() == 4);
112+
}
113+
}
49114
}

0 commit comments

Comments
 (0)