diff --git a/README.md b/README.md
index 0c32850..1ca17bc 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,11 @@ LLMLingua-2, a small-size yet powerful prompt compression method trained via dat
- [LLMLingua-2: Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression](https://aclanthology.org/2024.findings-acl.57/) (ACL 2024 Findings)
_Zhuoshi Pan, Qianhui Wu, Huiqiang Jiang, Menglin Xia, Xufang Luo, Jue Zhang, Qingwei Lin, Victor Ruhle, Yuqing Yang, Chin-Yew Lin, H. Vicky Zhao, Lili Qiu, Dongmei Zhang_
+SecurityLingua is a safety guardrail model that uses the security-aware prompt compression to reveal the malicious intentions behind jailbreak attacks, enabling LLMs to detect attacks and generate safe responses. Due to the highly efficient prompt compression, the defense involves negligible overhead and 100x less token costs compared to state-of-the-art LLM guardrail approaches.
+
+- [SecurityLingua: Efficient Defense of LLM Jailbreak Attacks via Security-Aware Prompt Compression](https://openreview.net/forum?id=tybbSo6wba) (CoLM 2025)
+ _Yucheng Li, Surin Ahn, Huiqiang Jiang, Amir H. Abdi, Yuqing Yang and Lili Qiu_
+
## 🎥 Overview

@@ -133,6 +138,16 @@ If you find this repo helpful, please cite the following papers:
}
```
+```bibtex
+@inproceedings{li2025securitylingua,
+ title={{S}ecurity{L}ingua: Efficient Defense of {LLM} Jailbreak Attacks via Security-Aware Prompt Compression},
+ author={Yucheng Li and Surin Ahn and Huiqiang Jiang and Amir H. Abdi and Yuqing Yang and Lili Qiu},
+ booktitle={Second Conference on Language Modeling},
+ year={2025},
+ url={https://openreview.net/forum?id=tybbSo6wba}
+}
+```
+
## 🎯 Quick Start
#### 1. **Installing LLMLingua:**
@@ -205,6 +220,20 @@ llm_lingua = PromptCompressor(
)
```
+To try **SecurityLingua** in your scenarios, you can use
+
+```python
+from llmlingua import PromptCompressor
+
+securitylingua = PromptCompressor(
+ model_name="SecurityLingua/securitylingua-xlm-s2s",
+ use_slingua=True
+)
+intention = securitylingua.compress_prompt(malicious_prompt)
+```
+
+For more details about SecurityLingua, please refer to [securitylingua readme](./experiments/securitylingua/readme.md).
+
#### 3. **Advanced usage - Structured Prompt Compression:**
Split text into sections, decide on whether to compress and its rate. Use `` tags for context segmentation, with optional rate and compress parameters.
diff --git a/experiments/securitylingua/env_setup.sh b/experiments/securitylingua/env_setup.sh
new file mode 100644
index 0000000..18a7530
--- /dev/null
+++ b/experiments/securitylingua/env_setup.sh
@@ -0,0 +1,13 @@
+conda create -n llmlingua python=3.10 -y && conda activate llmlingua
+pip install -e .
+pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
+pip install accelerate wandb
+pip install openai==0.28
+
+pip install spacy
+python -m spacy download en_core_web_sm
+pip install scikit-learn
+pip install tensorboard
+pip install datasets hf_transfer
+
+unset WANDB_RUN_ID WANDB_RUN_GROUP WANDB_PROJECT WANDB_NOTES WANDB_NAME
diff --git a/experiments/securitylingua/filter.py b/experiments/securitylingua/filter.py
new file mode 100644
index 0000000..a3ddcb1
--- /dev/null
+++ b/experiments/securitylingua/filter.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+from collections import defaultdict
+from typing import Dict, List, Tuple, DefaultDict
+import numpy as np
+import torch
+
+def parse_arguments() -> argparse.Namespace:
+ """Parse command line arguments"""
+ parser = argparse.ArgumentParser(description="Filter compressed prompts based on metrics.")
+ parser.add_argument(
+ "--load_path",
+ help="path to load data",
+ default="../../../results/meetingbank/gpt-4-32k_comp/annotation_cs512_meetingbank_train_formated.pt",
+ )
+ parser.add_argument(
+ "--save_path",
+ help="path to save filtered data",
+ default="../../../results/meetingbank/gpt-4-32k_comp/annotation_kept_cs512_meetingbank_train_formated.pt",
+ )
+ parser.add_argument(
+ "--percentile",
+ help="percentile threshold for filtering",
+ default=90,
+ type=int
+ )
+ return parser.parse_args()
+
+def filter_by_metric(
+ data: DefaultDict[str, List],
+ metric_name: str,
+ percentile: float
+) -> Tuple[DefaultDict[str, List], DefaultDict[str, List]]:
+ """
+ Filter data based on a specific metric and percentile threshold
+
+ Args:
+ data: Dictionary containing all data points and their metrics
+ metric_name: Name of the metric to filter by
+ percentile: Percentile threshold for filtering
+
+ Returns:
+ Tuple of (kept_data, filtered_data)
+ """
+ metric_list = data[metric_name]
+ threshold = np.percentile(metric_list, percentile)
+
+ kept = defaultdict(list)
+ filtered = defaultdict(list)
+
+ # List of all metrics to transfer
+ metrics = [
+ "labels", "origin", "comp", "retrieval", "comp_rate",
+ "variation_rate", "hitting_rate", "matching_rate", "alignment_gap"
+ ]
+
+ for values in zip(*(data[metric] for metric in metrics)):
+ # Create a dictionary of current values
+ current = dict(zip(metrics, values))
+
+ # Determine which container to use based on the metric threshold
+ target = filtered if current[metric_name] >= threshold else kept
+
+ # Add values to appropriate container
+ for metric, value in current.items():
+ target[metric].append(value)
+
+ return kept, filtered
+
+def main():
+ """Main function to run the filtering process"""
+ args = parse_arguments()
+
+ # Load data
+ res_pt = torch.load(args.load_path, weights_only=False)
+ print(f"Initial sample count: {len(res_pt['variation_rate'])}")
+
+ # First filtering stage: variation rate
+ kept, filtered = filter_by_metric(
+ data=res_pt,
+ metric_name="variation_rate",
+ percentile=args.percentile
+ )
+
+ # Second filtering stage: alignment gap
+ final_kept, additional_filtered = filter_by_metric(
+ data=kept,
+ metric_name="alignment_gap",
+ percentile=args.percentile
+ )
+
+ # Save filtered results
+ torch.save(final_kept, args.save_path)
+
+ # Print statistics
+ print(f"Samples after first filter: {len(kept['variation_rate'])}")
+ print(f"Final kept samples: {len(final_kept['variation_rate'])}")
+
+if __name__ == "__main__":
+ main()
diff --git a/experiments/securitylingua/label_word.py b/experiments/securitylingua/label_word.py
new file mode 100644
index 0000000..94c9b4c
--- /dev/null
+++ b/experiments/securitylingua/label_word.py
@@ -0,0 +1,236 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import json
+import logging
+import os
+from collections import defaultdict
+from typing import List, Set, Tuple, Dict, Any
+from datasets import load_dataset
+import spacy
+import torch
+from tqdm import tqdm
+from multiprocessing import Pool
+import multiprocessing
+
+def setup_logging(save_path: str) -> logging.Logger:
+ """Setup logging configuration"""
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ logging.basicConfig(
+ filename=f"{os.path.dirname(save_path)}/log.log",
+ level=logging.INFO,
+ format="%(asctime)s - %(levelname)s - %(message)s",
+ )
+ return logging.getLogger()
+
+def parse_arguments() -> argparse.Namespace:
+ """Parse command line arguments"""
+ parser = argparse.ArgumentParser(description="annotate token")
+ parser.add_argument(
+ "--dataset_name", help="dataset used to compress", default="meetingbank"
+ )
+ parser.add_argument("--split", help="dataset part", default="train")
+ parser.add_argument(
+ "--load_prompt_from",
+ help="where to load compressed prompt",
+ default="results/meetingbank/origin-comp-list_llmcomp_cs512.json",
+ )
+ parser.add_argument(
+ "--save_path",
+ help="path to save results",
+ default="results/meetingbank/annotation/label_word.json",
+ )
+ parser.add_argument("--window_size", help="window size", type=int, default=150)
+ parser.add_argument(
+ "--verbose",
+ help="print debug info",
+ action=argparse.BooleanOptionalAction,
+ default=False,
+ )
+ parser.add_argument(
+ "--max_samples",
+ help="max samples",
+ type=int,
+ default=1000,
+ )
+ return parser.parse_args()
+
+def split_string(input_string: str, ignore_tokens: Set[str] = {","}) -> List[str]:
+ """Split string into tokens using spaCy"""
+ doc = nlp(input_string)
+ return [word.lemma_ for word in doc if word.lemma_ not in ignore_tokens]
+
+def is_equal(token1: str, token2: str) -> bool:
+ """Compare tokens case-insensitively"""
+ return token1.lower() == token2.lower()
+
+def load_data(load_path: str, max_samples: int = 1000) -> Tuple[List[str], List[str]]:
+ """Load and prepare dataset"""
+ origins, comps = [], []
+ dataset = load_dataset(load_path, split="train")
+
+ for i, sample in enumerate(dataset):
+ if len(sample["prompt_list"]) != len(sample["compressed_prompt_list"]):
+ print(f"{i}-th length not equal")
+ continue
+
+ origins.extend(sample["prompt_list"])
+ comps.extend(sample["compressed_prompt_list"])
+
+ if len(origins) > max_samples:
+ break
+
+ return origins, comps
+
+def load_secure_data(load_path: str, max_samples: int = 1000) -> Tuple[List[str], List[str]]:
+ """Load and prepare dataset"""
+ origins, comps = [], []
+ if "json" in load_path:
+ with open(load_path, "r") as f:
+ dataset = json.load(f)
+ else:
+ dataset = load_dataset(load_path, split="train")
+
+ for sample in dataset:
+ origins.append(sample["extended"])
+ comps.append(sample["original"])
+
+ if max_samples != -1 and len(origins) > max_samples:
+ break
+
+ return origins, comps
+
+def process_sample(origin: str, comp: str, window_size: int, verbose: bool = False) -> Dict[str, Any]:
+ """Process a single sample pair"""
+ origin_tokens = split_string(origin)
+ comp_tokens = split_string(comp)
+ origin_tokens_set = set(origin_tokens) | set(token.lower() for token in origin_tokens)
+
+ num_find = 0
+ prev_idx = 0
+ num_origin_tokens = len(origin_tokens)
+ labels = [False] * num_origin_tokens
+
+ # Token matching logic
+ for token in comp_tokens:
+ if token in origin_tokens_set or token.lower() in origin_tokens_set:
+ num_find += 1
+
+ for i in range(window_size):
+ # Forward and backward token matching
+ for token_idx in [
+ min(prev_idx + i, num_origin_tokens - 1),
+ max(prev_idx - i, 0)
+ ]:
+ if is_equal(origin_tokens[token_idx], token) and not labels[token_idx]:
+ labels[token_idx] = True
+ prev_idx = token_idx if token_idx < prev_idx else (
+ token_idx if token_idx - prev_idx <= window_size // 2
+ else prev_idx + window_size // 2
+ )
+
+ if verbose:
+ print(f"{token}, {token_idx}, {prev_idx}, {origin_tokens[token_idx - 1 : token_idx + 2]}")
+ break
+ else:
+ continue
+ break
+
+ # Calculate metrics
+ retrieval_tokens = [token for idx, token in enumerate(origin_tokens) if labels[idx]]
+ retrieval = " ".join(retrieval_tokens)
+
+ metrics = calculate_metrics(
+ len(comp_tokens), len(origin_tokens), num_find, labels
+ )
+
+ return {
+ "labels": labels,
+ "origin": origin,
+ "comp": comp,
+ "retrieval": retrieval,
+ "origin_tokens": origin_tokens,
+ **metrics
+ }
+
+def calculate_metrics(comp_len: int, origin_len: int, num_find: int, labels: List[bool]) -> Dict[str, float]:
+ """Calculate various metrics for the compression"""
+ comp_rate = comp_len / origin_len if origin_len > 0 else 0
+ find_rate = num_find / comp_len if comp_len > 0 else 0
+ variation_rate = 1 - find_rate
+ hitting_rate = num_find / origin_len if origin_len > 0 else 0
+ matching_rate = sum(labels) / len(labels) if labels else 0
+ alignment_gap = hitting_rate - matching_rate
+
+ return {
+ "comp_rate": comp_rate,
+ "variation_rate": variation_rate,
+ "hitting_rate": hitting_rate,
+ "matching_rate": matching_rate,
+ "alignment_gap": alignment_gap
+ }
+
+def process_chunk(args):
+ """Worker function for multiprocessing"""
+ chunk_idx, (origin, comp), window_size, verbose = args
+ if not origin or not comp:
+ return None
+
+ sample_results = process_sample(origin, comp, window_size, verbose)
+ return chunk_idx, sample_results
+
+def main():
+ args = parse_arguments()
+ logger = setup_logging(args.save_path)
+
+ # origins, comps = load_data(args.load_prompt_from, args.max_samples)
+ origins, comps = load_secure_data(args.load_prompt_from, args.max_samples)
+ print(f'origins: {len(origins)}')
+ print(f'comps: {len(comps)}')
+
+ res = {}
+ res_pt = defaultdict(list)
+ metrics_sum = defaultdict(float)
+
+ # Prepare arguments for multiprocessing
+ process_args = [
+ (idx, (origin, comp), args.window_size, args.verbose)
+ for idx, (origin, comp) in enumerate(zip(origins, comps))
+ ]
+
+ num_processes = 24
+
+ with Pool(num_processes) as pool:
+ for chunk_result in tqdm(pool.imap(process_chunk, process_args), total=len(process_args)):
+ if chunk_result is None:
+ continue
+
+ chunk_idx, sample_results = chunk_result
+
+ # Store results in memory
+ res[chunk_idx] = sample_results
+ for key, value in sample_results.items():
+ res_pt[key].append(value)
+
+ # Update running metrics
+ for key in ["comp_rate", "variation_rate", "hitting_rate", "matching_rate", "alignment_gap"]:
+ metrics_sum[key] += sample_results[key]
+
+ # Save all results at once at the end
+ json.dump(res, open(args.save_path, "w"), indent=4)
+ torch.save(res_pt, args.save_path.replace(".json", ".pt"))
+
+ # Log final metrics
+ num_samples = len(origins)
+ metrics_avg = {k: v/num_samples for k, v in metrics_sum.items()}
+ print_info = (f"window size: {args.window_size}, "
+ f"comp rate: {metrics_avg['comp_rate']:.3f}, "
+ f"hitting_rate: {metrics_avg['hitting_rate']:.3f}, "
+ f"retrieval rate: {metrics_avg['matching_rate']:.3f}")
+ print(print_info)
+ logger.info(print_info)
+
+if __name__ == "__main__":
+ nlp = spacy.load("en_core_web_sm")
+ main()
\ No newline at end of file
diff --git a/experiments/securitylingua/readme.md b/experiments/securitylingua/readme.md
new file mode 100644
index 0000000..30b1d28
--- /dev/null
+++ b/experiments/securitylingua/readme.md
@@ -0,0 +1,76 @@
+# SecurityLingua
+
+To use securitylingua to safeguard your LLM, please follow this simple two steps instruction:
+
+```
+
+# 0. first load the securitylingua model
+from llmlingua import PromptCompressor
+llm_lingua = PromptCompressor(
+ model_name="SecurityLingua/securitylingua-xlm-s2s",
+ use_slingua=True
+)
+
+# 1. compress the prompt to reveal the malicious intention
+intention = llm_lingua.compress_prompt(malicious_prompt)
+
+# 2. construct the augmented system prompt, to provide the LLM with the malicious intention
+augmented_system_prompt = f"{system_prompt}\n\nTo help you better understand the user's intention to detect potential malicious behavior, I have extracted the user's intention and it is: {intention}. If you believe the user's intention is malicious, please donot respond or respond with I'm sorry, I can't help with that."
+
+# at last, chat with the LLM using the augmented system prompt
+response = vllm.generate([
+ augmented_system_prompt + malicious_prompt
+])
+```
+
+# Train a SecurityLingua on your own data
+
+1. setup environment
+
+```bash
+bash env_setup.sh
+```
+
+2. build your own data for securitylingua training
+
+```bash
+python label_word.py \
+ --load_prompt_from SecurityLingua/securitylingua-jailbreak-pairs \
+ --window_size 400 \
+ --save_path ../results/security_lingua/jailbreak_pairs_annotated.pt
+
+python filter.py \
+ --load_path ../results/security_lingua/jailbreak_pairs_annotated.pt \
+ --save_path ../results/security_lingua/jailbreak_pairs_annotated_filtered.pt
+```
+
+refer to [securitylingua-jailbreak-pairs](https://huggingface.co/datasets/SecurityLingua/securitylingua-jailbreak-pairs) for the format of the dataset before parsing.
+
+you can also finetune the filtering threshold in [filter.py](filter.py) to trade off performance and security.
+
+3. train the securitylingua model
+
+```bash
+python train_roberta.py \
+ --data_path ../results/security_lingua/jailbreak_pairs_annotated_filtered.pt \
+ --save_path ../results/security_lingua/jailbreak_pairs_annotated_filtered_roberta.pt \
+ --model_name microsoft/llmlingua-2-xlm-roberta-large-meetingbank \
+ --num_epoch 5 \
+ --run_name meetbank_slingua \
+ --wandb_project slingua \
+ --wandb_name meetbank_slingua
+```
+
+or you can do multi-GPU training with
+
+```bash
+ACCELERATE_LOG_LEVEL="ERROR" accelerate launch --num_processes 4 experiments/llmlingua2/model_training/train_roberta.py \
+ --data_path experiments/llmlingua2/results/security_lingua/jailbreak_pairs_annotated_filtered.pt \
+ --save_path experiments/llmlingua2/results/models/xlm_slingua.pth \
+ --num_epoch 5 \
+ --run_name xlm_slingua \
+ --wandb_project slingua \
+ --wandb_name xlm_slingua
+```
+
+4. At last load your own checkpoint and use it in `PromptCompressor` (see above for usage)
\ No newline at end of file
diff --git a/experiments/securitylingua/run.sh b/experiments/securitylingua/run.sh
new file mode 100644
index 0000000..1fc9c11
--- /dev/null
+++ b/experiments/securitylingua/run.sh
@@ -0,0 +1,28 @@
+python label_word.py \
+ --load_prompt_from liyucheng/jailbreak-pairs \
+ --window_size 400 \
+ --save_path ../results/security_lingua/jailbreak_pairs_annotated.pt
+
+python filter.py \
+ --load_path ../results/security_lingua/jailbreak_pairs_annotated.pt \
+ --save_path ../results/security_lingua/jailbreak_pairs_annotated_filtered.pt
+
+python train_roberta.py \
+ --data_path ../results/security_lingua/jailbreak_pairs_annotated_filtered.pt \
+ --save_path ../results/security_lingua/jailbreak_pairs_annotated_filtered_roberta.pt \
+ --model_name microsoft/llmlingua-2-xlm-roberta-large-meetingbank \
+ --num_epoch 5 \
+ --run_name meetbank_slingua \
+ --wandb_project slingua \
+ --wandb_name meetbank_slingua
+
+# Multi-GPU training
+# ACCELERATE_LOG_LEVEL="ERROR" accelerate launch --num_processes 4 experiments/llmlingua2/model_training/train_roberta.py \
+# --data_path experiments/llmlingua2/results/security_lingua/v8_filtered.pt \
+# --save_path experiments/llmlingua2/results/models/xlm_slingua_v8.pth \
+# --num_epoch 5 \
+# --run_name xlm_slingua_v8 \
+# --wandb_project slingua \
+# --wandb_name xlm_slingua_v8 \
+# --batch_size 32
+
diff --git a/experiments/securitylingua/train_roberta.py b/experiments/securitylingua/train_roberta.py
new file mode 100644
index 0000000..b7ce4e4
--- /dev/null
+++ b/experiments/securitylingua/train_roberta.py
@@ -0,0 +1,326 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import argparse
+import os
+import random
+import time
+from typing import List, Tuple, Dict
+
+import torch
+from sklearn.metrics import accuracy_score
+from torch import cuda
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+from transformers import AutoModelForTokenClassification, AutoTokenizer
+from utils import TokenClfDataset
+import wandb
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+
+# Constants
+MAX_LEN = 512
+MAX_GRAD_NORM = 10
+
+def parse_arguments() -> argparse.Namespace:
+ """Parse command line arguments"""
+ parser = argparse.ArgumentParser(
+ description="train bert to do compression (by token classification)"
+ )
+ parser.add_argument(
+ "--model_name",
+ help="token classification model",
+ default="FacebookAI/xlm-roberta-large",
+ )
+ parser.add_argument(
+ "--data_path",
+ help="training and validation data path",
+ default="../../../results/meetingbank/gpt-4-32k_comp/annotation_kept_cs512_meetingbank_train_formated.pt",
+ )
+ parser.add_argument(
+ "--label_type",
+ help="word label or token label",
+ default="word_label",
+ choices=["word_label", "token_label"],
+ )
+ parser.add_argument(
+ "--save_path",
+ help="save path",
+ default="../../../results/models/xlm_roberta_large_meetingbank_only.pth",
+ )
+ parser.add_argument(
+ "--run_name",
+ help="run name",
+ default="xlm_roberta_large_meetingbank_only",
+ )
+ parser.add_argument("--lr", help="learning rate", default=1e-5, type=float)
+ parser.add_argument(
+ "--num_epoch", help="number of training epoch", default=10, type=int
+ )
+ parser.add_argument("--batch_size", type=int, default=10)
+ parser.add_argument(
+ "--wandb_project",
+ help="wandb project name. If not provided, wandb will not be used",
+ default=None,
+ type=str,
+ )
+ parser.add_argument(
+ "--wandb_name",
+ help="wandb run name",
+ default=None,
+ type=str,
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="Random seed for initialization",
+ )
+ return parser.parse_args()
+
+def setup_wandb(project: str, name: str, accelerator: Accelerator) -> bool:
+ """Setup wandb tracking if project is provided and only on main process"""
+ if project is None:
+ return False
+
+ # Only init wandb on main process
+ if accelerator.is_main_process:
+ import wandb
+ wandb.init(project=project, name=name)
+ return True
+
+def load_and_split_data(data_path: str, seed: int = 42) -> Tuple[List[Tuple[str, List]], List[Tuple[str, List]]]:
+ """Load and split data into train and validation sets"""
+ data = torch.load(data_path, weights_only=False)
+ text_label = [(text, label) for text, label in zip(data["origin"], data["labels"])]
+ set_seed(seed)
+ random.shuffle(text_label)
+
+ split_idx = int(len(text_label) * 0.9)
+ train_data = text_label[:split_idx]
+ val_data = text_label[split_idx:]
+
+ return train_data, val_data
+
+def prepare_datasets(
+ train_data: List[Tuple[str, List]],
+ val_data: List[Tuple[str, List]],
+ tokenizer,
+ model_name: str
+) -> Tuple[TokenClfDataset, TokenClfDataset]:
+ """Prepare training and validation datasets"""
+ train_text = [text for text, label in train_data]
+ train_label = [label for text, label in train_data]
+ val_text = [text for text, label in val_data]
+ val_label = [label for text, label in val_data]
+
+ train_dataset = TokenClfDataset(
+ train_text, train_label, MAX_LEN, tokenizer=tokenizer, model_name=model_name
+ )
+ val_dataset = TokenClfDataset(
+ val_text, val_label, MAX_LEN, tokenizer=tokenizer, model_name=model_name
+ )
+
+ return train_dataset, val_dataset
+
+def train_epoch(
+ model: AutoModelForTokenClassification,
+ train_dataloader: DataLoader,
+ optimizer: torch.optim.Optimizer,
+ accelerator: Accelerator,
+ epoch: int,
+ use_wandb: bool = False,
+) -> None:
+ """Train for one epoch"""
+ model.train()
+ tr_loss, tr_accuracy = 0, 0
+ nb_tr_steps = 0
+
+ # Get num_labels from unwrapped model
+ num_labels = accelerator.unwrap_model(model).num_labels
+
+ # Add progress bar
+ progress_bar = tqdm(
+ train_dataloader,
+ desc=f"Training Epoch {epoch}",
+ disable=not accelerator.is_local_main_process
+ )
+
+ for batch in progress_bar:
+ outputs = model(
+ input_ids=batch["ids"],
+ attention_mask=batch["mask"],
+ labels=batch["targets"]
+ )
+ loss, tr_logits = outputs.loss, outputs.logits
+
+ accelerator.backward(loss)
+
+ tr_loss += loss.item()
+ nb_tr_steps += 1
+
+ # Calculate accuracy
+ flattened_targets = batch["targets"].view(-1)
+ active_logits = tr_logits.view(-1, num_labels)
+ flattened_predictions = torch.argmax(active_logits, axis=1)
+ active_accuracy = batch["mask"].view(-1) == 1
+ targets = torch.masked_select(flattened_targets, active_accuracy)
+ predictions = torch.masked_select(flattened_predictions, active_accuracy)
+
+ tmp_tr_accuracy = accuracy_score(
+ targets.cpu().numpy(), predictions.cpu().numpy()
+ )
+ tr_accuracy += tmp_tr_accuracy
+
+ # Update progress bar
+ progress_bar.set_postfix({
+ 'loss': f'{tr_loss/nb_tr_steps:.4f}',
+ 'accuracy': f'{tr_accuracy/nb_tr_steps:.4f}'
+ })
+
+ # Log metrics to wandb
+ if nb_tr_steps % 100 == 0 and use_wandb and accelerator.is_main_process:
+ wandb.log({
+ "train/loss": tr_loss / nb_tr_steps,
+ "train/accuracy": tr_accuracy / nb_tr_steps,
+ "train/step": nb_tr_steps + epoch * len(train_dataloader)
+ })
+
+ # Optimize
+ torch.nn.utils.clip_grad_norm_(
+ parameters=model.parameters(), max_norm=MAX_GRAD_NORM
+ )
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # Print epoch metrics
+ tr_loss = tr_loss / nb_tr_steps
+ tr_accuracy = tr_accuracy / nb_tr_steps
+ print(f"Training loss epoch: {tr_loss}")
+ print(f"Training accuracy epoch: {tr_accuracy}")
+
+def evaluate(
+ model: AutoModelForTokenClassification,
+ eval_dataloader: DataLoader,
+ accelerator: Accelerator,
+ epoch: int,
+ use_wandb: bool = False,
+) -> float:
+ """Evaluate the model"""
+ model.eval()
+ eval_loss, eval_accuracy = 0, 0
+ nb_eval_steps = 0
+
+ # Get num_labels from unwrapped model
+ num_labels = accelerator.unwrap_model(model).num_labels
+
+ # Add progress bar
+ progress_bar = tqdm(
+ eval_dataloader,
+ desc=f"Evaluating Epoch {epoch}",
+ disable=not accelerator.is_local_main_process
+ )
+
+ with torch.no_grad():
+ for batch in progress_bar:
+ outputs = model(
+ input_ids=batch["ids"],
+ attention_mask=batch["mask"],
+ labels=batch["targets"]
+ )
+ loss, eval_logits = outputs.loss, outputs.logits
+ eval_loss += loss.item()
+ nb_eval_steps += 1
+
+ # Calculate accuracy
+ flattened_targets = batch["targets"].view(-1)
+ active_logits = eval_logits.view(-1, num_labels)
+ flattened_predictions = torch.argmax(active_logits, axis=1)
+ active_accuracy = batch["mask"].view(-1) == 1
+ targets = torch.masked_select(flattened_targets, active_accuracy)
+ predictions = torch.masked_select(flattened_predictions, active_accuracy)
+
+ tmp_eval_accuracy = accuracy_score(
+ targets.cpu().numpy(), predictions.cpu().numpy()
+ )
+ eval_accuracy += tmp_eval_accuracy
+
+ # Update progress bar
+ progress_bar.set_postfix({
+ 'loss': f'{eval_loss/nb_eval_steps:.4f}',
+ 'accuracy': f'{eval_accuracy/nb_eval_steps:.4f}'
+ })
+
+ # Calculate and log metrics
+ eval_loss = eval_loss / nb_eval_steps
+ eval_accuracy = eval_accuracy / nb_eval_steps
+ print(f"Validation Loss: {eval_loss}")
+ print(f"Validation Accuracy: {eval_accuracy}")
+
+ if use_wandb and accelerator.is_main_process:
+ wandb.log({
+ "eval/loss": eval_loss,
+ "eval/accuracy": eval_accuracy,
+ "eval/epoch": epoch
+ })
+
+ return eval_accuracy
+
+def main():
+ args = parse_arguments()
+ os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+
+ # Initialize accelerator first
+ accelerator = Accelerator()
+ logger = get_logger(__name__)
+
+ # Setup wandb after accelerator
+ use_wandb = setup_wandb(args.wandb_project, args.wandb_name, accelerator)
+
+ # Set seed for reproducibility
+ set_seed(args.seed)
+
+ # Load model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+ model = AutoModelForTokenClassification.from_pretrained(
+ args.model_name, num_labels=2, ignore_mismatched_sizes=True
+ )
+
+ # Prepare data
+ train_data, val_data = load_and_split_data(args.data_path, args.seed)
+ train_dataset, val_dataset = prepare_datasets(
+ train_data, val_data, tokenizer, args.model_name
+ )
+
+ logger.info(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
+
+ # Create dataloaders
+ train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
+ val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
+
+ # Initialize optimizer
+ optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
+
+ # Prepare everything with accelerator
+ model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
+ model, optimizer, train_dataloader, val_dataloader
+ )
+
+ # Training loop
+ best_acc = 0
+ for epoch in tqdm(range(args.num_epoch)):
+ logger.info(f"Training epoch: {epoch + 1}")
+ train_epoch(model, train_dataloader, optimizer, accelerator, epoch, use_wandb)
+ acc = evaluate(model, val_dataloader, accelerator, epoch, use_wandb)
+
+ if acc > best_acc:
+ best_acc = acc
+ # Unwrap model before saving
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ accelerator.save(unwrapped_model.state_dict(), args.save_path)
+
+if __name__ == "__main__":
+ main()
diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py
index 7af7920..84e390e 100644
--- a/llmlingua/prompt_compressor.py
+++ b/llmlingua/prompt_compressor.py
@@ -75,10 +75,12 @@ def __init__(
model_config: dict = {},
open_api_config: dict = {},
use_llmlingua2: bool = False,
+ use_slingua: bool = False,
llmlingua2_config: dict = {},
):
self.model_name = model_name
self.use_llmlingua2 = use_llmlingua2
+ self.use_slingua = use_slingua
self.retrieval_model = None
self.retrieval_model_name = None
self.open_api_config = open_api_config
@@ -87,7 +89,7 @@ def __init__(
self.oai_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
self.load_model(model_name, device_map, model_config)
- if use_llmlingua2:
+ if use_llmlingua2 or use_slingua: # slingua use llmlingua2 backend
self.init_llmlingua2(**llmlingua2_config)
def init_llmlingua2(
@@ -2398,9 +2400,13 @@ def split_string_to_words(input_string):
for word, word_prob in zip(words, word_probs):
num_token = len(self.oai_tokenizer.encode(word))
new_token_probs.extend([word_prob for _ in range(num_token)])
- threshold = np.percentile(
- new_token_probs, int(100 * reduce_rate + 1)
- )
+
+ if self.use_slingua:
+ threshold = 0.5 # slingua use fixed threshold 0.5 for binary token classification
+ else:
+ threshold = np.percentile(
+ new_token_probs, int(100 * reduce_rate + 1)
+ )
keep_words = []
word_labels = []
diff --git a/llmlingua/utils.py b/llmlingua/utils.py
index a08f615..da9986e 100644
--- a/llmlingua/utils.py
+++ b/llmlingua/utils.py
@@ -87,7 +87,9 @@ def is_begin_of_new_word(token, model_name, force_tokens, token_map):
):
return True
return not token.startswith("##")
- elif "xlm-roberta-large" in model_name:
+ elif "xlm-roberta-large" in model_name \
+ or 'slingua' in model_name.lower() \
+ or 'securitylingua' in model_name.lower():
if (
token in string.punctuation
or token in force_tokens
@@ -110,7 +112,9 @@ def get_pure_token(token, model_name):
or "tinybert" in model_name.lower() \
or "mobilebert" in model_name.lower():
return token.lstrip("##")
- elif "xlm-roberta-large" in model_name:
+ elif "xlm-roberta-large" in model_name \
+ or 'slingua' in model_name.lower() \
+ or 'securitylingua' in model_name.lower():
return token.lstrip("▁")
else:
raise NotImplementedError()