The data_meta_map package provides tools for embedding datasets into a shared vector space via multiple algorithms. The main entry points are WassersteinEmbedder and Task2Vec, both inheriting from the abstract BaseEmbedder; a Dataset2VecEmbedder for tabular datasets is also included.
Abstract base class for all dataset embedders.
All concrete embedder classes must inherit from BaseEmbedder and
implement the embed method.
embed(*args, **kwargs)
Override in subclass to produce embeddings for a given dataset.
Raises NotImplementedError if called directly.
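As a minimal sketch (the import path and the toy embed logic are assumptions, not the library's API), a concrete subclass only needs to implement embed:

```python
import torch
# Hypothetical import path; adjust to wherever BaseEmbedder is defined.
from data_meta_map.base import BaseEmbedder

class MeanEmbedder(BaseEmbedder):
    """Toy subclass: represents each dataset by its feature mean."""

    def embed(self, datasets, **kwargs):
        # datasets is assumed here to be a list of (X, Y) tensor pairs.
        return [X.mean(dim=0) for X, _ in datasets]
```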
get_class_statistics(X, Y)
Compute per-class mean and covariance from feature and label tensors.
| Parameter | Type | Description |
|---|---|---|
| X | Tensor [N, D] | Feature matrix. |
| Y | Tensor [N], long | Integer class labels. |
Returns: (means, covs) — tensors of shape
[C, D] and [C, D, D] where C = number of classes.
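For intuition, a plain-PyTorch sketch of the same computation (not the library's implementation):

```python
import torch

def class_statistics(X, Y):
    # Per-class mean [C, D] and covariance [C, D, D], one row per class.
    classes = torch.unique(Y)
    means = torch.stack([X[Y == c].mean(dim=0) for c in classes])
    covs = torch.stack([torch.cov(X[Y == c].T) for c in classes])
    return means, covs
```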
Dataset embedder based on Wasserstein distance (Optimal Transport). Supports two distance modes:
- Gaussian approximation (gaussian_assumption=True): Bures-Wasserstein distance. Fast: O(d³) per pair. Based on OTDD (Microsoft Research).
- Exact EMD (gaussian_assumption=False): distribution-free but slow: O(n³ log n).
Class-level embeddings are produced via Multidimensional Scaling (MDS).
| Parameter | Type | Description |
|---|---|---|
| emb_dim | int | Target dimensionality of label embeddings. |
| device | str \| torch.device | Computation device ("cpu" or "cuda"). |
| max_samples | int \| None | Maximum samples to load per dataset. None = all. |
| batch_size | int | Batch size for DataLoader creation. |
| gaussian_assumption | bool | If True, use Gaussian (Bures) approximation; else exact EMD. |
| diagonal_cov | bool | Use only diagonal of covariance (faster on high-dim data). |
| sqrt_method | str | Matrix square root method: "ns" (Newton-Schulz) or "eig". |
| sqrt_niters | int | Iterations for Newton-Schulz matrix square root. |
compute_wte(datasets, reference=None, create_reference=True)
Main method. Computes Wasserstein Transport Embeddings for a collection of datasets. Pipeline: pairwise distances → MDS label embeddings → feature augmentation → transport to reference.
| Parameter | Type | Description |
|---|---|---|
| datasets | List[Dataset \| DataLoader] | Datasets to embed. |
| reference | Tensor \| None | Optional reference distribution. |
| create_reference | bool | If True and reference is None, builds reference from merged data. |
Returns: (task_embeddings, label_embeddings, augmented_datasets)
- task_embeddings: [num_datasets, ref_size, feature_dim + emb_dim]
- label_embeddings: [total_classes, emb_dim]
- augmented_datasets: list of [num_samples, feature_dim + emb_dim]

```python
embedder = WassersteinEmbedder(emb_dim=32, device="cpu", max_samples=1000)
task_embs, label_embs, augmented = embedder.compute_wte([cifar10, mnist])
```
compute_pairwise_distances(datasets, symmetric=True)
Compute a [total_classes × total_classes] pairwise distance matrix
between all classes across all datasets.
Global class indices are contiguous: dataset 0 occupies [0 … k₀-1],
dataset 1 occupies [k₀ … k₀+k₁-1], and so on.
| Parameter | Type | Description |
|---|---|---|
| datasets | List[Dataset \| DataLoader] | Input datasets. |
| symmetric | bool | If True, compute only the upper triangle and mirror it (saves ~50% of the time). |
Returns: Distance tensor [total_classes, total_classes].
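A short usage sketch (dataset names and class counts are assumptions) showing how the contiguous global indexing lets you slice per-dataset blocks out of the matrix:

```python
# Global class indices are contiguous per dataset (see above).
D = embedder.compute_pairwise_distances([cifar10, mnist])  # [20, 20] for 10+10 classes
k0 = 10                  # number of classes in dataset 0 (CIFAR-10)
within_0 = D[:k0, :k0]   # CIFAR-10 class-to-class distances
cross = D[:k0, k0:]      # CIFAR-10 classes vs. MNIST classes
```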
preprocess_dataset(data, dataset_id=None) (cached)
Load and flatten dataset into (X, Y) tensors.
Results are cached by dataset_id to avoid re-loading.
Returns:
X: Tensor[N, D], Y: Tensor[N] (long).
embed_distance_matrix(distance_matrix, emb_dim=None)
Transform a precomputed distance matrix into low-dimensional embeddings via MDS.
Returns: Tensor[N, emb_dim].
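The MDS step can be reproduced outside the library; a hedged sketch using scikit-learn's metric MDS (the library may use a different variant or solver):

```python
import torch
from sklearn.manifold import MDS

def mds_embed(distance_matrix, emb_dim=32):
    # Metric MDS on a precomputed dissimilarity matrix.
    mds = MDS(n_components=emb_dim, dissimilarity="precomputed", random_state=0)
    coords = mds.fit_transform(distance_matrix.cpu().numpy())
    return torch.from_numpy(coords).float()  # [N, emb_dim]
```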
augment_features(data, label_embeddings, dataset_idx, class_offsets)
Concatenate label embeddings to each sample's feature vector.
Returns: Z: Tensor[N, feature_dim + emb_dim].
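Conceptually this is a single gather-and-concatenate; a sketch (the offset handling is an assumption based on the contiguous global indexing described above):

```python
import torch

def augment(X, Y, label_embeddings, class_offset=0):
    # Look up each sample's global label embedding and append it to its features.
    E = label_embeddings[Y.long() + class_offset]  # [N, emb_dim]
    return torch.cat([X, E], dim=1)                # [N, feature_dim + emb_dim]
```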
clear_cache()
Clear internal data and statistics caches to free memory.
embed(datasets, **kwargs)
Satisfies the BaseEmbedder abstract interface.
Delegates to compute_wte(datasets, **kwargs) and returns
(task_embeddings, label_embeddings, augmented_datasets).
Task2Vec embeds a dataset as a fixed-length vector by computing the diagonal of the Fisher Information Matrix (FIM) of a pre-trained probe network fine-tuned on the target task. Intuitively, parameters that change a lot when adapting to a task carry more information about that task — and vice versa. See Achille et al., ICCV 2019.
| Parameter | Type | Description |
|---|---|---|
| model | ProbeNetwork | Pre-trained network exposing a .classifier property (e.g. ResNet-18). |
| skip_layers | int | Number of initial layers to skip. Cached activations are used as inputs instead. Useful for large models. |
| max_samples | int \| None | Maximum samples to use from the dataset. None = all. |
| classifier_opts | dict \| None | Kwargs forwarded to _fit_classifier (e.g. {"epochs": 5, "learning_rate": 1e-3}). |
| method | str | Fisher approximation: "montecarlo" (default, faster) or "variational" (more accurate). |
| method_opts | dict \| None | Kwargs forwarded to the Fisher method (e.g. {"epochs": 2}). |
| loader_opts | dict \| None | Kwargs for the internal DataLoader (e.g. {"batch_size": 32, "num_workers": 4}). |
| bernoulli | bool | If True, use BCEWithLogitsLoss (binary classification). Default uses CrossEntropyLoss. |
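A hedged configuration sketch combining the documented options (the constructor-call pattern, import path, and variable names are assumptions based on the parameter table above):

```python
# Hypothetical import path, mirroring the task2vec imports shown below.
from data_meta_map.task2vec.task2vec import Task2Vec

t2v = Task2Vec(
    probe,                                  # any ProbeNetwork
    max_samples=1000,
    method="montecarlo",
    classifier_opts={"epochs": 5, "learning_rate": 1e-3},
    loader_opts={"batch_size": 32, "num_workers": 4},
)
emb = t2v.embed(cifar10)                    # normalized Fisher diagonal (ndarray)
```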
embed(dataset, create_final_embedding=True)
Full pipeline: cache features → fit classifier → compute Fisher → extract embedding.
| Parameter | Type | Description |
|---|---|---|
| dataset | Dataset | Target PyTorch dataset to embed. |
| create_final_embedding | bool | If True, returns hessian / scale (normalized numpy array). If False, returns the raw Embedding object. |
Returns: ndarray or Embedding (see above).
```python
from data_meta_map.task2vec.task2vec import task2vec
from data_meta_map.models import resnet18

probe = resnet18(pretrained=True)
embedding = task2vec(probe, cifar10, max_samples=500)
print(embedding.hessian.shape)  # (num_params,)
```
compute_fisher(dataset)
Compute the diagonal FIM approximation and store it in each layer's
weight.grad2_acc attribute. Dispatches to
montecarlo_fisher or variational_fisher
depending on self.method.
montecarlo_fisher(dataset, epochs=1)
Monte-Carlo approximation of the FIM. For each mini-batch, samples y from the model's output distribution (not from the dataset labels) and accumulates squared gradients. Fast but approximate.
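A self-contained sketch of the idea (not the library's code; the loop structure and the final averaging are assumptions):

```python
import torch
import torch.nn.functional as F

def montecarlo_fim_diag(model, loader, device="cpu", epochs=1):
    grad2 = [torch.zeros_like(p) for p in model.parameters()]
    n = 0
    model.to(device).train()
    for _ in range(epochs):
        for x, _ in loader:                        # dataset labels are ignored
            logits = model(x.to(device))
            # Sample labels from the model's own predictive distribution.
            y = torch.multinomial(F.softmax(logits, dim=-1), 1).squeeze(-1)
            model.zero_grad()
            F.cross_entropy(logits, y).backward()
            for g2, p in zip(grad2, model.parameters()):
                if p.grad is not None:
                    g2 += p.grad.detach() ** 2     # accumulate squared gradients
            n += 1
    return [g2 / n for g2 in grad2]                # diagonal FIM estimate
```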
variational_fisher(dataset, epochs=1, beta=1e-7)
Variational approximation: optimises per-weight noise variance subject to a KL regulariser. Slower but more accurate than Monte-Carlo.
extract_embedding(model)
Read grad2_acc values stored by compute_fisher
and pack them into an Embedding object.
Classifier weights are excluded.
Returns: Embedding with .hessian and .scale.
Abstract base class that all probe networks must inherit from.
Extends torch.nn.Module.
The only requirement is exposing a classifier property that returns
the final classification layer (e.g. a nn.Linear).
Task2Vec uses this property to fit and exclude the head during FIM extraction.
classifier (abstract property)
Returns the final classification sub-module (e.g. the last fully-connected layer).
The library ships a ready-made ResNet probe in data_meta_map.models:
```python
from data_meta_map.models import resnet18, resnet34

# Pre-trained on ImageNet
probe = resnet18(pretrained=True, num_classes=10)
```
To use your own network, subclass ProbeNetwork:
```python
from data_meta_map.task2vec.task2vec import ProbeNetwork
import torch.nn as nn

class MyProbe(ProbeNetwork):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(512, 256), nn.ReLU())
        self.fc = nn.Linear(256, 10)
        self.layers = [self.backbone, self.fc]

    @property
    def classifier(self):
        return self.fc

    @classifier.setter
    def classifier(self, val):
        self.fc = val

    def forward(self, x, start_from=0):
        # start_from lets Task2Vec skip cached layers (see skip_layers).
        for layer in self.layers[start_from:]:
            x = layer(x)
        return x
```
Plain data class returned by Task2Vec.extract_embedding.
Stores the per-parameter Fisher diagonal and layer scales as numpy arrays.
| Attribute | Type | Description |
|---|---|---|
| hessian | ndarray [P] | Filter-wise diagonal FIM; one value per convolutional filter / linear unit. Larger values → parameter more task-relevant. |
| scale | ndarray [P] | Prior scale for each parameter, used for normalization. Typically all ones for the Monte-Carlo method. |
| meta | Any | Optional metadata (e.g. task name, dataset info). |
The normalized embedding used for distance computation is hessian / scale:
```python
raw_emb = task2vec(probe, dataset, create_final_embedding=False)
normalized = raw_emb.hessian / raw_emb.scale  # same as create_final_embedding=True
```
Utility module for comparing Embedding objects.
Provides pairwise distance functions and matrix-building helpers.
All functions accept two Embedding objects and return a scalar distance.
| Function | Formula | Notes |
|---|---|---|
| cosine(e0, e1) | Cosine distance on scaled Hessians | Most commonly used. Range [0, 2]. |
| kl(e0, e1) | Symmetric KL divergence on variances | Symmetric: max(KL(p‖q), KL(q‖p)). Always ≥ 0. |
| asymmetric_kl(e0, e1) | KL(p‖q) only | Not symmetric. Useful for transfer-direction analysis. |
| jsd(e0, e1) | Jensen-Shannon divergence | Symmetric, bounded. Range [0, log 2]. |
| normalized_cosine(e0, e1) | Cosine on normalized variances | Scale-invariant variant of cosine. |
| correlation(e0, e1) | Correlation distance on variances | 1 − Pearson correlation. |
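As a rough illustration of the cosine case (the library's own normalization may differ in detail):

```python
import numpy as np

def cosine_distance(e0, e1):
    # Cosine distance between scale-normalized Fisher diagonals; range [0, 2].
    v0 = e0.hessian / e0.scale
    v1 = e1.hessian / e1.scale
    return 1.0 - v0 @ v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))
```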
pdist(embeddings, distance='cosine')
Compute all-pairs distance matrix for a list of embeddings. Result is symmetric with zero diagonal.
Returns: ndarray [N, N]
```python
from data_meta_map.task2vec.task_similarity import pdist

embeddings = [task2vec(probe, ds) for ds in [cifar10, mnist, stl10]]
D = pdist(embeddings, distance="cosine")
# array([[0.   , 0.312, 0.481],
#        [0.312, 0.   , 0.395],
#        [0.481, 0.395, 0.   ]])
```
cdist(from_embeddings, to_embeddings, distance='cosine')
Cross-distance matrix between two lists of embeddings.
Returns: ndarray [M, N]
plot_distance_matrix(embeddings, labels=None, distance='cosine')
Render a clustered heatmap of pairwise distances using seaborn.
Requires seaborn, scipy, matplotlib.
Dataset embedder based on Dataset2Vec (Jomaa et al., 2021). Learns a permutation-invariant meta-feature extractor for tabular datasets via contrastive training. The encoder operates on (feature, target) pairs and aggregates them with two levels of mean-pooling.
| Parameter | Type | Description |
|---|---|---|
| model | Dataset2Vec | Pre-initialized model. Create with get_model('dataset2vec') or Dataset2Vec(config, opt_config). |
| max_epochs | int | Training epochs for fit(). |
| batch_size | int | Batch size for the data loader. |
| n_batches | int | Number of batches per epoch. |
fit(data, val_data=None, trainer_kwargs=None)
Train the Dataset2Vec model on a collection of tabular datasets.
| Parameter | Type | Description |
|---|---|---|
| data | Path \| list[Path] \| list[DataFrame] \| list | Training datasets. Each element is one dataset; the last column is the target. |
| val_data | same as data \| None | Optional validation data. |
| trainer_kwargs | dict | None | Extra kwargs for pytorch_lightning.Trainer. |
Returns: self (for chaining).
embed(X, y)
Compute embedding for a single tabular dataset. Must call fit() first.
| Parameter | Type | Description |
|---|---|---|
| X | Tensor [n, d] | Feature matrix. |
| y | Tensor [n] or [n, 1] | Target vector. |
Returns: ndarray [output_size]
Raises: RuntimeError if called before fit().
```python
from data_meta_map.models import get_model
from data_meta_map.dataset2vec_embedder import Dataset2VecEmbedder
import torch

model = get_model("dataset2vec")
embedder = Dataset2VecEmbedder(model, max_epochs=20)
embedder.fit(train_datasets)

X = torch.randn(100, 10)
y = torch.randint(0, 3, (100,)).float()
vec = embedder.embed(X, y)
print(vec.shape)  # (16,) — default output_size
```
save(path) / load(path)
Save and load model weights (state_dict) to/from disk.
load() sets _is_fitted = True and returns self.
```python
embedder.save("d2v_weights.pt")
embedder2 = Dataset2VecEmbedder(get_model("dataset2vec"))
embedder2.load("d2v_weights.pt")
```
dataset2vec(model, X, y, fit_data=None, **kwargs)
One-shot helper: optionally fit then embed.
If fit_data is None, the model must already be trained.
```python
from data_meta_map.dataset2vec_embedder import dataset2vec
from data_meta_map.models import get_model

model = get_model("dataset2vec")
vec = dataset2vec(model, X_test, y_test, fit_data=train_datasets)
```
The Dataset2Vec architecture is controlled by
data_meta_map.dataset2vec.config.Dataset2VecConfig:
```python
from data_meta_map.dataset2vec.config import Dataset2VecConfig, OptimizerConfig
from data_meta_map.dataset2vec.model import Dataset2Vec
import torch.nn as nn

cfg = Dataset2VecConfig(
    output_size=32,            # embedding dimension
    f_dense_hidden_size=64,    # encoder hidden size
    activation_cls=nn.GELU,    # activation function
)
opt_cfg = OptimizerConfig(learning_rate=5e-4)
model = Dataset2Vec(cfg, opt_cfg)
```
sqrtm_newton_schulz(A, num_iters=20)
Matrix square root via Newton-Schulz iteration. Adapted from OTDD. Runs on GPU and stays inside PyTorch's autograd graph, so the result remains differentiable.
| Parameter | Type | Description |
|---|---|---|
| A | Tensor [d, d] | Square positive semi-definite matrix. |
| num_iters | int | Number of Newton-Schulz iterations (default: 20). |
Returns: sqrtA: Tensor [d, d]
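For reference, a standard Newton-Schulz iteration looks roughly like this (a sketch, not the package's exact code):

```python
import torch

def sqrtm_ns(A, num_iters=20):
    # Newton-Schulz iteration for the principal square root of a PSD matrix.
    norm = A.norm()                   # normalize so the iteration converges
    Y = A / norm
    I = torch.eye(A.shape[0], dtype=A.dtype, device=A.device)
    Z = I.clone()
    for _ in range(num_iters):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T                     # converges to sqrt(A / ||A||)
        Z = T @ Z                     # converges to inverse sqrt(A / ||A||)
    return Y * norm.sqrt()            # undo the normalization
```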
compute_bures_term(cov1, cov2, sqrt_cov1=None, diagonal_cov=False, num_iters=20)
Compute the covariance term of the Bures-Wasserstein distance:
Tr(Σ₁ + Σ₂ − 2·(Σ₁^½ Σ₂ Σ₁^½)^½)
Returns: scalar tensor.
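Combining the formula with the square-root helper sketched above gives a direct (hedged) illustration of the computation; the real function's argument handling follows its documented signature:

```python
import torch

def bures_term(cov1, cov2, num_iters=20):
    # Tr(S1) + Tr(S2) - 2 * Tr((S1^1/2 S2 S1^1/2)^1/2), using sqrtm_ns from above.
    sqrt_cov1 = sqrtm_ns(cov1, num_iters)
    cross = sqrtm_ns(sqrt_cov1 @ cov2 @ sqrt_cov1, num_iters)
    return torch.trace(cov1) + torch.trace(cov2) - 2.0 * torch.trace(cross)
```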