DataMetaMap

Quickstart

This guide shows the basic workflow: loading datasets, computing Wasserstein Task Embeddings, and visualizing similarity between datasets.

1 Install and import

from data_meta_map import WassersteinEmbedder
from data_meta_map.task2vec.task2vec import Task2Vec

import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, MNIST, STL10

2 Load datasets

DataMetaMap works with any torch.utils.data.Dataset. Below we use three standard image datasets.

transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    # MNIST is grayscale; expand to 3 channels so all datasets share one shape
    transforms.Lambda(lambda x: x.expand(3, -1, -1) if x.size(0) == 1 else x),
    transforms.Normalize((0.5,), (0.5,)),
])

cifar10 = CIFAR10(root="./data", train=True, transform=transform, download=True)
mnist   = MNIST(root="./data",   train=True, transform=transform, download=True)
stl10   = STL10(root="./data",   split="train", transform=transform, download=True)

datasets = [cifar10, mnist, stl10]
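Any object following the torch Dataset protocol works, as long as each item is an (input, label) pair like the torchvision datasets above yield. A minimal custom wrapper might look like this (an illustrative sketch, not a class shipped by the library):

```python
import torch
from torch.utils.data import Dataset

class ArrayDataset(Dataset):
    """Minimal Dataset wrapper: yields (tensor, int label) pairs."""
    def __init__(self, inputs, labels):
        self.inputs = inputs    # float tensor, shape (N, C, H, W)
        self.labels = labels    # int tensor, shape (N,)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.inputs[idx], int(self.labels[idx])

# 100 random 3x32x32 "images" with 10 classes
toy = ArrayDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
x, y = toy[0]
print(len(toy), x.shape)  # 100 torch.Size([3, 32, 32])
```

Such a wrapper can be appended to the `datasets` list alongside the torchvision datasets.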

3 Compute Wasserstein Task Embeddings

WassersteinEmbedder maps each dataset into a shared embedding space using Optimal Transport distances between class distributions.

embedder = WassersteinEmbedder(
    emb_dim=32,          # dimensionality of class embeddings
    device="cuda",       # or "cpu"
    max_samples=2000,    # subsample for speed
    gaussian_assumption=True,   # use Bures-Wasserstein (faster)
)

task_embeddings, label_embeddings, augmented = embedder.compute_wte(datasets)

print("Task embeddings :", task_embeddings.shape)
# torch.Size([3, ref_size, feature_dim + 32])

print("Label embeddings:", label_embeddings.shape)
# torch.Size([total_classes, 32])
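With gaussian_assumption=True, the distance between two distributions reduces to the closed-form 2-Wasserstein (Bures) distance between Gaussians fitted to them, which avoids solving a full optimal-transport problem. The standalone sketch below shows that closed form (illustrative only, not the library's internal code):

```python
import numpy as np
from scipy.linalg import sqrtm

def bures_wasserstein(m1, C1, m2, C2):
    """2-Wasserstein distance between N(m1, C1) and N(m2, C2):
    W2^2 = ||m1 - m2||^2 + tr(C1) + tr(C2) - 2 tr((C1^1/2 C2 C1^1/2)^1/2)."""
    sC1 = sqrtm(C1)
    cross = sqrtm(sC1 @ C2 @ sC1)
    d2 = np.sum((m1 - m2) ** 2) + np.trace(C1) + np.trace(C2) - 2 * np.trace(cross).real
    return float(np.sqrt(max(d2, 0.0)))

m1, C1 = np.zeros(2), np.eye(2)
m2, C2 = np.ones(2), 4 * np.eye(2)
print(bures_wasserstein(m1, C1, m2, C2))  # 2.0
```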

4 Compute pairwise distances

Use the distance matrix to rank datasets by similarity to a target.

D = embedder.compute_pairwise_distances(datasets, symmetric=True)
print("Distance matrix:\n", D)
# tensor([[0.0000, 0.4231, 0.6782],
#         [0.4231, 0.0000, 0.5103],
#         [0.6782, 0.5103, 0.0000]])
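The matrix can be used directly to rank candidate source datasets by proximity to a target, for example when picking a transfer-learning source. A sketch using the example values printed above (illustrative numbers, not guaranteed outputs):

```python
import torch

names = ["CIFAR-10", "MNIST", "STL-10"]
D = torch.tensor([[0.0000, 0.4231, 0.6782],
                  [0.4231, 0.0000, 0.5103],
                  [0.6782, 0.5103, 0.0000]])

target = 0  # rank the other datasets by distance to CIFAR-10
order = torch.argsort(D[target])
ranking = [names[i] for i in order if i != target]
print(ranking)  # ['MNIST', 'STL-10']
```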

5 Visualize with MDS

import matplotlib.pyplot as plt
from sklearn.manifold import MDS

D_np = D.cpu().numpy()
mds = MDS(n_components=2, dissimilarity="precomputed", random_state=42)
coords = mds.fit_transform(D_np)

names = ["CIFAR-10", "MNIST", "STL-10"]
plt.figure(figsize=(5, 4))
plt.scatter(coords[:, 0], coords[:, 1], s=120, zorder=3)
for i, name in enumerate(names):
    plt.annotate(name, coords[i], textcoords="offset points", xytext=(8, 4))
plt.title("Dataset similarity map")
plt.axis("equal")
plt.tight_layout()
plt.savefig("dataset_map.png", dpi=150)
plt.show()
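Hierarchical clustering gives another view of the same matrix, showing which datasets merge first. The sketch below reuses the example distances from section 4 (illustrative values):

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import squareform

D_np = np.array([[0.0000, 0.4231, 0.6782],
                 [0.4231, 0.0000, 0.5103],
                 [0.6782, 0.5103, 0.0000]])

# linkage expects the condensed (upper-triangular) form of the matrix
Z = linkage(squareform(D_np), method="average")

plt.figure(figsize=(4, 3))
dendrogram(Z, labels=["CIFAR-10", "MNIST", "STL-10"])
plt.ylabel("WTE distance")
plt.tight_layout()
plt.savefig("dataset_dendrogram.png", dpi=150)
```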

6 Using Task2Vec

Task2Vec computes the diagonal of the Fisher Information Matrix of a pre-trained probe network as a task embedding. The library ships ResNet probes in data_meta_map.models that already expose the required .classifier property.

from data_meta_map.models import resnet18
from data_meta_map.task2vec.task2vec import Task2Vec, task2vec
from data_meta_map.task2vec.task_similarity import pdist

# Use the built-in ResNet-18 probe (already implements ProbeNetwork)
probe = resnet18(pretrained=True)

# Embed each dataset
embeddings = [
    task2vec(probe, ds, max_samples=500)
    for ds in [cifar10, mnist, stl10]
]

# Pairwise distances
D = pdist(embeddings, distance="cosine")
print("Distance matrix:\n", D)
# [[0.    0.312 0.481]
#  [0.312 0.    0.395]
#  [0.481 0.395 0.   ]]

# Or use the class directly for more control
embedder = Task2Vec(probe, max_samples=500, method="montecarlo")
embedding = embedder.embed(cifar10)
print("Hessian shape:", embedding.hessian.shape)
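The Fisher-diagonal idea behind Task2Vec can be illustrated end-to-end on a toy probe: the embedding is the expected squared gradient of the log-likelihood with respect to the probe's parameters, with labels sampled from the model's own predictive distribution as in a Monte Carlo estimate. A plain-PyTorch sketch, not the library's estimator:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
probe = torch.nn.Linear(8, 4)   # toy "probe network": 4x8 weights + 4 biases
X = torch.randn(64, 8)          # toy dataset inputs

fisher = [torch.zeros_like(p) for p in probe.parameters()]
for x in X:
    logits = probe(x.unsqueeze(0))
    # sample a label from the model's own predictive distribution
    y = torch.distributions.Categorical(logits=logits).sample()
    loss = F.cross_entropy(logits, y)
    grads = torch.autograd.grad(loss, list(probe.parameters()))
    for f, g in zip(fisher, grads):
        f += g.detach() ** 2    # accumulate squared gradients

# average over samples and flatten into one vector per task
fisher = torch.cat([f.flatten() / len(X) for f in fisher])
print("Fisher diagonal:", fisher.shape)  # torch.Size([36])
```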

7 Using Dataset2Vec

Dataset2Vec learns a permutation-invariant meta-feature extractor for tabular datasets via contrastive training.
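The key property is that a dataset is a set of rows, so reordering them must not change the embedding. A minimal Deep-Sets-style sketch demonstrates this with mean pooling (an illustration of the invariance, not the library's architecture):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
phi = nn.Linear(11, 16)   # per-row encoder (10 features + 1 target column)
rho = nn.Linear(16, 16)   # post-pooling head

def embed_table(rows):
    # mean over rows makes the result invariant to row order
    return rho(phi(rows).mean(dim=0))

rows = torch.randn(200, 11)
perm = torch.randperm(200)
v1 = embed_table(rows)
v2 = embed_table(rows[perm])
print(torch.allclose(v1, v2, atol=1e-5))  # True
```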

import torch
from data_meta_map.models import get_model
from data_meta_map.dataset2vec_embedder import Dataset2VecEmbedder

# Build model with default config (output_size=16)
model    = get_model("dataset2vec")
embedder = Dataset2VecEmbedder(model, max_epochs=20, batch_size=32)

# Build a toy training collection; each array holds features with the
# target in the last column (DataFrames work the same way)
import numpy as np
rng = np.random.default_rng(0)
train_datasets = [
    np.hstack([rng.standard_normal((100, 5)), rng.integers(0, 2, (100, 1))])
    for _ in range(4)
]

embedder.fit(train_datasets)

# Embed a new tabular dataset
X = torch.randn(200, 10)   # 200 samples, 10 features
y = torch.randint(0, 3, (200,)).float()

vec = embedder.embed(X, y)
print("Dataset2Vec embedding:", vec.shape)  # (16,)

# Save / reload
embedder.save("d2v.pt")
embedder2 = Dataset2VecEmbedder(get_model("dataset2vec"))
embedder2.load("d2v.pt")

Tip: For a fully worked example with plots and benchmarks, see the demo notebook.