## Description

Add a new sklearn-compatible transformer that wraps the `sentence-transformers` library, allowing it to be used seamlessly within sklearn pipelines for NLP tasks.
## Motivation

Currently, there is no built-in way to use `sentence-transformers` models within sklearn pipelines. This transformer would enable users to:

- Integrate state-of-the-art sentence embeddings into sklearn workflows
- Leverage lazy loading for efficient model initialization
- Handle various input formats (strings, lists, pandas Series, numpy arrays)
- Process None/NaN values gracefully
## Proposed Implementation

**New Module:** `ds_utils.text` (or `ds_utils.nlp`)

**Class:** `SentenceEmbeddingTransformer(BaseEstimator, TransformerMixin)`
**Key Features:**

- **Lazy loading**: the model loads only when first needed (in `fit` or `transform`)
- **Sklearn compatible**: implements the `fit()` and `transform()` methods
- **Flexible input handling**:
  - supports strings, lists, pandas Series, and numpy arrays
  - handles None/NaN values by replacing them with empty strings
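The None/NaN handling above can be sketched as a small standalone helper (the name `clean_texts` is hypothetical; in the transformer this logic would live inside `transform()`):

```python
import numpy as np
import pandas as pd

# Hypothetical standalone version of the input-cleaning step described above
def clean_texts(X):
    # Accept pandas Series / numpy arrays as well as plain lists
    texts = X.tolist() if hasattr(X, 'tolist') else list(X)
    # Replace None/NaN with empty strings so the encoder only sees strings
    return [str(t) if pd.notna(t) else "" for t in texts]

clean_texts(pd.Series(["hello", None, np.nan, "world"]))
# the two missing values become empty strings: ['hello', '', '', 'world']
```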
**Configurable Parameters:**

- `model_name`: the sentence-transformers model to use (default: `'sentence-transformers/all-MiniLM-L6-v2'`)
- `batch_size`: batch size for encoding (default: `32`)
- `show_progress_bar`: whether to show progress during encoding (default: `True`)
- `convert_to_numpy`: whether to return numpy arrays (default: `True`)
- `normalize_embeddings`: whether to normalize embeddings to unit length (default: `False`)
- `device`: device to use for computation (`'cpu'`, `'cuda'`, etc.; default: `None`, i.e. auto-detect)
- `precision`: precision for embeddings (`'float32'`, `'int8'`, `'uint8'`, `'binary'`, `'ubinary'`; default: `'float32'`)
- `truncate_dim`: dimension to truncate embeddings to, useful for Matryoshka models (default: `None`)
- `prompt_name`: name of a prompt from the model's prompts dictionary (default: `None`)
- `prompt`: custom prompt to prepend to inputs (default: `None`)
## Implementation Based on Code
```python
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.base import BaseEstimator, TransformerMixin


class SentenceEmbeddingTransformer(BaseEstimator, TransformerMixin):
    def __init__(
        self,
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        batch_size=32,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=False,
        device=None,
        precision='float32',
        truncate_dim=None,
        prompt_name=None,
        prompt=None,
    ):
        self.model_name = model_name
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar
        self.convert_to_numpy = convert_to_numpy
        self.normalize_embeddings = normalize_embeddings
        self.device = device
        self.precision = precision
        self.truncate_dim = truncate_dim
        self.prompt_name = prompt_name
        self.prompt = prompt
        self.model = None

    def fit(self, X, y=None):
        # Lazy loading: only download/initialize the model once
        if self.model is None:
            print(f"Loading sentence transformer: {self.model_name}")
            # device and truncate_dim are model-level settings, so they are
            # passed to the constructor rather than to encode()
            self.model = SentenceTransformer(
                self.model_name,
                device=self.device,
                truncate_dim=self.truncate_dim,
            )
        return self

    def transform(self, X):
        if self.model is None:
            self.fit(X)
        # Accept pandas Series / numpy arrays as well as plain lists
        texts = X.tolist() if hasattr(X, 'tolist') else list(X)
        # Replace None/NaN with empty strings so encode() only sees strings
        texts = [str(t) if pd.notna(t) else "" for t in texts]
        return self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_numpy=self.convert_to_numpy,
            normalize_embeddings=self.normalize_embeddings,
            precision=self.precision,
            prompt_name=self.prompt_name,
            prompt=self.prompt,
        )
```
## Example Usage

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline

from ds_utils.text import SentenceEmbeddingTransformer

# X_train / X_test are iterables of raw text strings
pipeline = Pipeline([
    ('embeddings', SentenceEmbeddingTransformer(
        model_name='all-MiniLM-L6-v2',
        normalize_embeddings=True,
    )),
    ('classifier', RandomForestClassifier()),
])

pipeline.fit(X_train, y_train)
predictions = pipeline.predict(X_test)
```
## Implementation Checklist

- [ ] Create `ds_utils/text.py` (or `ds_utils/nlp.py`)
- [ ] Implement the `SentenceEmbeddingTransformer` class with all parameters
- [ ] Add `sentence-transformers` to dependencies

## Dependencies

Add `sentence-transformers>=3.0.0` to `requirements.txt` or as an optional dependency.
## Additional Parameters from the sentence-transformers Documentation

Based on the sentence-transformers API, the `encode()` method supports:

- `precision`: `'float32'`, `'int8'`, `'uint8'`, `'binary'`, `'ubinary'` for quantized embeddings
- `normalize_embeddings`: normalize vectors to unit length for faster dot-product similarity
- `truncate_dim`: truncate embeddings (useful for Matryoshka models)
- `prompt_name`: use predefined prompts from the model
- `prompt`: custom prompt to prepend to text
- `device`: single device, or a list of devices for multi-process encoding
- `output_value`: `'sentence_embedding'` or `'token_embeddings'`
- `convert_to_numpy`: return numpy arrays instead of torch tensors
- `convert_to_tensor`: return a single stacked tensor instead of a list or array
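The `normalize_embeddings` point can be seen with plain numpy: once vectors have unit length, a dot product already equals cosine similarity (toy 2-D vectors below, not real embeddings):

```python
import numpy as np

# Toy 2-D vectors standing in for embeddings
a = np.array([3.0, 4.0])
b = np.array([1.0, 0.0])

# What normalize_embeddings=True does: rescale each vector to unit length
a_unit = a / np.linalg.norm(a)
b_unit = b / np.linalg.norm(b)

cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
# For unit vectors, the plain dot product equals cosine similarity
print(np.isclose(a_unit @ b_unit, cosine))  # True
```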
## Questions to Consider

- Should we support `output_value='token_embeddings'` for token-level embeddings?
- Should we add a method to get the embedding dimension via `get_sentence_embedding_dimension()`?
- Should we support multi-process encoding via the `pool` parameter?
- Should we add support for the `encode_query()` and `encode_document()` methods for specialized encoding?
## Related Documentation