From aabda130e0b1949df2d7d698c2e0987ded96fc2f Mon Sep 17 00:00:00 2001 From: cornzz <39997278+cornzz@users.noreply.github.com> Date: Thu, 16 Jan 2025 18:49:27 +0100 Subject: [PATCH] Fix(LLMLingua): fix perplexity calculation and resulting overcompression (#195) --- llmlingua/prompt_compressor.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py index da5e765..1169612 100644 --- a/llmlingua/prompt_compressor.py +++ b/llmlingua/prompt_compressor.py @@ -1702,14 +1702,19 @@ def iterative_compress_prompt( for delta_end, ratio in iterative_ratios[idx]: loss = past_loss + seg_end = end - iterative_size + delta_end + 1 + if seg_end < iterative_size: + seg_end = iterative_size + seg_start = seg_end - iterative_size if condition_compare: self_loss = self_past_loss + self_seg_start, self_seg_end = seg_start - start, seg_end - start threshold = self.get_estimate_threshold_base_distribution( - self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False + self_loss[self_seg_start:self_seg_end] - loss[seg_start:seg_end], ratio, False ) else: threshold = self.get_estimate_threshold_base_distribution( - loss, ratio, False + loss[seg_start:seg_end], ratio, False ) (