Skip to content

Commit 1a33e4e

Browse files
authored
feat(ivf): Add Float16 numpy array specializations for IVF/DynamicIVF assembly (#293)
Register Float16 pybind specializations for numpy array assembly paths. Previously only float32 was registered for the numpy array overloads of assemble_from_clustering and assemble_from_file in both IVF and DynamicIVF. Added: - add_assemble_from_clustering_array_specialization<svs::Float16> - add_assemble_from_file_array_specialization<svs::Float16>
1 parent 26fa402 commit 1a33e4e

4 files changed

Lines changed: 41 additions & 0 deletions

File tree

bindings/python/src/dynamic_ivf.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,9 @@ Method {}:
668668
}
669669

670670
// Assemble from numpy array.
671+
add_assemble_from_clustering_array_specialization<svs::Float16>(dynamic_ivf);
671672
add_assemble_from_clustering_array_specialization<float>(dynamic_ivf);
673+
add_assemble_from_file_array_specialization<svs::Float16>(dynamic_ivf);
672674
add_assemble_from_file_array_specialization<float>(dynamic_ivf);
673675

674676
// Index modification.

bindings/python/src/ivf.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,9 @@ void wrap(py::module& m) {
784784
detail::wrap_assemble(ivf);
785785

786786
// Assemble from numpy array.
787+
detail::add_assemble_from_clustering_array_specialization<svs::Float16>(ivf);
787788
detail::add_assemble_from_clustering_array_specialization<float>(ivf);
789+
detail::add_assemble_from_file_array_specialization<svs::Float16>(ivf);
788790
detail::add_assemble_from_file_array_specialization<float>(ivf);
789791

790792
// Make the IVF type searchable.

bindings/python/tests/test_dynamic_ivf.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,25 @@ def test_assemble_from_numpy(self):
257257
print(f" assemble_from_file numpy recall: {recall2}")
258258
self.assertTrue(0.5 < recall2 <= 1.0)
259259

260+
# Test with float16 numpy array
261+
data_f16 = data.astype('float16')
262+
print("Testing DynamicIVF.assemble_from_clustering with numpy array (float16)")
263+
index_f16 = svs.DynamicIVF.assemble_from_clustering(
264+
clustering = clustering,
265+
py_data = data_f16,
266+
ids = ids,
267+
distance = svs.DistanceType.L2,
268+
num_threads = num_threads,
269+
)
270+
self.assertEqual(index_f16.size, test_number_of_vectors)
271+
self.assertEqual(index_f16.dimensions, test_data_dims)
272+
273+
index_f16.search_parameters = search_params
274+
I_f16, D_f16 = index_f16.search(queries, k)
275+
recall_f16 = svs.k_recall_at(groundtruth, I_f16, k, k)
276+
print(f" assemble_from_clustering numpy float16 recall: {recall_f16}")
277+
self.assertTrue(0.4 < recall_f16 <= 1.0)
278+
260279
def test_build_from_loader(self):
261280
"""Test building DynamicIVF using a VectorDataLoader and explicit IDs."""
262281
num_threads = 2

bindings/python/tests/test_ivf.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,24 @@ def test_assemble_from_numpy(self):
404404
print(f" assemble_from_file numpy recall: {recall2}")
405405
self.assertTrue(0.5 < recall2 <= 1.0)
406406

407+
# Test with float16 numpy array
408+
data_f16 = data.astype('float16')
409+
print("Testing IVF.assemble_from_clustering with numpy array (float16)")
410+
ivf_f16 = svs.IVF.assemble_from_clustering(
411+
clustering = clustering,
412+
py_data = data_f16,
413+
distance = svs.DistanceType.L2,
414+
num_threads = num_threads,
415+
)
416+
self.assertEqual(ivf_f16.size, test_number_of_vectors)
417+
self.assertEqual(ivf_f16.dimensions, test_dimensions)
418+
419+
ivf_f16.search_parameters = search_params
420+
I_f16, D_f16 = ivf_f16.search(queries, k)
421+
recall_f16 = svs.k_recall_at(groundtruth, I_f16, k, k)
422+
print(f" assemble_from_clustering numpy float16 recall: {recall_f16}")
423+
self.assertTrue(0.4 < recall_f16 <= 1.0)
424+
407425
def test_build(self):
408426
# Build directly from data
409427
data = svs.read_vecs(test_data_vecs)

0 commit comments

Comments
 (0)