Skip to content

Commit 302e58d

Browse files
committed
address PR review: refactor and optimize adaptive batch size
- Rename compute_filtered_batch_size to predict_further_processing and move it to svs_runtime_utils.h for reuse
- Use float arithmetic instead of double for the hit rate calculation
- Compute the batch size at loop start to avoid unnecessary computation
- Use iterator.size() instead of a per-element increment for total_checked
- Set the initial batch size to max(k, search_window_size)
- Apply the adaptive batch size to the vamana_index_impl.h filtered search
1 parent 30d26c4 commit 302e58d

3 files changed

Lines changed: 24 additions & 19 deletions

File tree

bindings/cpp/src/dynamic_vamana_index_impl.h

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,6 @@
3838
namespace svs {
3939
namespace runtime {
4040

41-
// Compute the next batch size based on the observed filter hit rate.
42-
// When found == 0 (e.g. on the first round) or found >= needed, returns initial_batch_size unchanged.
43-
// Otherwise, estimates how many more candidates are needed to find the
44-
// remaining (needed - found) results at the observed hit rate (found / total_checked), truncated to size_t.
45-
inline size_t compute_filtered_batch_size(
46-
size_t found, size_t needed, size_t total_checked, size_t initial_batch_size
47-
) {
48-
if (found == 0 || found >= needed) {
49-
return initial_batch_size;
50-
}
51-
double hit_rate = static_cast<double>(found) / total_checked;
52-
return static_cast<size_t>((needed - found) / hit_rate);
53-
}
54-
5541
// Dynamic Vamana index implementation
5642
class DynamicVamanaIndexImpl {
5743
using allocator_type = svs::data::Blocked<svs::lib::Allocator<float>>;
@@ -140,11 +126,13 @@ class DynamicVamanaIndexImpl {
140126
auto iterator = impl_->batch_iterator(query);
141127
size_t found = 0;
142128
size_t total_checked = 0;
143-
auto batch_size = sp.buffer_config_.get_search_window_size();
129+
auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size());
144130
do {
131+
batch_size =
132+
predict_further_processing(total_checked, found, k, batch_size);
145133
iterator.next(batch_size);
134+
total_checked += iterator.size();
146135
for (auto& neighbor : iterator.results()) {
147-
total_checked++;
148136
if (filter->is_member(neighbor.id())) {
149137
result.set(neighbor, i, found);
150138
found++;
@@ -153,8 +141,6 @@ class DynamicVamanaIndexImpl {
153141
}
154142
}
155143
}
156-
batch_size =
157-
compute_filtered_batch_size(found, k, total_checked, batch_size);
158144
} while (found < k && !iterator.done());
159145

160146
// Pad results if not enough neighbors found

bindings/cpp/src/svs_runtime_utils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,20 @@ auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... args) {
431431
}
432432
} // namespace storage
433433

434+
// Predict how many more items need to be processed to reach `goal` hits,
435+
// based on the hit rate observed so far (`hits` out of `processed` items).
436+
// If there are no hits yet, or `hits` has already reached `goal`, returns `hint` unchanged.
437+
// Otherwise returns the float estimate truncated to size_t, but at least 1. No upper bound is applied here, so the caller should cap the result to a max batch size if needed.
438+
inline size_t predict_further_processing(
439+
size_t processed, size_t hits, size_t goal, size_t hint
440+
) {
441+
if (hits == 0 || hits >= goal) {
442+
return hint;
443+
}
444+
float batch_size = static_cast<float>(goal - hits) * processed / hits;
445+
return std::max(static_cast<size_t>(batch_size), size_t{1});
446+
}
447+
434448
inline svs::threads::ThreadPoolHandle default_threadpool() {
435449
return svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads())
436450
);

bindings/cpp/src/vamana_index_impl.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,13 @@ class VamanaIndexImpl {
131131
auto query = queries.get_datum(i);
132132
auto iterator = get_impl()->batch_iterator(query);
133133
size_t found = 0;
134+
size_t total_checked = 0;
135+
auto batch_size = std::max(k, sp.buffer_config_.get_search_window_size());
134136
do {
135-
iterator.next(k);
137+
batch_size =
138+
predict_further_processing(total_checked, found, k, batch_size);
139+
iterator.next(batch_size);
140+
total_checked += iterator.size();
136141
for (auto& neighbor : iterator.results()) {
137142
if (filter->is_member(neighbor.id())) {
138143
result.set(neighbor, i, found);

0 commit comments

Comments
 (0)