diff --git a/src/main/java/org/apache/datasketches/sampling/VarOptItemsSketch.java b/src/main/java/org/apache/datasketches/sampling/VarOptItemsSketch.java index a5f6fc909..ad46b4f04 100644 --- a/src/main/java/org/apache/datasketches/sampling/VarOptItemsSketch.java +++ b/src/main/java/org/apache/datasketches/sampling/VarOptItemsSketch.java @@ -331,6 +331,11 @@ public static VarOptItemsSketch heapify(final MemorySegment srcSeg, if (numPreLongs == Family.VAROPT.getMaxPreLongs()) { if (rCount > 0) { totalRWeight = extractTotalRWeight(srcSeg); + if (Double.isNaN(totalRWeight) || (totalRWeight <= 0.0)) { + throw new SketchesArgumentException("Possible Corruption: deserializing in full mode " + + "but invalid R region weight. Found r = " + rCount + + ", R region weight = " + totalRWeight); + } } else { throw new SketchesArgumentException( "Possible Corruption: " diff --git a/src/test/java/org/apache/datasketches/sampling/VarOptItemsSketchTest.java b/src/test/java/org/apache/datasketches/sampling/VarOptItemsSketchTest.java index 43aed7ef8..1331c57a7 100644 --- a/src/test/java/org/apache/datasketches/sampling/VarOptItemsSketchTest.java +++ b/src/test/java/org/apache/datasketches/sampling/VarOptItemsSketchTest.java @@ -288,6 +288,42 @@ public void checkCorruptSerializedWeight() { } } + @Test + public void checkCorruptSerializedRWeightNaN() { + final int k = 32; + final VarOptItemsSketch sketch = getUnweightedLongsVIS(k, k + 1); + final byte[] bytes = sketch.toByteArray(new ArrayOfLongsSerDe()); + final MemorySegment seg = MemorySegment.ofArray(bytes); + assertEquals(PreambleUtil.extractPreLongs(seg), Family.VAROPT.getMaxPreLongs()); + + PreambleUtil.insertTotalRWeight(seg, Double.NaN); + + try { + VarOptItemsSketch.heapify(seg, new ArrayOfLongsSerDe()); + fail(); + } catch (final SketchesArgumentException e) { + assertTrue(e.getMessage().contains("invalid R region weight")); + } + } + + @Test + public void checkCorruptSerializedRWeightZero() { + final int k = 32; + final VarOptItemsSketch sketch = getUnweightedLongsVIS(k, k + 1); + final byte[] bytes = sketch.toByteArray(new ArrayOfLongsSerDe()); + final MemorySegment seg = MemorySegment.ofArray(bytes); + assertEquals(PreambleUtil.extractPreLongs(seg), Family.VAROPT.getMaxPreLongs()); + + PreambleUtil.insertTotalRWeight(seg, 0.0); + + try { + VarOptItemsSketch.heapify(seg, new ArrayOfLongsSerDe()); + fail(); + } catch (final SketchesArgumentException e) { + assertTrue(e.getMessage().contains("invalid R region weight")); + } + } + @Test public void checkCumulativeWeight() { final int k = 256;