diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py index da5e765..b9a0ea9 100644 --- a/llmlingua/prompt_compressor.py +++ b/llmlingua/prompt_compressor.py @@ -1413,6 +1413,8 @@ def get_compressed_input( self_input_ids=None, self_attention_mask=None, ): + if end < iterative_size: + end = iterative_size if self_loss is not None: need_idx = torch.concat( [