diff --git a/mmlearn/modules/metrics/retrieval_recall.py b/mmlearn/modules/metrics/retrieval_recall.py
index abee79d..91f9ba6 100644
--- a/mmlearn/modules/metrics/retrieval_recall.py
+++ b/mmlearn/modules/metrics/retrieval_recall.py
@@ -12,6 +12,7 @@
 from torchmetrics.utilities.compute import _safe_matmul
 from torchmetrics.utilities.data import dim_zero_cat
 from torchmetrics.utilities.distributed import gather_all_tensors
+from tqdm import tqdm
 
 
 @store(group="modules/metrics", provider="mmlearn")
@@ -159,7 +160,7 @@ def update(self, x: torch.Tensor, y: torch.Tensor, indexes: torch.Tensor) -> Non
         self._batch_size = x.size(0)  # global batch size
 
     def compute(self) -> torch.Tensor:
-        """Compute the metric.
+        """Compute the metric in a RAM-efficient manner.
 
         Returns
         -------
@@ -169,10 +170,11 @@ def compute(self) -> torch.Tensor:
         x = dim_zero_cat(self.x)
         y = dim_zero_cat(self.y)
 
-        # compute the cosine similarity
+        # normalize embeddings
         x_norm = x / x.norm(dim=-1, p=2, keepdim=True)
         y_norm = y / y.norm(dim=-1, p=2, keepdim=True)
-        similarity = _safe_matmul(x_norm, y_norm)
+
+        # instantiate reduction function
         reduction_mapping: Dict[
             Optional[str], Callable[[torch.Tensor], torch.Tensor]
         ] = {
@@ -181,18 +183,24 @@ def compute(self) -> torch.Tensor:
             "none": lambda x: x,
             None: lambda x: x,
         }
-        scores: torch.Tensor = reduction_mapping[self.reduction](similarity)
 
+        # concatenate indexes of true pairs
         indexes = dim_zero_cat(self.indexes)
-        positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
-        positive_pairs[torch.arange(len(scores)), indexes] = True
 
         results = []
-        for start in range(0, len(scores), self._batch_size):
+        for start in tqdm(
+            range(0, len(x), self._batch_size), desc=f"Recall@{self.top_k}"
+        ):
             end = start + self._batch_size
-            x = scores[start:end]
-            y = positive_pairs[start:end]
-            result = recall_at_k(x, y, self.top_k)
+            # compute the cosine similarity
+            x_norm_batch = x_norm[start:end]
+            similarity = _safe_matmul(x_norm_batch, y_norm)
+            scores: torch.Tensor = reduction_mapping[self.reduction](similarity)
+            indexes_batch = indexes[start:end]
+            positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
+            positive_pairs[torch.arange(len(scores)), indexes_batch] = True
+            # compute recall_at_k
+            result = recall_at_k(scores, positive_pairs, self.top_k)
             results.append(result)
 
         return _retrieval_aggregate(
diff --git a/mmlearn/tasks/zero_shot_retrieval.py b/mmlearn/tasks/zero_shot_retrieval.py
index 982dbca..6bb9b43 100644
--- a/mmlearn/tasks/zero_shot_retrieval.py
+++ b/mmlearn/tasks/zero_shot_retrieval.py
@@ -1,14 +1,14 @@
 """Zero-shot cross-modal retrieval evaluation task."""
 
 from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Union
 
 import lightning.pytorch as pl
 import torch
 import torch.distributed
 import torch.distributed.nn
 from hydra_zen import store
-from torchmetrics import MetricCollection
+from torchmetrics import Metric, MetricCollection
 
 from mmlearn.datasets.core import Modalities
 from mmlearn.datasets.core.modalities import Modality
@@ -48,7 +48,7 @@ def __init__(self, task_specs: List[RetrievalTaskSpec]):
         super().__init__()
 
         self.task_specs = task_specs
-        self.metrics: Dict[Tuple[Modality, Modality], MetricCollection] = {}
+        self.metrics: Union[Dict[str, Metric], MetricCollection] = {}
 
         for spec in self.task_specs:
             assert Modalities.has_modality(spec.query_modality)
@@ -57,7 +57,7 @@ def __init__(self, task_specs: List[RetrievalTaskSpec]):
 
             query_modality = Modalities.get_modality(spec.query_modality)
             target_modality = Modalities.get_modality(spec.target_modality)
-            self.metrics[(query_modality, target_modality)] = MetricCollection(
+            self.metrics.update(
                 {
                     f"{query_modality}_to_{target_modality}_R@{k}": RetrievalRecallAtK(
                         top_k=k, aggregation="mean", reduction="none"
@@ -65,11 +65,20 @@ def __init__(self, task_specs: List[RetrievalTaskSpec]):
                     for k in spec.top_k
                 }
             )
+        self.metrics = MetricCollection(self.metrics)
+
+        self.modality_pairs = [
+            (key.split("_to_")[0], key.split("_to_")[1].split("_R@")[0])
+            for key in self.metrics
+        ]
+        self.modality_pairs = [
+            (Modalities.get_modality(query), Modalities.get_modality(target))
+            for (query, target) in self.modality_pairs
+        ]
 
     def on_evaluation_epoch_start(self, pl_module: pl.LightningModule) -> None:
         """Move the metrics to the device of the Lightning module."""
-        for metric in self.metrics.values():
-            metric.to(pl_module.device)
+        self.metrics.to(pl_module.device)  # type: ignore [union-attr]
 
     def evaluation_step(
         self,
@@ -96,7 +105,9 @@ def evaluation_step(
             return
 
         outputs: Dict[Union[str, Modality], Any] = pl_module(batch)
-        for (query_modality, target_modality), metric in self.metrics.items():
+        for (query_modality, target_modality), metric in zip(
+            self.modality_pairs, self.metrics.values()
+        ):
             query_embeddings: torch.Tensor = outputs[query_modality.embedding]
             target_embeddings: torch.Tensor = outputs[target_modality.embedding]
             indexes = torch.arange(query_embeddings.size(0), device=pl_module.device)
@@ -111,9 +122,8 @@ def on_evaluation_epoch_end(self, pl_module: pl.LightningModule) -> Dict[str, An
         pl_module : pl.LightningModule
             A reference to the Lightning module being evaluated.
         """
-        results = {}
-        for metric in self.metrics.values():
-            results.update(metric.compute())
-            metric.reset()
+        results: Dict[str, Any] = {}
+        results.update(self.metrics.compute())  # type: ignore [union-attr]
+        self.metrics.reset()  # type: ignore [union-attr]
 
         return results
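For reference, below is a minimal standalone sketch of the chunked recall@k computation that the patched `compute` method implements, written in plain PyTorch. The function name, its arguments, and the assumption that each query's positive target shares its row index are illustrative only, not part of the mmlearn API.

```python
import torch


def batched_recall_at_k(
    query_emb: torch.Tensor,   # (N, D) query embeddings
    target_emb: torch.Tensor,  # (M, D) target embeddings
    top_k: int,
    batch_size: int = 256,
) -> torch.Tensor:
    """Mean recall@k computed without materializing the full N x M score matrix."""
    # L2-normalize so that a plain matmul yields cosine similarities
    q = query_emb / query_emb.norm(dim=-1, p=2, keepdim=True)
    t = target_emb / target_emb.norm(dim=-1, p=2, keepdim=True)

    # assumption (for illustration): the positive target of query i is row i of target_emb
    positives = torch.arange(q.size(0), device=q.device)

    hits = []
    for start in range(0, q.size(0), batch_size):
        end = start + batch_size
        scores = q[start:end] @ t.T  # (batch, M) slice of the similarity matrix
        k = min(top_k, scores.size(1))
        top_indices = scores.topk(k, dim=1).indices  # (batch, k)
        # a query counts as a hit if its positive index appears in its top-k results
        hits.append((top_indices == positives[start:end, None]).any(dim=1).float())

    return torch.cat(hits).mean()
```

The point of the batching is that only a `(batch_size, M)` slice of the similarity matrix exists at any time, so peak memory scales with the batch size rather than with the full N x M product.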