RAM-efficient Retrieval #23

Draft · wants to merge 2 commits into base: main
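Summary (inferred from the diff below, since the PR carries no description): RetrievalRecallAtK.compute now normalizes the embeddings once, then computes similarity scores, positive-pair masks, and recall@k one query batch at a time (with a tqdm progress bar) instead of materializing the full query-by-target similarity matrix up front. ZeroShotCrossModalRetrieval is refactored to hold all per-modality-pair metrics in a single torchmetrics MetricCollection, so device moves, compute, and reset each happen in one call.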
28 changes: 18 additions & 10 deletions mmlearn/modules/metrics/retrieval_recall.py
@@ -12,6 +12,7 @@
 from torchmetrics.utilities.compute import _safe_matmul
 from torchmetrics.utilities.data import dim_zero_cat
 from torchmetrics.utilities.distributed import gather_all_tensors
+from tqdm import tqdm
 
 
 @store(group="modules/metrics", provider="mmlearn")
@@ -159,7 +160,7 @@ def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> None:
         self._batch_size = x.size(0)  # global batch size
 
     def compute(self) -> torch.Tensor:
-        """Compute the metric.
+        """Compute the metric in a RAM-efficient manner.
 
         Returns
         -------
@@ -169,10 +170,11 @@ def compute(self) -> torch.Tensor:
         x = dim_zero_cat(self.x)
         y = dim_zero_cat(self.y)
 
-        # compute the cosine similarity
+        # normalize embeddings
         x_norm = x / x.norm(dim=-1, p=2, keepdim=True)
         y_norm = y / y.norm(dim=-1, p=2, keepdim=True)
-        similarity = _safe_matmul(x_norm, y_norm)
 
+        # instantiate reduction function
         reduction_mapping: Dict[
             Optional[str], Callable[[torch.Tensor], torch.Tensor]
         ] = {
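Normalizing both embedding sets once up front is what lets the loop below recover cosine similarities chunk by chunk with a plain matmul. A quick standalone check of that equivalence (plain PyTorch, illustrative only, not part of this diff):

```python
import torch

x, y = torch.randn(4, 8), torch.randn(6, 8)
x_norm = x / x.norm(dim=-1, p=2, keepdim=True)
y_norm = y / y.norm(dim=-1, p=2, keepdim=True)

# matmul of L2-normalized rows == pairwise cosine similarity
cos = torch.nn.functional.cosine_similarity(x[:, None], y[None, :], dim=-1)
assert torch.allclose(x_norm @ y_norm.T, cos, atol=1e-5)
```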
@@ -181,18 +183,24 @@
             "none": lambda x: x,
             None: lambda x: x,
         }
-        scores: torch.Tensor = reduction_mapping[self.reduction](similarity)
 
+        # concatenate indexes of true pairs
         indexes = dim_zero_cat(self.indexes)
-        positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
-        positive_pairs[torch.arange(len(scores)), indexes] = True
 
         results = []
-        for start in range(0, len(scores), self._batch_size):
+        for start in tqdm(
+            range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
+        ):
             end = start + self._batch_size
-            x = scores[start:end]
-            y = positive_pairs[start:end]
-            result = recall_at_k(x, y, self.top_k)
+            # compute the cosine similarity
+            x_norm_batch = x_norm[start:end]
+            similarity = _safe_matmul(x_norm_batch, y_norm)
+            scores: torch.Tensor = reduction_mapping[self.reduction](similarity)
+            indexes_batch = indexes[start:end]
+            positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
+            positive_pairs[torch.arange(len(scores)), indexes_batch] = True
+            # compute recall_at_k
+            result = recall_at_k(scores, positive_pairs, self.top_k)
             results.append(result)
 
         return _retrieval_aggregate(
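The net effect of this hunk: peak memory for the scores drops from O(N·M) for the full similarity matrix to O(B·M) per query batch. A minimal self-contained sketch of the same batching idea — chunked_recall_at_k is a hypothetical helper for illustration, not mmlearn's API, and it omits the real metric's reduction and aggregation options:

```python
import torch


def chunked_recall_at_k(
    x: torch.Tensor,        # (N, D) query embeddings
    y: torch.Tensor,        # (M, D) target embeddings
    indexes: torch.Tensor,  # (N,) column index of each query's true target
    k: int,
    batch_size: int = 256,
) -> torch.Tensor:
    """Mean recall@k over all queries, using O(batch_size * M) score memory."""
    x = x / x.norm(dim=-1, p=2, keepdim=True)
    y = y / y.norm(dim=-1, p=2, keepdim=True)
    hits = []
    for start in range(0, x.size(0), batch_size):
        end = start + batch_size
        scores = x[start:end] @ y.T  # (b, M) chunk of the similarity matrix
        top_k = scores.topk(min(k, scores.size(1)), dim=-1).indices  # (b, k)
        # a query counts as a hit if its true target is among its top-k scores
        hits.append((top_k == indexes[start:end, None]).any(dim=-1))
    return torch.cat(hits).float().mean()
```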
32 changes: 21 additions & 11 deletions mmlearn/tasks/zero_shot_retrieval.py
@@ -1,14 +1,14 @@
 """Zero-shot cross-modal retrieval evaluation task."""
 
 from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Union
 
 import lightning.pytorch as pl
 import torch
 import torch.distributed
 import torch.distributed.nn
 from hydra_zen import store
-from torchmetrics import MetricCollection
+from torchmetrics import Metric, MetricCollection
 
 from mmlearn.datasets.core import Modalities
 from mmlearn.datasets.core.modalities import Modality
@@ -48,7 +48,7 @@ def __init__(self, task_specs: List[RetrievalTaskSpec]):
         super().__init__()
 
         self.task_specs = task_specs
-        self.metrics: Dict[Tuple[Modality, Modality], MetricCollection] = {}
+        self.metrics: Union[Dict[str, Metric], MetricCollection] = {}
 
         for spec in self.task_specs:
             assert Modalities.has_modality(spec.query_modality)
@@ -57,19 +57,28 @@ def __init__(self, task_specs: List[RetrievalTaskSpec]):
             query_modality = Modalities.get_modality(spec.query_modality)
             target_modality = Modalities.get_modality(spec.target_modality)
 
-            self.metrics[(query_modality, target_modality)] = MetricCollection(
+            self.metrics.update(
                 {
                     f"{query_modality}_to_{target_modality}_R@{k}": RetrievalRecallAtK(
                         top_k=k, aggregation="mean", reduction="none"
                     )
                     for k in spec.top_k
                 }
             )
+        self.metrics = MetricCollection(self.metrics)
+
+        self.modality_pairs = [
+            (key.split("_to_")[0], key.split("_to_")[1].split("_R@")[0])
+            for key in self.metrics
+        ]
+        self.modality_pairs = [
+            (Modalities.get_modality(query), Modalities.get_modality(target))
+            for (query, target) in self.modality_pairs
+        ]
 
     def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
         """Move the metrics to the device of the Lightning module."""
-        for metric in self.metrics.values():
-            metric.to(pl_module.device)
+        self.metrics.to(pl_module.device)  # type: ignore [union-attr]
 
     def evaluation_step(
         self,
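Since MetricCollection flattens everything to string keys, the (query, target) modality pairs are recovered by parsing those keys back apart. Assuming keys of the form {query}_to_{target}_R@{k}, the parsing works like this (standalone illustration):

```python
key = "image_to_text_R@10"
query = key.split("_to_")[0]                   # "image"
target = key.split("_to_")[1].split("_R@")[0]  # "text"
assert (query, target) == ("image", "text")
```

Note this assumes modality names themselves never contain "_to_" or "_R@"; a name that did would break the round-trip.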
@@ -96,7 +105,9 @@ def evaluation_step(
             return
 
         outputs: Dict[Union[str, Modality], Any] = pl_module(batch)
-        for (query_modality, target_modality), metric in self.metrics.items():
+        for (query_modality, target_modality), metric in zip(
+            self.modality_pairs, self.metrics.values()
+        ):
             query_embeddings: torch.Tensor = outputs[query_modality.embedding]
             target_embeddings: torch.Tensor = outputs[target_modality.embedding]
             indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)
@@ -111,9 +122,8 @@ def on_evaluation_epoch_end(self, pl_module: pl.LightningModule) -> Dict[str, Any]:
         pl_module : pl.LightningModule
             A reference to the Lightning module being evaluated.
         """
-        results = {}
-        for metric in self.metrics.values():
-            results.update(metric.compute())
-            metric.reset()
+        results: Dict[str, Any] = {}
+        results.update(self.metrics.compute())  # type: ignore [union-attr]
+        self.metrics.reset()  # type: ignore [union-attr]
 
         return results
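For context, wrapping everything in one MetricCollection is what enables the one-shot compute/reset/to calls above. A minimal sketch of that torchmetrics behavior (assuming a recent torchmetrics; the binary metrics are stand-ins for RetrievalRecallAtK):

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision

metrics = MetricCollection({"acc": BinaryAccuracy(), "prec": BinaryPrecision()})
metrics.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))

results = metrics.compute()  # single dict: {"acc": ..., "prec": ...}
metrics.reset()              # resets every wrapped metric in one call
metrics.to("cpu")            # moves all metric states at once
```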