diff --git a/.gitignore b/.gitignore index b520deb..726df61 100644 --- a/.gitignore +++ b/.gitignore @@ -121,4 +121,9 @@ logs/ *.csv !.dvc data +logs +pretrain_embeddings wget-log +checkpoints +*.parquet +*.pt \ No newline at end of file diff --git a/benchmarks/pretrain_benchmark/apply_benchmark.ipynb b/benchmarks/pretrain_benchmark/apply_benchmark.ipynb new file mode 100644 index 0000000..75db554 --- /dev/null +++ b/benchmarks/pretrain_benchmark/apply_benchmark.ipynb @@ -0,0 +1,492 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "345ea8f5-7653-4e0f-92ce-2b5d713d24d7", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/machenike/bmml/DataMetaMap/src/data_meta_map/wasserstein_embedder.py:8: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", + " from tqdm.autonotebook import tqdm\n" + ] + } + ], + "source": [ + "from data_meta_map.task2vec import task2vec\n", + "from data_meta_map.models import get_model\n", + "from data_meta_map import datasets\n", + "from data_meta_map.task2vec import plot_distance_matrix\n", + "from data_meta_map.task2vec import Task2Vec\n", + "from data_meta_map.task2vec.task_similarity import cosine\n", + "\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ef2aa314-ea65-440b-8db2-45c0ae674e14", + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "from collections import defaultdict\n", + "\n", + "def get_pretrained_results(dataset_names, path_to_logs):\n", + " pretrain2downstream_results = defaultdict(dict)\n", + " for pretrain_name in dataset_names:\n", + " for downstream_name in dataset_names:\n", + " if downstream_name == 'imagenet':\n", + " continue\n", + " if downstream_name == pretrain_name:\n", + " continue\n", + " with open(f'{path_to_logs}/{pretrain_name}-{downstream_name}.yaml', 'r') as f:\n", 
+ " test_acc = yaml.safe_load(f)['task']['test_acc']\n", + " pretrain2downstream_results[pretrain_name][downstream_name] = test_acc\n", + " return pretrain2downstream_results\n", + "\n", + "def get_embedder_results(dataset_names, path_to_pretrained_logs, embedder_func, similarity_func):\n", + " embeddings = []\n", + "\n", + " dataset_list = [datasets.__dict__[name](root='../../data')[0] for name in dataset_names]\n", + " for name, dataset in zip(dataset_names, dataset_list):\n", + " embeddings.append(embedder_func(dataset))\n", + "\n", + " def find_closest(dataset_names, embeddings, similarity_metric):\n", + " res = dict()\n", + " for dataset_name, embed in zip(dataset_names, embeddings):\n", + " dists = {name:similarity_metric(embed, other_embed) for name, other_embed in zip(dataset_names, embeddings) if name != dataset_name}\n", + " argmax = max(dists.items(), key = lambda x: x[1])[0]\n", + " res[dataset_name] = argmax\n", + " return res\n", + " \n", + " dataset2closest = find_closest(dataset_names, embeddings, cosine_similarity)\n", + " pretrain2downstream_results = get_pretrained_results(dataset_names, path_to_pretrained_logs)\n", + " method_performance = {name: {'accuracy' : pretrain2downstream_results[dataset2closest[name]][name], \n", + " 'pretrain': dataset2closest[name]} for name in dataset_names}\n", + " return method_performance\n", + "\n", + "def get_random_baseline(dataset_names, path_to_pretrained_logs):\n", + " pretrain2downstream_results = get_pretrained_results(dataset_names + ['imagenet'], path_to_pretrained_logs)\n", + " random_performance = {}\n", + " for name in dataset_names:\n", + " if name == 'imagenet':\n", + " continue\n", + " choice = name\n", + " while choice == name:\n", + " choice = np.random.choice(dataset_names)\n", + " random_performance[name] = {\n", + " 'accuracy': pretrain2downstream_results[choice][name],\n", + " 'pretrain': choice\n", + " }\n", + " return random_performance\n", + " \n", + "def 
get_big_pretrain_baseline(dataset_names, path_to_pretrained_logs):\n", + " pretrain2downstream_results = get_pretrained_results(dataset_names + ['imagenet'], path_to_pretrained_logs)\n", + " big_baseline_performance = {}\n", + " for name in dataset_names:\n", + " if name == 'imagenet':\n", + " continue\n", + " big_baseline_performance[name] = {\n", + " 'accuracy': pretrain2downstream_results['imagenet'][name],\n", + " 'pretrain': 'imagenet'\n", + " }\n", + " return big_baseline_performance " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b8c01eb9-3a90-4c8b-bfea-d2253b8785ff", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Caching features: 0%| | 0/14 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
mnistcifar10cifar100letterskmnist
task2vec0.99600.87980.49600.9325960.9957
random0.98150.63900.35120.9335100.9752
big_pretrain0.98480.88340.67890.9197120.9851
\n", + "" + ], + "text/plain": [ + " mnist cifar10 cifar100 letters kmnist\n", + "task2vec 0.9960 0.8798 0.4960 0.932596 0.9957\n", + "random 0.9815 0.6390 0.3512 0.933510 0.9752\n", + "big_pretrain 0.9848 0.8834 0.6789 0.919712 0.9851" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "pd.DataFrame(res).T" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/benchmarks/pretrain_benchmark/get_pretrained_to_task.py b/benchmarks/pretrain_benchmark/get_pretrained_to_task.py new file mode 100644 index 0000000..717c028 --- /dev/null +++ b/benchmarks/pretrain_benchmark/get_pretrained_to_task.py @@ -0,0 +1,291 @@ +from data_meta_map.task2vec import task2vec +from data_meta_map.models import get_model +from data_meta_map import datasets +from data_meta_map.task2vec import plot_distance_matrix + +import torch +from torch.utils.data import DataLoader, random_split +import torch.nn as nn +import pandas as pd +from tqdm import tqdm + +import pandas as pd +import torch +from torch.utils.data import TensorDataset, DataLoader, random_split + +import hydra +import yaml +from omegaconf import DictConfig, OmegaConf + + +@torch.inference_mode() +def evaluate(model, loader, device): + + model.eval() + correct = 0 + total = 0 + + for x, y in loader: + + x = x.to(device) + y = y.to(device) + + logits = model(x) + pred = logits.argmax(1) + + correct += (pred == y).sum().item() + total += y.size(0) + + return correct / total + + +def train(model: nn.Module, + train_loader: DataLoader, + val_loader: 
DataLoader, + optimizer, + criterion, + best_path: str, + num_epochs: int, + patience: int, + device: str): + + best_acc = 0 + best_epoch = 0 + epochs_without_improvement = 0 + logs = {} + + for epoch in range(num_epochs): + + model.train() + pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") + for x, y in pbar: + x = x.to(device) + y = y.to(device) + + logits = model(x) + loss = criterion(logits, y) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + val_acc = evaluate(model, val_loader, device=device) + + # improvement + if val_acc > best_acc: + + best_acc = val_acc + epochs_without_improvement = 0 + best_epoch = epoch + + torch.save(model.state_dict(), best_path) + else: + epochs_without_improvement += 1 + + # early stopping condition + if epochs_without_improvement >= patience: + + break + + model.load_state_dict(torch.load(best_path)) + model.eval() + logs = { + 'best_val_accuracy': best_acc, + 'best_val_epoch': best_epoch + } + return model, logs + + +@torch.inference_mode() +def extract_embeddings(model, loader): + + embeddings = [] + labels = [] + + model.eval() + + for x, y in tqdm(loader): + + x = x.cuda() + + feat = model(x) + feat = feat.view(feat.size(0), -1) + + embeddings.append(feat.cpu()) + labels.append(y) + + embeddings = torch.cat(embeddings).numpy() + labels = torch.cat(labels).numpy() + + return embeddings, labels + + +def save_parquet(embeds, labels, path): + + df = pd.DataFrame({ + "hidden_state": embeds.tolist(), + "target": labels + }) + + df.to_parquet(path, index=False) + + +class MLP(nn.Module): + + def __init__(self, input_dim=512, hidden_dim=256, num_classes=10): + super().__init__() + + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Dropout(0.15), + + nn.Linear(hidden_dim, num_classes) + ) + + def forward(self, x): + return self.net(x) + + +def estimate_task_performance(config): + train_df = pd.read_parquet(config['paths']['train_parquet']) + 
test_df = pd.read_parquet(config['paths']['test_parquet']) + + X_train = torch.tensor(train_df.hidden_state.tolist(), dtype=torch.float32) + y_train = torch.tensor(train_df.target.values, dtype=torch.long) + + X_test = torch.tensor(test_df.hidden_state.tolist(), dtype=torch.float32) + y_test = torch.tensor(test_df.target.values, dtype=torch.long) + + VAL_RATIO = config['training_task']['val_split'] + + dataset = TensorDataset(X_train, y_train) + + train_size = int((1 - VAL_RATIO) * len(dataset)) + val_size = len(dataset) - train_size + + train_ds, val_ds = random_split(dataset, [train_size, val_size]) + + BATCH_SIZE = config['training_task']['batch_size'] + train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True) + val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE) + test_loader = DataLoader(TensorDataset( + X_test, y_test), batch_size=BATCH_SIZE) + + criterion = nn.CrossEntropyLoss() + + device = torch.device(config.device) + model = MLP(input_dim=512, + hidden_dim=config['estimate_network_params']['hidden_dim'], + num_classes=max(y_test) + 1).to(device) + optimizer = torch.optim.Adam( + model.parameters(), lr=config['training']['lr']) + + model, logs = train( + model, + train_loader, + val_loader, + optimizer, + criterion, + best_path=config['paths']['checkpoint_task'], + num_epochs=config['training_task']['epochs'], + patience=config['training_task']['early_stopping_epochs'], + device=device + ) + test_acc = evaluate(model, test_loader, device) + logs['test_acc'] = test_acc + return logs + + +def apply_pretrain(config): + device = torch.device(config.device) + + model = get_model(config['model']['name'], + pretrained=config['model']['pretrained']).to(device) + train_dataset, test_dataset = datasets.__dict__[ + config['pretrain_dataset_name']](root=config['paths']['data_dir']) + + device = "cuda" + + VAL_RATIO = config['training']['val_split'] + BATCH_SIZE = config['training']['batch_size'] + + train_size = int((1 - VAL_RATIO) * 
len(train_dataset))
+    val_size = len(train_dataset) - train_size
+
+    train_ds, val_ds = random_split(train_dataset, [train_size, val_size])
+
+    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
+    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
+    test_loader = DataLoader(
+        test_dataset, batch_size=BATCH_SIZE, shuffle=False)
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(
+        model.parameters(), lr=config['training']['lr'])
+
+    model, logs = train(
+        model,
+        train_loader,
+        val_loader,
+        optimizer,
+        criterion,
+        best_path=config['paths']['checkpoint'],
+        num_epochs=config['training']['epochs'],
+        patience=config['training']['early_stopping_epochs'],
+        device=device
+    )
+    return logs
+
+
+def inference_task(config):
+    device = torch.device(config.device)
+
+    model = get_model(config['model']['name'],
+                      pretrained=config['model']['pretrained']).to(device)
+
+    if not config['use_basic_model']:
+        best_path = config['paths']['checkpoint']
+        model.load_state_dict(torch.load(best_path))
+        model.eval()
+
+    embedder = torch.nn.Sequential(*list(model.children())[:-1]).to(device)
+
+    train_dataset, test_dataset = datasets.__dict__[
+        config['task_dataset_name']](root=config['paths']['data_dir'])
+
+    BATCH_SIZE = config['training']['batch_size']
+    test_loader = DataLoader(
+        test_dataset, batch_size=BATCH_SIZE, shuffle=False)
+    train_full_loader = DataLoader(
+        train_dataset, batch_size=BATCH_SIZE, shuffle=False)
+
+    train_embeds, train_labels = extract_embeddings(
+        embedder, train_full_loader)
+    test_embeds, test_labels = extract_embeddings(embedder, test_loader)
+
+    save_parquet(train_embeds, train_labels, config['paths']['train_parquet'])
+    save_parquet(test_embeds, test_labels, config['paths']['test_parquet'])
+
+
+@hydra.main(config_path=".", config_name="config", version_base=None)
+def main(config: DictConfig):
+    pretrain_logs = apply_pretrain(
+        config) if config['need_pretrain'] else {}
+    
task_logs = {} + if not config['pretrain_only']: + inference_task(config) + task_logs = estimate_task_performance(config) + logs = { + 'pretrain': pretrain_logs, + 'task': task_logs, + 'pretrain_to_task_config': OmegaConf.to_container(config, resolve=True) + } + with open(config['paths']['save_logs'], 'w') as f: + yaml.safe_dump(logs, f) + + +if __name__ == "__main__": + main() diff --git a/configs/pretrain_to_task.yaml b/configs/pretrain_to_task.yaml new file mode 100644 index 0000000..4b37156 --- /dev/null +++ b/configs/pretrain_to_task.yaml @@ -0,0 +1,39 @@ +seed: 42 +device: cuda + +model: + name: resnet18 # resnet18 | resnet34 + pretrained: true + +pretrain_only: false +need_pretrain: true +use_basic_model: false + +pretrain_dataset_name: stl10 # mnist | cifar10 | cifar100 | letters | kmnist +task_dataset_name: mnist # mnist | cifar10 | cifar100 | letters | kmnist + +training: + batch_size: 128 + epochs: 20 + lr: 0.0003 + val_split: 0.2 + early_stopping_epochs: 5 + +training_task: + batch_size: 128 + epochs: 20 + lr: 0.0003 + val_split: 0.2 + early_stopping_epochs: 10 + +estimate_network_params: + hidden_dim: 256 + num_classes: 10 + +paths: + data_dir: /home/machenike/bmml/DataMetaMap/data + train_parquet: /home/machenike/bmml/DataMetaMap/pretrain_embeddings/train_embeds.parquet + test_parquet: /home/machenike/bmml/DataMetaMap/pretrain_embeddings/test_embeds.parquet + checkpoint: /home/machenike/bmml/DataMetaMap/checkpoints/${pretrain_dataset_name}_pretrain.pt + checkpoint_task: /home/machenike/bmml/DataMetaMap/checkpoints/best_model.pt + save_logs: /home/machenike/bmml/DataMetaMap/logs/pretrain_to_task_logs/${pretrain_dataset_name}-${task_dataset_name}.yaml \ No newline at end of file diff --git a/demo/task2vec/simple_example.ipynb b/demo/task2vec/simple_example.ipynb index d5b84ea..ab5be76 100644 --- a/demo/task2vec/simple_example.ipynb +++ b/demo/task2vec/simple_example.ipynb @@ -2,18 +2,19 @@ "cells": [ { "cell_type": "code", - "execution_count": null, - 
"id": "732cfc36-76c6-4b8a-b4fb-1e67c9e48902", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "id": "7c7f2bed-6353-4b29-aeed-b08cc9835a1b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/machenike/bmml/DataMetaMap/src/data_meta_map/wasserstein_embedder.py:8: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", + " from tqdm.autonotebook import tqdm\n" + ] + } + ], "source": [ "from data_meta_map.task2vec import task2vec\n", "from data_meta_map.models import get_model\n", @@ -21,10 +22,274 @@ "from data_meta_map.task2vec import plot_distance_matrix" ] }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2f3376cd-a564-4f99-8966-ad6d89b9ecaa", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Embedding mnist\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Caching features: 0%| | 0/14 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_distance_matrix(embeddings=embeddings, labels=dataset_names)" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "c5107e1b-93e1-4861-becb-16acf1fdd9c6", + "id": "705192a9-5ee8-4f92-aa41-a2c19197d017", "metadata": {}, "outputs": [], "source": [] diff --git a/setup.py b/setup.py index 01ea3b0..bba14c2 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ version="0.1.0", author="...", description="...", - long_description=open("README.rst").read(), + #long_description=open("README.rst").read(), long_description_content_type="text/x-rst", url="https://github.com/intsystems/DataMetaMap", packages=find_packages(where="src"), diff --git a/src/data_meta_map/__init__.py 
b/src/data_meta_map/__init__.py
index aa0f67e..a4a8c1f 100644
--- a/src/data_meta_map/__init__.py
+++ b/src/data_meta_map/__init__.py
@@ -1,9 +1,8 @@
-from .BaseEmbedder import BaseEmbedder
-from .WassersteinEmbedder import WassersteinEmbedder
-
+from data_meta_map.base_embedder import BaseEmbedder
+from data_meta_map.wasserstein_embedder import WassersteinEmbedder
 
 __all__ = [
-    "BaseEmbedder",
-    "WassersteinEmbedder"
+    "BaseEmbedder",
+    "WassersteinEmbedder"
 ]
 
