From c821129fde6ea0dbba942615001084e96d3b7338 Mon Sep 17 00:00:00 2001
From: vody-am <158507583+vody-am@users.noreply.github.com>
Date: Sun, 8 Mar 2026 00:27:44 -0500
Subject: [PATCH] Fix issue with label attention and n_head None

---
 examples/multilabel_classification.py         | 115 ++++++++++++++++++
 .../model/components/text_embedder.py         |  12 +-
 torchTextClassifiers/torchTextClassifiers.py  |   2 +-
 3 files changed, 127 insertions(+), 2 deletions(-)
 create mode 100644 examples/multilabel_classification.py

diff --git a/examples/multilabel_classification.py b/examples/multilabel_classification.py
new file mode 100644
index 0000000..ce739fe
--- /dev/null
+++ b/examples/multilabel_classification.py
@@ -0,0 +1,115 @@
+import numpy as np
+import torch
+
+from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
+from torchTextClassifiers.dataset import TextClassificationDataset
+from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
+from torchTextClassifiers.model.components import (
+    AttentionConfig,
+    CategoricalVariableNet,
+    ClassificationHead,
+    TextEmbedder,
+    TextEmbedderConfig,
+)
+from torchTextClassifiers.tokenizers import HuggingFaceTokenizer
+
+# ==========================================
+# 1. Ragged-lists approach
+# ==========================================
+
+# In multilabel classification, each instance can be assigned multiple labels simultaneously.
+# Let's use fake data where the labels are a list of lists (a ragged array).
+sample_text_data = [
+    "This is a positive example",
+    "This is a negative example",
+    "Another positive case",
+    "Another negative case",
+    "Good example here",
+    "Bad example here",
+]
+
+# Each inner list contains the labels for the corresponding instance
+labels_ragged = [[0, 1, 5], [0, 4], [1, 5], [0, 1, 4], [1, 5], [0]]
+
+# Note: labels_ragged is a "jagged array":
+# np.array(labels_ragged) would not produce a standard numeric matrix.
+# torchTextClassifiers, however, handles this format directly.
+
+# Load a pre-trained tokenizer
+tokenizer = HuggingFaceTokenizer.load_from_pretrained(
+    "google-bert/bert-base-uncased", output_dim=126
+)
+
+X = np.array(sample_text_data)
+Y_ragged = labels_ragged
+
+# Configure the model and training.
+# We use BCEWithLogitsLoss for multilabel tasks to treat each label
+# as a separate binary classification problem.
+embedding_dim = 96
+num_classes = max(max(label_list) for label_list in labels_ragged) + 1
+
+model_config = ModelConfig(
+    embedding_dim=embedding_dim,
+    num_classes=num_classes,
+)
+
+training_config = TrainingConfig(
+    lr=1e-3,
+    batch_size=4,
+    num_epochs=1,
+    loss=torch.nn.BCEWithLogitsLoss(),  # Essential for multilabel
+)
+
+# Initialize the classifier with ragged_multilabel=True
+ttc_ragged = torchTextClassifiers(
+    tokenizer=tokenizer,
+    model_config=model_config,
+    ragged_multilabel=True,  # Key for ragged-list input!
+)
+
+if __name__ == "__main__":
+
+    print("Starting training with ragged labels...")
+    ttc_ragged.train(
+        X_train=X,
+        y_train=Y_ragged,
+        training_config=training_config,
+    )
+
+    # Behind the scenes, the ragged lists are converted into a binary matrix (one-hot version).
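+
+    # Purely for illustration (not part of the library API), an equivalent
+    # binary matrix can be built by hand with numpy:
+    manual_one_hot = np.zeros((len(labels_ragged), num_classes), dtype=np.float32)
+    for row, label_list in enumerate(labels_ragged):
+        manual_one_hot[row, label_list] = 1.0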
+
+    # ==========================================
+    # 2. One-hot / multidimensional output approach
+    # ==========================================
+
+    # You can also provide a one-hot/multidimensional array (or float probabilities).
+    # Here, each row is a vector of size equal to the number of labels.
+    labels_one_hot = [
+        [1., 1., 0., 0., 0., 1.],
+        [1., 0., 0., 0., 1., 0.],
+        [0., 1., 0., 0., 0., 1.],
+        [1., 1., 0., 0., 1., 0.],
+        [0., 1., 0., 0., 0., 1.],
+        [1., 0., 0., 0., 1., 0.]
+    ]
+    Y_one_hot = np.array(labels_one_hot)
+
+    # When using one-hot/dense arrays, set ragged_multilabel=False (the default)
+    ttc_dense = torchTextClassifiers(
+        tokenizer=tokenizer,
+        model_config=model_config,
+    )
+
+    print("\nStarting training with one-hot labels...")
+    ttc_dense.train(
+        X_train=X,
+        y_train=Y_one_hot,
+        training_config=training_config,
+    )
diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index 9d5aaa5..117aeae 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -42,7 +42,7 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
         # always see a LabelAttentionConfig instance rather than a raw dict.
         self.config.label_attention_config = self.label_attention_config
 
-        self.enable_label_attention = self.label_attention_config is not None
+        self.enable_label_attention = self.label_attention_config is not None and getattr(self.label_attention_config, 'n_head', None) is not None
 
         if self.enable_label_attention:
             self.label_attention_module = LabelAttentionClassifier(self.config)
@@ -58,6 +58,16 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
 
         if self.attention_config is not None:
             self.attention_config.n_embd = text_embedder_config.embedding_dim
+
+            if self.attention_config.n_embd is None:
+                raise ValueError("n_embd must be specified in AttentionConfig or embedding_dim in TextEmbedderConfig.")
+            if self.attention_config.n_head is None:
+                raise ValueError("n_head must be specified in AttentionConfig.")
+            if self.attention_config.n_layers is None:
+                raise ValueError("n_layers must be specified in AttentionConfig.")
+            if self.attention_config.n_kv_head is None:
+                raise ValueError("n_kv_head must be specified in AttentionConfig.")
+
             self.transformer = nn.ModuleDict(
                 {
                     "h": nn.ModuleList(
diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py
index e38bb97..33dc3ca 100644
--- a/torchTextClassifiers/torchTextClassifiers.py
+++ b/torchTextClassifiers/torchTextClassifiers.py
@@ -302,7 +302,7 @@ def train(
             texts=X_train["text"],
             categorical_variables=X_train["categorical_variables"],  # None if no cat vars
             tokenizer=self.tokenizer,
-            labels=y_train.tolist(),
+            labels=y_train.tolist() if not isinstance(y_train, list) else y_train,
             ragged_multilabel=self.ragged_multilabel,
         )
         train_dataloader = train_dataset.create_dataloader(