diff --git a/.gitignore b/.gitignore
index b520deb..726df61 100644
--- a/.gitignore
+++ b/.gitignore
@@ -121,4 +121,9 @@ logs/
*.csv
!.dvc
data
+logs
+pretrain_embeddings
wget-log
+checkpoints
+*.parquet
+*.pt
\ No newline at end of file
diff --git a/benchmarks/pretrain_benchmark/apply_benchmark.ipynb b/benchmarks/pretrain_benchmark/apply_benchmark.ipynb
new file mode 100644
index 0000000..75db554
--- /dev/null
+++ b/benchmarks/pretrain_benchmark/apply_benchmark.ipynb
@@ -0,0 +1,492 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "345ea8f5-7653-4e0f-92ce-2b5d713d24d7",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/machenike/bmml/DataMetaMap/src/data_meta_map/wasserstein_embedder.py:8: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
+ " from tqdm.autonotebook import tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "from data_meta_map.task2vec import task2vec\n",
+ "from data_meta_map.models import get_model\n",
+ "from data_meta_map import datasets\n",
+ "from data_meta_map.task2vec import plot_distance_matrix\n",
+ "from data_meta_map.task2vec import Task2Vec\n",
+ "from data_meta_map.task2vec.task_similarity import cosine\n",
+ "\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "ef2aa314-ea65-440b-8db2-45c0ae674e14",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import yaml\n",
+ "from collections import defaultdict\n",
+ "\n",
+ "def get_pretrained_results(dataset_names, path_to_logs):\n",
+ " pretrain2downstream_results = defaultdict(dict)\n",
+ " for pretrain_name in dataset_names:\n",
+ " for downstream_name in dataset_names:\n",
+ " if downstream_name == 'imagenet':\n",
+ " continue\n",
+ " if downstream_name == pretrain_name:\n",
+ " continue\n",
+ " with open(f'{path_to_logs}/{pretrain_name}-{downstream_name}.yaml', 'r') as f:\n",
+ " test_acc = yaml.safe_load(f)['task']['test_acc']\n",
+ " pretrain2downstream_results[pretrain_name][downstream_name] = test_acc\n",
+ " return pretrain2downstream_results\n",
+ "\n",
+ "def get_embedder_results(dataset_names, path_to_pretrained_logs, embedder_func, similarity_func):\n",
+ " embeddings = []\n",
+ "\n",
+ " dataset_list = [datasets.__dict__[name](root='../../data')[0] for name in dataset_names]\n",
+ " for name, dataset in zip(dataset_names, dataset_list):\n",
+ " embeddings.append(embedder_func(dataset))\n",
+ "\n",
+ " def find_closest(dataset_names, embeddings, similarity_metric):\n",
+ " res = dict()\n",
+ " for dataset_name, embed in zip(dataset_names, embeddings):\n",
+ " dists = {name:similarity_metric(embed, other_embed) for name, other_embed in zip(dataset_names, embeddings) if name != dataset_name}\n",
+ " argmax = max(dists.items(), key = lambda x: x[1])[0]\n",
+ " res[dataset_name] = argmax\n",
+ " return res\n",
+ " \n",
+    "    dataset2closest = find_closest(dataset_names, embeddings, similarity_func)\n",
+ " pretrain2downstream_results = get_pretrained_results(dataset_names, path_to_pretrained_logs)\n",
+ " method_performance = {name: {'accuracy' : pretrain2downstream_results[dataset2closest[name]][name], \n",
+ " 'pretrain': dataset2closest[name]} for name in dataset_names}\n",
+ " return method_performance\n",
+ "\n",
+ "def get_random_baseline(dataset_names, path_to_pretrained_logs):\n",
+ " pretrain2downstream_results = get_pretrained_results(dataset_names + ['imagenet'], path_to_pretrained_logs)\n",
+ " random_performance = {}\n",
+ " for name in dataset_names:\n",
+ " if name == 'imagenet':\n",
+ " continue\n",
+ " choice = name\n",
+ " while choice == name:\n",
+ " choice = np.random.choice(dataset_names)\n",
+ " random_performance[name] = {\n",
+ " 'accuracy': pretrain2downstream_results[choice][name],\n",
+ " 'pretrain': choice\n",
+ " }\n",
+ " return random_performance\n",
+ " \n",
+ "def get_big_pretrain_baseline(dataset_names, path_to_pretrained_logs):\n",
+ " pretrain2downstream_results = get_pretrained_results(dataset_names + ['imagenet'], path_to_pretrained_logs)\n",
+ " big_baseline_performance = {}\n",
+ " for name in dataset_names:\n",
+ " if name == 'imagenet':\n",
+ " continue\n",
+ " big_baseline_performance[name] = {\n",
+ " 'accuracy': pretrain2downstream_results['imagenet'][name],\n",
+ " 'pretrain': 'imagenet'\n",
+ " }\n",
+ " return big_baseline_performance "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "b8c01eb9-3a90-4c8b-bfea-d2253b8785ff",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Caching features: 0%| | 0/14 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fitting classifier: 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Computing Fisher: 0%| | 0/156 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Caching features: 0%| | 0/14 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fitting classifier: 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Computing Fisher: 0%| | 0/156 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Caching features: 0%| | 0/14 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fitting classifier: 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Computing Fisher: 0%| | 0/156 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Caching features: 0%| | 0/14 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fitting classifier: 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Computing Fisher: 0%| | 0/156 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Caching features: 0%| | 0/14 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fitting classifier: 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Computing Fisher: 0%| | 0/156 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'mnist': 'kmnist', 'cifar10': 'cifar100', 'cifar100': 'cifar10', 'letters': 'kmnist', 'kmnist': 'mnist'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "def cosine_similarity(e0, e1):\n",
+ " return (e0*e1).sum()/np.linalg.norm(e0)/np.linalg.norm(e1)\n",
+ "\n",
+ "def task2vec_resnet_embedder(dataset):\n",
+ " resnet = get_model('resnet18', pretrained=True, num_classes=int(max(dataset.targets)+1)).cuda()\n",
+ " task2vec_embedder = Task2Vec(resnet, skip_layers=6, max_samples=1000)\n",
+ " return task2vec_embedder.embed(dataset)\n",
+ "\n",
+ "dataset_names = ['mnist', 'cifar10', 'cifar100', 'letters', 'kmnist']\n",
+ "task2vec_res = get_embedder_results(\n",
+ " dataset_names,\n",
+ " path_to_pretrained_logs='/home/machenike/bmml/DataMetaMap/logs/pretrain_to_task_logs',\n",
+ " embedder_func=task2vec_resnet_embedder,\n",
+ " similarity_func=cosine_similarity\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "eb480227-2075-44fd-b386-4a0f2b811317",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "random_baseline_res = get_random_baseline(\n",
+ " dataset_names,\n",
+ " path_to_pretrained_logs='/home/machenike/bmml/DataMetaMap/logs/pretrain_to_task_logs',\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "d86598c9-30f7-441b-9ce0-c89d5ae0d13f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "imagenet_baseline_res = get_big_pretrain_baseline(\n",
+ " dataset_names,\n",
+ " path_to_pretrained_logs='/home/machenike/bmml/DataMetaMap/logs/pretrain_to_task_logs',\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "763ba4e2-d5a5-4dc9-8077-ed05f29a696f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "res = {\n",
+ " 'task2vec':{\n",
+ " name:task2vec_res[name]['accuracy'] for name in dataset_names\n",
+ " },\n",
+ " 'random':{\n",
+ " name:random_baseline_res[name]['accuracy'] for name in dataset_names\n",
+ " },\n",
+ " 'big_pretrain':{\n",
+ " name:imagenet_baseline_res[name]['accuracy'] for name in dataset_names\n",
+ " }\n",
+ "}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "cb76994b-3c01-4f04-afd8-f5ab2a8074f2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " mnist | \n",
+ " cifar10 | \n",
+ " cifar100 | \n",
+ " letters | \n",
+ " kmnist | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | task2vec | \n",
+ " 0.9960 | \n",
+ " 0.8798 | \n",
+ " 0.4960 | \n",
+ " 0.932596 | \n",
+ " 0.9957 | \n",
+ "
\n",
+ " \n",
+ " | random | \n",
+ " 0.9815 | \n",
+ " 0.6390 | \n",
+ " 0.3512 | \n",
+ " 0.933510 | \n",
+ " 0.9752 | \n",
+ "
\n",
+ " \n",
+ " | big_pretrain | \n",
+ " 0.9848 | \n",
+ " 0.8834 | \n",
+ " 0.6789 | \n",
+ " 0.919712 | \n",
+ " 0.9851 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " mnist cifar10 cifar100 letters kmnist\n",
+ "task2vec 0.9960 0.8798 0.4960 0.932596 0.9957\n",
+ "random 0.9815 0.6390 0.3512 0.933510 0.9752\n",
+ "big_pretrain 0.9848 0.8834 0.6789 0.919712 0.9851"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "pd.DataFrame(res).T"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/benchmarks/pretrain_benchmark/get_pretrained_to_task.py b/benchmarks/pretrain_benchmark/get_pretrained_to_task.py
new file mode 100644
index 0000000..717c028
--- /dev/null
+++ b/benchmarks/pretrain_benchmark/get_pretrained_to_task.py
@@ -0,0 +1,291 @@
from data_meta_map.task2vec import task2vec
from data_meta_map.models import get_model
from data_meta_map import datasets
from data_meta_map.task2vec import plot_distance_matrix

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import pandas as pd
from tqdm import tqdm

import hydra
import yaml
from omegaconf import DictConfig, OmegaConf
+
+
@torch.inference_mode()
def evaluate(model, loader, device):
    """Return the top-1 accuracy of `model` over every batch in `loader`.

    Runs under `torch.inference_mode()` (no autograd bookkeeping) and puts
    the model in eval mode before iterating.
    """
    model.eval()
    num_correct, num_seen = 0, 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        predictions = model(inputs).argmax(dim=1)
        num_correct += (predictions == targets).sum().item()
        num_seen += targets.size(0)

    return num_correct / num_seen
+
+
def train(model: nn.Module,
          train_loader: DataLoader,
          val_loader: DataLoader,
          optimizer,
          criterion,
          best_path: str,
          num_epochs: int,
          patience: int,
          device: str):
    """Train `model` with early stopping on validation accuracy.

    The best weights (by validation accuracy) are checkpointed to `best_path`
    and restored before returning.

    Args:
        model: network to train (moved to `device` by the caller).
        train_loader / val_loader: training and validation batches.
        optimizer: configured optimizer over `model.parameters()`.
        criterion: loss taking (logits, targets).
        best_path: filesystem path for the best checkpoint.
        num_epochs: maximum number of epochs.
        patience: stop after this many epochs without validation improvement.
        device: device for inputs/targets.

    Returns:
        (model, logs): model with best weights loaded and in eval mode, and a
        dict with 'best_val_accuracy' and 'best_val_epoch'.
    """
    best_acc = 0.0
    best_epoch = 0
    epochs_without_improvement = 0
    checkpoint_saved = False

    for epoch in range(num_epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for x, y in pbar:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

        val_acc = evaluate(model, val_loader, device=device)

        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
            epochs_without_improvement = 0
            torch.save(model.state_dict(), best_path)
            checkpoint_saved = True
        else:
            epochs_without_improvement += 1

        # early stopping condition
        if epochs_without_improvement >= patience:
            break

    # Bug fix: only reload when a checkpoint was actually written — previously
    # torch.load crashed if no epoch ever improved (e.g. num_epochs == 0).
    # map_location keeps the restore valid regardless of the saving device.
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_path, map_location=device))
    model.eval()
    logs = {
        'best_val_accuracy': best_acc,
        'best_val_epoch': best_epoch
    }
    return model, logs
