DataMetaMap is a Python library designed to represent multiple datasets in the same vector space for direct comparison. The library offers a suite of advanced dataset embedding techniques compatible with PyTorch.
In transfer learning, choosing the right pre-trained model for a new task is expensive. DataMetaMap addresses this by embedding datasets into a shared vector space in which semantically similar datasets lie close together. The working assumption is that a model that performs well on dataset A will likely also perform well on datasets near A in embedding space, which narrows the set of candidate models that need to be evaluated.
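Here is a minimal sketch of this selection step. It assumes each dataset is already represented by a single embedding vector. The hypothetical helper `shortlist_models` ranks source datasets by Euclidean distance to a target dataset, then gathers the models recorded for the k nearest ones. The toy vectors and the `models_by_dataset` bookkeeping are illustrative and are not part of DataMetaMap's API.

```python
import numpy as np


def shortlist_models(target_emb, source_embs, models_by_dataset, k=2):
    """Return the k source datasets closest to the target and their models.

    target_emb: 1-D embedding of the new (target) dataset.
    source_embs: dict mapping dataset name -> 1-D embedding in the same space.
    models_by_dataset: hypothetical dict mapping dataset name -> list of
        models known to perform well on that dataset.
    """
    target = np.asarray(target_emb, dtype=float)
    # Euclidean distance from the target to every source dataset.
    dists = {
        name: float(np.linalg.norm(np.asarray(emb, dtype=float) - target))
        for name, emb in source_embs.items()
    }
    # Sort source datasets from nearest to farthest and keep the top k.
    nearest = sorted(dists, key=dists.get)[:k]
    # Collect candidate models from the nearest datasets, without duplicates.
    candidates = []
    for name in nearest:
        for model in models_by_dataset.get(name, []):
            if model not in candidates:
                candidates.append(model)
    return nearest, candidates


# Toy 2-D embeddings, purely illustrative (not produced by DataMetaMap).
sources = {"A": [0.0, 0.0], "B": [5.0, 5.0], "C": [0.5, 0.2]}
models = {"A": ["resnet18"], "B": ["vit_b16"], "C": ["resnet18", "mobilenet_v3"]}
nearest, candidates = shortlist_models([0.1, 0.1], sources, models, k=2)
print(nearest)     # → ['A', 'C']
print(candidates)  # → ['resnet18', 'mobilenet_v3']
```

Only the shortlisted models would then need to be fine-tuned or evaluated on the target task.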
Kernel-based measure of distance between two probability distributions.
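This description fits the maximum mean discrepancy (MMD). Assuming that is the measure meant here, a biased estimate of squared MMD with a Gaussian (RBF) kernel can be sketched in NumPy as follows. The names `mmd2_biased` and `rbf_kernel` and the bandwidth `sigma` are illustrative; they are not DataMetaMap functions.

```python
import numpy as np


def rbf_kernel(X, Y, sigma=1.0):
    """Gaussian kernel matrix k(x, y) = exp(-||x - y||^2 / (2 sigma^2))."""
    # Pairwise squared Euclidean distances via broadcasting: shape (len(X), len(Y)).
    sq_dists = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))


def mmd2_biased(X, Y, sigma=1.0):
    """Biased (V-statistic) estimate of squared MMD between samples X and Y."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    k_xx = rbf_kernel(X, X, sigma).mean()
    k_yy = rbf_kernel(Y, Y, sigma).mean()
    k_xy = rbf_kernel(X, Y, sigma).mean()
    return k_xx + k_yy - 2.0 * k_xy


rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=(200, 2))
b = rng.normal(0.0, 1.0, size=(200, 2))   # same distribution as a
c = rng.normal(3.0, 1.0, size=(200, 2))   # mean-shifted distribution
print(mmd2_biased(a, b) < mmd2_biased(a, c))  # → True
```

The bandwidth `sigma` strongly affects the measure's sensitivity. A common heuristic sets it to the median pairwise distance between samples.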
Optimal transport distances between class distributions, embedded via MDS.
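The pipeline above computes pairwise optimal transport distances between the per-class feature distributions, then applies multidimensional scaling (MDS). The sketch below illustrates it under two assumptions. First, the exact OT solver is replaced by an entropy-regularised (Sinkhorn) approximation computed in the log domain. Second, the MDS variant is classical (Torgerson) MDS. The helpers `sinkhorn_cost`, `classical_mds` and `label_embeddings` are hypothetical and are not the library's implementation.

```python
import numpy as np


def _logsumexp(M, axis):
    """Numerically stable log(sum(exp(M))) along one axis."""
    m = np.max(M, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(M - m), axis=axis))


def sinkhorn_cost(X, Y, reg=0.5, n_iter=200):
    """Transport cost <P, C> under the entropic OT plan between uniform clouds."""
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    C = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)  # squared Euclidean cost
    log_a = np.log(np.full(len(X), 1.0 / len(X)))  # uniform weights on X
    log_b = np.log(np.full(len(Y), 1.0 / len(Y)))  # uniform weights on Y
    f, g = np.zeros(len(X)), np.zeros(len(Y))
    for _ in range(n_iter):
        # Log-domain Sinkhorn updates of the dual potentials.
        f = reg * log_a - reg * _logsumexp((g[None, :] - C) / reg, axis=1)
        g = reg * log_b - reg * _logsumexp((f[:, None] - C) / reg, axis=0)
    P = np.exp((f[:, None] + g[None, :] - C) / reg)  # approximate transport plan
    return float(np.sum(P * C))


def classical_mds(D, dim=2):
    """Classical (Torgerson) MDS of a symmetric distance matrix D."""
    D = np.asarray(D, dtype=float)
    n = D.shape[0]
    J = np.eye(n) - np.ones((n, n)) / n        # centering matrix
    B = -0.5 * J @ (D ** 2) @ J                 # double-centred Gram matrix
    eigvals, eigvecs = np.linalg.eigh(B)        # eigenvalues in ascending order
    idx = np.argsort(eigvals)[::-1][:dim]       # keep the largest `dim`
    scale = np.sqrt(np.clip(eigvals[idx], 0.0, None))
    return eigvecs[:, idx] * scale


def label_embeddings(features, labels, dim=2, reg=0.5):
    """Embed each class as a point via OT distances between class clouds + MDS."""
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    clouds = [features[labels == c] for c in classes]
    n = len(classes)
    D = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            # sqrt of an approximate squared-cost OT value, i.e. a W2-style distance.
            D[i, j] = D[j, i] = np.sqrt(sinkhorn_cost(clouds[i], clouds[j], reg))
    return classes, D, classical_mds(D, dim)


# Toy data: three classes sharing one cloud shape, shifted along the x-axis.
rng = np.random.default_rng(0)
base = 0.3 * rng.normal(size=(30, 2))
features = np.vstack([base, base + [1.0, 0.0], base + [6.0, 0.0]])
labels = np.repeat([0, 1, 2], 30)
classes, D, coords = label_embeddings(features, labels)
print(np.round(D, 2))
```

Classical MDS coordinates are unique only up to rotation and reflection. Only the distances between label embeddings carry meaning.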
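from data_meta_map import WassersteinEmbedder
from torchvision.datasets import CIFAR10, MNIST
import torchvision.transforms as transforms

# CIFAR-10 images are already 32x32 RGB.
cifar_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
])
# MNIST images are 28x28 grayscale: resize to 32x32 and replicate the single
# channel three times so both datasets share the input shape (3, 32, 32).
mnist_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

cifar = CIFAR10(root="./data", train=True, transform=cifar_transform, download=True)
mnist = MNIST(root="./data", train=True, transform=mnist_transform, download=True)

# emb_dim sets the embedding size (it appears as the 64 in the output shape below).
embedder = WassersteinEmbedder(emb_dim=64)
task_embeddings, label_embeddings, _ = embedder.compute_wte([cifar, mnist])
print("Task embeddings shape:", task_embeddings.shape)
# → Task embeddings shape: torch.Size([2, ref_size, 64 + feature_dim])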
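Once the task embeddings are stacked, datasets can be compared by pairwise distance. The sketch below makes two assumptions. First, the tensor has been converted to NumPy (for example with `task_embeddings.cpu().numpy()`). Second, row r of every dataset's embedding corresponds to the same reference point, as in linear optimal transport embeddings. Under that assumption, the Euclidean distance scaled by 1/sqrt(ref_size) is intended to approximate a 2-Wasserstein distance. `pairwise_task_distances` is a hypothetical helper, not a DataMetaMap function.

```python
import numpy as np


def pairwise_task_distances(task_embs):
    """Pairwise distances between per-dataset embeddings of shape (n, ref_size, d).

    Assumes row r of every dataset's embedding corresponds to the same
    reference point, so embeddings can be compared row by row.
    """
    E = np.asarray(task_embs, dtype=float)
    n, ref_size, _ = E.shape
    flat = E.reshape(n, -1)
    # Squared Euclidean distance between flattened embeddings, averaged over the
    # reference points (an L2 norm under a uniform reference distribution).
    diff = flat[:, None, :] - flat[None, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=-1) / ref_size)


# Toy stand-in for task_embeddings.cpu().numpy(): 3 datasets, 4 reference points, 5 dims.
toy = np.zeros((3, 4, 5))
toy[1] += 1.0   # every coordinate shifted by 1
toy[2] += 2.0   # every coordinate shifted by 2
dists = pairwise_task_distances(toy)
print(np.round(dists, 3))  # d(0, 1) = sqrt(5), d(0, 2) = sqrt(20)
```

The resulting matrix can be fed directly to a nearest-neighbour search over candidate source datasets.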
See Quickstart for a full walkthrough.