DataMetaMap

DataMetaMap is a Python library that embeds multiple datasets in a shared vector space so they can be compared directly. It provides several dataset embedding techniques that work with PyTorch.

Motivation

In transfer learning, choosing the right pre-trained model for a new task is expensive, because each candidate usually has to be fine-tuned and evaluated. DataMetaMap addresses this by embedding datasets into a shared vector space, so that semantically similar datasets lie close together. If a model performs well on dataset A, it is a reasonable first candidate for datasets near A in embedding space, which narrows the candidate search.

Algorithms

Maximum Mean Discrepancy

Kernel-based measure of distance between two probability distributions.

Review
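
As a rough illustration of the idea (not DataMetaMap's own implementation), squared MMD with a Gaussian kernel can be estimated from two samples in a few lines of PyTorch. The function names `rbf_kernel` and `mmd_squared`, the fixed `bandwidth`, and the toy data are assumptions made for this sketch; it uses the simple biased estimator.

```python
import torch


def rbf_kernel(a: torch.Tensor, b: torch.Tensor, bandwidth: float) -> torch.Tensor:
    """Gaussian kernel matrix k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * bandwidth^2))."""
    sq_dists = torch.cdist(a, b).pow(2)
    return torch.exp(-sq_dists / (2.0 * bandwidth ** 2))


def mmd_squared(x: torch.Tensor, y: torch.Tensor, bandwidth: float = 1.0) -> torch.Tensor:
    """Biased (V-statistic) estimate of squared MMD between samples x and y."""
    k_xx = rbf_kernel(x, x, bandwidth).mean()
    k_yy = rbf_kernel(y, y, bandwidth).mean()
    k_xy = rbf_kernel(x, y, bandwidth).mean()
    return k_xx + k_yy - 2.0 * k_xy


torch.manual_seed(0)
same_a = torch.randn(200, 5)          # two samples from the same distribution
same_b = torch.randn(200, 5)
shifted = torch.randn(200, 5) + 2.0   # a sample with a shifted mean

mmd_same = mmd_squared(same_a, same_b).item()
mmd_shifted = mmd_squared(same_a, shifted).item()
print(f"same distribution:    {mmd_same:.4f}")
print(f"shifted distribution: {mmd_shifted:.4f}")
```

The shifted sample should yield a clearly larger value than the pair drawn from the same distribution. In practice the bandwidth is often chosen by a heuristic such as the median pairwise distance rather than fixed at 1.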

Task2Vec

Fisher information-based task embeddings using a probe neural network.

Paper
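
To make the Fisher-information idea concrete, here is a minimal sketch that estimates the diagonal Fisher information of a small probe network on one task's inputs and uses it as the embedding. It is an assumption-laden simplification: the tiny MLP probe, the helper name `diagonal_fisher`, and the random inputs are hypothetical, and the published method uses a pre-trained feature extractor whose classifier head is first fitted to the task.

```python
import torch
import torch.nn as nn


def diagonal_fisher(probe: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    """Monte Carlo estimate of the diagonal Fisher information of `probe`.

    For each input, a label is sampled from the probe's own predictive
    distribution and the squared gradient of its log-probability is accumulated.
    """
    params = [p for p in probe.parameters() if p.requires_grad]
    fisher = [torch.zeros_like(p) for p in params]
    for x in inputs:
        log_probs = torch.log_softmax(probe(x.unsqueeze(0)), dim=1)
        # Sample a label from the model's prediction, not from the ground truth.
        sampled = torch.multinomial(log_probs.detach().exp(), num_samples=1).item()
        grads = torch.autograd.grad(log_probs[0, sampled], params)
        for acc, g in zip(fisher, grads):
            acc += g.pow(2)
    # Flatten all per-parameter estimates into a single embedding vector.
    return torch.cat([f.flatten() for f in fisher]) / len(inputs)


torch.manual_seed(0)
# Hypothetical probe: 8 input features, 3 classes.
probe = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
task_inputs = torch.randn(64, 8)  # stand-in for one task's data

embedding = diagonal_fisher(probe, task_inputs)
print(embedding.shape)
```

The embedding has one entry per probe parameter, so embeddings of different tasks are comparable only when they are computed with the same probe architecture.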

Dataset2Vec

Meta-features extracted from dataset characteristics via neural networks.

Paper
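
The sketch below shows one way such a meta-feature extractor can be structured: a DeepSets-style network that pools over instances and then over (feature, target) pairs, so that datasets with different numbers of rows and columns map to vectors of the same size. The class name `SetMetaFeatureExtractor`, the layer sizes, and the random tabular data are assumptions for illustration; the network is untrained and is not the library's implementation.

```python
import torch
import torch.nn as nn


class SetMetaFeatureExtractor(nn.Module):
    """Hypothetical permutation-invariant encoder for tabular datasets."""

    def __init__(self, hidden: int = 32, out_dim: int = 16):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
        self.g = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
        self.h = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, out_dim))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # x: (N, M) features, y: (N, C) targets.
        n, m = x.shape
        c = y.shape[1]
        # Pair every feature value with every target value: shape (N, M, C, 2).
        pairs = torch.stack(
            [x[:, :, None].expand(n, m, c), y[:, None, :].expand(n, m, c)], dim=-1
        )
        per_pair = self.f(pairs).mean(dim=0)         # pool over instances -> (M, C, hidden)
        pooled = self.g(per_pair).mean(dim=(0, 1))   # pool over pairs -> (hidden,)
        return self.h(pooled)                        # meta-features -> (out_dim,)


torch.manual_seed(0)
extractor = SetMetaFeatureExtractor()

# Two toy datasets with different numbers of rows and columns.
x_a, y_a = torch.randn(100, 4), torch.randn(100, 1)
x_b, y_b = torch.randn(250, 9), torch.randn(250, 2)

meta_a = extractor(x_a, y_a)
meta_b = extractor(x_b, y_b)
print(meta_a.shape, meta_b.shape)
```

Because every pooling step is a mean, shuffling rows or feature columns leaves the output unchanged, which is the property that lets a single network handle datasets with differing schemas.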

Wasserstein Task Embedding

Optimal transport distances between class distributions, embedded via MDS.

Paper
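
The "distances, then MDS" step can be sketched as follows. For simplicity each class is a one-dimensional sample, where the 1-Wasserstein distance has a closed form (mean absolute difference of the sorted samples). The helpers `wasserstein_1d` and `classical_mds` and the toy classes are assumptions for this sketch; the library works on real feature distributions, and its implementation may differ.

```python
import torch


def wasserstein_1d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Exact 1-Wasserstein distance between two equal-size 1-D samples."""
    return (torch.sort(a).values - torch.sort(b).values).abs().mean()


def classical_mds(dist: torch.Tensor, dim: int) -> torch.Tensor:
    """Embed points so Euclidean distances approximate `dist` (classical MDS)."""
    n = dist.shape[0]
    centering = torch.eye(n) - torch.full((n, n), 1.0 / n)
    gram = -0.5 * centering @ dist.pow(2) @ centering    # double-centred Gram matrix
    eigvals, eigvecs = torch.linalg.eigh(gram)           # eigenvalues in ascending order
    top = torch.argsort(eigvals, descending=True)[:dim]  # keep the largest `dim`
    return eigvecs[:, top] * eigvals[top].clamp(min=0).sqrt()


torch.manual_seed(0)
# Three toy "classes", each a 1-D sample with a different mean (illustrative data).
classes = [torch.randn(500) + shift for shift in (0.0, 0.5, 5.0)]

# Pairwise optimal transport distances between class distributions.
dist = torch.zeros(3, 3)
for i in range(3):
    for j in range(3):
        dist[i, j] = wasserstein_1d(classes[i], classes[j])

label_embeddings = classical_mds(dist, dim=2)
print(label_embeddings.shape)
```

Classes whose distributions are close in Wasserstein distance end up with nearby label embeddings, so labels from different datasets become comparable without sharing a label set.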

Quick Example

from data_meta_map import WassersteinEmbedder
from torchvision.datasets import CIFAR10, MNIST
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
])

cifar = CIFAR10(root="./data", train=True, transform=transform, download=True)
mnist = MNIST(root="./data", train=True, transform=transform, download=True)

embedder = WassersteinEmbedder(emb_dim=64)
task_embeddings, label_embeddings, _ = embedder.compute_wte([cifar, mnist])

print("Task embeddings shape:", task_embeddings.shape)
# → Task embeddings shape: torch.Size([2, ref_size, 64 + feature_dim])
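
Once every dataset has an embedding of the same shape, comparing datasets reduces to a distance computation. The sketch below uses a random stand-in tensor with the same layout as `task_embeddings` (datasets × reference points × dimensions) rather than real library output, and the choice of Euclidean distance on the flattened embeddings is an assumption made for illustration.

```python
import torch

torch.manual_seed(0)
# Stand-in for `task_embeddings`: 3 datasets, 10 reference points, 8 dimensions.
base = torch.randn(10, 8)
task_embeddings = torch.stack([base, base + 0.1 * torch.randn(10, 8), base + 3.0])

flat = task_embeddings.flatten(start_dim=1)   # (num_datasets, ref_size * dim)
distances = torch.cdist(flat, flat)           # pairwise Euclidean distances

# Nearest other dataset for each dataset (mask the self-distance first).
masked = distances + torch.diag(torch.full((len(flat),), float("inf")))
nearest = masked.argmin(dim=1)
print(nearest.tolist())
```

Here the first two stand-in datasets differ only by small noise, so each is the other's nearest neighbour, while the third sits much further away.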

See Quickstart for a full walkthrough.

Contents

Guidelines

Code