
Add SentenceEmbeddingTransformer to support sentence-transformers in sklearn pipelines #88

@idanmoradarthas

Description

Add a new sklearn-compatible transformer that wraps the sentence-transformers library, allowing it to be used seamlessly within sklearn pipelines for NLP tasks.

Motivation

Currently, there's no built-in way to use sentence-transformers models within sklearn pipelines. This transformer would enable users to:

  • Integrate state-of-the-art sentence embeddings into sklearn workflows
  • Leverage lazy loading for efficient model initialization
  • Handle various input formats (strings, lists, Series, arrays)
  • Process None/NaN values gracefully

Proposed Implementation

New Module: ds_utils.text (or ds_utils.nlp)

Class: SentenceEmbeddingTransformer(BaseEstimator, TransformerMixin)

Key Features:

  1. Lazy Loading: Model loads only when first needed (in fit or transform)
  2. Sklearn Compatible: Implements fit() and transform() methods
  3. Flexible Input Handling:
    • Supports strings, lists, pandas Series, numpy arrays
    • Handles None/NaN values by replacing with empty strings
  4. Configurable Parameters:
    • model_name: The sentence-transformers model to use (default: 'sentence-transformers/all-MiniLM-L6-v2')
    • batch_size: Batch size for encoding (default: 32)
    • show_progress_bar: Whether to show progress during encoding (default: True)
    • convert_to_numpy: Whether to return numpy arrays (default: True)
    • normalize_embeddings: Whether to normalize embeddings to unit length (default: False)
    • device: Device to use for computation ('cpu', 'cuda', etc.) (default: None - auto-detect)
    • precision: Precision for embeddings ('float32', 'int8', 'uint8', 'binary', 'ubinary') (default: 'float32')
  • truncate_dim: Dimension to truncate embeddings to, useful for Matryoshka models; note this is a SentenceTransformer constructor argument rather than an encode() argument (default: None)
    • prompt_name: Name of prompt to use from model's prompts dictionary (default: None)
    • prompt: Custom prompt to prepend to inputs (default: None)

Reference Implementation

import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.base import BaseEstimator, TransformerMixin


class SentenceEmbeddingTransformer(BaseEstimator, TransformerMixin):
    def __init__(
        self,
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        batch_size=32,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=False,
        device=None,
        precision='float32',
        truncate_dim=None,
        prompt_name=None,
        prompt=None
    ):
        self.model_name = model_name
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar
        self.convert_to_numpy = convert_to_numpy
        self.normalize_embeddings = normalize_embeddings
        self.device = device
        self.precision = precision
        self.truncate_dim = truncate_dim
        self.prompt_name = prompt_name
        self.prompt = prompt
        self.model = None

    def fit(self, X, y=None):
        # Lazy loading: instantiate the model only on first use.
        # device and truncate_dim are SentenceTransformer constructor
        # arguments, not encode() arguments.
        if self.model is None:
            self.model = SentenceTransformer(
                self.model_name,
                device=self.device,
                truncate_dim=self.truncate_dim
            )
        return self

    def transform(self, X):
        if self.model is None:
            self.fit(X)

        # Accept pandas Series / numpy arrays as well as plain iterables
        texts = X.tolist() if hasattr(X, 'tolist') else list(X)

        # Replace None/NaN with empty strings (pd.notna(None) is False,
        # so this covers both missing-value cases)
        texts = [str(t) if pd.notna(t) else "" for t in texts]

        return self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_numpy=self.convert_to_numpy,
            normalize_embeddings=self.normalize_embeddings,
            precision=self.precision,
            prompt_name=self.prompt_name,
            prompt=self.prompt
        )

Example Usage:

from sklearn.pipeline import Pipeline
from ds_utils.text import SentenceEmbeddingTransformer
from sklearn.ensemble import RandomForestClassifier

pipeline = Pipeline([
    ('embeddings', SentenceEmbeddingTransformer(
        model_name='all-MiniLM-L6-v2',
        normalize_embeddings=True
    )),
    ('classifier', RandomForestClassifier())
])

# X_train / X_test are iterables of raw text strings
pipeline.fit(X_train, y_train)
predictions = pipeline.predict(X_test)

Implementation Checklist

  • Create new module ds_utils/text.py (or ds_utils/nlp.py)
  • Implement SentenceEmbeddingTransformer class with all parameters
  • Add unit tests covering:
    • Model lazy loading
    • Different input formats (list, Series, array, single string)
    • None/NaN handling
    • Integration with sklearn pipelines
    • Batch processing
    • Different precision modes (float32, int8, binary, etc.)
    • normalize_embeddings parameter
    • truncate_dim parameter
    • prompt and prompt_name parameters
  • Add documentation and docstrings
  • Update README.md with usage examples
  • Add sentence-transformers to dependencies

Dependencies

Add sentence-transformers>=3.0.0 to requirements.txt or as an optional dependency
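
If it is made optional, one possible layout is an extras group (sketch assuming a pyproject.toml-based build; the extra's name is illustrative):

```toml
[project.optional-dependencies]
nlp = ["sentence-transformers>=3.0.0"]
```

Users would then opt in with pip install "ds-utils[nlp]" (package name assumed from the module path).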

Additional Parameters from sentence-transformers Documentation

Based on the sentence-transformers API, the encode() method supports:

  • precision: 'float32', 'int8', 'uint8', 'binary', 'ubinary' for quantized embeddings
  • normalize_embeddings: Normalize vectors to unit length for faster dot-product similarity
  • truncate_dim: Truncate embeddings to a smaller dimension (useful for Matryoshka models); set on the SentenceTransformer constructor rather than passed to encode()
  • prompt_name: Use predefined prompts from the model
  • prompt: Custom prompt to prepend to text
  • device: Single device or list of devices for multi-process encoding
  • output_value: 'sentence_embedding' or 'token_embeddings'
  • convert_to_numpy: Return numpy arrays vs torch tensors
  • convert_to_tensor: Return single tensor instead of array
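
To make the precision options concrete, here is a standard-library illustration of what 'binary'/'ubinary' quantization does conceptually: each dimension is reduced to its sign bit and packed eight per byte, a 32x size reduction versus float32. This is an explanatory sketch, not the library's implementation:

```python
def binarize(embedding):
    """Pack the sign bits of an embedding, 8 dimensions per byte."""
    bits = [1 if x > 0 else 0 for x in embedding]
    packed = bytearray()
    for i in range(0, len(bits), 8):
        byte = 0
        for b in bits[i:i + 8]:
            byte = (byte << 1) | b  # shift in one sign bit at a time
        packed.append(byte)
    return bytes(packed)

# A 384-dim float32 embedding (1536 bytes) packs into 48 bytes.
```

Quantized embeddings trade a small amount of retrieval accuracy for much cheaper storage and faster (Hamming-distance) comparisons.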

Questions to Consider

  1. Should we support output_value='token_embeddings' for token-level embeddings?
  2. Should we add a method to get embedding dimension via get_sentence_embedding_dimension()?
  3. Should we support multi-process encoding via the pool parameter?
  4. Should we add support for encode_query() and encode_document() methods for specialized encoding?

Labels

enhancement (New feature or request)
