From 824e6c02b64e2bb3749a9d9ce42d45200abaaec6 Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Mon, 14 Oct 2024 11:59:53 +0200
Subject: [PATCH] Add a check for same object partitioners

---
 datasets/flwr_datasets/federated_dataset.py | 24 +++++++++
 .../flwr_datasets/federated_dataset_test.py | 52 +++++++++++++++++++
 2 files changed, 76 insertions(+)

diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py
index 72ea5477356..509716c852e 100644
--- a/datasets/flwr_datasets/federated_dataset.py
+++ b/datasets/flwr_datasets/federated_dataset.py
@@ -128,6 +128,7 @@ def __init__(
         self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
             partitioners
         )
+        self._check_partitioners_correctness()
         self._shuffle = shuffle
         self._seed = seed
         # _dataset is prepared lazily on the first call to `load_partition`
@@ -336,3 +337,26 @@ def _check_if_no_split_keyword_possible(self) -> None:
                 "Please set the `split` argument. You can only omit the split keyword "
                 "if there is exactly one partitioner specified."
             )
+
+    def _check_partitioners_correctness(self) -> None:
+        """Check if the partitioners are correctly specified.
+
+        Each split must have its own partitioner object. Reusing one Python object
+        for several splits is rejected, as the partitioner objects should be
+        independent (one partitioner per split).
+
+        Raises
+        ------
+        ValueError
+            If the same partitioner object is assigned to more than one split.
+        """
+        partitioners_keys = list(self._partitioners)
+        for i, first_split in enumerate(partitioners_keys):
+            # Compare only against later splits so each pair is checked once.
+            for second_split in partitioners_keys[i + 1 :]:
+                if self._partitioners[first_split] is self._partitioners[second_split]:
+                    raise ValueError(
+                        "The same partitioner object is used for multiple splits: "
+                        f"('{first_split}', '{second_split}'). "
+                        "Each partitioner should be a separate object."
+                    )
diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py
index bbdfa42292c..895d4419bf7 100644
--- a/datasets/flwr_datasets/federated_dataset_test.py
+++ b/datasets/flwr_datasets/federated_dataset_test.py
@@ -32,6 +32,7 @@
     _load_mocked_dataset_dict_by_partial_download,
 )
 from flwr_datasets.partitioner import IidPartitioner, NaturalIdPartitioner, Partitioner
+from flwr_datasets.preprocessor.divider import Divider
 
 
 mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"]
@@ -568,6 +569,57 @@ def test_use_load_dataset_kwargs(self) -> None:
         with self.assertRaises(ValueError):
             _ = fds.load_partition(0)
 
+    def test_incorrect_two_partitioners(self) -> None:
+        """Test if sharing one partitioner between two splits raises ValueError."""
+        partitioner = IidPartitioner(num_partitions=10)
+        partitioners: dict[str, Partitioner | int] = {
+            "train": partitioner,
+            "test": partitioner,
+        }
+        first_split = "train"
+        second_split = "test"
+        with self.assertRaises(ValueError) as context:
+            FederatedDataset(
+                dataset="mnist",
+                partitioners=partitioners,
+            )
+        self.assertIn(
+            "The same partitioner object is used for multiple splits: "
+            f"('{first_split}', '{second_split}'). "
+            "Each partitioner should be a separate object.",
+            str(context.exception),
+        )
+
+    def test_incorrect_three_partitioners(self) -> None:
+        """Test if a shared partitioner is detected among non-adjacent splits."""
+        partitioner = IidPartitioner(num_partitions=10)
+        partitioners: dict[str, int | Partitioner] = {
+            "train1": partitioner,
+            "train2": 10,
+            "test": partitioner,
+        }
+        divider = Divider(
+            divide_config={
+                "train1": 0.5,
+                "train2": 0.5,
+            },
+            divide_split="train",
+        )
+
+        with self.assertRaises(
+            ValueError,
+        ) as context:
+
+            FederatedDataset(
+                dataset="mnist", partitioners=partitioners, preprocessor=divider
+            )
+
+        self.assertIn(
+            "The same partitioner object is used for multiple splits: "
+            "('train1', 'test'). Each partitioner should be a separate object.",
+            str(context.exception),
+        )
+
 
 def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
     """Check if two Datasets have the same values."""