2222import org .apache .datasketches .common .Family ;
2323import org .apache .datasketches .common .SketchesArgumentException ;
2424import org .apache .datasketches .common .SketchesException ;
25+ import org .apache .datasketches .common .Util ;
26+ import org .apache .datasketches .common .positional .PositionalSegment ;
2527import org .apache .datasketches .hash .MurmurHash3 ;
26- import org .apache .datasketches .tuple .Util ;
2728
28- import java .io .ByteArrayOutputStream ;
29- import java .nio .ByteBuffer ;
29+ import java .lang .foreign .MemorySegment ;
3030import java .nio .charset .StandardCharsets ;
3131import java .util .Random ;
3232
33+ import static java .lang .foreign .ValueLayout .JAVA_BYTE ;
34+ import static java .lang .foreign .ValueLayout .JAVA_INT_UNALIGNED ;
35+ import static java .lang .foreign .ValueLayout .JAVA_LONG_UNALIGNED ;
36+ import static java .lang .foreign .ValueLayout .JAVA_SHORT_UNALIGNED ;
3337
38+
39+ /**
40+ * Java implementation of the CountMin sketch data structure of Cormode and Muthukrishnan.
41+ * This implementation is inspired by and compatible with the datasketches-cpp version by Charlie Dickens.
42+ *
43+ * The CountMin sketch is a probabilistic data structure that provides frequency estimates for items
44+ * in a data stream. It uses multiple hash functions to distribute items across a two-dimensional array,
45+ * providing approximate counts with configurable error bounds.
46+ *
47+ * Reference: http://dimacs.rutgers.edu/~graham/pubs/papers/cm-full.pdf
48+ */
3449public class CountMinSketch {
3550 private final byte numHashes_ ;
3651 private final int numBuckets_ ;
@@ -39,6 +54,9 @@ public class CountMinSketch {
3954 private final long [] sketchArray_ ;
4055 private long totalWeight_ ;
4156
57+ // Thread-local MemorySegment to avoid allocations in hot paths with explicit endianness control
58+ private static final ThreadLocal <MemorySegment > LONG_SEGMENT =
59+ ThreadLocal .withInitial (() -> MemorySegment .ofArray (new byte [Long .BYTES ]));
4260
4361 private enum Flag {
4462 IS_EMPTY ;
@@ -57,35 +75,64 @@ int mask() {
5775 * @param seed The base hash seed
5876 */
5977 CountMinSketch (final byte numHashes , final int numBuckets , final long seed ) {
60- numHashes_ = numHashes ;
61- numBuckets_ = numBuckets ;
62- seed_ = seed ;
63- hashSeeds_ = new long [numHashes ];
64- sketchArray_ = new long [numHashes * numBuckets ];
65- totalWeight_ = 0 ;
78+ // Validate numHashes
79+ if (numHashes <= 0 ) {
80+ throw new SketchesArgumentException ("Number of hash functions must be positive, got: " + numHashes );
81+ }
6682
83+ // Validate numBuckets with clear mathematical justification
84+ if (numBuckets <= 0 ) {
85+ throw new SketchesArgumentException ("Number of buckets must be positive, got: " + numBuckets );
86+ }
6787 if (numBuckets < 3 ) {
68- throw new SketchesArgumentException ("Using fewer than 3 buckets incurs relative error greater than 1." );
88+ throw new SketchesArgumentException ("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. " +
89+ "With " + numBuckets + " buckets, relative error would be " + String .format ("%.3f" , Math .exp (1.0 ) / numBuckets ));
90+ }
91+
92+ // Check for potential overflow in array size calculation
93+ // Use long arithmetic to detect overflow before casting
94+ final long totalSize = (long ) numHashes * (long ) numBuckets ;
95+ if (totalSize > Integer .MAX_VALUE ) {
96+ throw new SketchesArgumentException ("Sketch array size would overflow: " + numHashes + " * " + numBuckets
97+ + " = " + totalSize + " > " + Integer .MAX_VALUE );
6998 }
7099
71100 // This check is to ensure later compatibility with a Java implementation whose maximum size can only
72101 // be 2^31-1. We check only against 2^30 for simplicity.
73- if (numBuckets * numHashes >= 1 << 30 ) {
74- throw new SketchesArgumentException ("These parameters generate a sketch that exceeds 2^30 elements. \n " +
75- "Try reducing either the number of buckets or the number of hash functions." );
102+ if (totalSize >= (1L << 30 )) {
103+ throw new SketchesArgumentException ("Sketch would require excessive memory: " + numHashes + " * " + numBuckets
104+ + " = " + totalSize + " elements (~" + String .format ("%d" , totalSize * Long .BYTES / (1024 * 1024 * 1024 )) + " GB). "
105+ + "Consider reducing numHashes or numBuckets." );
76106 }
77107
78- Random rand = new Random (seed );
108+ numHashes_ = numHashes ;
109+ numBuckets_ = numBuckets ;
110+ seed_ = seed ;
111+ hashSeeds_ = new long [numHashes ];
112+ sketchArray_ = new long [(int ) totalSize ];
113+ totalWeight_ = 0 ;
114+
115+ final Random rand = new Random (seed );
79116 for (int i = 0 ; i < numHashes ; i ++) {
80117 hashSeeds_ [i ] = rand .nextLong ();
81118 }
82119 }
83120
84- private long [] getHashes (byte [] item ) {
85- long [] updateLocations = new long [numHashes_ ];
121+ /**
122+ * Efficiently converts a long to byte array using thread-local MemorySegment with explicit endianness.
123+ */
124+ private static byte [] longToBytes (final long value ) {
125+ final MemorySegment segment = LONG_SEGMENT .get ();
126+ segment .set (JAVA_LONG_UNALIGNED , 0 , value );
127+ return segment .toArray (JAVA_BYTE );
128+ }
129+
130+
131+ private long [] getHashes (final byte [] item ) {
132+ final long [] updateLocations = new long [numHashes_ ];
86133
87134 for (int i = 0 ; i < numHashes_ ; i ++) {
88- long [] index = MurmurHash3 .hash (item , hashSeeds_ [i ]);
135+ final long [] index = MurmurHash3 .hash (item , hashSeeds_ [i ]);
89136 updateLocations [i ] = i * (long )numBuckets_ + Math .floorMod (index [0 ], numBuckets_ );
90137 }
91138
@@ -145,11 +192,11 @@ public double getRelativeError() {
145192 * @param confidence The desired confidence level between 0 and 1.
146193 * @return Suggested number of hash functions.
147194 */
148- public static byte suggestNumHashes (double confidence ) {
195+ public static byte suggestNumHashes (final double confidence ) {
149196 if (confidence < 0 || confidence > 1 ) {
150197 throw new SketchesException ("Confidence must be between 0 and 1.0 (inclusive)." );
151198 }
152- int value = (int ) Math .ceil (Math .log (1.0 / (1.0 - confidence )));
199+ final int value = (int ) Math .ceil (Math .log (1.0 / (1.0 - confidence )));
153200 return (byte ) Math .min (value , 127 );
154201 }
155202
@@ -158,7 +205,7 @@ public static byte suggestNumHashes(double confidence) {
158205 * @param relativeError The desired relative error.
159206 * @return Suggested number of buckets.
160207 */
161- public static int suggestNumBuckets (double relativeError ) {
208+ public static int suggestNumBuckets (final double relativeError ) {
162209 if (relativeError < 0. ) {
163210 throw new SketchesException ("Relative error must be at least 0." );
164211 }
@@ -171,8 +218,7 @@ public static int suggestNumBuckets(double relativeError) {
171218 * @param weight The weight of the item.
172219 */
173220 public void update (final long item , final long weight ) {
174- byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
175- update (longByte , weight );
221+ update (longToBytes (item ), weight );
176222 }
177223
178224 /**
@@ -199,8 +245,8 @@ public void update(final byte[] item, final long weight) {
199245 }
200246
201247 totalWeight_ += weight > 0 ? weight : -weight ;
202- long [] hashLocations = getHashes (item );
203- for (long h : hashLocations ) {
248+ final long [] hashLocations = getHashes (item );
249+ for (final long h : hashLocations ) {
204250 sketchArray_ [(int ) h ] += weight ;
205251 }
206252 }
@@ -211,8 +257,7 @@ public void update(final byte[] item, final long weight) {
211257 * @return Estimated frequency.
212258 */
213259 public long getEstimate (final long item ) {
214- byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
215- return getEstimate (longByte );
260+ return getEstimate (longToBytes (item ));
216261 }
217262
218263 /**
@@ -239,10 +284,11 @@ public long getEstimate(final byte[] item) {
239284 return 0 ;
240285 }
241286
242- long [] hashLocations = getHashes (item );
287+ final long [] hashLocations = getHashes (item );
243288 long res = sketchArray_ [(int ) hashLocations [0 ]];
244- for (long h : hashLocations ) {
245- res = Math .min (res , sketchArray_ [(int ) h ]);
289+ // Start from index 1 to avoid processing first element twice
290+ for (int i = 1 ; i < hashLocations .length ; i ++) {
291+ res = Math .min (res , sketchArray_ [(int ) hashLocations [i ]]);
246292 }
247293
248294 return res ;
@@ -254,8 +300,7 @@ public long getEstimate(final byte[] item) {
254300 * @return Upper bound of estimated frequency.
255301 */
256302 public long getUpperBound (final long item ) {
257- byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
258- return getUpperBound (longByte );
303+ return getUpperBound (longToBytes (item ));
259304 }
260305
261306 /**
@@ -268,8 +313,8 @@ public long getUpperBound(final String item) {
268313 return 0 ;
269314 }
270315
271- byte [] strByte = item .getBytes (StandardCharsets .UTF_8 );
272- return getUpperBound (strByte );
316+ final byte [] strByte = item .getBytes (StandardCharsets .UTF_8 );
317+ return getUpperBound (strByte );
273318 }
274319
275320 /**
@@ -291,8 +336,7 @@ public long getUpperBound(final byte[] item) {
291336 * @return Lower bound of estimated frequency.
292337 */
293338 public long getLowerBound (final long item ) {
294- byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
295- return getLowerBound (longByte );
339+ return getLowerBound (longToBytes (item ));
296340 }
297341
298342 /**
@@ -305,7 +349,7 @@ public long getLowerBound(final String item) {
305349 return 0 ;
306350 }
307351
308- byte [] strByte = item .getBytes (StandardCharsets .UTF_8 );
352+ final byte [] strByte = item .getBytes (StandardCharsets .UTF_8 );
309353 return getLowerBound (strByte );
310354 }
311355
@@ -327,8 +371,8 @@ public void merge(final CountMinSketch other) {
327371 throw new SketchesException ("Cannot merge a sketch with itself" );
328372 }
329373
330- boolean acceptableConfig = getNumBuckets_ () == other .getNumBuckets_ () &&
331- getNumHashes_ () == other .getNumHashes_ () && getSeed_ () == other .getSeed_ ();
374+ final boolean acceptableConfig = getNumBuckets_ () == other .getNumBuckets_ ()
375+ && getNumHashes_ () == other .getNumHashes_ () && getSeed_ () == other .getSeed_ ();
332376
333377 if (!acceptableConfig ) {
334378 throw new SketchesException ("Incompatible sketch configuration." );
@@ -342,39 +386,56 @@ public void merge(final CountMinSketch other) {
342386 }
343387
344388 /**
345- * Serializes the sketch into the provided ByteBuffer.
346- * @param buf The ByteBuffer to write into.
389+ * Returns the serialized size in bytes.
390+ */
391+ private int getSerializedSizeBytes () {
392+ final int preambleBytes = Family .COUNTMIN .getMinPreLongs () * Long .BYTES ;
393+ if (isEmpty ()) {
394+ return preambleBytes ;
395+ }
396+ return preambleBytes + Long .BYTES + (sketchArray_ .length * Long .BYTES );
397+ }
398+
399+
400+ /**
401+ * Returns the sketch as a byte array.
347402 */
348- public void serialize (ByteArrayOutputStream buf ) {
403+ public byte [] toByteArray () {
404+ final int serializedSizeBytes = getSerializedSizeBytes ();
405+ final byte [] bytes = new byte [serializedSizeBytes ];
406+ final PositionalSegment posSeg = PositionalSegment .wrap (MemorySegment .ofArray (bytes ));
407+
349408 // Long 0
350409 final int preambleLongs = Family .COUNTMIN .getMinPreLongs ();
351- buf . write ((byte ) preambleLongs );
410+ posSeg . setByte ((byte ) preambleLongs );
352411 final int serialVersion = 1 ;
353- buf . write ((byte ) serialVersion );
412+ posSeg . setByte ((byte ) serialVersion );
354413 final int familyId = Family .COUNTMIN .getID ();
355- buf . write ((byte ) familyId );
414+ posSeg . setByte ((byte ) familyId );
356415 final int flagsByte = isEmpty () ? Flag .IS_EMPTY .mask () : 0 ;
357- buf . write ((byte )flagsByte );
416+ posSeg . setByte ((byte ) flagsByte );
358417 final int NULL_32 = 0 ;
359- buf . writeBytes ( ByteBuffer . allocate ( 4 ). putInt ( NULL_32 ). array () );
418+ posSeg . setInt ( NULL_32 );
360419
361420 // Long 1
362- buf . writeBytes ( ByteBuffer . allocate ( 4 ). putInt ( numBuckets_ ). array () );
363- buf . write (numHashes_ );
364- short hashSeed = Util .computeSeedHash (seed_ );
365- buf . writeBytes ( ByteBuffer . allocate ( 2 ). putShort ( hashSeed ). array () );
421+ posSeg . setInt ( numBuckets_ );
422+ posSeg . setByte (numHashes_ );
423+ final short hashSeed = Util .computeSeedHash (seed_ );
424+ posSeg . setShort ( hashSeed );
366425 final byte NULL_8 = 0 ;
367- buf .write (NULL_8 );
426+ posSeg .setByte (NULL_8 );
427+
368428 if (isEmpty ()) {
369- return ;
429+ return bytes ;
370430 }
371431
372- final byte [] totWeightByte = ByteBuffer .allocate (8 ).putLong (totalWeight_ ).array ();
373- buf .writeBytes (totWeightByte );
432+ posSeg .setLong (totalWeight_ );
374433
375- for (long w : sketchArray_ ) {
376- buf . writeBytes ( ByteBuffer . allocate ( 8 ). putLong ( w ). array () );
434+ for (final long w : sketchArray_ ) {
435+ posSeg . setLong ( w );
377436 }
437+
438+ return bytes ;
378439 }
379440
380441 /**
@@ -384,36 +445,51 @@ public void serialize(ByteArrayOutputStream buf) {
384445 * @return The deserialized CountMinSketch.
385446 */
386447 public static CountMinSketch deserialize (final byte [] b , final long seed ) {
387- ByteBuffer buf = ByteBuffer .allocate (b .length );
388- buf .put (b );
389- buf .flip ();
390-
391- final byte preambleLongs = buf .get ();
392- final byte serialVersion = buf .get ();
393- final byte familyId = buf .get ();
394- final byte flagsByte = buf .get ();
395- final int NULL_32 = buf .getInt ();
448+ final PositionalSegment posSeg = PositionalSegment .wrap (MemorySegment .ofArray (b ));
449+
450+ final byte preambleLongs = posSeg .getByte ();
451+ final byte serialVersion = posSeg .getByte ();
452+ final byte familyId = posSeg .getByte ();
453+ final byte flagsByte = posSeg .getByte ();
454+ posSeg .getInt (); // skip NULL_32
455+
456+ // Validate serialization format
457+ final int expectedPreambleLongs = Family .COUNTMIN .getMinPreLongs ();
458+ if (preambleLongs != expectedPreambleLongs ) {
459+ throw new SketchesArgumentException ("Preamble longs mismatch: expected " + expectedPreambleLongs
460+ + ", actual " + preambleLongs );
461+ }
462+ final int expectedSerialVersion = 1 ;
463+ if (serialVersion != expectedSerialVersion ) {
464+ throw new SketchesArgumentException ("Serial version mismatch: expected " + expectedSerialVersion
465+ + ", actual " + serialVersion );
466+ }
467+ final int expectedFamilyId = Family .COUNTMIN .getID ();
468+ if (familyId != expectedFamilyId ) {
469+ throw new SketchesArgumentException ("Family ID mismatch: expected " + expectedFamilyId
470+ + ", actual " + familyId );
471+ }
396472
397- final int numBuckets = buf .getInt ();
398- final byte numHashes = buf . get ();
399- final short seedHash = buf .getShort ();
400- final byte NULL_8 = buf . get ();
473+ final int numBuckets = posSeg .getInt ();
474+ final byte numHashes = posSeg . getByte ();
475+ final short seedHash = posSeg .getShort ();
476+ posSeg . getByte (); // skip NULL_8
401477
402478 if (seedHash != Util .computeSeedHash (seed )) {
403- throw new SketchesArgumentException ("Incompatible seed hashes: " + String . valueOf ( seedHash ) + ", "
404- + String . valueOf ( Util .computeSeedHash (seed ) ));
479+ throw new SketchesArgumentException ("Incompatible seed hashes: " + seedHash + ", "
480+ + Util .computeSeedHash (seed ));
405481 }
406482
407- CountMinSketch cms = new CountMinSketch (numHashes , numBuckets , seed );
483+ final CountMinSketch cms = new CountMinSketch (numHashes , numBuckets , seed );
408484 final boolean empty = (flagsByte & Flag .IS_EMPTY .mask ()) > 0 ;
409485 if (empty ) {
410486 return cms ;
411487 }
412- long w = buf .getLong ();
488+ final long w = posSeg .getLong ();
413489 cms .totalWeight_ = w ;
414490
415491 for (int i = 0 ; i < cms .sketchArray_ .length ; i ++) {
416- cms .sketchArray_ [i ] = buf .getLong ();
492+ cms .sketchArray_ [i ] = posSeg .getLong ();
417493 }
418494
419495 return cms ;
0 commit comments