Skip to content

Commit 9d1f24f

Browse files
authored
Merge pull request #676 from apache/fix-build-cpc-cms
CMS and CPCxLangTest move to FFM, Build fix
2 parents b163bc8 + 66f1333 commit 9d1f24f

3 files changed

Lines changed: 155 additions & 83 deletions

File tree

src/main/java/org/apache/datasketches/count/CountMinSketch.java

Lines changed: 150 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,30 @@
2222
import org.apache.datasketches.common.Family;
2323
import org.apache.datasketches.common.SketchesArgumentException;
2424
import org.apache.datasketches.common.SketchesException;
25+
import org.apache.datasketches.common.Util;
26+
import org.apache.datasketches.common.positional.PositionalSegment;
2527
import org.apache.datasketches.hash.MurmurHash3;
26-
import org.apache.datasketches.tuple.Util;
2728

28-
import java.io.ByteArrayOutputStream;
29-
import java.nio.ByteBuffer;
29+
import java.lang.foreign.MemorySegment;
3030
import java.nio.charset.StandardCharsets;
3131
import java.util.Random;
3232

33+
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
34+
import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED;
35+
import static java.lang.foreign.ValueLayout.JAVA_LONG_UNALIGNED;
36+
import static java.lang.foreign.ValueLayout.JAVA_SHORT_UNALIGNED;
3337

38+
39+
/**
40+
* Java implementation of the CountMin sketch data structure of Cormode and Muthukrishnan.
41+
* This implementation is inspired by and compatible with the datasketches-cpp version by Charlie Dickens.
42+
*
43+
* The CountMin sketch is a probabilistic data structure that provides frequency estimates for items
44+
* in a data stream. It uses multiple hash functions to distribute items across a two-dimensional array,
45+
* providing approximate counts with configurable error bounds.
46+
*
47+
* Reference: http://dimacs.rutgers.edu/~graham/pubs/papers/cm-full.pdf
48+
*/
3449
public class CountMinSketch {
3550
private final byte numHashes_;
3651
private final int numBuckets_;
@@ -39,6 +54,9 @@ public class CountMinSketch {
3954
private final long[] sketchArray_;
4055
private long totalWeight_;
4156

57+
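// Thread-local scratch MemorySegment for long-to-byte[] conversion. Note: JAVA_LONG_UNALIGNED writes in the platform's native byte order (not a fixed endianness), and toArray() still allocates the returned array.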
58+
private static final ThreadLocal<MemorySegment> LONG_SEGMENT =
59+
ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[Long.BYTES]));
4260

4361
private enum Flag {
4462
IS_EMPTY;
@@ -57,35 +75,64 @@ int mask() {
5775
* @param seed The base hash seed
5876
*/
5977
CountMinSketch(final byte numHashes, final int numBuckets, final long seed) {
60-
numHashes_ = numHashes;
61-
numBuckets_ = numBuckets;
62-
seed_ = seed;
63-
hashSeeds_ = new long[numHashes];
64-
sketchArray_ = new long[numHashes * numBuckets];
65-
totalWeight_ = 0;
78+
// Validate numHashes
79+
if (numHashes <= 0) {
80+
throw new SketchesArgumentException("Number of hash functions must be positive, got: " + numHashes);
81+
}
6682

83+
// Validate numBuckets: the relative error is e / numBuckets, so fewer than 3 buckets gives an error greater than 1
84+
if (numBuckets <= 0) {
85+
throw new SketchesArgumentException("Number of buckets must be positive, got: " + numBuckets);
86+
}
6787
if (numBuckets < 3) {
68-
throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1.");
88+
throw new SketchesArgumentException("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. " +
89+
"With " + numBuckets + " buckets, relative error would be " + String.format("%.3f", Math.exp(1.0) / numBuckets));
90+
}
91+
92+
// Check for potential overflow in array size calculation
93+
// Use long arithmetic to detect overflow before casting
94+
final long totalSize = (long) numHashes * (long) numBuckets;
95+
if (totalSize > Integer.MAX_VALUE) {
96+
throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets
97+
+ " = " + totalSize + " > " + Integer.MAX_VALUE);
6998
}
7099

71100
// A Java array can hold at most 2^31-1 elements, so the total number of counters must stay
72101
// below that limit. We check only against 2^30 for simplicity.
73-
if (numBuckets * numHashes >= 1 << 30) {
74-
throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n" +
75-
"Try reducing either the number of buckets or the number of hash functions.");
102+
if (totalSize >= (1L << 30)) {
103+
throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets
104+
+ " = " + totalSize + " elements (~" + String.format("%d", totalSize * Long.BYTES / (1024 * 1024 * 1024)) + " GB). "
105+
+ "Consider reducing numHashes or numBuckets.");
76106
}
77107

78-
Random rand = new Random(seed);
108+
numHashes_ = numHashes;
109+
numBuckets_ = numBuckets;
110+
seed_ = seed;
111+
hashSeeds_ = new long[numHashes];
112+
sketchArray_ = new long[(int) totalSize];
113+
totalWeight_ = 0;
114+
115+
final Random rand = new Random(seed);
79116
for (int i = 0; i < numHashes; i++) {
80117
hashSeeds_[i] = rand.nextLong();
81118
}
82119
}
83120

84-
private long[] getHashes(byte[] item) {
85-
long[] updateLocations = new long[numHashes_];
121+
/**
122+
* Converts a long to an 8-byte array via the thread-local MemorySegment, using the platform's native byte order (JAVA_LONG_UNALIGNED).
123+
*/
124+
private static byte[] longToBytes(final long value) {
125+
final MemorySegment segment = LONG_SEGMENT.get();
126+
segment.set(JAVA_LONG_UNALIGNED, 0, value);
127+
return segment.toArray(JAVA_BYTE);
128+
}
129+
130+
131+
private long[] getHashes(final byte[] item) {
132+
final long[] updateLocations = new long[numHashes_];
86133

87134
for (int i = 0; i < numHashes_; i++) {
88-
long[] index = MurmurHash3.hash(item, hashSeeds_[i]);
135+
final long[] index = MurmurHash3.hash(item, hashSeeds_[i]);
89136
updateLocations[i] = i * (long)numBuckets_ + Math.floorMod(index[0], numBuckets_);
90137
}
91138

@@ -145,11 +192,11 @@ public double getRelativeError() {
145192
* @param confidence The desired confidence level between 0 and 1.
146193
* @return Suggested number of hash functions.
147194
*/
148-
public static byte suggestNumHashes(double confidence) {
195+
public static byte suggestNumHashes(final double confidence) {
149196
if (confidence < 0 || confidence > 1) {
150197
throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive).");
151198
}
152-
int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence)));
199+
final int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence)));
153200
return (byte) Math.min(value, 127);
154201
}
155202

@@ -158,7 +205,7 @@ public static byte suggestNumHashes(double confidence) {
158205
* @param relativeError The desired relative error.
159206
* @return Suggested number of buckets.
160207
*/
161-
public static int suggestNumBuckets(double relativeError) {
208+
public static int suggestNumBuckets(final double relativeError) {
162209
if (relativeError < 0.) {
163210
throw new SketchesException("Relative error must be at least 0.");
164211
}
@@ -171,8 +218,7 @@ public static int suggestNumBuckets(double relativeError) {
171218
* @param weight The weight of the item.
172219
*/
173220
public void update(final long item, final long weight) {
174-
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
175-
update(longByte, weight);
221+
update(longToBytes(item), weight);
176222
}
177223

178224
/**
@@ -199,8 +245,8 @@ public void update(final byte[] item, final long weight) {
199245
}
200246

201247
totalWeight_ += weight > 0 ? weight : -weight;
202-
long[] hashLocations = getHashes(item);
203-
for (long h : hashLocations) {
248+
final long[] hashLocations = getHashes(item);
249+
for (final long h : hashLocations) {
204250
sketchArray_[(int) h] += weight;
205251
}
206252
}
@@ -211,8 +257,7 @@ public void update(final byte[] item, final long weight) {
211257
* @return Estimated frequency.
212258
*/
213259
public long getEstimate(final long item) {
214-
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
215-
return getEstimate(longByte);
260+
return getEstimate(longToBytes(item));
216261
}
217262

218263
/**
@@ -239,10 +284,11 @@ public long getEstimate(final byte[] item) {
239284
return 0;
240285
}
241286

242-
long[] hashLocations = getHashes(item);
287+
final long[] hashLocations = getHashes(item);
243288
long res = sketchArray_[(int) hashLocations[0]];
244-
for (long h : hashLocations) {
245-
res = Math.min(res, sketchArray_[(int) h]);
289+
// Start from index 1 to avoid processing first element twice
290+
for (int i = 1; i < hashLocations.length; i++) {
291+
res = Math.min(res, sketchArray_[(int) hashLocations[i]]);
246292
}
247293

248294
return res;
@@ -254,8 +300,7 @@ public long getEstimate(final byte[] item) {
254300
* @return Upper bound of estimated frequency.
255301
*/
256302
public long getUpperBound(final long item) {
257-
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
258-
return getUpperBound(longByte);
303+
return getUpperBound(longToBytes(item));
259304
}
260305

261306
/**
@@ -268,8 +313,8 @@ public long getUpperBound(final String item) {
268313
return 0;
269314
}
270315

271-
byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
272-
return getUpperBound(strByte);
316+
final byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
317+
return getUpperBound(strByte);
273318
}
274319

275320
/**
@@ -291,8 +336,7 @@ public long getUpperBound(final byte[] item) {
291336
* @return Lower bound of estimated frequency.
292337
*/
293338
public long getLowerBound(final long item) {
294-
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
295-
return getLowerBound(longByte);
339+
return getLowerBound(longToBytes(item));
296340
}
297341

298342
/**
@@ -305,7 +349,7 @@ public long getLowerBound(final String item) {
305349
return 0;
306350
}
307351

308-
byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
352+
final byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
309353
return getLowerBound(strByte);
310354
}
311355

@@ -327,8 +371,8 @@ public void merge(final CountMinSketch other) {
327371
throw new SketchesException("Cannot merge a sketch with itself");
328372
}
329373

330-
boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_() &&
331-
getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_();
374+
final boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_()
375+
&& getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_();
332376

333377
if (!acceptableConfig) {
334378
throw new SketchesException("Incompatible sketch configuration.");
@@ -342,39 +386,56 @@ public void merge(final CountMinSketch other) {
342386
}
343387

344388
/**
345-
* Serializes the sketch into the provided ByteBuffer.
346-
* @param buf The ByteBuffer to write into.
389+
* Returns the serialized size in bytes.
390+
*/
391+
private int getSerializedSizeBytes() {
392+
final int preambleBytes = Family.COUNTMIN.getMinPreLongs() * Long.BYTES;
393+
if (isEmpty()) {
394+
return preambleBytes;
395+
}
396+
return preambleBytes + Long.BYTES + (sketchArray_.length * Long.BYTES);
397+
}
398+
399+
400+
/**
401+
* Returns the sketch as a byte array.
347402
*/
348-
public void serialize(ByteArrayOutputStream buf) {
403+
public byte[] toByteArray() {
404+
final int serializedSizeBytes = getSerializedSizeBytes();
405+
final byte[] bytes = new byte[serializedSizeBytes];
406+
final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(bytes));
407+
349408
// Long 0
350409
final int preambleLongs = Family.COUNTMIN.getMinPreLongs();
351-
buf.write((byte) preambleLongs);
410+
posSeg.setByte((byte) preambleLongs);
352411
final int serialVersion = 1;
353-
buf.write((byte) serialVersion);
412+
posSeg.setByte((byte) serialVersion);
354413
final int familyId = Family.COUNTMIN.getID();
355-
buf.write((byte) familyId);
414+
posSeg.setByte((byte) familyId);
356415
final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0;
357-
buf.write((byte)flagsByte);
416+
posSeg.setByte((byte) flagsByte);
358417
final int NULL_32 = 0;
359-
buf.writeBytes(ByteBuffer.allocate(4).putInt(NULL_32).array());
418+
posSeg.setInt(NULL_32);
360419

361420
// Long 1
362-
buf.writeBytes(ByteBuffer.allocate(4).putInt(numBuckets_).array());
363-
buf.write(numHashes_);
364-
short hashSeed = Util.computeSeedHash(seed_);
365-
buf.writeBytes(ByteBuffer.allocate(2).putShort(hashSeed).array());
421+
posSeg.setInt(numBuckets_);
422+
posSeg.setByte(numHashes_);
423+
final short hashSeed = Util.computeSeedHash(seed_);
424+
posSeg.setShort(hashSeed);
366425
final byte NULL_8 = 0;
367-
buf.write(NULL_8);
426+
posSeg.setByte(NULL_8);
427+
368428
if (isEmpty()) {
369-
return;
429+
return bytes;
370430
}
371431

372-
final byte[] totWeightByte = ByteBuffer.allocate(8).putLong(totalWeight_).array();
373-
buf.writeBytes(totWeightByte);
432+
posSeg.setLong(totalWeight_);
374433

375-
for (long w: sketchArray_) {
376-
buf.writeBytes(ByteBuffer.allocate(8).putLong(w).array());
434+
for (final long w: sketchArray_) {
435+
posSeg.setLong(w);
377436
}
437+
438+
return bytes;
378439
}
379440

380441
/**
@@ -384,36 +445,51 @@ public void serialize(ByteArrayOutputStream buf) {
384445
* @return The deserialized CountMinSketch.
385446
*/
386447
public static CountMinSketch deserialize(final byte[] b, final long seed) {
387-
ByteBuffer buf = ByteBuffer.allocate(b.length);
388-
buf.put(b);
389-
buf.flip();
390-
391-
final byte preambleLongs = buf.get();
392-
final byte serialVersion = buf.get();
393-
final byte familyId = buf.get();
394-
final byte flagsByte = buf.get();
395-
final int NULL_32 = buf.getInt();
448+
final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(b));
449+
450+
final byte preambleLongs = posSeg.getByte();
451+
final byte serialVersion = posSeg.getByte();
452+
final byte familyId = posSeg.getByte();
453+
final byte flagsByte = posSeg.getByte();
454+
posSeg.getInt(); // skip NULL_32
455+
456+
// Validate serialization format
457+
final int expectedPreambleLongs = Family.COUNTMIN.getMinPreLongs();
458+
if (preambleLongs != expectedPreambleLongs) {
459+
throw new SketchesArgumentException("Preamble longs mismatch: expected " + expectedPreambleLongs
460+
+ ", actual " + preambleLongs);
461+
}
462+
final int expectedSerialVersion = 1;
463+
if (serialVersion != expectedSerialVersion) {
464+
throw new SketchesArgumentException("Serial version mismatch: expected " + expectedSerialVersion
465+
+ ", actual " + serialVersion);
466+
}
467+
final int expectedFamilyId = Family.COUNTMIN.getID();
468+
if (familyId != expectedFamilyId) {
469+
throw new SketchesArgumentException("Family ID mismatch: expected " + expectedFamilyId
470+
+ ", actual " + familyId);
471+
}
396472

397-
final int numBuckets = buf.getInt();
398-
final byte numHashes = buf.get();
399-
final short seedHash = buf.getShort();
400-
final byte NULL_8 = buf.get();
473+
final int numBuckets = posSeg.getInt();
474+
final byte numHashes = posSeg.getByte();
475+
final short seedHash = posSeg.getShort();
476+
posSeg.getByte(); // skip NULL_8
401477

402478
if (seedHash != Util.computeSeedHash(seed)) {
403-
throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", "
404-
+ String.valueOf(Util.computeSeedHash(seed)));
479+
throw new SketchesArgumentException("Incompatible seed hashes: " + seedHash + ", "
480+
+ Util.computeSeedHash(seed));
405481
}
406482

407-
CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed);
483+
final CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed);
408484
final boolean empty = (flagsByte & Flag.IS_EMPTY.mask()) > 0;
409485
if (empty) {
410486
return cms;
411487
}
412-
long w = buf.getLong();
488+
final long w = posSeg.getLong();
413489
cms.totalWeight_ = w;
414490

415491
for (int i = 0; i < cms.sketchArray_.length; i++) {
416-
cms.sketchArray_[i] = buf.getLong();
492+
cms.sketchArray_[i] = posSeg.getLong();
417493
}
418494

419495
return cms;

0 commit comments

Comments
 (0)