+
+
@torch.inference_mode()
def extract_embeddings(model, loader, device="cuda"):
    """Run `model` over `loader` and collect flattened features plus labels.

    Args:
        model: feature extractor; its output is flattened to (batch, -1).
        loader: yields (inputs, labels) batches.
        device: device to run inference on. Defaults to "cuda" so existing
            two-argument callers keep the old behavior; the previous
            implementation hard-coded `.cuda()`, which ignored the
            config-driven device used everywhere else in this file.

    Returns:
        (embeddings, labels): both as numpy arrays.
    """
    embeddings = []
    labels = []

    model.eval()

    for x, y in tqdm(loader):
        feat = model(x.to(device))
        embeddings.append(feat.view(feat.size(0), -1).cpu())
        labels.append(y)

    return torch.cat(embeddings).numpy(), torch.cat(labels).numpy()
+
+
def save_parquet(embeds, labels, path):
    """Persist embeddings and labels to a parquet file at `path`.

    Columns: 'hidden_state' (one list of floats per row) and 'target'.
    """
    frame = pd.DataFrame({
        "hidden_state": embeds.tolist(),
        "target": labels,
    })
    frame.to_parquet(path, index=False)
+
+
class MLP(nn.Module):
    """Small classification head: Linear -> BatchNorm -> ReLU -> Dropout -> Linear."""

    def __init__(self, input_dim=512, hidden_dim=256, num_classes=10):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.15),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for the input batch `x` of shape (batch, input_dim)."""
        return self.net(x)
+
+
def estimate_task_performance(config):
    """Train an MLP probe on cached embeddings and report its test accuracy.

    Reads train/test embeddings from the parquet paths in `config`, trains an
    MLP head with early stopping (via `train`), and returns the training logs
    augmented with 'test_acc'.
    """
    train_df = pd.read_parquet(config['paths']['train_parquet'])
    test_df = pd.read_parquet(config['paths']['test_parquet'])

    X_train = torch.tensor(train_df.hidden_state.tolist(), dtype=torch.float32)
    y_train = torch.tensor(train_df.target.values, dtype=torch.long)

    X_test = torch.tensor(test_df.hidden_state.tolist(), dtype=torch.float32)
    y_test = torch.tensor(test_df.target.values, dtype=torch.long)

    val_ratio = config['training_task']['val_split']
    dataset = TensorDataset(X_train, y_train)

    train_size = int((1 - val_ratio) * len(dataset))
    val_size = len(dataset) - train_size
    train_ds, val_ds = random_split(dataset, [train_size, val_size])

    batch_size = config['training_task']['batch_size']
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=batch_size)

    criterion = nn.CrossEntropyLoss()

    device = torch.device(config.device)
    # Bug fix: `max(y_test) + 1` passed a 0-dim tensor where nn.Linear expects
    # an int, and ignored labels that appear only in the train split.
    num_classes = int(max(y_train.max(), y_test.max()).item()) + 1
    model = MLP(input_dim=512,
                hidden_dim=config['estimate_network_params']['hidden_dim'],
                num_classes=num_classes).to(device)
    # Bug fix: the task head previously read config['training']['lr'] (the
    # pretrain lr); config['training_task']['lr'] was defined but unused.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config['training_task']['lr'])

    model, logs = train(
        model,
        train_loader,
        val_loader,
        optimizer,
        criterion,
        best_path=config['paths']['checkpoint_task'],
        num_epochs=config['training_task']['epochs'],
        patience=config['training_task']['early_stopping_epochs'],
        device=device
    )
    logs['test_acc'] = evaluate(model, test_loader, device)
    return logs
+
+
def apply_pretrain(config):
    """Fine-tune the backbone on the configured pretrain dataset.

    Returns the training logs dict produced by `train` (best validation
    accuracy and epoch). The best checkpoint is written to
    config['paths']['checkpoint'].
    """
    # Bug fix: the old body re-assigned `device = "cuda"` after this line,
    # silently ignoring config.device; the configured device is now honored.
    device = torch.device(config.device)

    model = get_model(config['model']['name'],
                      pretrained=config['model']['pretrained']).to(device)
    train_dataset, test_dataset = datasets.__dict__[
        config['pretrain_dataset_name']](root=config['paths']['data_dir'])

    val_ratio = config['training']['val_split']
    batch_size = config['training']['batch_size']

    train_size = int((1 - val_ratio) * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_ds, val_ds = random_split(train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    # NOTE: the pretrain test split is not evaluated here (a test_loader was
    # previously built but never used); downstream accuracy is measured in
    # estimate_task_performance.

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config['training']['lr'])

    model, logs = train(
        model,
        train_loader,
        val_loader,
        optimizer,
        criterion,
        best_path=config['paths']['checkpoint'],
        num_epochs=config['training']['epochs'],
        patience=config['training']['early_stopping_epochs'],
        device=device
    )
    return logs
+
+
def inference_task(config):
    """Embed the task dataset with the (optionally fine-tuned) backbone.

    Loads the pretrain checkpoint unless config['use_basic_model'] is set,
    strips the final classification layer, extracts embeddings for both
    splits, and caches them as parquet files at the configured paths.
    """
    device = torch.device(config.device)

    model = get_model(config['model']['name'],
                      pretrained=config['model']['pretrained']).to(device)

    if not config['use_basic_model']:
        best_path = config['paths']['checkpoint']
        # Bug fix: map_location lets a GPU-saved checkpoint load on whatever
        # device the config selects (previously this crashed on CPU-only runs).
        model.load_state_dict(torch.load(best_path, map_location=device))
    model.eval()

    # Drop the final (classification) child module; everything before it acts
    # as the feature extractor.
    embedder = torch.nn.Sequential(*list(model.children())[:-1]).to(device)

    train_dataset, test_dataset = datasets.__dict__[
        config['task_dataset_name']](root=config['paths']['data_dir'])

    batch_size = config['training']['batch_size']
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    train_full_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=False)

    train_embeds, train_labels = extract_embeddings(embedder, train_full_loader)
    test_embeds, test_labels = extract_embeddings(embedder, test_loader)

    save_parquet(train_embeds, train_labels, config['paths']['train_parquet'])
    save_parquet(test_embeds, test_labels, config['paths']['test_parquet'])
+
+
@hydra.main(config_path=".", config_name="config", version_base=None)
def main(config: DictConfig):
    """Run the pretrain -> embed -> task-probe pipeline and dump logs to YAML."""
    # Bug fix: the previous
    #   _, pretrain_logs = apply_pretrain(config) if config['need_pretrain'] else None, {}
    # parsed as the tuple ((... if ... else None), {}), so pretrain_logs was
    # ALWAYS {} and the real pretrain logs were discarded (apply_pretrain also
    # returns a single dict, not a pair).
    pretrain_logs = apply_pretrain(config) if config['need_pretrain'] else {}
    task_logs = {}
    if not config['pretrain_only']:
        inference_task(config)
        task_logs = estimate_task_performance(config)
    logs = {
        'pretrain': pretrain_logs,
        'task': task_logs,
        'pretrain_to_task_config': OmegaConf.to_container(config, resolve=True)
    }
    with open(config['paths']['save_logs'], 'w') as f:
        yaml.safe_dump(logs, f)


if __name__ == "__main__":
    main()
diff --git a/configs/pretrain_to_task.yaml b/configs/pretrain_to_task.yaml
new file mode 100644
index 0000000..4b37156
--- /dev/null
+++ b/configs/pretrain_to_task.yaml
@@ -0,0 +1,39 @@
+seed: 42
+device: cuda
+
+model:
+ name: resnet18 # resnet18 | resnet34
+ pretrained: true
+
+pretrain_only: false
+need_pretrain: true
+use_basic_model: false
+
+pretrain_dataset_name: stl10 # mnist | cifar10 | cifar100 | letters | kmnist
+task_dataset_name: mnist # mnist | cifar10 | cifar100 | letters | kmnist
+
+training:
+ batch_size: 128
+ epochs: 20
+ lr: 0.0003
+ val_split: 0.2
+ early_stopping_epochs: 5
+
+training_task:
+ batch_size: 128
+ epochs: 20
+ lr: 0.0003
+ val_split: 0.2
+ early_stopping_epochs: 10
+
+estimate_network_params:
+ hidden_dim: 256
+ num_classes: 10
+
+paths:
+ data_dir: /home/machenike/bmml/DataMetaMap/data
+ train_parquet: /home/machenike/bmml/DataMetaMap/pretrain_embeddings/train_embeds.parquet
+ test_parquet: /home/machenike/bmml/DataMetaMap/pretrain_embeddings/test_embeds.parquet
+ checkpoint: /home/machenike/bmml/DataMetaMap/checkpoints/${pretrain_dataset_name}_pretrain.pt
+ checkpoint_task: /home/machenike/bmml/DataMetaMap/checkpoints/best_model.pt
+ save_logs: /home/machenike/bmml/DataMetaMap/logs/pretrain_to_task_logs/${pretrain_dataset_name}-${task_dataset_name}.yaml
\ No newline at end of file
diff --git a/demo/task2vec/simple_example.ipynb b/demo/task2vec/simple_example.ipynb
index d5b84ea..ab5be76 100644
--- a/demo/task2vec/simple_example.ipynb
+++ b/demo/task2vec/simple_example.ipynb
@@ -2,18 +2,19 @@
"cells": [
{
"cell_type": "code",
- "execution_count": null,
- "id": "732cfc36-76c6-4b8a-b4fb-1e67c9e48902",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": 3,
+ "execution_count": 1,
"id": "7c7f2bed-6353-4b29-aeed-b08cc9835a1b",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/machenike/bmml/DataMetaMap/src/data_meta_map/wasserstein_embedder.py:8: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
+ " from tqdm.autonotebook import tqdm\n"
+ ]
+ }
+ ],
"source": [
"from data_meta_map.task2vec import task2vec\n",
"from data_meta_map.models import get_model\n",
@@ -21,10 +22,274 @@
"from data_meta_map.task2vec import plot_distance_matrix"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "2f3376cd-a564-4f99-8966-ad6d89b9ecaa",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Embedding mnist\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Caching features: 0%| | 0/14 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fitting classifier: 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Computing Fisher: 0%| | 0/156 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Embedding cifar10\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Caching features: 0%| | 0/14 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fitting classifier: 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Computing Fisher: 0%| | 0/156 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Embedding cifar100\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Caching features: 0%| | 0/14 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fitting classifier: 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Computing Fisher: 0%| | 0/156 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Embedding letters\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Caching features: 0%| | 0/14 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fitting classifier: 0%| | 0/10 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Computing Fisher: 0%| | 0/156 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "dataset_names = ('mnist', 'cifar10', 'cifar100', 'letters')\n",
+ "dataset_list = [datasets.__dict__[name](root='../../data')[0] for name in dataset_names] \n",
+ "\n",
+ "embeddings = []\n",
+ "for name, dataset in zip(dataset_names, dataset_list):\n",
+ " print(f\"Embedding {name}\")\n",
+ " probe_network = get_model('resnet18', pretrained=True, num_classes=int(max(dataset.targets)+1)).cuda()\n",
+ " embeddings.append(task2vec(probe_network, dataset, skip_layers=6, max_samples=1000))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "d5ebc361-2980-4257-939b-0d8fed942143",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.15588998794555664"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "def cosine_similarity(e0, e1):\n",
+ " e0_new = e0/(e0 + e1)\n",
+ " e1_new = e1/(e0 + e1)\n",
+ " e0 = e0_new\n",
+ " e1 = e1_new\n",
+ " return (e0*e1).sum()/np.linalg.norm(e0)/np.linalg.norm(e1)\n",
+ "\n",
+ "1-cosine_similarity(embeddings[0].hessian/embeddings[0].scale, embeddings[1].hessian/embeddings[1].scale)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "444da2db-d33d-469c-b5e9-4d34c8df45b5",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAPdCAYAAACXzguGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAABFhUlEQVR4nO3de3SV1Z3w8d8JkAREEFEIXgCLKFUREIVC7aCWCoy9oIyCnVcodawdwVbwmlZBaEfEWkGqI7WOqG+rgtW6HOugNEttB1E7eMEqomNxeDsQELuUiyWRcN4/jKkpRDk52RxJPp9ZZ5U8z3O2+8TQ6Tf7uWSy2Ww2AAAAgCZXVOgJAAAAQHMlugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIq3zefOXis5sqnkU1JId9xV6CgAAADRDeUV3plWrppoHAAAANDt5RXdknJ0OAAAADclvpbso01TzAAAAgGYnv5Vup5cDNCvl5eWxbdu2Qk8DgERKS0tj1qxZhZ4GtCh5nl7ecle6b7755vjRj34UlZWV0a9fv/jJT34SgwYN2uWxL7/8ckybNi2WL18e//M//xNz5syJiy66qN4xV199dcyYMaPetiOPPDJeffXVVB8BYCfbtm2LOXPmFHoaACQyZcqUQk8BWhw3UmuEhQsXxtSpU2P+/PkxePDgmDt3bowYMSJWrVoVXbp02en49957Lz7zmc/EmWee+bH/RXf00UfHb37zm7qvW7fO73ciAAAAFFZ+VVfUMm+kdsMNN8R5550XEydOjIiI+fPnx69//eu4/fbb44orrtjp+BNOOCFOOOGEiIhd7v9Q69ato6ysLM2kAQAA2OPyvKa7eUR3VVVVVFVV1dtWUlISJSUlOx1bXV0dy5cvj/Ly8rptRUVFMXz48Fi2bFle83j99dfjoIMOitLS0hgyZEjMmjUrunfvnteYAAAAFE5e1ZzJFDWL16xZs6Jjx471Xg3dYGLjxo1RU1MTXbt2rbe9a9euUVlZ2ejv5eDBg+OOO+6IxYsXxy233BKrV6+OL3zhC7F58+ZGjwkAAEBh5Xl6efO4kVp5eXlMnTq13rZdrXKnNGrUqLo/H3vssTF48ODo0aNHLFq0KM4999w9OhcAAACahmu6o+FTyXflgAMOiFatWsX69evrbV+/fn2TXo+93377xRFHHBH//d//3WRjAgAAsGflV81FRc3jlYPi4uIYOHBgVFRU1G3bsWNHVFRUxJAhQ/L6dn7Uli1b4o033ohu3bo12ZgAAADsWVa6G2Hq1KkxYcKEOP7442PQoEExd+7c2Lp1a93dzMePHx8HH3xw3XXh1dXV8corr9T9+X//93/jhRdeiPbt28fhhx8eERGXXHJJfOUrX4kePXrE2rVrY/r06dGqVas4++yzC/MhAQAAyFt+0Z1pHtd052rs2LHx1ltvxbRp06KysjL69+8fixcvrru52po1a6LoI7+QWLt2bQwYMKDu6+uvvz6uv/76GDZsWDzxxBMREfGnP/0pzj777Hj77bfjwAMPjBNPPDGefvrpOPDAA/foZwMAAKDpeGRYI02ePDkmT568y30fhvSHevbsGdls9mPHu/fee5tqagAAAHxK5LnS3XKjGwAAAD6JR4YBAABAIm6kBgAAAIm4kRoAAAAk4kZqAAAAkEhe0Z210g0AAAANyvOa7iaaBQAAADRDbqQGAAAAibiRGgAAACSS3zXdrUQ3AAAANMRKNwAAACSS30q3a7oBAACgQe5eDgAAAIk4vRwAAAASyfP0ctENAAAADbHSDQAAAIl4ZBgAAAAkkl90a24AAABoUJ53L1fdAAAA0JA8V7pFNwAAADTEc7oBAAAgEY8MAwAAgETcSA0AAAASsdINAAAAieR3TbfmBgAAgAblt9LdSnUDAABAQ6x0AwAA
QCJupAYAAACJuJEaAAAAJOL0cgAAAEjESjcAAAAk4ppuAAAASMTp5QAAAJBIns/pbqppAAAAQPPj9HIAAABIxI3UAAAAIBHXdAMAAEAiea50N9U0AAAAoPlxTTcAAAAk4vRyAAAASMTp5QAAAJBInqeXW+oGAACAhljpBgAAgERc0w0AAACJWOkGAACARKx0AwAAQCJ5rVVnM83j1Rg333xz9OzZM0pLS2Pw4MHx7LPPNnjsyy+/HGPGjImePXtGJpOJuXPn5j0mAAAAn375RXdR83jlauHChTF16tSYPn16PPfcc9GvX78YMWJEbNiwYZfHv/fee/GZz3wmrr322igrK2uSMQEAAPj0y++q7EwzeeXohhtuiPPOOy8mTpwYRx11VMyfPz/atWsXt99++y6PP+GEE+JHP/pRjBs3LkpKSppkTAAAAD79rHQXRVRVVcWmTZvqvaqqqnb5maurq2P58uUxfPjwv34Ti4pi+PDhsWzZskZ9H1OMCQAAQOFZ6c5EzJo1Kzp27FjvNWvWrF1+5I0bN0ZNTU107dq13vauXbtGZWVlI76JacYEAACg8PJ7ZFhTzaLAysvLY+rUqfW2NXQaOAAAAOyu/B4Z1kye011SUrLbkX3AAQdEq1atYv369fW2r1+/vsGbpBViTAAAAArPI8NyvJFacXFxDBw4MCoqKuq27dixIyoqKmLIkCGN+j6mGBMAAIDCy+/08may0p2rqVOnxoQJE+L444+PQYMGxdy5c2Pr1q0xceLEiIgYP358HHzwwXXXhVdXV8crr7xS9+f//d//jRdeeCHat28fhx9++G6NCQAAwN4nv9PLG/G4reZg7Nix8dZbb8W0adOisrIy+vfvH4sXL667EdqaNWuiqOivv5FYu3ZtDBgwoO7r66+/Pq6//voYNmxYPPHEE7s1JgAAAHsfK92NNHny5Jg8efIu930Y0h/q2bNnZLOffNu5jxsTAACAvY+VbgAAAEhEdAMAAEAiTi8HAACARPKLbivdAAAA0KD8Ti+30g0AAAANstINAAAAibiRGgAAACRipRsAAAASsdINAAAAiXhkGAAAACRipRsAAAASsdINAAAAieS50p1tomkAAABA82OlGwAAABJxTTcAAAAk4jndAAAAkEh+K91OLwcAAIAGuZEaAAAAJOL0cgAAAEjEjdQAAAAgESvdAAAAkEieN1JzTTcAAAA0xOnlAAAAkIiVbgAAAEjESjcAAAAkkt+N1Kx0AwAAQIOsdAMAAEAieUV3RnQDAABAg9xIDQAAABJxejkAAAAkkmd0W+kGAACAhuR5TbfoBgAgN+Xl5bFt27ZCT6NF2rBhQ0yZMqXQ02iRSktLY9asWYWeBgWQX3S7phsAgBxt27Yt5syZU+hpwB7llx0tl2u6AQAAIBGPDAMAAIBE8jy9fEdTzQMAAACaHSvdAAAAkIgbqQEAAEAiHhkGAAAAieQX3U01CwAAAGiG8oruIjdSAwAAgAa5kRoAAAAkkld0t7LSDQAAAA1yIzUAAABIJM+VbtENAAAADbHSDQAAAInkd/fyaLnRffPNN8ePfvSjqKysjH79+sVPfvKTGDRoUIPH33fffXHVVVfFm2++Gb17947Zs2fH3//939ft/8Y3vhF33nlnvfeMGDEiFi9enOwzwKdVeXl5bNu2rdDTaJE2bNgQU6ZMKfQ0WpzS0tKYNWtWoacBACTg9PJGWLhwYUydOjXmz58fgwcPjrlz58aIESNi1apV0aVLl52Of+qpp+Lss8+OWbNmxZe//OW4++67Y/To0fHcc8/FMcccU3fcyJEjY8GCBXVfl5SU7JHPA58227Ztizlz5hR6GrDH+EUHADRfRfm8OZPJNotXrm644YY477zzYuLEiXHUUUfF/Pnzo127dnH77bfv8vgbb7wxRo4cGZdeeml89rOfjR/84Adx3HHHxU033VTvuJKSkigrK6t7derUqVH/XgAAAPh0yCu6WxXtaBavqqqq2LRp
U71XVVXVLj9zdXV1LF++PIYPH/7Xb2JRUQwfPjyWLVu2y/csW7as3vERH5w6/rfHP/HEE9GlS5c48sgj45//+Z/j7bffzudfDwAAAAWWV3QXZbLN4jVr1qzo2LFjvVdD19Zt3LgxampqomvXrvW2d+3aNSorK3f5nsrKyk88fuTIkXHXXXdFRUVFzJ49O5588skYNWpU1NTU5POvCAAAgALK75ruzI6mmkdBlZeXx9SpU+tt29PXU48bN67uz3379o1jjz02evXqFU888UR88Ytf3KNzAQAAoGnkd/fyZvLIsJKSkt2O7AMOOCBatWoV69evr7d9/fr1UVZWtsv3lJWV5XR8RMRnPvOZOOCAA+K///u/RTcAAMBeyunlOf7ioLi4OAYOHBgVFRV123bs2BEVFRUxZMiQXb5nyJAh9Y6PiFiyZEmDx0dE/OlPf4q33347unXrltP8AAAA+PTIa6W7dVHzOL08V1OnTo0JEybE8ccfH4MGDYq5c+fG1q1bY+LEiRERMX78+Dj44IPrrgv/7ne/G8OGDYsf//jHcdppp8W9994b//Vf/xW33nprRERs2bIlZsyYEWPGjImysrJ444034rLLLovDDz88RowYUbDPCQAAQH7yO708msfp5bkaO3ZsvPXWWzFt2rSorKyM/v37x+LFi+tulrZmzZooKvrrSQRDhw6Nu+++O6688sr43ve+F717944HH3yw7hndrVq1ihUrVsSdd94Z77zzThx00EFx6qmnxg9+8APP6gYAANiLWelupMmTJ8fkyZN3ue+JJ57YaduZZ54ZZ5555i6Pb9u2bTz66KNNOT0AAAA+BdxIDQAAABLJb6W7mTwyDAAAAFKw0g0AAACJ5BndVroBAACgIU4vBwAAgEScXg4AAACJ5BXdbYpqmmoeAAAA0Ozkt9IdVroBAACgIW6kBgAAAInkd3q56AYAAIAGWekGAACARPK8kZroBgAAgIbkeSM10Q0AAAANyfOabo8MAwAAgIbkeU23R4YBAABAQ/KK7lZupAYAAAANcno5AAAAJOKRYQAAAJCIlW4AAABIJL9rusON1AAAAKAhea50b2+qeQAAAECz45FhAAAAkEiep5e7kRoAAAA0xI3UAAAAIJH8Vro9MgwAAAAaZKUbAAAAEnFNNwAAACSS593LRTcAAAA0JK/oLnZ6OQAAADQov5Vup5cDAABAg/Jc6d7eVPMAAACAZifPle5sU80DAAAAmh3XdAMAAEAirukGAACARPJ7TrdHhgEAAECD8oruNqIbAAAAGpTfSrfTywEAAKBBVroBAAAgkTxXuj0yDAAAABqS50q36AYAAICGWOkGAACARPJ8TjcAAADQkLyiuziTaap5AAAAQLNjpRsAAAASyfNGarIbAAAAGpLnjdScXg4AAAANyWupuk2mqFm8GuPmm2+Onj17RmlpaQwePDieffbZjz3+vvvuiz59+kRpaWn07ds3HnnkkXr7s9lsTJs2Lbp16xZt27aN4cOHx+uvv96ouQEAAPDpkFd0FzWT/8vVwoULY+rUqTF9+vR47rnnol+/fjFixIjYsGHDLo9/6qmn4uyzz45zzz03nn/++Rg9enSMHj06/vCHP9Qdc91118W8efNi/vz58cwzz8Q+++wTI0aMiG3btjX63w8AAACFlWd0Z5rFK1c33HBDnHfeeTFx4sQ46qijYv78+dGuXbu4/fbbd3n8jTfeGCNHjoxLL700PvvZz8YPfvCDOO644+Kmm26KiA9WuefOnRtXXnllfO1rX4tjjz027rrrrli7dm08+OCD+fwrAgAAoIDyPL28dbN4VVVVxaZNm+q9qqqqdvmZq6urY/ny5TF8+PC/fhOLimL48OGxbNmyXb5n2bJl9Y6PiBgxYkTd8atXr47Kysp6x3Ts2DEGDx7c4JgAAAB8+uX3yLCy15pqHgU16+qrY8aMGfW2TZ8+Pa6++uqdjt24cWPU1NRE165d623v2rVrvPrqq7sc
v7KycpfHV1ZW1u3/cFtDxwAAALD3ySu6m4vy8vKYOnVqvW0lJSUFmg0AAADNheiODwJ7dyP7gAMOiFatWsX69evrbV+/fn2UlZXt8j1lZWUfe/yH/7l+/fro1q1bvWP69++/ux8DAACAT5m8ruluiYqLi2PgwIFRUVFRt23Hjh1RUVERQ4YM2eV7hgwZUu/4iIglS5bUHX/YYYdFWVlZvWM2bdoUzzzzTINjAgAA8OmXyWaz2UJPYm+zcOHCmDBhQvz0pz+NQYMGxdy5c2PRokXx6quvRteuXWP8+PFx8MEHx6xZsyLig0eGDRs2LK699to47bTT4t57741rrrkmnnvuuTjmmGMiImL27Nlx7bXXxp133hmHHXZYXHXVVbFixYp45ZVXorS0tJAft8XaUXlEoacAe1y/Z88u9BRgj+p0W/tCTwH2qNLFzxd6CrBHPVZ9d6Gn4PTyxhg7dmy89dZbMW3atKisrIz+/fvH4sWL626EtmbNmigq+utJBEOHDo277747rrzyyvje974XvXv3jgcffLAuuCMiLrvssti6dWt861vfinfeeSdOPPHEWLx4seAGAADYi1nphgZY6aYlstJNS2Olm5bGSjctzadhpds13QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBEWhd6AgAAAPBp8eyzz8ayZcuisrIyIiLKyspiyJAhMWjQoEaNJ7oBAABolqqqqqKqqqretpKSkigpKdnp2A0bNsSYMWNi6dKl0b179+jatWtERKxfvz6mTJkSn//85+P++++PLl265DQH0c0ulZeXx7Zt2wo9jYL68eWFngEAAJCPWbNmxYwZM+ptmz59elx99dU7HXvBBRdETU1NrFy5Mo488sh6+1atWhXf/OY3Y9KkSXHfffflNIdMNpvN5jxzmr0pU6bEnDlzCj2NgtpReUShpwB7XL9nzy70FGCP6nRb+0JPAfao0sXPF3oKsEf9++YFu73Sve+++8Zvf/vbGDBgwC7HWr58eZx00kmxefPmnOZgpRsAAIBmqaHAbujYTZs2Nbh/8+bNuz3WR7l7OQAAAC3e2LFjY8KECfGrX/2qXnxv2rQpfvWrX8XEiRPj7LNzPyvQSjcAAAAt3g033BA7duyIcePGxfbt26O4uDgiIqqrq6N169Zx7rnnxvXXX5/zuKIbAACAFq+kpCRuueWWmD17dixfvrzeI8MGDhwYHTp0aNS4ohsAAABqdejQIU4++eQmG8813QAAAPAJ1q9fHzNnzsz5faIb
AAAAPkFlZeVOz/zeHU4vBwAAoMVbsWLFx+5ftWpVo8YV3QAAALR4/fv3j0wmE9lsdqd9H27PZDI5jyu6AQAAaPH233//uO666+KLX/ziLve//PLL8ZWvfCXncUU3AAAALd7AgQNj7dq10aNHj13uf+edd3a5Cv5JRDcAAAAt3re//e3YunVrg/u7d+8eCxYsyHlc0Q0AAECLd/rpp3/s/k6dOsWECRNyHtcjwwAAACAi3n///ejVq1esXLmyycYU3QAAABARbdq0iW3btjXpmKIbAAAAak2aNClmz54d27dvb5LxXNMNAAAAtX7/+99HRUVFPPbYY9G3b9/YZ5996u1/4IEHchpPdAMAAECt/fbbL8aMGdNk44luAAAAqNWYx4J9HNd0AwAAQCJWugEAAOAjfvnLX8aiRYtizZo1UV1dXW/fc889l9NYVroBAACg1rx582LixInRtWvXeP7552PQoEHRuXPn+OMf/xijRo3KeTzRDQAAALX+9V//NW699db4yU9+EsXFxXHZZZfFkiVL4jvf+U68++67OY8nugEAAKDWmjVrYujQoRER0bZt29i8eXNERJxzzjlxzz335Dye6AYAAIBaZWVl8ec//zkiIrp37x5PP/10RESsXr06stlszuOJbgAAAKh1yimnxEMPPRQRERMnTowpU6bEl770pRg7dmycfvrpOY/n7uUAAABQ69Zbb40dO3ZERMSkSZOic+fO8dRTT8VXv/rVOP/883Mez0o3AAAALdoZZ5wRmzZtioiIn//851FTU1O3b9y4cTFv3ry48MILo7i4OOexRTcAAAAt2sMPPxxbt26NiA9OKW/MXcob4vRyAAAAWrQ+ffpEeXl5nHzyyZHNZmPRokXRoUOHXR47fvz4nMYW3QAAALRo8+fPj6lTp8avf/3ryGQyceWVV0Ymk9npuEwmI7oBAAAgF0OHDq17NFhRUVG89tpr0aVLlyYZ2zXdAAAAUGv16tVx4IEHNtl4VroBAABo0VasWBHHHHNMFBUVxbvvvhsvvfRSg8cee+yxOY0tugEAAGjR+vfvH5WVldGlS5fo379/ZDKZyGazdfs//DqTydR7nNjuEN0AAAC0aB89pXz16tVNOrZrugEAAGjRevToUXe38rvvvjsqKiqiR48e9V4VFRVx77335jy26AYAAIBaP/3pT6NPnz47bT/66KNj/vz5OY8nugEAAKBWZWVldOvWbaftBx54YKxbty7n8UQ3AAAA1Dr00ENj6dKlO21funRpHHTQQTmP50ZqAAAAUOu8886Liy66KN5///045ZRTIiKioqIiLrvssrj44otzHk90AwAAQK1LL7003n777bjggguiuro6IiJKS0vj8ssvj/Ly8pzHE90AAABQK5PJxOzZs+Oqq66KlStXRtu2baN3795RUlLSqPFENwAAAPyN9u3bxwknnJD3OG6kBgAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAIm0bsrBysvLY9u2bU05JAWyYcOGQk+h4Po9e3ahpwB73IuD7in0FGCPOvHn5xd6CrBHZbe/X+gpQIvTpNG9bdu2mDNnTlMOSYFMmTKl0FMAAADY6zm9HAAAABIR3QAAAJCI6AYA
AIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJBI60JPAAAAAD4ttm/fHi+//HJUVlZGRERZWVkcddRR0aZNm0aNJ7oBAABolqqqqqKqqqretpKSkigpKdnp2B07dsS0adPi5ptvjnfffbfevo4dO8bkyZNjxowZUVSU2wnjoptdKi0tjSlTphR6GoV1codCzwAAAMjDrFmzYsaMGfW2TZ8+Pa6++uqdjr3iiivijjvuiGuvvTZGjBgRXbt2jYiI9evXx2OPPRZXXXVVVFdXx+zZs3OaQyabzWYb/Qn+xpQpU2LOnDlNNRwUVN+Hphd6CrDHvTjonkJPAfaoE79zfqGnAHvUPr98utBTgD3q4b/8fLdXusvKyuLOO++MESNG7HKsRx99NMaPHx/r16/PaQ5WugEAAGiWGgrsXdm8eXMcdNBBDe7v1q1bbN26Nec5uHs5AAAALd5JJ50Ul1xySWzcuHGnfRs3bozLL788TjrppJzHtdINAABAizd//vz4+7//++jWrVv07du33jXdL730Uhx11FHx8MMP5zyu6AYAAKDFO/TQQ+PFF1+MRx99NJ5++um6R4YNGjQorrnmmjj11FNzvnN5hOgGAACAiIgoKiqKUaNGxahRo5psTNENAAAAtZ599tlYtmxZ3Up3WVlZDB06NE444YRGjSe6AQAAaPE2bNgQY8aMiaVLl0b37t3rXdM9ZcqU+PznPx/3339/dOnSJadx3b0cAACAFu+CCy6ImpqaWLlyZbz55pvxzDPPxDPPPBNvvvlmrFy5Mnbs2BGTJk3KeVwr3QAAALR4jz76aPz2t7+NI488cqd9Rx55ZMybN69Rjwyz0g0AAECLV1JSEps2bWpw/+bNm6OkpCTncUU3AAAALd7YsWNjwoQJ8atf/apefG/atCl+9atfxcSJE+Pss8/OeVynlwMAANDi3XDDDbFjx44YN25cbN++PYqLiyMiorq6Olq3bh3nnntuXH/99TmPK7oBAABo8UpKSuKWW26J2bNnx/Lly+s9MmzgwIHRoUOHRo0rugEAAKBWhw4d4uSTT26y8VzTDQAAAJ9g/fr1MXPmzJzfJ7oBAADgE1RWVsaMGTNyfp/TywEAAGjxVqxY8bH7V61a1ahxRTcAAAAtXv/+/SOTyUQ2m91p34fbM5lMzuOKbgAAAFq8/fffP6677rr44he/uMv9L7/8cnzlK1/JeVzRDQAAQIs3cODAWLt2bfTo0WOX+995551droJ/EtENAABAi/ftb387tm7d2uD+7t27x4IFC3IeV3QDAADQ
4p1++ukfu79Tp04xYcKEnMf1yDAAAACIiPfffz969eoVK1eubLIxRTcAAABERJs2bWLbtm1NOqboBgAAgFqTJk2K2bNnx/bt25tkPNd0AwAAQK3f//73UVFREY899lj07ds39tlnn3r7H3jggZzGE90AAABQa7/99osxY8Y02XiiGwAAAGo15rFgH8c13QAAAJCIlW4AAAD4iF/+8pexaNGiWLNmTVRXV9fb99xzz+U0lpVuAAAAqDVv3ryYOHFidO3aNZ5//vkYNGhQdO7cOf74xz/GqFGjch5PdAMAAECtf/3Xf41bb701fvKTn0RxcXFcdtllsWTJkvjOd74T7777bs7jiW4AAACotWbNmhg6dGhERLRt2zY2b94cERHnnHNO3HPPPTmPJ7oBAACgVllZWfz5z3+OiIju3bvH008/HRERq1evjmw2m/N4ohsAAABqnXLKKfHQQw9FRMTEiRNjypQp8aUvfSnGjh0bp59+es7juXs5AAAA1Lr11ltjx44dERExadKk6Ny5czz11FPx1a9+Nc4///ycx7PSDQAAQIt2xhlnxKZNmyIi4uc//3nU1NTU7Rs3blzMmzcvLrzwwiguLs55bNENAABAi/bwww/H1q1bI+KDU8obc5fyhji9HAAAgBatT58+UV5eHieffHJks9lYtGhRdOjQYZfHjh8/PqexRTcAAAAt2vz582Pq1Knx61//OjKZTFx55ZWRyWR2Oi6TyYhuAAAAyMXQoUPrHg1WVFQUr732WnTp0qVJxnZNNwAAANRavXp1HHjggU02npVuAAAAWrQVK1bEMcccE0VFRfHuu+/GSy+91OCxxx57bE5ji24AAABatP79+0dlZWV06dIl+vfvH5lMJrLZbN3+D7/OZDL1Hie2O0Q3AAAALdpHTylfvXp1k47tmm4AAABatB49etTdrfzuu++OioqK6NGjR71XRUVF3HvvvTmPLboBAACg1k9/+tPo06fPTtuPPvromD9/fs7jiW4AAACoVVlZGd26ddtp+4EHHhjr1q3LeTzRDQAAALUOPfTQWLp06U7bly5dGgcddFDO47mRGgAAANQ677zz4qKLLor3338/TjnllIiIqKioiMsuuywuvvjinMcT3QAAAFDr0ksvjbfffjsuuOCCqK6ujoiI0tLSuPzyy6O8vDzn8UQ3AAAA1MpkMjF79uy46qqrYuXKldG2bdvo3bt3lJSUNGo80Q0AAAB/o3379nHCCSfkPY4bqQEAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEikdaEnAJ9WnW5rX+gpwB534s/PL/QUYI/6z3k/LfQUYI/6zKh/KvQUoMWx0g0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABI
RHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEhHdAAAAkIjoBgAAgERENwAAACQiugEAACAR0Q0AAACJiG4AAABIRHQDAABAIqIbAAAAEmld6AkAAABAClVVVVFVVVVvW0lJSZSUlOQ0zjvvvBP77bdfo+Ygupux8vLy2LZtW6GnsRc7qNATAAAA8jBr1qyYMWNGvW3Tp0+Pq6++usH3zJ49O3r27Bljx46NiIizzjor7r///igrK4tHHnkk+vXrl9McRHcztm3btpgzZ06hp7HX+ruv/qjQUwAAAPJQXl4eU6dOrbftk1a558+fH7/4xS8iImLJkiWxZMmS+I//+I9YtGhRXHrppfHYY4/lNAfRDQAAQLPUmFPJKysr49BDD42IiIcffjjOOuusOPXUU6Nnz54xePDgnOfgRmoAAABQq1OnTvH//t//i4iIxYsXx/DhwyMiIpvNRk1NTc7jWekGAACAWmeccUZ8/etfj969e8fbb78do0aNioiI559/Pg4//PCcxxPdAAAAUGvOnDlx2GGHxZo1a+K6666L9u3bR0TEunXr4oILLsh5PNENAAAAEfH+++/H+eefH1dddVUcdthh9fZNmTKlUWO6phsAAAAiok2bNnH//fc36ZiiGwAAAGqNHj06HnzwwSYbz+nlAAAAUKt3794xc+bMWLp0aQwcODD22Wefevu/853v5DSe6AYAAIBa//Zv/xb77bdfLF++PJYvX15vXyaTEd0AAADQWKtXr27S8VzTDQAAAH+juro6Vq1aFdu3b89rHNENAAAAtd57770499xzo127dnH00UfHmjVrIiLiwgsvjGuvvTbn8UQ3AAAA1CovL48XX3wxnnjiiSgtLa3bPnz48Fi4cGHO47mmGwAAAGo9+OCDsXDhwvjc5z4XmUymbvvRRx8db7zxRs7jWekGAACAWm+99VZ06dJlp+1bt26tF+G7S3QDAABAreOPPz5+/etf1339YWjfdtttMWTIkJzHc3o5AAAA1Lrmmmti1KhR8corr8T27dvjxhtvjFdeeSWeeuqpePLJJ3Mez0o3AAAA1DrxxBPjhRdeiO3bt0ffvn3jscceiy5dusSyZcti4MCBOY9npRsAAAA+olevXvGzn/2sScay0g0AAAC1WrVqFRs2bNhp+9tvvx2tWrXKeTzRDQAAALWy2ewut1dVVUVxcXHO4zm9HAAAgBZv3rx5EfHB3cpvu+22aN++fd2+mpqa+O1vfxt9+vTJeVzRDQAAQIs3Z86ciPhgpXv+/Pn1TiUvLi6Onj17xvz583MeV3QDAADQ4q1evToiIk4++eR44IEHolOnTk0yrmu6AQAAoNbJJ58cJSUlO23/y1/+EjNnzsx5PNENAAAAtWbMmBFbtmzZaft7770XM2bMyHk80Q0AAAC1stlsZDKZnba/+OKLsf/+++c8nmu6AQAAaPE6deoUmUwmMplMHHHEEfXCu6amJrZs2RLf/va3cx5XdAMAANDizZ07N7LZbHzzm9+MGTNmRMeOHev2fXj38iFDhuQ8rugGAACgxZswYUJERBx22GHx+c9/Plq3bppcdk03AAAA1Bo2bFj8z//8T1x55ZVx9tlnx4YNGyIi4j/+4z/i5Zdfznk80Q0AAAC1nnzyyejbt28888wz8cADD9TdyfzFF1+M6dOn5zye
6AYAAIBaV1xxRfzwhz+MJUuWRHFxcd32U045JZ5++umcxxPdAAAAUOull16K008/faftXbp0iY0bN+Y8nugGAACAWvvtt1+sW7dup+3PP/98HHzwwTmPJ7oBAACg1rhx4+Lyyy+PysrKyGQysWPHjli6dGlccsklMX78+JzHE90AAABQ65prrok+ffrEoYceGlu2bImjjjoqvvCFL8TQoUPjyiuvzHk8z+kGAACAWsXFxfGzn/0spk2bFi+99FJs2bIlBgwYEL17927UeKIbAACAFm3q1Kkfu/+jdy2/4YYbchpbdAMAANCiPf/887t1XCaTyXls0Q0AAECL9vjjjycb243UAAAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAirZtysNLS0pgyZUpTDkkeNmzYUOgp7NVKFz9f6CnAHpfd/n6hpwB71GdG/VOhpwB71B9H3VboKcAedlmhJ9C00T1r1qymHI48+QUIAABAYTm9HAAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAADUuuuuu6Kqqmqn7dXV1XHXXXflPJ7oBgAAoFmqqqqKTZs21XvtKqg/auLEifHuu+/utH3z5s0xceLEnOfQOud3sNcoLS2NKVOmFHoaAAAABTFr1qyYMWNGvW3Tp0+Pq6++usH3ZLPZyGQyO23/05/+FB07dsx5DplsNpvN+V3QApxa/PVCTwH2uOz29ws9BdijXv+3Ewo9Bdij/jjqtkJPAfao9zu9tNPKdklJSZSUlOx07IABAyKTycSLL74YRx99dLRu/dc16pqamli9enWMHDkyFi1alNMcrHQDAADQLDUU2LsyevToiIh44YUXYsSIEdG+
ffu6fcXFxdGzZ88YM2ZMznMQ3QAAALR406dPj4iInj17xrhx43Y71j+JG6kBAABArVNOOSXeeuutuq+fffbZuOiii+LWW29t1HiiGwAAAGp9/etfj8cffzwiIiorK2P48OHx7LPPxve///2YOXNmzuOJbgAAAKj1hz/8IQYNGhQREYsWLYq+ffvGU089Fb/4xS/ijjvuyHk80Q0AAAC13n///brruX/zm9/EV7/61YiI6NOnT6xbty7n8UQ3AAAA1Dr66KNj/vz58bvf/S6WLFkSI0eOjIiItWvXRufOnXMeT3QDAABArdmzZ8dPf/rTOOmkk+Lss8+Ofv36RUTEQw89VHfaeS48MgwAAABqnXTSSbFx48bYtGlTdOrUqW77t771rWjXrl3O44luAAAA+IhWrVrVC+6ID57f3RiiGwAAgBbtuOOOi4qKiujUqVMMGDAgMplMg8c+99xzOY0tugEAAGjRvva1r9XdsXz06NFNOnYmm81mm3REaCZOLf56oacAe1x2+/uFngLsUa//2wmFngLsUX8cdVuhpwB7VFHZa4WegpVuAAAA+FvV1dWxYcOG2LFjR73t3bt3z2kc0Q0AAAC1XnvttTj33HPjqaeeqrc9m81GJpOJmpqanMYT3QAAAFBr4sSJ0bp163j44YejW7duH3tTtd0hugEAAKDWCy+8EMuXL48+ffo0yXhFTTIKAAAANANHHXVUbNy4scnGE90AAABQa/bs2XHZZZfFE088EW+//XZs2rSp3itXTi8HAACAWsOHD4+IiFNOOaXe9dxupAYAAAB5evzxx5t0PNENAAAAtYYNGxbbtm2LFStW7PI53bkS3QAAAFBr8eLFMX78+F3eTK0xp5e7kRoAAADUuvDCC+PMM8+MdevWxY4dO+q9cg3uCNENAAAAddavXx9Tp06Nrl27Nsl4ohsAAABq/cM//EM88cQTTTaea7oBAACg1k033RRnnnlm/O53v4u+fftGmzZt6u3/zne+k9N4ohsAAABq3XPPPfHYY49FaWlpPPHEE/We1Z3JZEQ3AAAANNb3v//9mDFjRlxxxRVRVJT/Fdmu6QYAAIBa1dXVMXbs2CYJ7gjRDQAAAHUmTJgQCxcubLLxnF4OAAAAtWpqauK6666LRx99NI499tidbqR2ww035DSe6AYAAIBaL730UgwYMCAiIv7whz/U2/fRm6rtLtENAAAAtR5//PEmHc813QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBERDcAAAAkIroBAAAgEdENAAAAiYhuAAAASER0AwAAQCKiGwAAABIR3QAAAJCI6AYAAIBEMtlsNlvoSQB8qKqqKmbNmhXl5eVRUlJS6OlAcn7maWn8zNMS+blv2UQ38KmyadOm6NixY7z77rvRoUOHQk8HkvMzT0vjZ56WyM99y+b0cgAAAEhEdAMAAEAiohsAAAASEd3Ap0pJSUlMnz7dTUZoMfzM09L4macl8nPfsrmRGgAAACRipRsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6gSb35ptvRiaTiRdeeKFu29KlS6Nv377Rpk2bGD16dMHmBvny801Lc9JJJ8VFF11U6GlAs3L1
1VdH//79Cz0N9hDRDTS5Qw89NNatWxfHHHNM3bapU6dG//79Y/Xq1XHHHXfkNf4DDzwQp556anTu3Hmn+PnQtm3bYtKkSdG5c+do3759jBkzJtavX5/XPxci9p6f7zVr1sRpp50W7dq1iy5dusSll14a27dvz2tu8EnuuOOO2G+//Xba3rNnz5g7d+4enw98Wl1yySVRUVGxW8cK9L2f6AaaXKtWraKsrCxat25dt+2NN96IU045JQ455JBd/g+y3VFdXR0REVu3bo0TTzwxZs+e3eCxU6ZMiX//93+P++67L5588slYu3ZtnHHGGY3658JH7Q0/3zU1NXHaaadFdXV1PPXUU3HnnXfGHXfcEdOmTWvU3ODT4sO/J7C3a9++fXTu3LnQ02BPyQI0Uk1NTXb27NnZXr16ZYuLi7OHHnpo9oc//GF29erV2YjIPv/883V//uhrwYIF2e3bt2e/+c1vZnv27JktLS3NHnHEEdm5c+fWG3/ChAnZr33ta9kf/vCH2W7dumV79uxZb/9H/zkf9c4772TbtGmTve++++q2rVy5MhsR2WXLliX7ftC87M0/34888ki2qKgoW1lZWXfMLbfcku3QoUO2qqqqib9TNHfDhg3Lfve7381ms9nstm3bshdffHH2oIMOyrZr1y47aNCg7OOPP57NZrPZxx9/fKe/D9OnT88OGzZsp+0f+t3vfpc98cQTs6WlpdlDDjkke+GFF2a3bNlSt79Hjx7ZmTNnZs8555zsvvvum50wYUK2qqoqO2nSpGxZWVm2pKQk27179+w111yzJ78ltDDDhg3LTp48Ofvd7343u99++2W7dOmSvfXWW7NbtmzJfuMb38i2b98+26tXr+wjjzySzWb/+nfhN7/5TXbgwIHZtm3bZocMGZJ99dVX68acPn16tl+/fnVfP/7449kTTjgh265du2zHjh2zQ4cOzb755pvZBQsW7PL/z7B3sdINNFp5eXlce+21cdVVV8Urr7wSd999d3Tt2rXeMR+eituhQ4eYO3durFu3LsaOHRs7duyIQw45JO6777545ZVXYtq0afG9730vFi1aVO/9FRUVsWrVqliyZEk8/PDDuzWv5cuXx/vvvx/Dhw+v29anT5/o3r17LFu2LP8PTouwN/98L1u2LPr27VtvviNGjIhNmzbFyy+/3NhvCcTkyZNj2bJlce+998aKFSvizDPPjJEjR8brr78eQ4cOjblz50aHDh1i3bp1sW7durjkkkvigQceiEMOOSRmzpxZtz3igzNERo4cGWPGjIkVK1bEwoUL4z//8z9j8uTJ9f6Z119/ffTr1y+ef/75uOqqq2LevHnx0EMPxaJFi2LVqlXxi1/8Inr27FmA7wYtyZ133hkHHHBAPPvss3HhhRfGP//zP8eZZ54ZQ4cOjeeeey5OPfXUOOecc+K9996re8/3v//9+PGPfxz/9V//Fa1bt45vfvObuxx7+/btMXr06Bg2bFisWLEili1bFt/61rcik8nE2LFj4+KLL46jjz667u/P2LFj99THpom0/uRDAHa2efPmuPHGG+Omm26KCRMmREREr1694sQTT4w333yz7rgPT8XNZDLRsWPHKCsrq9s3Y8aMuj8fdthhsWzZsli0aFGcddZZddv32WefuO2226K4uHi351ZZWRnFxcU7nebbtWvXqKyszPGT0hLt7T/flZWVO/2C4MOv/R2gsdasWRMLFiyINWvWxEEHHRQRH1yXunjx4liwYEFcc8010bFjx8hkMvX+LkR88Hdl3333rbd91qxZ8Y//+I91N2nr3bt3zJs3L4YNGxa33HJLlJaWRkTEKaecEhdffHG9efTu3TtOPPHEyGQy0aNHj8SfHCL69esXV155ZUT89ZeyBxxwQJx33nkRETFt2rS45ZZbYsWKFXXv+Zd/+ZcYNmxYRERcccUVcdppp8W2bdvqfrY/tGnTpnj33Xfjy1/+cvTq1SsiIj772c/W7W/fvn20bt16
p79X7D1EN9AoK1eujKqqqvjiF7/Y6DFuvvnmuP3222PNmjXxl7/8Jaqrq3e6UUjfvn1zChJoCn6+YWcvvfRS1NTUxBFHHFFve1VVVaOuTX3xxRdjxYoV8Ytf/KJuWzabjR07dsTq1avrouP444+v975vfOMb8aUvfSmOPPLIGDlyZHz5y1+OU089tRGfCHbfscceW/fnVq1aRefOnaNv37512z78xeaGDRuiQ4cOO72nW7dudfu7d+9eb+z9998/vvGNb8SIESPiS1/6UgwfPjzOOuusuvew93N6OdAobdu2zev99957b1xyySVx7rnnxmOPPRYvvPBCTJw4caeb5Oyzzz45j11WVhbV1dXxzjvv1Nu+fv16vyVmt+ztP99lZWU73c38w6/9HaCxtmzZEq1atYrly5fHCy+8UPdauXJl3HjjjY0a7/zzz6831osvvhivv/563WpfxM5/T4477rhYvXp1/OAHP4i//OUvcdZZZ8U//MM/5P354OO0adOm3teZTKbetkwmExERO3bs2OV7drX/oxYsWBDLli2LoUOHxsKFC+OII46Ip59+usnmT2GJbqBRevfuHW3btt3tx138raVLl8bQoUPjggsuiAEDBsThhx8eb7zxRpPMbeDAgdGmTZt6c1u1alWsWbMmhgwZ0iT/DJq3vf3ne8iQIfHSSy/Fhg0b6o5ZsmRJdOjQIY466qgmmQctz4ABA6KmpiY2bNgQhx9+eL3Xh7/MKS4ujpqamp3eu6vtxx13XLzyyis7jXX44Yd/4hkgHTp0iLFjx8bPfvazWLhwYdx///3x5z//uek+LBTAgAEDory8PJ566qk45phj4u67746Ihv9esfdwejnQKKWlpXH55ZfHZZddFsXFxfH5z38+3nrrrXj55Zd365Tc3r17x1133RWPPvpoHHbYYfF//+//jd///vdx2GGHfeJ7//znP8eaNWti7dq1EfFBcER8sIJXVlYWHTt2jHPPPTemTp0a+++/f3To0CEuvPDCGDJkSHzuc5/L74PTIuztP9+nnnpqHHXUUXHOOefEddddF5WVlXHllVfGpEmToqSkJI/vDC3ZEUccEf/4j/8Y48ePjx//+McxYMCAeOutt6KioiKOPfbYOO2006Jnz56xZcuWqKioiH79+kW7du2iXbt20bNnz/jtb38b48aNi5KSkjjggAPi8ssvj8997nMxefLk+Kd/+qfYZ5994pVXXoklS5bETTfd1OA8brjhhujWrVsMGDAgioqK4r777ouysrJGP64PCm316tVx6623xle/+tU46KCDYtWqVfH666/H+PHjI+KD59yvXr06XnjhhTjkkENi33339d/lexkr3UCjXXXVVXHxxRfHtGnT4rOf/WyMHTu23sraxzn//PPjjDPOiLFjx8bgwYPj7bffjgsuuGC33vvQQw/FgAED4rTTTouIiHHjxsWAAQNi/vz5dcfMmTMnvvzlL8eYMWPi7/7u76KsrCweeOCB3D8kLdbe/PPdqlWrePjhh6NVq1YxZMiQ+D//5//E+PHjY+bMmTl8B2BnCxYsiPHjx8fFF18cRx55ZIwePTp+//vf112jOnTo0Pj2t78dY8eOjQMPPDCuu+66iIiYOXNmvPnmm9GrV6848MADI+KD612ffPLJeO211+ILX/hCDBgwIKZNm1Z3k7aG7LvvvnHdddfF8ccfHyeccEK8+eab8cgjj0RRkf9Zy96pXbt28eqrr8aYMWPiiCOOiG9961sxadKkOP/88yMiYsyYMTFy5Mg4+eST48ADD4x77rmnwDMmV5lsNpst9CQAAACgOfIrQQAAAEhEdAMAAEAiohsAAAASEd0AAACQiOgGAACAREQ3AAAAJCK6AQAAIBHRDQAAAImIbgAAAEhEdAMAAEAiohsAAAAS+f+IYxRsDWUb0wAAAABJRU5ErkJggg==",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plot_distance_matrix(embeddings=embeddings, labels=dataset_names)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
- "id": "c5107e1b-93e1-4861-becb-16acf1fdd9c6",
+ "id": "705192a9-5ee8-4f92-aa41-a2c19197d017",
"metadata": {},
"outputs": [],
"source": []
diff --git a/setup.py b/setup.py
index 01ea3b0..bba14c2 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
version="0.1.0",
author="...",
description="...",
- long_description=open("README.rst").read(),
+ #long_description=open("README.rst").read(),
long_description_content_type="text/x-rst",
url="https://github.com/intsystems/DataMetaMap",
packages=find_packages(where="src"),
diff --git a/src/data_meta_map/__init__.py b/src/data_meta_map/__init__.py
index aa0f67e..a4a8c1f 100644
--- a/src/data_meta_map/__init__.py
+++ b/src/data_meta_map/__init__.py
@@ -1,9 +1,8 @@
-from .BaseEmbedder import BaseEmbedder
-from .WassersteinEmbedder import WassersteinEmbedder
-
+from data_meta_map.base_embedder import BaseEmbedder
+from data_meta_map.wasserstein_embedder import WassersteinEmbedder
__all__ = [
- "BaseEmbedder",
- "WassersteinEmbedder"
+ "BaseEmbedder",
+ "WassersteinEmbedder"
]
diff --git a/src/data_meta_map/BaseEmbedder.py b/src/data_meta_map/base_embedder.py
similarity index 94%
rename from src/data_meta_map/BaseEmbedder.py
rename to src/data_meta_map/base_embedder.py
index 1ca2f9b..a299c4d 100644
--- a/src/data_meta_map/BaseEmbedder.py
+++ b/src/data_meta_map/base_embedder.py
@@ -19,11 +19,25 @@ class SupportsGetItem(Protocol):
Protocol for objects with __getitem__ and __len__ interface.
Ensures compatibility with both Dataset and custom iterable objects.
"""
+
def __getitem__(self, index: int) -> Tuple[Any, int]: ...
def __len__(self) -> int: ...
class BaseEmbedder(ABC):
+ def __init__(
+ self
+ ):
+ pass
+
+ @abstractmethod
+ def embed(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Override this method in your Embedder class")
+
+
+# DEPRECATED
+class BaseEmbedderDEPRECATED(ABC):
"""
Abstract base class for dataset embedding.
@@ -191,7 +205,8 @@ def get_class_statistics(
feature_dim = X.shape[1]
means = torch.zeros((num_classes, feature_dim), device=self.device)
- covs = torch.zeros((num_classes, feature_dim, feature_dim), device=self.device)
+ covs = torch.zeros(
+ (num_classes, feature_dim, feature_dim), device=self.device)
for idx, label in enumerate(unique_labels):
mask = (Y == label)
@@ -201,7 +216,8 @@ def get_class_statistics(
covs[idx] = torch.cov(class_samples.T)
else:
# For single sample, covariance is undefined — use zero matrix
- covs[idx] = torch.zeros((feature_dim, feature_dim), device=self.device)
+ covs[idx] = torch.zeros(
+ (feature_dim, feature_dim), device=self.device)
return means, covs
diff --git a/src/data_meta_map/datasets.py b/src/data_meta_map/datasets.py
index fb577d3..48662af 100644
--- a/src/data_meta_map/datasets.py
+++ b/src/data_meta_map/datasets.py
@@ -341,7 +341,7 @@ def kmnist(root):
transforms.Resize(224),
transforms.ToTensor(),
])
- trainset = KMNIST(root, train=True, transform=transform, download=True)
+ trainset = KMNIST(root, train=True, transform=transform, download=False)
testset = KMNIST(root, train=False, transform=transform)
return trainset, testset
diff --git a/src/data_meta_map/task2vec/task2vec.py b/src/data_meta_map/task2vec/task2vec.py
index 40923ca..9989a3d 100644
--- a/src/data_meta_map/task2vec/task2vec.py
+++ b/src/data_meta_map/task2vec/task2vec.py
@@ -26,6 +26,8 @@
from torch.optim.optimizer import Optimizer
from data_meta_map.task2vec.utils import AverageMeter, get_error, get_device
+from data_meta_map import BaseEmbedder
+
class Embedding:
def __init__(self, hessian, scale, meta=None):
@@ -63,7 +65,7 @@ def task2vec(probe_network, dataset: Dataset, skip_layers=0, max_samples=None, c
return embed
-class Task2Vec:
+class Task2Vec(BaseEmbedder):
def __init__(self, model: ProbeNetwork, skip_layers=0, max_samples=None, classifier_opts=None,
method='montecarlo', method_opts=None, loader_opts=None, bernoulli=False):
@@ -90,7 +92,7 @@ def __init__(self, model: ProbeNetwork, skip_layers=0, max_samples=None, classif
self.loss_fn = nn.CrossEntropyLoss() if not self.bernoulli else nn.BCEWithLogitsLoss()
self.loss_fn = self.loss_fn.to(self.device)
- def embed(self, dataset: Dataset, create_final_embedding: bool = False):
+ def embed(self, dataset: Dataset, create_final_embedding: bool = True):
# Cache the last layer features (needed to train the classifier) and (if needed) the intermediate layer features
# so that we can skip the initial layers when computing the embedding
if self.skip_layers > 0:
diff --git a/src/data_meta_map/WassersteinEmbedder.py b/src/data_meta_map/wasserstein_embedder.py
similarity index 92%
rename from src/data_meta_map/WassersteinEmbedder.py
rename to src/data_meta_map/wasserstein_embedder.py
index e925a98..af2761d 100644
--- a/src/data_meta_map/WassersteinEmbedder.py
+++ b/src/data_meta_map/wasserstein_embedder.py
@@ -7,8 +7,7 @@
import ot # POT: Python Optimal Transport
from tqdm.autonotebook import tqdm
-from .BaseEmbedder import BaseEmbedder
-
+from data_meta_map.base_embedder import BaseEmbedder
def sqrtm_newton_schulz(A: torch.Tensor, num_iters: int = 20) -> torch.Tensor:
@@ -105,7 +104,8 @@ def __init__(
gaussian_assumption: bool = True,
diagonal_cov: bool = False,
commute: bool = False,
- sqrt_method: str = "ns", # 'ns' = Newton-Schulz (default), 'eig' = eigenvalue decomposition
+ # 'ns' = Newton-Schulz (default), 'eig' = eigenvalue decomposition
+ sqrt_method: str = "ns",
sqrt_niters: int = 20,
**kwargs
):
@@ -133,7 +133,8 @@ def __init__(
self.sqrt_niters = sqrt_niters
# Cache for class statistics: {dataset_id: (means, covs, class_offsets)}
- self._stats_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor, List[int]]] = {}
+ self._stats_cache: Dict[int,
+ Tuple[torch.Tensor, torch.Tensor, List[int]]] = {}
# Cache for preprocessed data: {dataset_id: (X, Y)}
self._data_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
@@ -171,9 +172,11 @@ def preprocess_dataset(
idxs = np.sort(np.random.choice(
len(data), self.max_samples, replace=False))
sampler = SubsetRandomSampler(idxs)
- loader = DataLoader(data, sampler=sampler, batch_size=self.batch_size)
+ loader = DataLoader(data, sampler=sampler,
+ batch_size=self.batch_size)
else:
- loader = DataLoader(data, batch_size=self.batch_size, shuffle=False)
+ loader = DataLoader(
+ data, batch_size=self.batch_size, shuffle=False)
elif isinstance(data, DataLoader):
loader = data
else:
@@ -231,7 +234,8 @@ def _compute_gaussian_stats(
if self.diagonal_cov:
covs = torch.zeros((num_classes, feature_dim), device=self.device)
else:
- covs = torch.zeros((num_classes, feature_dim, feature_dim), device=self.device)
+ covs = torch.zeros(
+ (num_classes, feature_dim, feature_dim), device=self.device)
for idx, label in enumerate(unique_labels):
mask = (Y == label)
@@ -251,7 +255,8 @@ def _compute_gaussian_stats(
if self.diagonal_cov:
covs[idx] = torch.zeros(feature_dim, device=self.device)
else:
- covs[idx] = torch.zeros((feature_dim, feature_dim), device=self.device)
+ covs[idx] = torch.zeros(
+ (feature_dim, feature_dim), device=self.device)
# Global class indices
class_offsets = list(range(num_classes))
@@ -378,11 +383,14 @@ def compute_pairwise_distances(
means_j[idx_j], covs_j[idx_j]
)
else:
- X_i, Y_i = self._data_cache.get(i, self.preprocess_dataset(datasets[i], dataset_id=i))
- X_j, Y_j = self._data_cache.get(j, self.preprocess_dataset(datasets[j], dataset_id=j))
+ X_i, Y_i = self._data_cache.get(
+ i, self.preprocess_dataset(datasets[i], dataset_id=i))
+ X_j, Y_j = self._data_cache.get(
+ j, self.preprocess_dataset(datasets[j], dataset_id=j))
samples_i = X_i[Y_i == local_i]
samples_j = X_j[Y_j == local_j]
- d = self._exact_wasserstein_distance(samples_i, samples_j)
+ d = self._exact_wasserstein_distance(
+ samples_i, samples_j)
d = torch.tensor(d, device=self.device)
D[global_i, global_j] = d
@@ -446,7 +454,8 @@ def augment_features(
X, Y = self.preprocess_dataset(data, dataset_id=dataset_idx)
start_offset = class_offsets[dataset_idx]
- end_offset = class_offsets[dataset_idx + 1] if dataset_idx + 1 < len(class_offsets) else label_embeddings.shape[0]
+ end_offset = class_offsets[dataset_idx + 1] if dataset_idx + \
+ 1 < len(class_offsets) else label_embeddings.shape[0]
label_emb_for_dataset = label_embeddings[start_offset:end_offset]
label_indices = Y.long()
@@ -485,7 +494,8 @@ def compute_wte(
augmented_datasets: List[torch.Tensor] = []
for idx, dataset in enumerate(datasets):
- Z = self.augment_features(dataset, label_embeddings, idx, class_offsets)
+ Z = self.augment_features(
+ dataset, label_embeddings, idx, class_offsets)
augmented_datasets.append(Z)
if reference is None and create_reference:
@@ -494,15 +504,18 @@ def compute_wte(
ref_indices = torch.randperm(all_data.shape[0])[:ref_size]
reference = all_data[ref_indices].float()
elif reference is None:
- raise ValueError("Either provide 'reference' or set 'create_reference=True'")
+ raise ValueError(
+ "Either provide 'reference' or set 'create_reference=True'")
task_embeddings = []
ref_size = reference.shape[0]
for Z in augmented_datasets:
Z = Z.float()
- C = ot.dist(Z.cpu().numpy(), reference.cpu().numpy(), metric='euclidean')
- gamma = ot.emd(ot.unif(Z.shape[0]), ot.unif(ref_size), C, numItermax=1_000_000)
+ C = ot.dist(Z.cpu().numpy(), reference.cpu().numpy(),
+ metric='euclidean')
+ gamma = ot.emd(ot.unif(Z.shape[0]), ot.unif(
+ ref_size), C, numItermax=1_000_000)
gamma = torch.from_numpy(gamma).float().to(self.device)
f = (ref_size * gamma.T @ Z - reference) / np.sqrt(ref_size)
task_embeddings.append(f)
diff --git a/tests/test_wasserstein.py b/tests/test_wasserstein.py
index 8e1ce91..25d3c8a 100644
--- a/tests/test_wasserstein.py
+++ b/tests/test_wasserstein.py
@@ -4,8 +4,8 @@
import sys
import traceback
-from data_meta_map.BaseEmbedder import BaseEmbedder
-from data_meta_map.WassersteinEmbedder import WassersteinEmbedder
+from data_meta_map.base_embedder import BaseEmbedder
+from data_meta_map.wasserstein_embedder import WassersteinEmbedder
# ============================================================================
@@ -14,6 +14,7 @@
class MockDataset(Dataset):
"""Простой датасет для тестов."""
+
def __init__(self, num_samples=100, feature_dim=10, num_classes=5, seed=42):
torch.manual_seed(seed)
self.data = torch.randn(num_samples, feature_dim)
@@ -28,6 +29,7 @@ def __getitem__(self, idx):
class MockVectorizedDataset(Dataset):
"""Датасет с уже векторизованными данными (как текстовые эмбеддинги)."""
+
def __init__(self, num_samples=100, feature_dim=768, num_classes=3, seed=42):
torch.manual_seed(seed)
self.data = torch.randn(num_samples, feature_dim)
@@ -46,6 +48,7 @@ def __getitem__(self, idx):
class ConcreteEmbedder(BaseEmbedder):
"""Конкретная реализация для тестирования абстрактного класса."""
+
def preprocess_dataset(self, data, dataset_id=None):
if isinstance(data, Dataset):
X = torch.randn(100, 10)
@@ -54,7 +57,8 @@ def preprocess_dataset(self, data, dataset_id=None):
return data
def compute_pairwise_distances(self, datasets, symmetric=True):
- n = sum(len(torch.unique(self.preprocess_dataset(d)[1])) for d in datasets)
+ n = sum(
+ len(torch.unique(self.preprocess_dataset(d)[1])) for d in datasets)
return torch.zeros((n, n), device=self.device)
def embed_distance_matrix(self, distance_matrix, emb_dim=None):
@@ -64,7 +68,8 @@ def embed_distance_matrix(self, distance_matrix, emb_dim=None):
def augment_features(self, data, label_embeddings, dataset_idx, class_offsets):
X, Y = self.preprocess_dataset(data)
- label_emb = label_embeddings[class_offsets[dataset_idx]:class_offsets[dataset_idx+1]]
+ label_emb = label_embeddings[class_offsets[dataset_idx]
+ :class_offsets[dataset_idx+1]]
return torch.cat([X, label_emb[Y]], dim=1)
@@ -254,7 +259,8 @@ def test_wasserstein_preprocess_vectorized():
print("\n[TEST] WassersteinEmbedder preprocess vectorized data")
try:
embedder = WassersteinEmbedder(emb_dim=2)
- dataset = MockVectorizedDataset(num_samples=100, feature_dim=768, num_classes=5)
+ dataset = MockVectorizedDataset(
+ num_samples=100, feature_dim=768, num_classes=5)
X, Y = embedder.preprocess_dataset(dataset, dataset_id=2)
@@ -320,7 +326,8 @@ def test_wasserstein_bures_distance_identical():
mean2 = torch.tensor([0.0, 0.0])
cov2 = torch.eye(2)
- distance = embedder._bures_wasserstein_distance(mean1, cov1, mean2, cov2)
+ distance = embedder._bures_wasserstein_distance(
+ mean1, cov1, mean2, cov2)
assert torch.allclose(distance, torch.tensor(0.0), atol=1e-2)
print(" ✓ Расстояние Бюра для идентичных распределений = 0")
@@ -341,7 +348,8 @@ def test_wasserstein_bures_distance_different_means():
mean2 = torch.tensor([1.0, 0.0])
cov2 = torch.eye(2)
- distance = embedder._bures_wasserstein_distance(mean1, cov1, mean2, cov2)
+ distance = embedder._bures_wasserstein_distance(
+ mean1, cov1, mean2, cov2)
assert distance > 0.0
assert torch.allclose(distance, torch.tensor(1.0), atol=1e-5)
@@ -379,8 +387,10 @@ def test_wasserstein_pairwise_distances_multiple():
print("\n[TEST] WassersteinEmbedder compute_pairwise_distances multiple datasets")
try:
embedder = WassersteinEmbedder(emb_dim=2, max_samples=20)
- ds1 = MockDataset(num_samples=30, feature_dim=10, num_classes=2, seed=42)
- ds2 = MockDataset(num_samples=30, feature_dim=10, num_classes=3, seed=43)
+ ds1 = MockDataset(num_samples=30, feature_dim=10,
+ num_classes=2, seed=42)
+ ds2 = MockDataset(num_samples=30, feature_dim=10,
+ num_classes=3, seed=43)
D = embedder.compute_pairwise_distances([ds1, ds2], symmetric=True)
@@ -424,7 +434,8 @@ def test_wasserstein_augment_features():
label_embeddings = torch.randn(3, 3)
class_offsets = [0, 3]
- Z = embedder.augment_features(dataset, label_embeddings, 0, class_offsets)
+ Z = embedder.augment_features(
+ dataset, label_embeddings, 0, class_offsets)
assert Z.shape == (50, 13) # 10 features + 3 label embeddings
print(" ✓ Аугментация признаков работает корректно")