Skip to content

Commit 99a0d6f

Browse files
committed
Support for signed integer values
1 parent 3a44930 commit 99a0d6f

2 files changed

Lines changed: 109 additions & 1 deletion

File tree

cppbktree/cppbktree.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ hammingDistance64( const uint64_t a,
7070
}
7171

7272

73+
inline size_t
74+
hammingDistance64s( const int64_t a,
75+
const int64_t b )
76+
{
77+
return countBits( a ^ b );
78+
}
79+
80+
7381
template<typename T_ValueType,
7482
typename T_DistanceType>
7583
class CppBKTree
@@ -311,6 +319,8 @@ class CppBKTree
311319
if ( !m_metricFunction ) {
312320
if constexpr ( std::is_same_v<ValueType, uint64_t> ) {
313321
m_metricFunction = MetricFunction( &hammingDistance64 );
322+
} else if constexpr ( std::is_same_v<ValueType, int64_t> ) {
323+
m_metricFunction = MetricFunction( &hammingDistance64s );
314324
} else if constexpr ( std::is_same_v<ValueType, std::vector<uint8_t> > ) {
315325
m_metricFunction = MetricFunction( &hammingDistance );
316326
} else {

cppbktree/cppbktree.pyx

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from libc.stdlib cimport malloc, free
44
from libc.stdio cimport SEEK_SET
5-
from libc.stdint cimport uint8_t, uint64_t
5+
from libc.stdint cimport uint8_t, uint64_t, int64_t
66
from libcpp.map cimport map
77
from libcpp.vector cimport vector
88
from libcpp.string cimport string
@@ -36,6 +36,23 @@ cdef class _LinearLookup64:
3636
def size(self):
3737
return self.data.size()
3838

39+
40+
cdef class _LinearLookup64s:
    """
    Extension type owning a heap-allocated CppLinearLookup[int64_t].

    The C++ object is created in __cinit__ and destroyed in __dealloc__.
    """
    cdef CppLinearLookup[int64_t]* data

    def __cinit__(self, list_of_hashes):
        # list_of_hashes is handed straight to the C++ constructor.
        self.data = new CppLinearLookup[int64_t](list_of_hashes)

    def __dealloc__(self):
        # Release the C++ object allocated in __cinit__.
        del self.data

    def find(self, query, distance=0):
        """
        Forward query and distance to the C++ find() and return its result
        converted to a Python list.
        """
        matches = <list>(self.data.find(query, distance))
        return matches

    def size(self):
        """Return the value reported by the C++ size()."""
        return self.data.size()
54+
55+
3956
# Extra class because cdefs are not visible from outside
4057
class LinearLookup64:
4158
def __init__(self, list_of_hashes):
@@ -48,9 +65,21 @@ class LinearLookup64:
4865
return self.tree.size()
4966

5067

68+
class LinearLookup64s:
    """
    Plain Python facade over the _LinearLookup64s extension type.

    Every method delegates to the wrapped object stored in self.tree.
    """

    def __init__(self, list_of_hashes):
        # The attribute is called `tree` even though it holds a linear lookup.
        self.tree = _LinearLookup64s(list_of_hashes)

    def find(self, query, distance=0):
        """Delegate to the wrapped object's find(query, distance)."""
        backend = self.tree
        return backend.find(query, distance)

    def size(self):
        """Delegate to the wrapped object's size()."""
        backend = self.tree
        return backend.size()
77+
78+
5179
cdef extern from "cppbktree.hpp":
5280
size_t hammingDistance(const vector[uint8_t]&, const vector[uint8_t]& ) except +
5381
size_t hammingDistance64( const uint64_t, const uint64_t ) except +;
82+
size_t hammingDistance64s( const int64_t, const int64_t ) except +;
5483

5584
cppclass CppBKTree[T_ValueType, T_DistanceType]:
5685
struct TreeStatistics:
@@ -178,6 +207,55 @@ cdef class _BKTree64:
178207
return self.tree.rebalance(self.max_element_count if max_element_count is None else <size_t>max_element_count)
179208

180209

210+
cdef class _BKTree64s:
    """
    Extension type owning a heap-allocated CppBKTree[int64_t, size_t].

    add() only flags the tree as needing a rebalance; the rebalance itself is
    performed by the next find() or by an explicit rebalance() call.
    """
    cdef CppBKTree[int64_t, size_t]* tree
    cdef int max_element_count
    cdef bool _needs_rebalance

    def __cinit__(self, list_of_hashes_or_file_name, max_element_count = 32 * 1024):
        # Build the C++ tree, rebalance it once right away, and remember the
        # threshold as the default for later rebalance() calls.
        self.tree = new CppBKTree[int64_t, size_t](list_of_hashes_or_file_name)
        self.tree.rebalance(max_element_count)
        self.max_element_count = max_element_count
        self._needs_rebalance = False

    def __dealloc__(self):
        # Release the C++ tree allocated in __cinit__.
        del self.tree

    def add(self, list_of_hashes_or_file_name):
        """
        Add values to the tree. The argument is converted to vector[int64_t]
        before being passed to the C++ add(). The tree is only marked as
        needing a rebalance here; it is not rebalanced immediately.
        """
        self.tree.add(<vector[int64_t]>list_of_hashes_or_file_name)
        self._needs_rebalance = True

    def find(self, query, distance=0):
        """
        Forward query and distance to the C++ find() and return its result
        converted to a Python list. A rebalance pending from add() is
        performed first, using the stored default threshold.
        """
        if self._needs_rebalance:
            # rebalance() clears _needs_rebalance itself.
            self.rebalance()
        return <list>(self.tree.find(query, distance))

    def size(self):
        """Return the value reported by the C++ size()."""
        return self.tree.size()

    def statistics(self):
        """Return the fields of the C++ TreeStatistics struct as a dict."""
        # Automatic POD to dict conversion did not work for me. Maybe because the contained types?
        cdef CppBKTree[int64_t, size_t].TreeStatistics result = self.tree.statistics()
        stats = {
            'nodeCount'                : result.nodeCount               ,
            'leafCount'                : result.leafCount               ,
            'valueCount'               : result.valueCount              ,
            'averageChildCountPerNode' : result.averageChildCountPerNode,
            'maxDepth'                 : result.maxDepth                ,
            'minChildrenPerNode'       : result.minChildrenPerNode      ,
            'maxChildrenPerNode'       : result.maxChildrenPerNode      ,
            'minPayloadsPerNode'       : result.minPayloadsPerNode      ,
            'maxPayloadsPerNode'       : result.maxPayloadsPerNode      ,
            'duplicateCount'           : result.duplicateCount          ,
            'valueBitCount'            : result.valueBitCount           ,
        }
        return stats

    def rebalance(self, max_element_count = None):
        """
        Rebalance the C++ tree. When max_element_count is None, the threshold
        given at construction is used.

        The pending-rebalance flag is cleared here so that an explicit call
        after add() is not followed by a second rebalance, with the default
        threshold, on the next find().
        """
        # Convert the threshold first so a bad argument raises before the flag changes.
        cdef size_t limit = self.max_element_count if max_element_count is None else <size_t>max_element_count
        self._needs_rebalance = False
        return self.tree.rebalance(limit)
257+
258+
181259
# Extra class because cdefs are not visible from outside
182260
class BKTree:
183261
def __init__(self, list_of_hashes, max_element_count = 32 * 1024):
@@ -219,4 +297,24 @@ class BKTree64:
219297
return self.tree.rebalance(max_element_count)
220298

221299

300+
class BKTree64s:
    """
    Plain Python facade over the _BKTree64s extension type.

    Every method delegates to the wrapped object stored in self.tree.
    """

    def __init__(self, list_of_hashes, max_element_count = 32 * 1024):
        self.tree = _BKTree64s(list_of_hashes, max_element_count)

    def add(self, list_of_hashes_or_file_name):
        """Delegate to the wrapped object's add()."""
        backend = self.tree
        backend.add(list_of_hashes_or_file_name)

    def find(self, query, distance=0):
        """Delegate to the wrapped object's find(query, distance)."""
        backend = self.tree
        return backend.find(query, distance)

    def size(self):
        """Delegate to the wrapped object's size()."""
        backend = self.tree
        return backend.size()

    def statistics(self):
        """Delegate to the wrapped object's statistics()."""
        backend = self.tree
        return backend.statistics()

    def rebalance(self, max_element_count = None):
        """Delegate to the wrapped object's rebalance(); None is passed through unchanged."""
        backend = self.tree
        return backend.rebalance(max_element_count)
318+
319+
222320
# Version string of this module.
__version__ = '0.2.0'

0 commit comments

Comments
 (0)