diff --git a/src/data_meta_map/BaseEmbedder.py b/src/data_meta_map/base_embedder.py
similarity index 94%
rename from src/data_meta_map/BaseEmbedder.py
rename to src/data_meta_map/base_embedder.py
index 1ca2f9b..a299c4d 100644
--- a/src/data_meta_map/BaseEmbedder.py
+++ b/src/data_meta_map/base_embedder.py
@@ -19,11 +19,25 @@ class SupportsGetItem(Protocol):
     Protocol for objects with __getitem__ and __len__ interface.
     Ensures compatibility with both Dataset and custom iterable objects.
     """
+
     def __getitem__(self, index: int) -> Tuple[Any, int]: ...
     def __len__(self) -> int: ...
 
 
 class BaseEmbedder(ABC):
+    def __init__(
+        self
+    ):
+        pass
+
+    @abstractmethod
+    def embed(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Override this method in your Embedder class")
+
+
+# DEPRECATED
+class BaseEmbedderDEPRECATED(ABC):
     """
     Abstract base class for dataset embedding.
 
@@ -191,7 +205,8 @@ def get_class_statistics( feature_dim = X.shape[1] means = torch.zeros((num_classes, feature_dim), device=self.device) - covs = torch.zeros((num_classes, feature_dim, feature_dim), device=self.device) + covs = torch.zeros( + (num_classes, feature_dim, feature_dim), device=self.device) for idx, label in enumerate(unique_labels): mask = (Y == label) @@ -201,7 +216,8 @@ def get_class_statistics( covs[idx] = torch.cov(class_samples.T) else: # For single sample, covariance is undefined — use zero matrix - covs[idx] = torch.zeros((feature_dim, feature_dim), device=self.device) + covs[idx] = torch.zeros( + (feature_dim, feature_dim), device=self.device) return means, covs diff --git a/src/data_meta_map/datasets.py b/src/data_meta_map/datasets.py index fb577d3..48662af 100644 --- a/src/data_meta_map/datasets.py +++ b/src/data_meta_map/datasets.py @@ -341,7 +341,7 @@ def kmnist(root): transforms.Resize(224), transforms.ToTensor(), ]) - trainset = KMNIST(root, train=True, transform=transform, download=True) + trainset = KMNIST(root, train=True, transform=transform, download=False) testset = KMNIST(root, train=False, transform=transform) return trainset, testset diff --git a/src/data_meta_map/task2vec/task2vec.py b/src/data_meta_map/task2vec/task2vec.py index 40923ca..9989a3d 100644 --- a/src/data_meta_map/task2vec/task2vec.py +++ b/src/data_meta_map/task2vec/task2vec.py @@ -26,6 +26,8 @@ from torch.optim.optimizer import Optimizer from data_meta_map.task2vec.utils import AverageMeter, get_error, get_device +from data_meta_map import BaseEmbedder + class Embedding: def __init__(self, hessian, scale, meta=None): @@ -63,7 +65,7 @@ def task2vec(probe_network, dataset: Dataset, skip_layers=0, max_samples=None, c return embed -class Task2Vec: +class Task2Vec(BaseEmbedder): def __init__(self, model: ProbeNetwork, skip_layers=0, max_samples=None, classifier_opts=None, method='montecarlo', method_opts=None, loader_opts=None, bernoulli=False): @@ -90,7 +92,7 @@ def 
__init__(self, model: ProbeNetwork, skip_layers=0, max_samples=None, classif self.loss_fn = nn.CrossEntropyLoss() if not self.bernoulli else nn.BCEWithLogitsLoss() self.loss_fn = self.loss_fn.to(self.device) - def embed(self, dataset: Dataset, create_final_embedding: bool = False): + def embed(self, dataset: Dataset, create_final_embedding: bool = True): # Cache the last layer features (needed to train the classifier) and (if needed) the intermediate layer features # so that we can skip the initial layers when computing the embedding if self.skip_layers > 0: diff --git a/src/data_meta_map/WassersteinEmbedder.py b/src/data_meta_map/wasserstein_embedder.py similarity index 92% rename from src/data_meta_map/WassersteinEmbedder.py rename to src/data_meta_map/wasserstein_embedder.py index e925a98..af2761d 100644 --- a/src/data_meta_map/WassersteinEmbedder.py +++ b/src/data_meta_map/wasserstein_embedder.py @@ -7,8 +7,7 @@ import ot # POT: Python Optimal Transport from tqdm.autonotebook import tqdm -from .BaseEmbedder import BaseEmbedder - +from data_meta_map.base_embedder import BaseEmbedder def sqrtm_newton_schulz(A: torch.Tensor, num_iters: int = 20) -> torch.Tensor: @@ -105,7 +104,8 @@ def __init__( gaussian_assumption: bool = True, diagonal_cov: bool = False, commute: bool = False, - sqrt_method: str = "ns", # 'ns' = Newton-Schulz (default), 'eig' = eigenvalue decomposition + # 'ns' = Newton-Schulz (default), 'eig' = eigenvalue decomposition + sqrt_method: str = "ns", sqrt_niters: int = 20, **kwargs ): @@ -133,7 +133,8 @@ def __init__( self.sqrt_niters = sqrt_niters # Cache for class statistics: {dataset_id: (means, covs, class_offsets)} - self._stats_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor, List[int]]] = {} + self._stats_cache: Dict[int, + Tuple[torch.Tensor, torch.Tensor, List[int]]] = {} # Cache for preprocessed data: {dataset_id: (X, Y)} self._data_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {} @@ -171,9 +172,11 @@ def preprocess_dataset( idxs 
= np.sort(np.random.choice( len(data), self.max_samples, replace=False)) sampler = SubsetRandomSampler(idxs) - loader = DataLoader(data, sampler=sampler, batch_size=self.batch_size) + loader = DataLoader(data, sampler=sampler, + batch_size=self.batch_size) else: - loader = DataLoader(data, batch_size=self.batch_size, shuffle=False) + loader = DataLoader( + data, batch_size=self.batch_size, shuffle=False) elif isinstance(data, DataLoader): loader = data else: @@ -231,7 +234,8 @@ def _compute_gaussian_stats( if self.diagonal_cov: covs = torch.zeros((num_classes, feature_dim), device=self.device) else: - covs = torch.zeros((num_classes, feature_dim, feature_dim), device=self.device) + covs = torch.zeros( + (num_classes, feature_dim, feature_dim), device=self.device) for idx, label in enumerate(unique_labels): mask = (Y == label) @@ -251,7 +255,8 @@ def _compute_gaussian_stats( if self.diagonal_cov: covs[idx] = torch.zeros(feature_dim, device=self.device) else: - covs[idx] = torch.zeros((feature_dim, feature_dim), device=self.device) + covs[idx] = torch.zeros( + (feature_dim, feature_dim), device=self.device) # Global class indices class_offsets = list(range(num_classes)) @@ -378,11 +383,14 @@ def compute_pairwise_distances( means_j[idx_j], covs_j[idx_j] ) else: - X_i, Y_i = self._data_cache.get(i, self.preprocess_dataset(datasets[i], dataset_id=i)) - X_j, Y_j = self._data_cache.get(j, self.preprocess_dataset(datasets[j], dataset_id=j)) + X_i, Y_i = self._data_cache.get( + i, self.preprocess_dataset(datasets[i], dataset_id=i)) + X_j, Y_j = self._data_cache.get( + j, self.preprocess_dataset(datasets[j], dataset_id=j)) samples_i = X_i[Y_i == local_i] samples_j = X_j[Y_j == local_j] - d = self._exact_wasserstein_distance(samples_i, samples_j) + d = self._exact_wasserstein_distance( + samples_i, samples_j) d = torch.tensor(d, device=self.device) D[global_i, global_j] = d @@ -446,7 +454,8 @@ def augment_features( X, Y = self.preprocess_dataset(data, dataset_id=dataset_idx) 
start_offset = class_offsets[dataset_idx]
-        end_offset = class_offsets[dataset_idx + 1] if dataset_idx + 1 < len(class_offsets) else label_embeddings.shape[0]
+        end_offset = class_offsets[dataset_idx + 1] if dataset_idx + \
+            1 < len(class_offsets) else label_embeddings.shape[0]
         label_emb_for_dataset = label_embeddings[start_offset:end_offset]
 
         label_indices = Y.long()
@@ -485,7 +494,8 @@ def compute_wte(
 
         augmented_datasets: List[torch.Tensor] = []
         for idx, dataset in enumerate(datasets):
-            Z = self.augment_features(dataset, label_embeddings, idx, class_offsets)
+            Z = self.augment_features(
+                dataset, label_embeddings, idx, class_offsets)
             augmented_datasets.append(Z)
 
         if reference is None and create_reference:
@@ -494,15 +504,18 @@ def compute_wte(
             ref_indices = torch.randperm(all_data.shape[0])[:ref_size]
             reference = all_data[ref_indices].float()
         elif reference is None:
-            raise ValueError("Either provide 'reference' or set 'create_reference=True'")
+            raise ValueError(
+                "Either provide 'reference' or set 'create_reference=True'")
 
         task_embeddings = []
         ref_size = reference.shape[0]
 
         for Z in augmented_datasets:
             Z = Z.float()
-            C = ot.dist(Z.cpu().numpy(), reference.cpu().numpy(), metric='euclidean')
-            gamma = ot.emd(ot.unif(Z.shape[0]), ot.unif(ref_size), C, numItermax=1_000_000)
+            C = ot.dist(Z.cpu().numpy(), reference.cpu().numpy(),
+                        metric='euclidean')
+            gamma = ot.emd(ot.unif(Z.shape[0]), ot.unif(
+                ref_size), C, numItermax=1_000_000)
             gamma = torch.from_numpy(gamma).float().to(self.device)
             f = (ref_size * gamma.T @ Z - reference) / np.sqrt(ref_size)
             task_embeddings.append(f)
diff --git a/tests/test_wasserstein.py b/tests/test_wasserstein.py
index 8e1ce91..25d3c8a 100644
--- a/tests/test_wasserstein.py
+++ b/tests/test_wasserstein.py
@@ -4,8 +4,8 @@
 import sys
 import traceback
 
-from data_meta_map.BaseEmbedder import BaseEmbedder
-from data_meta_map.WassersteinEmbedder import WassersteinEmbedder
+from data_meta_map.base_embedder import BaseEmbedder
+from 
data_meta_map.wasserstein_embedder import WassersteinEmbedder # ============================================================================ @@ -14,6 +14,7 @@ class MockDataset(Dataset): """Простой датасет для тестов.""" + def __init__(self, num_samples=100, feature_dim=10, num_classes=5, seed=42): torch.manual_seed(seed) self.data = torch.randn(num_samples, feature_dim) @@ -28,6 +29,7 @@ def __getitem__(self, idx): class MockVectorizedDataset(Dataset): """Датасет с уже векторизованными данными (как текстовые эмбеддинги).""" + def __init__(self, num_samples=100, feature_dim=768, num_classes=3, seed=42): torch.manual_seed(seed) self.data = torch.randn(num_samples, feature_dim) @@ -46,6 +48,7 @@ def __getitem__(self, idx): class ConcreteEmbedder(BaseEmbedder): """Конкретная реализация для тестирования абстрактного класса.""" + def preprocess_dataset(self, data, dataset_id=None): if isinstance(data, Dataset): X = torch.randn(100, 10) @@ -54,7 +57,8 @@ def preprocess_dataset(self, data, dataset_id=None): return data def compute_pairwise_distances(self, datasets, symmetric=True): - n = sum(len(torch.unique(self.preprocess_dataset(d)[1])) for d in datasets) + n = sum( + len(torch.unique(self.preprocess_dataset(d)[1])) for d in datasets) return torch.zeros((n, n), device=self.device) def embed_distance_matrix(self, distance_matrix, emb_dim=None): @@ -64,7 +68,8 @@ def embed_distance_matrix(self, distance_matrix, emb_dim=None): def augment_features(self, data, label_embeddings, dataset_idx, class_offsets): X, Y = self.preprocess_dataset(data) - label_emb = label_embeddings[class_offsets[dataset_idx]:class_offsets[dataset_idx+1]] + label_emb = label_embeddings[class_offsets[dataset_idx] + :class_offsets[dataset_idx+1]] return torch.cat([X, label_emb[Y]], dim=1) @@ -254,7 +259,8 @@ def test_wasserstein_preprocess_vectorized(): print("\n[TEST] WassersteinEmbedder preprocess vectorized data") try: embedder = WassersteinEmbedder(emb_dim=2) - dataset = 
MockVectorizedDataset(num_samples=100, feature_dim=768, num_classes=5) + dataset = MockVectorizedDataset( + num_samples=100, feature_dim=768, num_classes=5) X, Y = embedder.preprocess_dataset(dataset, dataset_id=2) @@ -320,7 +326,8 @@ def test_wasserstein_bures_distance_identical(): mean2 = torch.tensor([0.0, 0.0]) cov2 = torch.eye(2) - distance = embedder._bures_wasserstein_distance(mean1, cov1, mean2, cov2) + distance = embedder._bures_wasserstein_distance( + mean1, cov1, mean2, cov2) assert torch.allclose(distance, torch.tensor(0.0), atol=1e-2) print(" ✓ Расстояние Бюра для идентичных распределений = 0") @@ -341,7 +348,8 @@ def test_wasserstein_bures_distance_different_means(): mean2 = torch.tensor([1.0, 0.0]) cov2 = torch.eye(2) - distance = embedder._bures_wasserstein_distance(mean1, cov1, mean2, cov2) + distance = embedder._bures_wasserstein_distance( + mean1, cov1, mean2, cov2) assert distance > 0.0 assert torch.allclose(distance, torch.tensor(1.0), atol=1e-5) @@ -379,8 +387,10 @@ def test_wasserstein_pairwise_distances_multiple(): print("\n[TEST] WassersteinEmbedder compute_pairwise_distances multiple datasets") try: embedder = WassersteinEmbedder(emb_dim=2, max_samples=20) - ds1 = MockDataset(num_samples=30, feature_dim=10, num_classes=2, seed=42) - ds2 = MockDataset(num_samples=30, feature_dim=10, num_classes=3, seed=43) + ds1 = MockDataset(num_samples=30, feature_dim=10, + num_classes=2, seed=42) + ds2 = MockDataset(num_samples=30, feature_dim=10, + num_classes=3, seed=43) D = embedder.compute_pairwise_distances([ds1, ds2], symmetric=True) @@ -424,7 +434,8 @@ def test_wasserstein_augment_features(): label_embeddings = torch.randn(3, 3) class_offsets = [0, 3] - Z = embedder.augment_features(dataset, label_embeddings, 0, class_offsets) + Z = embedder.augment_features( + dataset, label_embeddings, 0, class_offsets) assert Z.shape == (50, 13) # 10 features + 3 label embeddings print(" ✓ Аугментация признаков работает корректно")