data_meta_map

The data_meta_map package provides tools for embedding datasets into a shared vector space via multiple algorithms. The main entry points are WassersteinEmbedder and Task2Vec (both inheriting from the abstract BaseEmbedder), plus Dataset2VecEmbedder for tabular data.

data_meta_map.base_embedder source

class data_meta_map.base_embedder.BaseEmbedder

Abstract base class for all dataset embedders. All concrete embedder classes must inherit from BaseEmbedder and implement the embed method.

abstractmethod embed(*args, **kwargs)

Override in subclass to produce embeddings for a given dataset. Raises NotImplementedError if called directly.
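
Since embed is the only abstract method, a concrete subclass can be very small. A toy sketch (MeanEmbedder is hypothetical, not part of the library):

import torch
from data_meta_map.base_embedder import BaseEmbedder

class MeanEmbedder(BaseEmbedder):
    """Toy embedder: represent a dataset by its per-feature mean."""

    def embed(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        # X: [N, D] features; Y is ignored in this toy example.
        return X.mean(dim=0)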

get_class_statistics(X, Y)

Compute per-class mean and covariance from feature and label tensors.

Parameters:
  X (Tensor [N, D]): Feature matrix.
  Y (Tensor [N], long): Integer class labels.

Returns: (means, covs) — tensors of shape [C, D] and [C, D, D] where C = number of classes.
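
A quick sketch of the expected shapes, using random data and the concrete WassersteinEmbedder (any subclass works, since the method is inherited):

import torch
from data_meta_map.wasserstein_embedder import WassersteinEmbedder

X = torch.randn(100, 8)              # 100 samples, 8 features
Y = torch.randint(0, 3, (100,))      # 3 classes
means, covs = WassersteinEmbedder().get_class_statistics(X, Y)
print(means.shape, covs.shape)       # torch.Size([3, 8]) torch.Size([3, 8, 8])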

data_meta_map.wasserstein_embedder source

class data_meta_map.wasserstein_embedder.WassersteinEmbedder(
    emb_dim: int = 2,
    device: str | torch.device = "cpu",
    max_samples: int | None = None,
    batch_size: int = 64,
    gaussian_assumption: bool = True,
    diagonal_cov: bool = False,
    sqrt_method: str = "ns",
    sqrt_niters: int = 20,
)

Dataset embedder based on the Wasserstein distance (optimal transport). Supports two distance modes: exact earth mover's distance (EMD) and a Gaussian (Bures-Wasserstein) approximation, selected via gaussian_assumption. Class-level embeddings are produced via multidimensional scaling (MDS).

Parameters:
  emb_dim (int): Target dimensionality of label embeddings.
  device (str | torch.device): Computation device ("cpu" or "cuda").
  max_samples (int | None): Maximum samples to load per dataset. None = all.
  batch_size (int): Batch size for DataLoader creation.
  gaussian_assumption (bool): If True, use the Gaussian (Bures) approximation; else exact EMD.
  diagonal_cov (bool): Use only the diagonal of the covariance (faster on high-dim data).
  sqrt_method (str): Matrix square root method: "ns" (Newton-Schulz) or "eig".
  sqrt_niters (int): Iterations for the Newton-Schulz matrix square root.

Methods

compute_wte(datasets, reference=None, create_reference=True)

Main method. Computes Wasserstein Transport Embeddings for a collection of datasets. Pipeline: pairwise distances → MDS label embeddings → feature augmentation → transport to reference.

Parameters:
  datasets (List[Dataset | DataLoader]): Datasets to embed.
  reference (Tensor | None): Optional reference distribution.
  create_reference (bool): If True and reference is None, builds the reference from the merged data.

Returns: (task_embeddings, label_embeddings, augmented_datasets)

  • task_embeddings: [num_datasets, ref_size, feature_dim + emb_dim]
  • label_embeddings: [total_classes, emb_dim]
  • augmented_datasets: list of [num_samples, feature_dim + emb_dim]
from data_meta_map.wasserstein_embedder import WassersteinEmbedder

embedder = WassersteinEmbedder(emb_dim=32, device="cpu", max_samples=1000)
task_embs, label_embs, augmented = embedder.compute_wte([cifar10, mnist])  # any torch Datasets

compute_pairwise_distances(datasets, symmetric=True)

Compute a [total_classes × total_classes] pairwise distance matrix between all classes across all datasets. Global class indices are contiguous: dataset 0 occupies [0 … k₀-1], dataset 1 occupies [k₀ … k₀+k₁-1], and so on.

Parameters:
  datasets (List[Dataset | DataLoader]): Input datasets.
  symmetric (bool): Compute only the upper triangle (saves ~50% of the time).

Returns: Distance tensor [total_classes, total_classes].
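
The contiguous index layout makes per-dataset blocks easy to slice. A hypothetical example with two datasets of 10 and 5 classes (ds_a and ds_b stand for any torch Datasets):

D = embedder.compute_pairwise_distances([ds_a, ds_b])
print(D.shape)               # torch.Size([15, 15])
cross = D[0:10, 10:15]       # distances from ds_a classes to ds_b classes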

preprocess_dataset(data, dataset_id=None)

Load and flatten a dataset into (X, Y) tensors. Results are cached by dataset_id to avoid re-loading.

Returns: X: Tensor[N, D], Y: Tensor[N] (long).

embed_distance_matrix(distance_matrix, emb_dim=None)

Transform a precomputed distance matrix into low-dimensional embeddings via MDS.

Returns: Tensor[N, emb_dim].
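
The output of compute_pairwise_distances can be fed in directly (ds_a and ds_b are hypothetical datasets as above):

D = embedder.compute_pairwise_distances([ds_a, ds_b])
label_embs = embedder.embed_distance_matrix(D, emb_dim=2)
print(label_embs.shape)      # torch.Size([total_classes, 2])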

augment_features(data, label_embeddings, dataset_idx, class_offsets)

Concatenate label embeddings to each sample's feature vector.

Returns: Z: Tensor[N, feature_dim + emb_dim].

clear_cache()

Clear internal data and statistics caches to free memory.

embed(datasets, **kwargs)

Satisfies the BaseEmbedder abstract interface. Delegates to compute_wte(datasets, **kwargs) and returns (task_embeddings, label_embeddings, augmented_datasets).

data_meta_map.task2vec.Task2Vec source

class data_meta_map.task2vec.task2vec.Task2Vec(
    model: ProbeNetwork,
    skip_layers: int = 0,
    max_samples: int | None = None,
    classifier_opts: dict | None = None,
    method: str = "montecarlo",
    method_opts: dict | None = None,
    loader_opts: dict | None = None,
    bernoulli: bool = False,
)

Task2Vec embeds a dataset as a fixed-length vector by computing the diagonal of the Fisher Information Matrix (FIM) of a pre-trained probe network fine-tuned on the target task. Intuitively, parameters that change a lot when adapting to a task carry more information about that task — and vice versa. See Achille et al., ICCV 2019.
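
Schematically, the quantity estimated is the Fisher diagonal of the probe weights w (this is the standard definition; the notation here is mine, not the library's):

F_{ii} = \mathbb{E}_{x \sim \mathcal{D},\; y \sim p_w(y \mid x)}\left[ \left( \frac{\partial}{\partial w_i} \log p_w(y \mid x) \right)^{2} \right]

Note that the expectation samples y from the model's own predictive distribution, which is exactly what montecarlo_fisher (below) does.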

Parameters:
  model (ProbeNetwork): Pre-trained network exposing a .classifier property (e.g. ResNet-18).
  skip_layers (int): Number of initial layers to skip; cached activations are used as inputs instead. Useful for large models.
  max_samples (int | None): Maximum samples to use from the dataset. None = all.
  classifier_opts (dict | None): Kwargs forwarded to _fit_classifier (e.g. {"epochs": 5, "learning_rate": 1e-3}).
  method (str): Fisher approximation: "montecarlo" (default, faster) or "variational" (more accurate).
  method_opts (dict | None): Kwargs forwarded to the Fisher method (e.g. {"epochs": 2}).
  loader_opts (dict | None): Kwargs for the internal DataLoader (e.g. {"batch_size": 32, "num_workers": 4}).
  bernoulli (bool): If True, use BCEWithLogitsLoss (binary classification). Default uses CrossEntropyLoss.

Methods

embed(dataset, create_final_embedding=True)

Full pipeline: cache features → fit classifier → compute Fisher → extract embedding.

Parameters:
  dataset (Dataset): Target PyTorch dataset to embed.
  create_final_embedding (bool): If True, returns hessian / scale (a normalized numpy array). If False, returns the raw Embedding object.

Returns: ndarray or Embedding (see above).

from data_meta_map.task2vec.task2vec import task2vec
from data_meta_map.models import resnet18

# cifar10: any torch Dataset (e.g. torchvision.datasets.CIFAR10)
probe = resnet18(pretrained=True)
embedding = task2vec(probe, cifar10, max_samples=500)
print(embedding.hessian.shape)  # (num_params,)

compute_fisher(dataset)

Compute the diagonal FIM approximation and store it in each layer's weight.grad2_acc attribute. Dispatches to montecarlo_fisher or variational_fisher depending on self.method.

montecarlo_fisher(dataset, epochs=1)

Monte-Carlo approximation of the FIM. For each mini-batch, samples y from the model's output distribution (not from the dataset labels) and accumulates squared gradients. Fast but approximate.
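
In pseudocode, the accumulation looks roughly like the sketch below (my illustration, not the library's implementation; grad2_acc matches the attribute name documented above):

import torch
import torch.nn.functional as F

def montecarlo_fisher_sketch(model, loader, device="cpu"):
    # Accumulate E[(d log p(y|x) / dw)^2] with y drawn from the model itself.
    for p in model.parameters():
        p.grad2_acc = torch.zeros_like(p)
    n_batches = 0
    for x, _ in loader:                               # dataset labels are ignored
        x = x.to(device)
        logits = model(x)
        probs = torch.softmax(logits, dim=-1)
        y = torch.multinomial(probs, 1).squeeze(-1)   # y ~ p_w(y | x)
        loss = F.cross_entropy(logits, y)             # -log p_w(y | x)
        model.zero_grad()
        loss.backward()
        for p in model.parameters():
            if p.grad is not None:
                p.grad2_acc += p.grad.detach() ** 2
        n_batches += 1
    for p in model.parameters():
        p.grad2_acc /= max(n_batches, 1)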

variational_fisher(dataset, epochs=1, beta=1e-7)

Variational approximation: optimizes a per-weight noise variance subject to a KL regularizer. Slower but more accurate than Monte-Carlo.

extract_embedding(model)

Read grad2_acc values stored by compute_fisher and pack them into an Embedding object. Classifier weights are excluded.

Returns: Embedding with .hessian and .scale.

data_meta_map.task2vec.ProbeNetwork source

class data_meta_map.task2vec.task2vec.ProbeNetwork (abstract)

Abstract base class that all probe networks must inherit from. Extends torch.nn.Module. The only requirement is exposing a classifier property that returns the final classification layer (e.g. a nn.Linear). Task2Vec uses this property to fit and exclude the head during FIM extraction.

classifier (abstract property)

Returns the final classification sub-module (e.g. the last fully-connected layer).

The library ships a ready-made ResNet probe in data_meta_map.models:

from data_meta_map.models import resnet18, resnet34

# Pre-trained on ImageNet
probe = resnet18(pretrained=True, num_classes=10)

To use your own network, subclass ProbeNetwork:

from data_meta_map.task2vec.task2vec import ProbeNetwork
import torch.nn as nn

class MyProbe(ProbeNetwork):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(512, 256), nn.ReLU())
        self.fc = nn.Linear(256, 10)
        # Ordered layer list; skip_layers indexes into this via start_from.
        self.layers = [self.backbone, self.fc]

    @property
    def classifier(self):
        return self.fc

    @classifier.setter
    def classifier(self, val):
        self.fc = val

    def forward(self, x, start_from=0):
        # start_from > 0 lets Task2Vec feed cached activations into the
        # network partway through (see the skip_layers constructor argument).
        for layer in self.layers[start_from:]:
            x = layer(x)
        return x
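
The custom probe then plugs into Task2Vec exactly like the shipped ResNets (dataset stands for any torch Dataset whose inputs match MyProbe's 512-dim first layer):

from data_meta_map.task2vec.task2vec import Task2Vec

emb = Task2Vec(MyProbe(), max_samples=1000).embed(dataset)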

data_meta_map.task2vec.Embedding source

class data_meta_map.task2vec.task2vec.Embedding(
    hessian: array-like,
    scale: array-like,
    meta: Any = None,
)

Plain data class returned by Task2Vec.extract_embedding. Stores the per-parameter Fisher diagonal and layer scales as numpy arrays.

Attributes:
  hessian (ndarray [P]): Filter-wise diagonal FIM; one value per convolutional filter / linear unit. Larger values → parameter more task-relevant.
  scale (ndarray [P]): Prior scale for each parameter, used for normalization. Typically all-ones for the Monte-Carlo method.
  meta (Any): Optional metadata (e.g. task name, dataset info).

The normalized embedding used for distance computation is hessian / scale:

raw_emb = task2vec(probe, dataset, create_final_embedding=False)
normalized = raw_emb.hessian / raw_emb.scale  # same as create_final_embedding=True

data_meta_map.task2vec.task_similarity source

Utility module for comparing Embedding objects. Provides pairwise distance functions and matrix-building helpers.

Distance functions

All functions accept two Embedding objects and return a scalar distance.

  cosine(e0, e1): Cosine distance on scaled Hessians. Most commonly used. Range [0, 2].
  kl(e0, e1): Symmetric KL divergence on variances, max(KL(p‖q), KL(q‖p)). Always ≥ 0.
  asymmetric_kl(e0, e1): KL(p‖q) only. Not symmetric; useful for transfer direction analysis.
  jsd(e0, e1): Jensen-Shannon divergence. Symmetric, bounded. Range [0, log 2].
  normalized_cosine(e0, e1): Cosine on normalized variances. Scale-invariant variant of cosine.
  correlation(e0, e1): Correlation distance on variances, 1 - Pearson correlation.
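
For example, to compare two embeddings directly (e0 and e1 stand for raw Embedding objects, e.g. from Task2Vec.embed with create_final_embedding=False):

from data_meta_map.task2vec import task_similarity

d = task_similarity.cosine(e0, e1)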

Matrix helpers

pdist(embeddings, distance='cosine')

Compute all-pairs distance matrix for a list of embeddings. Result is symmetric with zero diagonal.

Returns: ndarray [N, N]

from data_meta_map.task2vec.task_similarity import pdist

embeddings = [task2vec(probe, ds) for ds in [cifar10, mnist, stl10]]
D = pdist(embeddings, distance="cosine")
# array([[0.   , 0.312, 0.481],
#        [0.312, 0.   , 0.395],
#        [0.481, 0.395, 0.   ]])

cdist(from_embeddings, to_embeddings, distance='cosine')

Cross-distance matrix between two lists of embeddings.

Returns: ndarray [M, N]
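
A hypothetical cross-distance query, e.g. from two source tasks to three candidate targets:

from data_meta_map.task2vec.task_similarity import cdist

C = cdist(source_embs, target_embs, distance="cosine")  # lists of 2 and 3 Embeddings
print(C.shape)  # (2, 3)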

plot_distance_matrix(embeddings, labels=None, distance='cosine')

Render a clustered heatmap of pairwise distances using seaborn. Requires seaborn, scipy, matplotlib.

data_meta_map.dataset2vec_embedder source

class data_meta_map.dataset2vec_embedder.Dataset2VecEmbedder(
    model: Dataset2Vec,
    max_epochs: int = 10,
    batch_size: int = 32,
    n_batches: int = 100,
)

Dataset embedder based on Dataset2Vec (Jomaa et al., 2021). Learns a permutation-invariant meta-feature extractor for tabular datasets via contrastive training. The encoder operates on (feature, target) pairs and aggregates them with two levels of mean-pooling.
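
Schematically, the two-level pooling can be written as follows (my notation, following the Dataset2Vec paper; f, g, h are learned MLPs):

\phi(D) = h\left( \frac{1}{M} \sum_{m=1}^{M} g\left( \frac{1}{N} \sum_{n=1}^{N} f(x_{n,m},\, y_{n}) \right) \right)

where N is the number of samples and M the number of (feature, target) column pairs: the inner mean pools over samples, the outer over columns.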

Parameters:
  model (Dataset2Vec): Pre-initialized model. Create with get_model('dataset2vec') or Dataset2Vec(config, opt_config).
  max_epochs (int): Training epochs for fit().
  batch_size (int): Batch size for the data loader.
  n_batches (int): Number of batches per epoch.

Methods

fit(data, val_data=None, trainer_kwargs=None)

Train the Dataset2Vec model on a collection of tabular datasets.

Parameters:
  data (Path | list[Path] | list[DataFrame] | list): Training datasets. Each element is one dataset; the last column is the target.
  val_data (same type as data, or None): Optional validation data.
  trainer_kwargs (dict | None): Extra kwargs for pytorch_lightning.Trainer.

Returns: self (for chaining).

embed(X, y)

Compute embedding for a single tabular dataset. Must call fit() first.

Parameters:
  X (Tensor [n, d]): Feature matrix.
  y (Tensor [n] or [n, 1]): Target vector.

Returns: ndarray [output_size]

Raises: RuntimeError if called before fit().

from data_meta_map.models import get_model
from data_meta_map.dataset2vec_embedder import Dataset2VecEmbedder
import torch

model = get_model("dataset2vec")
embedder = Dataset2VecEmbedder(model, max_epochs=20)
embedder.fit(train_datasets)  # list of DataFrames or paths; last column = target

X = torch.randn(100, 10)
y = torch.randint(0, 3, (100,)).float()
vec = embedder.embed(X, y)
print(vec.shape)  # (16,), the default output_size

save(path) / load(path)

Save and load model weights (state_dict) to/from disk. load() sets _is_fitted = True and returns self.

embedder.save("d2v_weights.pt")
embedder2 = Dataset2VecEmbedder(get_model("dataset2vec"))
embedder2.load("d2v_weights.pt")

Convenience function

dataset2vec(model, X, y, fit_data=None, **kwargs)

One-shot helper: optionally fit then embed. If fit_data is None, the model must already be trained.

from data_meta_map.dataset2vec_embedder import dataset2vec
from data_meta_map.models import get_model

model = get_model("dataset2vec")
# X_test / y_test as in embed(); train_datasets as in fit()
vec = dataset2vec(model, X_test, y_test, fit_data=train_datasets)

Model configuration

The Dataset2Vec architecture is controlled by data_meta_map.dataset2vec.config.Dataset2VecConfig:

from data_meta_map.dataset2vec.config import Dataset2VecConfig, OptimizerConfig
from data_meta_map.dataset2vec.model import Dataset2Vec
import torch.nn as nn

cfg = Dataset2VecConfig(
    output_size=32,          # embedding dimension
    f_dense_hidden_size=64,  # encoder hidden size
    activation_cls=nn.GELU,  # activation function
)
opt_cfg = OptimizerConfig(learning_rate=5e-4)
model = Dataset2Vec(cfg, opt_cfg)

Helper functions source

sqrtm_newton_schulz(A, num_iters=20)

Matrix square root via Newton-Schulz iteration. Adapted from OTDD. Works on GPU without leaving PyTorch's compute graph.

Parameters:
  A (Tensor [d, d]): Square positive semi-definite matrix.
  num_iters (int): Number of Newton-Schulz iterations (default: 20).

Returns: sqrtA: Tensor [d, d]
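
A quick numerical sanity check (the import path below is an assumption; adjust it to wherever the helper lives in your install):

import torch
from data_meta_map.utils import sqrtm_newton_schulz  # hypothetical path

A = torch.randn(8, 8)
A = A @ A.T + 1e-3 * torch.eye(8)       # make A symmetric PSD
sqrtA = sqrtm_newton_schulz(A, num_iters=20)
print(torch.dist(sqrtA @ sqrtA, A))     # small residual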

compute_bures_term(cov1, cov2, sqrt_cov1=None, diagonal_cov=False, num_iters=20)

Compute the covariance term of the Bures-Wasserstein distance:

Tr(Σ₁ + Σ₂ − 2 (Σ₁^{1/2} Σ₂ Σ₁^{1/2})^{1/2})

Returns: scalar tensor.
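
A minimal sanity check: for equal covariances the term vanishes, since Tr(Σ + Σ − 2 (Σ^{1/2} Σ Σ^{1/2})^{1/2}) = Tr(2Σ − 2Σ) = 0 (import path assumed as above):

import torch
from data_meta_map.utils import compute_bures_term  # hypothetical path

cov = torch.eye(4)
print(compute_bures_term(cov, cov))  # tensor(0.) up to numerical error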