This guide shows the basic workflow: loading datasets, computing Wasserstein Task Embeddings, and visualizing similarity between datasets.
from data_meta_map import WassersteinEmbedder
from data_meta_map.task2vec.task2vec import Task2Vec
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, MNIST, STL10
DataMetaMap works with any torch.utils.data.Dataset.
Below we use three standard image datasets.
# All three datasets must share one tensor format. MNIST images are grayscale
# (1 channel) while CIFAR-10 / STL-10 are RGB (3 channels); the pre-trained
# ResNet-18 probe used for Task2Vec below expects 3-channel input, so we
# convert every image to RGB up front (a no-op for images that are already RGB).
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),  # PIL grayscale -> RGB
    transforms.Resize(32),                  # unify spatial size (STL-10 is 96x96)
    transforms.ToTensor(),                  # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),   # rescale to [-1, 1] (broadcast per channel)
])
cifar10 = CIFAR10(root="./data", train=True, transform=transform, download=True)
mnist = MNIST(root="./data", train=True, transform=transform, download=True)
stl10 = STL10(root="./data", split="train", transform=transform, download=True)
datasets = [cifar10, mnist, stl10]
WassersteinEmbedder maps each dataset into a shared embedding space using
Optimal Transport distances between class distributions.
# Configure the embedder. The Gaussian (Bures-Wasserstein) assumption trades a
# little fidelity for much faster Optimal Transport computations.
wte_config = dict(
    emb_dim=32,                # dimensionality of the class embeddings
    device="cuda",             # switch to "cpu" when no GPU is available
    max_samples=2000,          # subsample each dataset to keep runtime low
    gaussian_assumption=True,  # use Bures-Wasserstein (faster)
)
embedder = WassersteinEmbedder(**wte_config)

# A single call embeds all datasets into the shared space.
task_embeddings, label_embeddings, augmented = embedder.compute_wte(datasets)
print("Task embeddings :", task_embeddings.shape)
# torch.Size([3, ref_size, feature_dim + 32])
print("Label embeddings:", label_embeddings.shape)
# torch.Size([total_classes, 32])
Use the distance matrix to rank datasets by similarity to a target.
# Full pairwise distance matrix between the three datasets.
distance_header = "Distance matrix:\n"
D = embedder.compute_pairwise_distances(datasets, symmetric=True)
print(distance_header, D)
# tensor([[0.0000, 0.4231, 0.6782],
#         [0.4231, 0.0000, 0.5103],
#         [0.6782, 0.5103, 0.0000]])
import matplotlib.pyplot as plt
from sklearn.manifold import MDS

# Project the precomputed distance matrix down to 2-D for plotting.
mds_projector = MDS(n_components=2, dissimilarity="precomputed", random_state=42)
coords = mds_projector.fit_transform(D.cpu().numpy())

names = ["CIFAR-10", "MNIST", "STL-10"]
plt.figure(figsize=(5, 4))
plt.scatter(coords[:, 0], coords[:, 1], s=120, zorder=3)
for name, point in zip(names, coords):
    plt.annotate(name, point, textcoords="offset points", xytext=(8, 4))
plt.title("Dataset similarity map")
plt.axis("equal")
plt.tight_layout()
plt.savefig("dataset_map.png", dpi=150)
plt.show()
Task2Vec embeds a task as the diagonal of the Fisher Information Matrix
of a pre-trained probe network, evaluated on that task's data.
The library ships ResNet probes in data_meta_map.models
that already expose the required .classifier property.
from data_meta_map.models import resnet18
from data_meta_map.task2vec.task2vec import Task2Vec, task2vec
from data_meta_map.task2vec.task_similarity import pdist

# The built-in ResNet-18 probe already implements the ProbeNetwork interface.
probe = resnet18(pretrained=True)

# Functional API: embed each dataset with the shared probe.
image_datasets = [cifar10, mnist, stl10]
embeddings = []
for dataset in image_datasets:
    embeddings.append(task2vec(probe, dataset, max_samples=500))

# Cosine distances between the resulting embeddings.
D = pdist(embeddings, distance="cosine")
print("Distance matrix:\n", D)
# [[0.    0.312 0.481]
#  [0.312 0.    0.395]
#  [0.481 0.395 0.   ]]

# Object API: instantiate Task2Vec directly for more control.
embedder = Task2Vec(probe, max_samples=500, method="montecarlo")
embedding = embedder.embed(cifar10)
print("Hessian shape:", embedding.hessian.shape)
Dataset2Vec learns a permutation-invariant meta-feature extractor for tabular datasets via contrastive training.
import torch
from data_meta_map.models import get_model
from data_meta_map.dataset2vec_embedder import Dataset2VecEmbedder

# Build model with default config (output_size=16)
model = get_model("dataset2vec")
embedder = Dataset2VecEmbedder(model, max_epochs=20, batch_size=32)

# The embedder is trained contrastively on a collection of tabular datasets,
# each a DataFrame / NDArray whose LAST column is the target. The original
# snippet referenced `train_datasets` without defining it; here we build small
# synthetic tables so the example runs end to end — substitute your own data.
train_datasets = [
    torch.cat(
        [torch.randn(100, 8), torch.randint(0, 2, (100, 1)).float()], dim=1
    ).numpy()
    for _ in range(4)
]
embedder.fit(train_datasets)

# Embed a new tabular dataset
X = torch.randn(200, 10)  # 200 samples, 10 features
y = torch.randint(0, 3, (200,)).float()
vec = embedder.embed(X, y)
print("Dataset2Vec embedding:", vec.shape)  # (16,)

# Persist the trained embedder and restore it into a fresh instance.
embedder.save("d2v.pt")
embedder2 = Dataset2VecEmbedder(get_model("dataset2vec"))
embedder2.load("d2v.pt")