diff --git a/.gitignore b/.gitignore index 9a52656a8..16988e7af 100644 --- a/.gitignore +++ b/.gitignore @@ -134,6 +134,7 @@ dmypy.json # vscode launch.json settings.json +.devcontainer* #mac .DS_Store @@ -169,6 +170,7 @@ settings.json **/*.pkl **/*.png **/*.pt +**/*.ckpt /metrics/ # dev diff --git a/examples/ae_examples/cvae_dim_example/server.py b/examples/ae_examples/cvae_dim_example/server.py index f00524198..f7159b9de 100644 --- a/examples/ae_examples/cvae_dim_example/server.py +++ b/examples/ae_examples/cvae_dim_example/server.py @@ -67,7 +67,6 @@ def main(config: Dict[str, Any]) -> None: client_manager=SimpleClientManager(), parameter_exchanger=parameter_exchanger, model=model, - wandb_reporter=None, strategy=strategy, checkpointer=checkpointer, ) diff --git a/examples/ae_examples/cvae_examples/conv_cvae_example/server.py b/examples/ae_examples/cvae_examples/conv_cvae_example/server.py index 4f97f1b94..4d3d8c2cc 100644 --- a/examples/ae_examples/cvae_examples/conv_cvae_example/server.py +++ b/examples/ae_examples/cvae_examples/conv_cvae_example/server.py @@ -68,7 +68,6 @@ def main(config: Dict[str, Any]) -> None: client_manager=SimpleClientManager(), parameter_exchanger=parameter_exchanger, model=model, - wandb_reporter=None, strategy=strategy, checkpointer=checkpointer, ) diff --git a/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py b/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py index 2459431aa..0be57fe18 100644 --- a/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py +++ b/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py @@ -68,7 +68,6 @@ def main(config: Dict[str, Any]) -> None: client_manager=SimpleClientManager(), parameter_exchanger=parameter_exchanger, model=model, - wandb_reporter=None, strategy=strategy, checkpointer=checkpointer, ) diff --git a/examples/ae_examples/fedprox_vae_example/server.py b/examples/ae_examples/fedprox_vae_example/server.py index 0f84c7ab3..60ab69b43 100644 --- a/examples/ae_examples/fedprox_vae_example/server.py +++ b/examples/ae_examples/fedprox_vae_example/server.py @@ -70,7 +70,6 @@ def main(config: Dict[str, Any]) -> None: client_manager=SimpleClientManager(), parameter_exchanger=parameter_exchanger, model=model, - wandb_reporter=None, strategy=strategy, checkpointer=checkpointer, ) diff --git a/examples/apfl_example/client.py b/examples/apfl_example/client.py index 6cd7900c8..2a5eebdd8 100644 --- a/examples/apfl_example/client.py +++ b/examples/apfl_example/client.py @@ -13,6 +13,7 @@ from examples.models.cnn_model import MnistNetWithBnAndFrozen from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import ApflModule +from fl4health.reporting import JsonReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_mnist_data from fl4health.utils.metrics import Accuracy @@ -58,7 +59,6 @@ def get_criterion(self, config: Config) -> _Loss: # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistApflClient(data_path, [Accuracy()], DEVICE) + client = MnistApflClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) - - client.metrics_reporter.dump() + client.shutdown() # This will tell the JsonReporter to dump data diff --git a/examples/apfl_example/server.py b/examples/apfl_example/server.py index 03b640577..01fd3a4d5 100644 --- a/examples/apfl_example/server.py +++ b/examples/apfl_example/server.py @@ -10,6 +10,7 @@ from examples.models.cnn_model import MnistNetWithBnAndFrozen from examples.utils.functions import make_dict_with_epochs_or_steps from fl4health.model_bases.apfl_base import ApflModule +from fl4health.reporting import JsonReporter from fl4health.server.base_server import FlServer from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn @@ -59,7 +60,7 @@ def main(config: Dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = FlServer(client_manager, strategy) + server = FlServer(client_manager, strategy, reporters=[JsonReporter()]) fl.server.start_server( server=server, @@ -67,7 +68,6 @@ def main(config: Dict[str, Any]) -> None: config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), ) - server.metrics_reporter.dump() server.shutdown() diff --git a/examples/basic_example/server.py b/examples/basic_example/server.py index 961118866..def22b981 100644 --- a/examples/basic_example/server.py +++ b/examples/basic_example/server.py @@ -66,7 +66,6 @@ def main(config: Dict[str, Any]) -> None: client_manager=SimpleClientManager(), parameter_exchanger=parameter_exchanger, model=model, - wandb_reporter=None, strategy=strategy, checkpointer=checkpointers, ) diff --git a/examples/ditto_example/client.py b/examples/ditto_example/client.py index 9ad8c88f0..5c22b32db 100644 --- a/examples/ditto_example/client.py +++ b/examples/ditto_example/client.py @@ -14,6 +14,7 @@ from examples.models.cnn_model import MnistNet from fl4health.clients.ditto_client import DittoClient +from fl4health.reporting import JsonReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_mnist_data from fl4health.utils.metrics import Accuracy @@ -68,10 +69,8 @@ def get_criterion(self, config: Config) -> _Loss: # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistDittoClient(data_path, [Accuracy()], DEVICE) + client = MnistDittoClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) fl.client.start_client(server_address=args.server_address, client=client.to_client()) # Shutdown the client gracefully client.shutdown() - - client.metrics_reporter.dump() diff --git a/examples/dp_fed_examples/instance_level_dp/server.py b/examples/dp_fed_examples/instance_level_dp/server.py index 738ac5387..ff2f27e98 100644 --- a/examples/dp_fed_examples/instance_level_dp/server.py +++ b/examples/dp_fed_examples/instance_level_dp/server.py @@ -1,5 +1,6 @@ import argparse import string +from collections.abc import Sequence from functools import partial from random import choices from typing import Any, Dict, Optional @@ -15,7 +16,7 @@ from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer, OpacusCheckpointer from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger -from fl4health.reporting.fl_wandb import ServerWandBReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.instance_level_dp_server import InstanceLevelDpServer from fl4health.strategies.basic_fedavg import OpacusBasicFedAvg from fl4health.utils.config import load_config @@ -71,8 +72,8 @@ def __init__( strategy: OpacusBasicFedAvg, local_epochs: Optional[int] = None, local_steps: Optional[int] = None, - wandb_reporter: Optional[ServerWandBReporter] = None, checkpointer: Optional[OpacusCheckpointer] = None, + reporters: Sequence[BaseReporter] | None = None, delta: Optional[float] = None, ) -> None: super().__init__( @@ -83,8 +84,8 @@ def __init__( strategy, local_epochs, local_steps, - wandb_reporter, checkpointer, + reporters, delta, ) self.parameter_exchanger = FullParameterExchanger() diff --git a/examples/feddg_ga_example/client.py b/examples/feddg_ga_example/client.py index 6cd7900c8..81023fd69 100644 --- a/examples/feddg_ga_example/client.py +++ b/examples/feddg_ga_example/client.py @@ -13,6 +13,7 @@ from examples.models.cnn_model import MnistNetWithBnAndFrozen from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import ApflModule +from fl4health.reporting import JsonReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_mnist_data from fl4health.utils.metrics import Accuracy @@ -58,7 +59,6 @@ def get_criterion(self, config: Config) -> _Loss: # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistApflClient(data_path, [Accuracy()], DEVICE) + client = MnistApflClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) - - client.metrics_reporter.dump() + client.shutdown() diff --git a/examples/feddg_ga_example/server.py b/examples/feddg_ga_example/server.py index 4cca07800..64c3429c1 100644 --- a/examples/feddg_ga_example/server.py +++ b/examples/feddg_ga_example/server.py @@ -9,6 +9,7 @@ from examples.utils.functions import make_dict_with_epochs_or_steps from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager from fl4health.model_bases.apfl_base import ApflModule +from fl4health.reporting import JsonReporter from fl4health.server.base_server import FlServer from fl4health.strategies.feddg_ga_strategy import FedDgGaStrategy from fl4health.utils.config import load_config @@ -66,7 +67,7 @@ def main(config: Dict[str, Any]) -> None: # will return the same sampling until it is told to reset, which in FedDgGaStrategy # is done right before fit_round. client_manager = FixedSamplingClientManager() - server = FlServer(strategy=strategy, client_manager=client_manager) + server = FlServer(strategy=strategy, client_manager=client_manager, reporters=[JsonReporter()]) fl.server.start_server( server=server, @@ -74,7 +75,7 @@ def main(config: Dict[str, Any]) -> None: config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), ) - server.metrics_reporter.dump() + server.shutdown() if __name__ == "__main__": diff --git a/examples/federated_eval_example/client.py b/examples/federated_eval_example/client.py index 57770a6dc..53bf742b1 100644 --- a/examples/federated_eval_example/client.py +++ b/examples/federated_eval_example/client.py @@ -11,7 +11,7 @@ from examples.models.cnn_model import Net from fl4health.clients.evaluate_client import EvaluateClient -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_cifar10_test_data from fl4health.utils.losses import LossMeterType @@ -25,7 +25,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, model_checkpoint_path: Optional[Path], - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: super().__init__( data_path=data_path, @@ -33,7 +33,7 @@ def __init__( device=device, model_checkpoint_path=model_checkpoint_path, loss_meter_type=LossMeterType.AVERAGE, - metrics_reporter=metrics_reporter, + reporters=reporters, ) def initialize_global_model(self, config: Config) -> Optional[nn.Module]: diff --git a/examples/fedpca_examples/dim_reduction/server.py b/examples/fedpca_examples/dim_reduction/server.py index 2105c98a9..9cfd6b35d 100644 --- a/examples/fedpca_examples/dim_reduction/server.py +++ b/examples/fedpca_examples/dim_reduction/server.py @@ -67,7 +67,6 @@ def main(config: Dict[str, Any]) -> None: client_manager=SimpleClientManager(), parameter_exchanger=parameter_exchanger, model=model, - wandb_reporter=None, strategy=strategy, checkpointer=checkpointer, ) diff --git a/examples/fedprox_example/client.py b/examples/fedprox_example/client.py index 16ad9bc45..a8871d03a 100644 --- a/examples/fedprox_example/client.py +++ b/examples/fedprox_example/client.py @@ -13,6 +13,7 @@ from examples.models.cnn_model import MnistNet from fl4health.clients.fed_prox_client import FedProxClient +from fl4health.reporting import JsonReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_mnist_data from fl4health.utils.metrics import Accuracy @@ -64,10 +65,8 @@ def get_criterion(self, config: Config) -> _Loss: # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistFedProxClient(data_path, [Accuracy()], DEVICE) + client = MnistFedProxClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) fl.client.start_client(server_address=args.server_address, client=client.to_client()) # Shutdown the client gracefully client.shutdown() - - client.metrics_reporter.dump() diff --git a/examples/fedprox_example/config.yaml b/examples/fedprox_example/config.yaml index 40c7c0e0d..4863105a2 100644 --- a/examples/fedprox_example/config.yaml +++ b/examples/fedprox_example/config.yaml @@ -16,10 +16,9 @@ local_epochs: 1 # The number of epochs to complete for client batch_size: 128 # The batch size for client training reporting_config: - enabled: False - project_name: FL4Health # Name of the project under which everything should be logged - run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name - group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored + project: FL4Health # Name of the project under which everything should be logged + name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name + group: "FedProx Experiment" # Group under which each of the FL run logging will be stored entity: "your_entity_here" # WandB user name notes: "Testing WB reporting" tags: ["Test", "FedProx"] diff --git a/examples/fedprox_example/server.py b/examples/fedprox_example/server.py index cbb1fc7f5..a73fea05f 100644 --- a/examples/fedprox_example/server.py +++ b/examples/fedprox_example/server.py @@ -10,7 +10,7 @@ from examples.models.cnn_model import MnistNet from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.reporting.fl_wandb import ServerWandBReporter +from fl4health.reporting import JsonReporter, WandBReporter from fl4health.server.adaptive_constraint_servers.fedprox_server import FedProxServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -22,24 +22,24 @@ def fit_config( batch_size: int, n_server_rounds: int, - reporting_enabled: bool, - project_name: str, - group_name: str, - entity: str, current_round: int, + reporting_config: Optional[Dict[str, str]] = None, local_epochs: Optional[int] = None, local_steps: Optional[int] = None, ) -> Config: - return { + base_config: Config = { **make_dict_with_epochs_or_steps(local_epochs, local_steps), "batch_size": batch_size, "n_server_rounds": n_server_rounds, "current_server_round": current_round, - "reporting_enabled": reporting_enabled, - "project_name": project_name, - "group_name": group_name, - "entity": entity, } + if reporting_config is not None: + # NOTE: that name is not included, it will be set in the clients + base_config["project"] = reporting_config.get("project", "") + base_config["group"] = reporting_config.get("group", "") + base_config["entity"] = reporting_config.get("entity", "") + + return base_config def main(config: Dict[str, Any], server_address: str) -> None: @@ -48,11 +48,7 @@ def main(config: Dict[str, Any], server_address: str) -> None: fit_config, config["batch_size"], config["n_server_rounds"], - config["reporting_config"].get("enabled", False), - # Note that run name is not included, it will be set in the clients - config["reporting_config"].get("project_name", ""), - config["reporting_config"].get("group_name", ""), - config["reporting_config"].get("entity", ""), + reporting_config=config.get("reporting_config"), local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), ) @@ -78,9 +74,15 @@ def main(config: Dict[str, Any], server_address: str) -> None: loss_weight_patience=config["proximal_weight_patience"], ) - wandb_reporter = ServerWandBReporter.from_config(config) + json_reporter = JsonReporter() client_manager = SimpleClientManager() - server = FedProxServer(client_manager=client_manager, strategy=strategy, model=None, wandb_reporter=wandb_reporter) + if "reporting_config" in config: + wandb_reporter = WandBReporter("round", **config["reporting_config"]) + server = FedProxServer( + client_manager=client_manager, strategy=strategy, model=None, reporters=[wandb_reporter, json_reporter] + ) + else: + server = FedProxServer(client_manager=client_manager, strategy=strategy, model=None, reporters=[json_reporter]) fl.server.start_server( server=server, @@ -89,7 +91,6 @@ def main(config: Dict[str, Any], server_address: str) -> None: ) # Shutdown the server gracefully server.shutdown() - server.metrics_reporter.dump() if __name__ == "__main__": diff --git a/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py b/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py index 794fdb3c5..e8d662e72 100644 --- a/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py +++ b/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py @@ -71,7 +71,6 @@ def main(config: Dict[str, Any]) -> None: client_manager=SimpleClientManager(), parameter_exchanger=parameter_exchanger, model=model, - wandb_reporter=None, strategy=strategy, checkpointer=checkpointer, ) diff --git a/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py b/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py index fd5be6dfb..3c1735472 100644 --- a/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py +++ b/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py @@ -43,7 +43,10 @@ def main(config: Dict[str, Any]) -> None: # Initializing the model on the server side model: nn.Module = FedSimClrModel( - CifarSslEncoder(), CifarSslProjectionHead(), CifarSslPredictionHead(), pretrain=True + CifarSslEncoder(), + CifarSslProjectionHead(), + CifarSslPredictionHead(), + pretrain=True, ) # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() @@ -67,7 +70,6 @@ def main(config: Dict[str, Any]) -> None: client_manager=SimpleClientManager(), parameter_exchanger=parameter_exchanger, model=model, - wandb_reporter=None, strategy=strategy, checkpointer=checkpointer, ) diff --git a/examples/fenda_ditto_example/client.py b/examples/fenda_ditto_example/client.py index ba7123365..1ae7d2a1b 100644 --- a/examples/fenda_ditto_example/client.py +++ b/examples/fenda_ditto_example/client.py @@ -22,6 +22,7 @@ from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.parallel_split_models import ParallelFeatureJoinMode from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel +from fl4health.reporting import JsonReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_mnist_data from fl4health.utils.metrics import Accuracy @@ -114,7 +115,8 @@ def get_criterion(self, config: Config) -> _Loss: ) checkpointer = ClientCheckpointModule( - pre_aggregation=pre_aggregation_checkpointer, post_aggregation=post_aggregation_checkpointer + pre_aggregation=pre_aggregation_checkpointer, + post_aggregation=post_aggregation_checkpointer, ) client = MnistFendaDittoClient( data_path, @@ -122,10 +124,9 @@ def get_criterion(self, config: Config) -> _Loss: DEVICE, args.checkpoint_path, checkpointer=checkpointer, + reporters=[JsonReporter()], ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) # Shutdown the client gracefully client.shutdown() - - client.metrics_reporter.dump() diff --git a/examples/mr_mtl_example/client.py b/examples/mr_mtl_example/client.py index 3fb2da3bd..c8eda73b9 100644 --- a/examples/mr_mtl_example/client.py +++ b/examples/mr_mtl_example/client.py @@ -13,6 +13,7 @@ from examples.models.cnn_model import MnistNet from fl4health.clients.mr_mtl_client import MrMtlClient +from fl4health.reporting import JsonReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_mnist_data from fl4health.utils.metrics import Accuracy @@ -64,10 +65,9 @@ def get_criterion(self, config: Config) -> _Loss: # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistMrMtlClient(data_path, [Accuracy()], DEVICE) + client = MnistMrMtlClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) + fl.client.start_client(server_address=args.server_address, client=client.to_client()) # Shutdown the client gracefully client.shutdown() - - client.metrics_reporter.dump() diff --git a/examples/scaffold_example/client.py b/examples/scaffold_example/client.py index 76d8bec87..19a9f4eb8 100644 --- a/examples/scaffold_example/client.py +++ b/examples/scaffold_example/client.py @@ -12,6 +12,7 @@ from examples.models.cnn_model import MnistNetWithBnAndFrozen from fl4health.clients.scaffold_client import ScaffoldClient +from fl4health.reporting import JsonReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_mnist_data from fl4health.utils.metrics import Accuracy @@ -55,8 +56,6 @@ def get_criterion(self, config: Config) -> _Loss: # Set the random seed for reproducibility set_all_random_seeds(args.seed) - client = MnistScaffoldClient(data_path, [Accuracy()], DEVICE) + client = MnistScaffoldClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()]) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() - - client.metrics_reporter.dump() diff --git a/examples/scaffold_example/server.py b/examples/scaffold_example/server.py index 8138c7de3..15b6c908a 100644 --- a/examples/scaffold_example/server.py +++ b/examples/scaffold_example/server.py @@ -7,6 +7,7 @@ from examples.models.cnn_model import MnistNetWithBnAndFrozen from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager +from fl4health.reporting import JsonReporter from fl4health.server.scaffold_server import ScaffoldServer from fl4health.strategies.scaffold import Scaffold from fl4health.utils.config import load_config @@ -51,14 +52,19 @@ def main(config: Dict[str, Any]) -> None: # ClientManager that performs Poisson type sampling client_manager = PoissonSamplingClientManager() - server = ScaffoldServer(client_manager=client_manager, strategy=strategy, warm_start=True) + server = ScaffoldServer( + client_manager=client_manager, + strategy=strategy, + warm_start=True, + reporters=[JsonReporter()], + ) fl.server.start_server( server=server, server_address="0.0.0.0:8080", config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), ) - server.metrics_reporter.dump() + server.shutdown() if __name__ == "__main__": diff --git a/examples/warm_up_example/fedavg_warm_up/config.yaml b/examples/warm_up_example/fedavg_warm_up/config.yaml index 93bb98c13..057c0e29a 100644 --- a/examples/warm_up_example/fedavg_warm_up/config.yaml +++ b/examples/warm_up_example/fedavg_warm_up/config.yaml @@ -7,10 +7,9 @@ local_epochs: 1 # The number of epochs to complete for client batch_size: 128 # The batch size for client training reporting_config: - enabled: False - project_name: FL4Health # Name of the project under which everything should be logged - run_name: "FedAvg Server" # Name of the run on the server-side, each client will also have it's own run name - group_name: "FedAvg Experiment" # Group under which each of the FL run logging will be stored + project: FL4Health # Name of the project under which everything should be logged + name: "FedAvg Server" # Name of the run on the server-side, each client will also have it's own run name + group: "FedAvg Experiment" # Group under which each of the FL run logging will be stored entity: "your_entity_here" # WandB user name notes: "Testing WB reporting" tags: ["Test", "FedAvg"] diff --git a/examples/warm_up_example/fedavg_warm_up/server.py b/examples/warm_up_example/fedavg_warm_up/server.py index ba343fe3b..553604600 100644 --- a/examples/warm_up_example/fedavg_warm_up/server.py +++ b/examples/warm_up_example/fedavg_warm_up/server.py @@ -10,7 +10,7 @@ from examples.models.cnn_model import MnistNet from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.reporting.fl_wandb import ServerWandBReporter +from fl4health.reporting import WandBReporter from fl4health.server.base_server import FlServer from fl4health.strategies.basic_fedavg import BasicFedAvg from fl4health.utils.config import load_config @@ -22,9 +22,8 @@ def fit_config( batch_size: int, n_server_rounds: int, - reporting_enabled: bool, - project_name: str, - group_name: str, + project: str, + group: str, entity: str, current_round: int, local_epochs: Optional[int] = None, @@ -35,9 +34,8 @@ def fit_config( "batch_size": batch_size, "n_server_rounds": n_server_rounds, "current_server_round": current_round, - "reporting_enabled": reporting_enabled, - "project_name": project_name, - "group_name": group_name, + "project": project, + "group": group, "entity": entity, } @@ -48,10 +46,9 @@ def main(config: Dict[str, Any], server_address: str) -> None: fit_config, config["batch_size"], config["n_server_rounds"], - config["reporting_config"].get("enabled", False), - # Note that run name is not included, it will be set in the clients - config["reporting_config"].get("project_name", ""), - config["reporting_config"].get("group_name", ""), + # NOTE: that name is not included, it will be set in the clients + config["reporting_config"].get("project", ""), + config["reporting_config"].get("group", ""), config["reporting_config"].get("entity", ""), local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), @@ -73,9 +70,10 @@ def main(config: Dict[str, Any], server_address: str) -> None: initial_parameters=get_all_model_parameters(initial_model), ) - wandb_reporter = ServerWandBReporter.from_config(config) + if "reporting_config" in config: + wandb_reporter = WandBReporter("round", **config["reporting_config"]) client_manager = SimpleClientManager() - server = FlServer(client_manager, strategy, wandb_reporter) + server = FlServer(client_manager, strategy, reporters=[wandb_reporter]) fl.server.start_server( server=server, diff --git a/examples/warm_up_example/warmed_up_fedprox/config.yaml b/examples/warm_up_example/warmed_up_fedprox/config.yaml index bb61d3a1d..420590b72 100644 --- a/examples/warm_up_example/warmed_up_fedprox/config.yaml +++ b/examples/warm_up_example/warmed_up_fedprox/config.yaml @@ -16,10 +16,9 @@ local_epochs: 1 # The number of epochs to complete for client batch_size: 128 # The batch size for client training reporting_config: - enabled: False - project_name: FL4Health # Name of the project under which everything should be logged - run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name - group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored + project: FL4Health # Name of the project under which everything should be logged + name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name + group: "FedProx Experiment" # Group under which each of the FL run logging will be stored entity: "your_entity_here" # WandB user name notes: "Testing WB reporting" tags: ["Test", "FedProx"] diff --git a/examples/warm_up_example/warmed_up_fedprox/server.py b/examples/warm_up_example/warmed_up_fedprox/server.py index faba1958e..dd85f2902 100644 --- a/examples/warm_up_example/warmed_up_fedprox/server.py +++ b/examples/warm_up_example/warmed_up_fedprox/server.py @@ -10,9 +10,9 @@ from examples.models.cnn_model import MnistNet from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.reporting.fl_wandb import ServerWandBReporter +from fl4health.reporting import WandBReporter from fl4health.server.base_server import FlServer -from fl4health.strategies.fedprox import FedProx +from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -22,9 +22,8 @@ def fit_config( batch_size: int, n_server_rounds: int, - reporting_enabled: bool, - project_name: str, - group_name: str, + project: str, + group: str, entity: str, current_round: int, local_epochs: Optional[int] = None, @@ -35,9 +34,9 @@ def fit_config( "batch_size": batch_size, "n_server_rounds": n_server_rounds, "current_server_round": current_round, - "reporting_enabled": reporting_enabled, - "project_name": project_name, - "group_name": group_name, + "project": project, + "group": group, + "entity": entity, "entity": entity, } @@ -48,10 +47,9 @@ def main(config: Dict[str, Any], server_address: str) -> None: fit_config, config["batch_size"], config["n_server_rounds"], - config["reporting_config"].get("enabled", False), - # Note that run name is not included, it will be set in the clients - config["reporting_config"].get("project_name", ""), - config["reporting_config"].get("group_name", ""), + # NOTE: that name is not included, it will be set in the clients + config["reporting_config"].get("project", ""), + config["reporting_config"].get("group", ""), config["reporting_config"].get("entity", ""), local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), @@ -60,7 +58,7 @@ def main(config: Dict[str, Any], server_address: str) -> None: initial_model = MnistNet() # Server performs simple FedAveraging as its server-side optimization strategy - strategy = FedProx( + strategy = FedAvgWithAdaptiveConstraint( min_fit_clients=config["n_clients"], min_evaluate_clients=config["n_clients"], # Server waits for min_available_clients before starting FL rounds @@ -71,15 +69,16 @@ def main(config: Dict[str, Any], server_address: str) -> None: fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(initial_model), - adaptive_proximal_weight=config["adaptive_proximal_weight"], - proximal_weight=config["proximal_weight"], - proximal_weight_delta=config["proximal_weight_delta"], - proximal_weight_patience=config["proximal_weight_patience"], + adapt_loss_weight=config["adaptive_proximal_weight"], + initial_loss_weight=config["proximal_weight"], + loss_weight_delta=config["proximal_weight_delta"], + loss_weight_patience=config["proximal_weight_patience"], ) - wandb_reporter = ServerWandBReporter.from_config(config) + if "reporting_config" in config: + wandb_reporter = WandBReporter("round", **config["reporting_config"]) client_manager = SimpleClientManager() - server = FlServer(client_manager, strategy, wandb_reporter) + server = FlServer(client_manager, strategy, reporters=[wandb_reporter]) fl.server.start_server( server=server, diff --git a/examples/warm_up_example/warmed_up_fenda/config.yaml b/examples/warm_up_example/warmed_up_fenda/config.yaml index 350bd86b0..2a2162164 100644 --- a/examples/warm_up_example/warmed_up_fenda/config.yaml +++ b/examples/warm_up_example/warmed_up_fenda/config.yaml @@ -7,10 +7,9 @@ local_epochs: 1 # The number of epochs to complete for client batch_size: 128 # The batch size for client training reporting_config: - enabled: False - project_name: FL4Health # Name of the project under which everything should be logged - run_name: "Fenda Server" # Name of the run on the server-side, each client will also have it's own run name - group_name: "Fenda Experiment" # Group under which each of the FL run logging will be stored + project: FL4Health # Name of the project under which everything should be logged + name: "Fenda Server" # Name of the run on the server-side, each client will also have it's own run name + group: "Fenda Experiment" # Group under which each of the FL run logging will be stored entity: "your_entity_here" # WandB user name notes: "Testing WB reporting" tags: ["Test", "Fenda"] diff --git a/examples/warm_up_example/warmed_up_fenda/server.py b/examples/warm_up_example/warmed_up_fenda/server.py index ec3b68759..f1d67a23f 100644 --- a/examples/warm_up_example/warmed_up_fenda/server.py +++ b/examples/warm_up_example/warmed_up_fenda/server.py @@ -13,7 +13,7 @@ from examples.utils.functions import make_dict_with_epochs_or_steps from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.parallel_split_models import ParallelFeatureJoinMode -from fl4health.reporting.fl_wandb import ServerWandBReporter +from fl4health.reporting import WandBReporter from fl4health.server.base_server import FlServer from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn @@ -24,9 +24,8 @@ def fit_config( batch_size: int, n_server_rounds: int, - reporting_enabled: bool, - project_name: str, - group_name: str, + project: str, + group: str, entity: str, current_round: int, local_epochs: Optional[int] = None, @@ -37,9 +36,8 @@ def fit_config( "batch_size": batch_size, "n_server_rounds": n_server_rounds, "current_server_round": current_round, - "reporting_enabled": reporting_enabled, - "project_name": project_name, - "group_name": group_name, + "project": project, + "group": group, "entity": entity, } @@ -50,10 +48,9 @@ def main(config: Dict[str, Any], server_address: str) -> None: fit_config, config["batch_size"], config["n_server_rounds"], - config["reporting_config"].get("enabled", False), - # Note that run name is not included, it will be set in the clients - config["reporting_config"].get("project_name", ""), - config["reporting_config"].get("group_name", ""), + # NOTE: that name is not included, it will be set in the clients + config["reporting_config"].get("project", ""), + config["reporting_config"].get("group", ""), config["reporting_config"].get("entity", ""), local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), @@ -77,9 +74,10 @@ def main(config: Dict[str, Any], server_address: str) -> None: initial_parameters=get_all_model_parameters(initial_model), ) - wandb_reporter = ServerWandBReporter.from_config(config) + if "reporting_config" in config: + wandb_reporter = WandBReporter("round", **config["reporting_config"]) client_manager = SimpleClientManager() - server = FlServer(client_manager, strategy, wandb_reporter) + server = FlServer(client_manager, strategy, reporters=[wandb_reporter]) fl.server.start_server( server=server, diff --git a/fl4health/clients/adaptive_drift_constraint_client.py b/fl4health/clients/adaptive_drift_constraint_client.py index 28ba539df..f3b02bdd2 100644 --- a/fl4health/clients/adaptive_drift_constraint_client.py +++ b/fl4health/clients/adaptive_drift_constraint_client.py @@ -13,7 +13,7 @@ from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType @@ -27,7 +27,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, ) -> None: """ @@ -48,8 +48,8 @@ def __init__( checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default parameters. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. progress_bar (bool): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False """ @@ -59,7 +59,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, progress_bar=progress_bar, ) # These are the tensors that will be used to compute the penalty loss diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index ca51e4171..458d6d0a2 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -9,6 +9,7 @@ from fl4health.clients.basic_client import BasicClient from fl4health.model_bases.apfl_base import ApflModule from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType @@ -22,8 +23,9 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: - super().__init__(data_path, metrics, device, loss_meter_type, checkpointer) + super().__init__(data_path, metrics, device, loss_meter_type, checkpointer, reporters) self.model: ApflModule self.learning_rate: float diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 04fd4a5f3..f262cd5de 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -1,9 +1,10 @@ import copy import datetime +from collections.abc import Iterable, Sequence from enum import Enum from logging import INFO, WARNING from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Tuple, Union import torch import torch.nn as nn @@ -20,8 +21,8 @@ from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger -from fl4health.reporting.fl_wandb import ClientWandBReporter -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter +from fl4health.reporting.report_manager import ReportsManager from fl4health.utils.config import narrow_dict_type, narrow_dict_type_and_set_attribute from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType, TrainingLosses from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, Metric, MetricManager @@ -43,7 +44,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, intermediate_client_state_dir: Optional[Path] = None, client_name: Optional[str] = None, @@ -63,8 +64,8 @@ def __init__( checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. progress_bar (bool): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False @@ -96,10 +97,9 @@ def __init__( else: self.per_round_checkpointer = None - if metrics_reporter is not None: - self.metrics_reporter = metrics_reporter - else: - self.metrics_reporter = MetricsReporter(run_id=self.client_name) + # Initialize reporters with client information. + self.reports_manager = ReportsManager(reporters) + self.reports_manager.initialize(id=self.client_name) self.initialized = False # Whether or not the client has been setup @@ -114,13 +114,12 @@ def __init__( # Optional variable to store the weights that the client was initialized with during each round of training self.initial_weights: Optional[NDArrays] = None - self.wandb_reporter: Optional[ClientWandBReporter] = None self.total_steps: int = 0 # Need to track total_steps across rounds for WANDB reporting # Attributes to be initialized in setup_client self.parameter_exchanger: ParameterExchanger self.model: nn.Module - self.optimizers: Dict[str, torch.optim.Optimizer] + self.optimizers: dict[str, torch.optim.Optimizer] self.train_loader: DataLoader self.val_loader: DataLoader self.test_loader: Optional[DataLoader] @@ -129,13 +128,13 @@ def __init__( self.num_test_samples: Optional[int] = None self.learning_rate: Optional[float] = None - def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None: + def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None: """ If checkpointer exists, maybe checkpoint model based on the provided metric values. Args: loss (float): validation loss to potentially be used for checkpointing - metrics (Dict[str, float]): validation metrics to potentially be used for checkpointing + metrics (dict[str, float]): validation metrics to potentially be used for checkpointing """ if self.checkpointer: self.checkpointer.maybe_checkpoint(self.model, loss, metrics, checkpoint_mode) @@ -153,7 +152,10 @@ def get_parameters(self, config: Config) -> NDArrays: parameters to be aggregated, but can contain more information. """ if not self.initialized: - log(INFO, "Setting up client and providing full model parameters to the server for initialization") + log( + INFO, + "Setting up client and providing full model parameters to the server for initialization", + ) # If initialized==False, the server is requesting model parameters from which to initialize all other # clients. As such get_parameters is being called before fit or evaluate, so we must call @@ -208,10 +210,9 @@ def shutdown(self) -> None: """ Shuts down the client. Involves shutting down W&B reporter if one exists. """ - if self.wandb_reporter: - self.wandb_reporter.shutdown_reporter() - - self.metrics_reporter.add_to_metrics({"shutdown": datetime.datetime.now()}) + # Shutdown reporters + self.reports_manager.report({"shutdown": str(datetime.datetime.now())}) + self.reports_manager.shutdown() def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, None], int, bool]: """ @@ -250,7 +251,7 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N # Either local epochs or local steps is none based on what key is passed in the config return local_epochs, local_steps, current_server_round, evaluate_after_fit - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict[str, Scalar]]: """ Processes config, initializes client (if first round) and performs training based on the passed config. If per_round_checkpointer is not None, on initialization the client checks if a checkpointed client state @@ -261,12 +262,13 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict config (NDArrays): The config from the server. Returns: - Tuple[NDArrays, int, Dict[str, Scalar]]: The parameters following the local training along with the + Tuple[NDArrays, int, dict[str, Scalar]]: The parameters following the local training along with the number of samples in the local training dataset and the computed metrics throughout the fit. Raises: ValueError: If local_steps or local_epochs is not specified in config. """ + round_start_time = datetime.datetime.now() local_epochs, local_steps, current_server_round, evaluate_after_fit = self.process_config(config) if not self.initialized: @@ -277,15 +279,11 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict if self.per_round_checkpointer is not None and self.per_round_checkpointer.checkpoint_exists(): self.load_client_state() - self.metrics_reporter.add_to_metrics_at_round( - current_server_round, - data={"fit_start": datetime.datetime.now()}, - ) - self.set_parameters(parameters, config, fitting_round=True) self.update_before_train(current_server_round) + fit_start_time = datetime.datetime.now() if local_epochs is not None: loss_dict, metrics = self.train_by_epochs(local_epochs, current_server_round) local_steps = len(self.train_loader) * local_epochs # total steps over training round @@ -293,6 +291,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict loss_dict, metrics = self.train_by_steps(local_steps, current_server_round) else: raise ValueError("Must specify either local_epochs or local_steps in the Config.") + fit_end_time = datetime.datetime.now() # Perform necessary updates after training has completed for the current FL round self.update_after_train(local_steps, loss_dict, config) @@ -305,13 +304,17 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict # We perform a pre-aggregation checkpoint if applicable self._maybe_checkpoint(validation_loss, validation_metrics, CheckpointMode.PRE_AGGREGATION) - self.metrics_reporter.add_to_metrics_at_round( - current_server_round, - data={ + self.reports_manager.report( + { "fit_metrics": metrics, - "loss_dict": loss_dict, - "fit_end": datetime.datetime.now(), + "fit_losses": loss_dict, + "round": current_server_round, + "round_start": str(round_start_time), + "round_end": str(datetime.datetime.now()), + "fit_start": str(fit_start_time), + "fit_end": str(fit_end_time), }, + current_server_round, ) # After local client training has finished, checkpoint client state @@ -327,11 +330,11 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict metrics, ) - def evaluate_after_fit(self) -> Tuple[float, Dict[str, Scalar]]: + def evaluate_after_fit(self) -> Tuple[float, dict[str, Scalar]]: """ Run self.validate right after fit to collect metrics on the local model against validation data. - Returns: (Dict[str, Scalar]) a dictionary with the metrics. + Returns: (dict[str, Scalar]) a dictionary with the metrics. """ loss, metric_values = self.validate() @@ -342,7 +345,7 @@ def evaluate_after_fit(self) -> Tuple[float, Dict[str, Scalar]]: } return loss, metrics_after_fit - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, dict[str, Scalar]]: """ Evaluates the model on the validation set, and test set (if defined). @@ -351,32 +354,33 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di config (NDArrays): The config object from the server. Returns: - Tuple[float, int, Dict[str, Scalar]]: A loss associated with the evaluation, the number of samples in the + Tuple[float, int, dict[str, Scalar]]: A loss associated with the evaluation, the number of samples in the validation/test set and the metric_values associated with evaluation. """ if not self.initialized: self.setup_client(config) + start_time = datetime.datetime.now() current_server_round = narrow_dict_type(config, "current_server_round", int) - self.metrics_reporter.add_to_metrics_at_round( - current_server_round, - data={"evaluate_start": datetime.datetime.now()}, - ) self.set_parameters(parameters, config, fitting_round=False) loss, metrics = self.validate() + end_time = datetime.datetime.now() + elapsed = end_time - start_time # Checkpoint based on the loss and metrics produced during validation AFTER server-side aggregation # NOTE: This assumes that the loss returned in the checkpointing loss self._maybe_checkpoint(loss, metrics, CheckpointMode.POST_AGGREGATION) - self.metrics_reporter.add_to_metrics_at_round( - current_server_round, - data={ - "evaluate_metrics": metrics, - "loss": loss, - "evaluate_end": datetime.datetime.now(), + self.reports_manager.report( + { + "eval_metrics": metrics, + "eval_loss": loss, + "eval_start": str(start_time), + "eval_time_elapsed": str(elapsed), + "eval_end": str(end_time), }, + current_server_round, ) # EvaluateRes should return the loss, number of examples on client, and a dictionary holding metrics @@ -435,8 +439,8 @@ def _log_header_str( def _log_results( self, - loss_dict: Dict[str, float], - metrics_dict: Dict[str, Scalar], + loss_dict: dict[str, float], + metrics_dict: dict[str, Scalar], current_round: Optional[int] = None, current_epoch: Optional[int] = None, logging_mode: LoggingMode = LoggingMode.TRAIN, @@ -446,8 +450,8 @@ def _log_results( output file. Called only at the end of an epoch or server round Args: - loss_dict (Dict[str, float]): A dictionary of losses to log. - metrics_dict (Dict[str, Scalar]): A dictionary of the metric to log. + loss_dict (dict[str, float]): A dictionary of losses to log. + metrics_dict (dict[str, Scalar]): A dictionary of the metric to log. current_round (Optional[int]): The current FL round (i.e., current server round). current_epoch (Optional[int]): The current epoch of local training. logging_mode (LoggingMode): The logging mode (Training, Validation, or Testing). @@ -472,8 +476,11 @@ def _log_results( [log(level.value, msg) for level, msg in client_logs] def get_client_specific_logs( - self, current_round: Optional[int], current_epoch: Optional[int], logging_mode: LoggingMode - ) -> Tuple[str, List[Tuple[LogLevel, str]]]: + self, + current_round: Optional[int], + current_epoch: Optional[int], + logging_mode: LoggingMode, + ) -> Tuple[str, list[Tuple[LogLevel, str]]]: """ This function can be overridden to provide any client specific information to the basic client logging. For example, perhaps a client @@ -492,7 +499,7 @@ def get_client_specific_logs( Optional[str]: A string to append to the header log string that typically announces the current server round and current epoch at the beginning of each round or local epoch. - Optional[List[Tuple[LogLevel, str]]]]: A list of tuples where the + Optional[list[Tuple[LogLevel, str]]]]: A list of tuples where the first element is a LogLevel as defined in fl4health.utils. typing and the second element is a string message. Each item in the list will be logged at the end of each server round or epoch. @@ -500,41 +507,13 @@ def get_client_specific_logs( """ return "", [] - def _handle_reporting( - self, - loss_dict: Dict[str, float], - metric_dict: Dict[str, Scalar], - current_round: Optional[int] = None, - ) -> None: - """ - Handles reporting of losses and metrics to W&B. - Args: - loss_dict (Dict[str, float]): A dictionary of losses to log. - metrics_dict (Dict[str, Scalar]): A dictionary of metrics to log. - current_round (Optional[int]): The current FL round. - """ - # If reporter is None we do not report to wandb and return - if self.wandb_reporter is None: - return - - # If no current_round is passed or current_round is None, set current_round to 0 - # This situation only arises when we do local fine-tuning and call train_by_epochs or train_by_steps explicitly - current_round = current_round if current_round is not None else 0 - - reporting_dict: Dict[str, Any] = {"server_round": current_round} - reporting_dict.update({"step": self.total_steps}) - reporting_dict.update(loss_dict) - reporting_dict.update(metric_dict) - reporting_dict.update(self.get_client_specific_reports()) - self.wandb_reporter.report_metrics(reporting_dict) - - def get_client_specific_reports(self) -> Dict[str, Any]: + def get_client_specific_reports(self) -> dict[str, Any]: """ This function can be overridden by an inheriting client to report additional client specific information to the wandb_reporter Returns: - Dict[str, Any]: A dictionary of things to report + dict[str, Any]: A dictionary of things to report """ return {} @@ -564,25 +543,25 @@ def _move_data_to_device( return {key: value.to(self.device) for key, value in data.items()} else: raise TypeError( - "data must be of type torch.Tensor or Dict[str, torch.Tensor]. \ + "data must be of type torch.Tensor or dict[str, torch.Tensor]. \ If definition of TorchInputType or TorchTargetType has \ changed this method might need to be updated or split into \ two" ) - def is_empty_batch(self, input: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> bool: + def is_empty_batch(self, input: Union[torch.Tensor, dict[str, torch.Tensor]]) -> bool: """ Check whether input, which represents a batch of inputs to a model, is empty. Args: - input (Union[torch.Tensor, Dict[str, torch.Tensor]]): input batch. - input can be of type torch.Tensor or Dict[str, torch.Tensor], and in the + input (Union[torch.Tensor, dict[str, torch.Tensor]]): input batch. + input can be of type torch.Tensor or dict[str, torch.Tensor], and in the latter case, the batch is considered to be empty if all tensors in the dictionary have length zero. Raises: - TypeError: raised if input is not of type torch.Tensor or Dict[str, torch.Tensor]. - ValueError: raised if input has type Dict[str, torch.Tensor] and not all tensors + TypeError: raised if input is not of type torch.Tensor or dict[str, torch.Tensor]. + ValueError: raised if input has type dict[str, torch.Tensor] and not all tensors within the dictionary have the same size. Returns: @@ -599,10 +578,13 @@ def is_empty_batch(self, input: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> else: return first_val_len == 0 else: - raise TypeError("Input must be of type torch.Tensor or Dict[str, torch.Tensor].") + raise TypeError("Input must be of type torch.Tensor or dict[str, torch.Tensor].") def update_metric_manager( - self, preds: TorchPredType, target: TorchTargetType, metric_manager: MetricManager + self, + preds: TorchPredType, + target: TorchTargetType, + metric_manager: MetricManager, ) -> None: """ Updates a metric manager with the provided model predictions and @@ -675,7 +657,7 @@ def train_by_epochs( self, epochs: int, current_round: Optional[int] = None, - ) -> Tuple[Dict[str, float], Dict[str, Scalar]]: + ) -> Tuple[dict[str, float], dict[str, Scalar]]: """ Train locally for the specified number of epochs. @@ -684,11 +666,12 @@ def train_by_epochs( current_round (Optional[int], optional): The current FL round. Returns: - Tuple[Dict[str, float], Dict[str, Scalar]]: The loss and metrics dictionary from the local training. + Tuple[dict[str, float], dict[str, Scalar]]: The loss and metrics dictionary from the local training. Loss is a dictionary of one or more losses that represent the different components of the loss. """ self.model.train() steps_this_round = 0 # Reset number of steps this round + report_data: dict = {"round": current_round} for local_epoch in range(epochs): self.train_metric_manager.clear() self.train_loss_meter.clear() @@ -696,6 +679,8 @@ def train_by_epochs( self._log_header_str(current_round, local_epoch) # update before epoch hook self.update_before_epoch(epoch=local_epoch) + # Update report data dict + report_data.update({"fit_epoch": local_epoch}) for input, target in self.maybe_progress_bar(self.train_loader): self.update_before_step(steps_this_round, current_round) # Assume first dimension is batch size. Sampling iterators (such as Poisson batch sampling), can @@ -711,14 +696,20 @@ def train_by_epochs( self.update_metric_manager(preds, target, self.train_metric_manager) self.update_after_step(steps_this_round, current_round) self.update_lr_schedulers(epoch=local_epoch) + report_data.update({"fit_losses": losses.as_dict(), "fit_step": self.total_steps}) + report_data.update(self.get_client_specific_reports()) + self.reports_manager.report(report_data, current_round, local_epoch, self.total_steps) self.total_steps += 1 steps_this_round += 1 + metrics = self.train_metric_manager.compute() loss_dict = self.train_loss_meter.compute().as_dict() - # Log results and maybe report via WANDB + # Log and report results self._log_results(loss_dict, metrics, current_round, local_epoch) - self._handle_reporting(loss_dict, metrics, current_round=current_round) + report_data.update({"fit_metrics": metrics}) + report_data.update(self.get_client_specific_reports()) + self.reports_manager.report(report_data, current_round, local_epoch) # Return final training metrics return loss_dict, metrics @@ -727,7 +718,7 @@ def train_by_steps( self, steps: int, current_round: Optional[int] = None, - ) -> Tuple[Dict[str, float], Dict[str, Scalar]]: + ) -> Tuple[dict[str, float], dict[str, Scalar]]: """ Train locally for the specified number of steps. @@ -736,7 +727,7 @@ def train_by_steps( current_round (Optional[int], optional): The current FL round Returns: - Tuple[Dict[str, float], Dict[str, Scalar]]: The loss and metrics dictionary from the local training. + Tuple[dict[str, float], dict[str, Scalar]]: The loss and metrics dictionary from the local training. Loss is a dictionary of one or more losses that represent the different components of the loss. """ self.model.train() @@ -747,8 +738,8 @@ def train_by_steps( self.train_loss_meter.clear() self.train_metric_manager.clear() self._log_header_str(current_round) + report_data: dict = {"round": current_round} for step in self.maybe_progress_bar(range(steps)): - self.update_before_step(step, current_round) try: @@ -772,14 +763,17 @@ def train_by_steps( self.update_metric_manager(preds, target, self.train_metric_manager) self.update_after_step(step, current_round) self.update_lr_schedulers(step=step) + report_data.update({"fit_losses": losses.as_dict(), "fit_step": self.total_steps}) + report_data.update(self.get_client_specific_reports()) + self.reports_manager.report(report_data, current_round, None, self.total_steps) self.total_steps += 1 loss_dict = self.train_loss_meter.compute().as_dict() metrics = self.train_metric_manager.compute() - # Log results and maybe report via WANDB + # Log and report results self._log_results(loss_dict, metrics, current_round) - self._handle_reporting(loss_dict, metrics, current_round=current_round) + report_data = {} return loss_dict, metrics @@ -789,7 +783,7 @@ def _validate_or_test( loss_meter: LossMeter, metric_manager: MetricManager, logging_mode: LoggingMode = LoggingMode.VALIDATION, - ) -> Tuple[float, Dict[str, Scalar]]: + ) -> Tuple[float, dict[str, Scalar]]: """ Evaluate the model on the given validation or test dataset. @@ -801,9 +795,12 @@ def _validate_or_test( Default is for validation. Returns: - Tuple[float, Dict[str, Scalar]]: The loss and a dictionary of metrics from evaluation. + Tuple[float, dict[str, Scalar]]: The loss and a dictionary of metrics from evaluation. """ - assert logging_mode in [LoggingMode.VALIDATION, LoggingMode.TEST], "logging_mode must be VALIDATION or TEST" + assert logging_mode in [ + LoggingMode.VALIDATION, + LoggingMode.TEST, + ], "logging_mode must be VALIDATION or TEST" self.model.eval() metric_manager.clear() loss_meter.clear() @@ -822,19 +819,22 @@ def _validate_or_test( return loss_dict["checkpoint"], metrics - def validate(self) -> Tuple[float, Dict[str, Scalar]]: + def validate(self) -> Tuple[float, dict[str, Scalar]]: """ Validate the current model on the entire validation and potentially an entire test dataset if it has been defined. Returns: - Tuple[float, Dict[str, Scalar]]: The validation loss and a dictionary of metrics + Tuple[float, dict[str, Scalar]]: The validation loss and a dictionary of metrics from validation (and test if present). """ val_loss, val_metrics = self._validate_or_test(self.val_loader, self.val_loss_meter, self.val_metric_manager) if self.test_loader: test_loss, test_metrics = self._validate_or_test( - self.test_loader, self.test_loss_meter, self.test_metric_manager, LoggingMode.TEST + self.test_loader, + self.test_loss_meter, + self.test_metric_manager, + LoggingMode.TEST, ) # There will be no clashes due to the naming convention associated with the metric managers if self.num_test_samples is not None: @@ -844,7 +844,7 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]: return val_loss, val_metrics - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """ Return properties (train and validation dataset sample counts) of client. @@ -852,13 +852,16 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: config (Config): The config from the server. Returns: - Dict[str, Scalar]: A dictionary with two entries corresponding to the sample counts in + dict[str, Scalar]: A dictionary with two entries corresponding to the sample counts in the train and validation set. """ if not self.initialized: self.setup_client(config) - return {"num_train_samples": self.num_train_samples, "num_val_samples": self.num_val_samples} + return { + "num_train_samples": self.num_train_samples, + "num_val_samples": self.num_val_samples, + } def setup_client(self, config: Config) -> None: """ @@ -897,9 +900,7 @@ def setup_client(self, config: Config) -> None: self.criterion = self.get_criterion(config).to(self.device) self.parameter_exchanger = self.get_parameter_exchanger(config) - self.wandb_reporter = ClientWandBReporter.from_config(self.client_name, config) - - self.metrics_reporter.add_to_metrics({"type": "client", "initialized": datetime.datetime.now()}) + self.reports_manager.report({"host_type": "client", "initialized": str(datetime.datetime.now())}) self.initialized = True @@ -921,7 +922,7 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureTyp Args: input (TorchInputType): Inputs to be fed into the model. If input is - of type Dict[str, torch.Tensor], it is assumed that the keys of + of type dict[str, torch.Tensor], it is assumed that the keys of input match the names of the keyword arguments of self.model. forward(). @@ -947,6 +948,8 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureTyp # Note that this assumes the keys of the input match (exactly) the keyword args # of self.model.forward(). output = self.model(**input) + else: + raise TypeError('"input" must be of type torch.Tensor or dict[str, torch.Tensor].') if isinstance(output, dict): return output, {} @@ -962,7 +965,7 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureTyp def compute_loss_and_additional_losses( self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType - ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[dict[str, torch.Tensor]]]: """ Computes the loss and any additional losses given predictions of the model and ground truth data. @@ -972,7 +975,7 @@ def compute_loss_and_additional_losses( target (TorchTargetType): Ground truth data to evaluate predictions against. Returns: - Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]; A tuple with: + Tuple[torch.Tensor, Union[dict[str, torch.Tensor], None]]; A tuple with: - The tensor for the loss - A dictionary of additional losses with their names and values, or None if there are no additional losses. @@ -1116,7 +1119,7 @@ def get_criterion(self, config: Config) -> _Loss: """ raise NotImplementedError - def get_optimizer(self, config: Config) -> Union[Optimizer, Dict[str, Optimizer]]: + def get_optimizer(self, config: Config) -> Union[Optimizer, dict[str, Optimizer]]: """ Method to be defined by user that returns the PyTorch optimizer used to train models locally Return value can be a single torch optimizer or a dictionary of string and torch optimizer. @@ -1127,7 +1130,7 @@ def get_optimizer(self, config: Config) -> Union[Optimizer, Dict[str, Optimizer] config (Config): The config sent from the server. Returns: - Union[Optimizer, Dict[str, Optimizer]]: An optimizer or dictionary of optimizers to + Union[Optimizer, dict[str, Optimizer]]: An optimizer or dictionary of optimizers to train model. Raises: @@ -1192,7 +1195,7 @@ def update_before_train(self, current_server_round: int) -> None: """ pass - def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], config: Config) -> None: + def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None: """ Hook method called after training with the number of local_steps performed over the FL round and the corresponding loss dictionary. For example, used by Scaffold to update the control variates @@ -1201,8 +1204,9 @@ def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], conf aggregation. Args: - local_steps (int): The number of steps so far in the round in the local training. - loss_dict (Dict[str, float]): A dictionary of losses from local training. + local_steps (int): The number of steps so far in the round in the local + training. + loss_dict (dict[str, float]): A dictionary of losses from local training. config (Config): The config from the server """ pass @@ -1295,13 +1299,16 @@ def save_client_state(self) -> None: "lr_schedulers_state": {key: scheduler.state_dict() for key, scheduler in self.lr_schedulers.items()}, "total_steps": self.total_steps, "client_name": self.client_name, - "metrics_reporter": self.metrics_reporter, + "reports_manager": self.reports_manager, "optimizers_state": {key: optimizer.state_dict()["state"] for key, optimizer in self.optimizers.items()}, } self.per_round_checkpointer.save_checkpoint(ckpt) - log(INFO, f"Saving client state to checkpoint at {self.per_round_checkpointer.checkpoint_path}") + log( + INFO, + f"Saving client state to checkpoint at {self.per_round_checkpointer.checkpoint_path}", + ) def load_client_state(self) -> None: """ @@ -1314,7 +1321,7 @@ def load_client_state(self) -> None: narrow_dict_type_and_set_attribute(self, ckpt, "client_name", "client_name", str) narrow_dict_type_and_set_attribute(self, ckpt, "total_steps", "total_steps", int) - narrow_dict_type_and_set_attribute(self, ckpt, "metrics_reporter", "metrics_reporter", MetricsReporter) + narrow_dict_type_and_set_attribute(self, ckpt, "reports_manager", "reports_manager", ReportsManager) assert "lr_schedulers_state" in ckpt and isinstance(ckpt["lr_schedulers_state"], dict) assert "optimizers_state" in ckpt and isinstance(ckpt["optimizers_state"], dict) diff --git a/fl4health/clients/ditto_client.py b/fl4health/clients/ditto_client.py index d39629e29..27f443ced 100644 --- a/fl4health/clients/ditto_client.py +++ b/fl4health/clients/ditto_client.py @@ -11,7 +11,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointModule from fl4health.clients.adaptive_drift_constraint_client import AdaptiveDriftConstraintClient from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric @@ -26,7 +26,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, ) -> None: """ @@ -40,17 +40,19 @@ def __init__( corresponding strategy used by the server Args: - data_path (Path): path to the data to be used to load the data for client-side training - metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model - device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or - 'cuda' - loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over - each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. + data_path (Path): path to the data to be used to load the data for + client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and + predictions of the client model + device (torch.device): Device indicator for where to send the model, + batches, labels etc. Often 'cpu' or 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and + compute the losses over each batch. Defaults to LossMeterType.AVERAGE. + checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer + module defining when and how to do checkpointing during client-side + training. No checkpointing is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. progress_bar (bool): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False """ @@ -60,7 +62,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, progress_bar=progress_bar, ) self.global_model: nn.Module @@ -120,7 +122,10 @@ def get_parameters(self, config: Config) -> NDArrays: NDArrays: GLOBAL model weights to be sent to the server for aggregation """ if not self.initialized: - log(INFO, "Setting up client and providing full model parameters to the server for initialization") + log( + INFO, + "Setting up client and providing full model parameters to the server for initialization", + ) # If initialized==False, the server is requesting model parameters from which to initialize all other # clients. As such get_parameters is being called before fit or evaluate, so we must call diff --git a/fl4health/clients/evaluate_client.py b/fl4health/clients/evaluate_client.py index 000084518..004422ab1 100644 --- a/fl4health/clients/evaluate_client.py +++ b/fl4health/clients/evaluate_client.py @@ -13,7 +13,8 @@ from fl4health.clients.basic_client import BasicClient from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter +from fl4health.reporting.report_manager import ReportsManager from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType from fl4health.utils.metrics import Metric, MetricManager from fl4health.utils.random import generate_hash @@ -34,7 +35,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, model_checkpoint_path: Optional[Path] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: # EvaluateClient does not call BasicClient constructor and sets attributes # in a custom way to account for the fact it does not involve any training @@ -45,10 +46,9 @@ def __init__( self.metrics = metrics self.initialized = False - if metrics_reporter is not None: - self.metrics_reporter = metrics_reporter - else: - self.metrics_reporter = MetricsReporter(run_id=self.client_name) + # Initialize reporters with client information. + self.reports_manager = ReportsManager(reporters) + self.reports_manager.initialize(id=self.client_name) # This data loader should be instantiated as the one on which to run evaluation self.global_loss_meter = LossMeter[EvaluationLosses](loss_meter_type, EvaluationLosses) @@ -63,7 +63,6 @@ def __init__( self.criterion: _Loss self.local_model: Optional[nn.Module] = None self.global_model: Optional[nn.Module] = None - self.wandb_reporter = None def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: raise ValueError("Get Parameters is not implemented for an Evaluation-Only Client") @@ -88,7 +87,7 @@ def setup_client(self, config: Config) -> None: self.criterion = self.get_criterion(config) self.parameter_exchanger = self.get_parameter_exchanger(config) - self.metrics_reporter.add_to_metrics({"type": "client", "initialized": datetime.datetime.now()}) + self.reports_manager.report({"host_type": "client", "initialized": str(datetime.datetime.now())}) self.initialized = True @@ -110,20 +109,24 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di if not self.initialized: self.setup_client(config) - self.metrics_reporter.add_to_metrics({"evaluate_start": datetime.datetime.now()}) - + start_time = datetime.datetime.now() self.set_parameters(parameters, config, fitting_round=False) # Make sure at least one of local or global model is not none (i.e. there is something to evaluate) assert self.local_model or self.global_model loss, metric_values = self.validate() + end_time = datetime.datetime.now() + elapsed = end_time - start_time - self.metrics_reporter.add_to_metrics( + self.reports_manager.report( { - "metrics": metric_values, - "loss": loss, - "evaluate_end": datetime.datetime.now(), - } + "eval_metrics": metric_values, + "eval_loss": loss, + "eval_start": str(start_time), + "eval_time_elapsed": str(elapsed), + "eval_end": str(end_time), + }, + 0, ) # EvaluateRes should return the loss, number of examples on client, and a dictionary holding metrics @@ -147,7 +150,11 @@ def _handle_logging( # type: ignore ) def validate_on_model( - self, model: nn.Module, metric_meter: MetricManager, loss_meter: LossMeter, is_global: bool + self, + model: nn.Module, + metric_meter: MetricManager, + loss_meter: LossMeter, + is_global: bool, ) -> Tuple[EvaluationLosses, Dict[str, Scalar]]: model.eval() metric_meter.clear() @@ -178,13 +185,19 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]: if self.local_model: log(INFO, "Performing evaluation on local model") local_loss, local_metrics = self.validate_on_model( - self.local_model, self.local_metric_manager, self.local_loss_meter, is_global=False + self.local_model, + self.local_metric_manager, + self.local_loss_meter, + is_global=False, ) if self.global_model: log(INFO, "Performing evaluation on global model") global_loss, global_metrics = self.validate_on_model( - self.global_model, self.global_metric_manager, self.global_loss_meter, is_global=True + self.global_model, + self.global_metric_manager, + self.global_loss_meter, + is_global=True, ) # Store the losses in the metrics, since we can't return more than one loss. @@ -199,7 +212,8 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]: @staticmethod def merge_metrics( - global_metrics: Optional[Dict[str, Scalar]], local_metrics: Optional[Dict[str, Scalar]] + global_metrics: Optional[Dict[str, Scalar]], + local_metrics: Optional[Dict[str, Scalar]], ) -> Dict[str, Scalar]: # Merge metrics if necessary if global_metrics: @@ -248,7 +262,10 @@ def get_local_model(self, config: Config) -> Optional[nn.Module]: """ # If a model checkpoint is provided, we load the checkpoint into the local model to be evaluated. if self.model_checkpoint_path: - log(INFO, f"Loading model checkpoint at: {self.model_checkpoint_path.__str__()}") + log( + INFO, + f"Loading model checkpoint at: {self.model_checkpoint_path.__str__()}", + ) return torch.load(self.model_checkpoint_path) else: return None diff --git a/fl4health/clients/fedpm_client.py b/fl4health/clients/fedpm_client.py index 625ab0ea3..8612f7a61 100644 --- a/fl4health/clients/fedpm_client.py +++ b/fl4health/clients/fedpm_client.py @@ -9,7 +9,7 @@ from fl4health.model_bases.masked_layers import convert_to_masked_model from fl4health.parameter_exchange.fedpm_exchanger import FedPmExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric @@ -23,7 +23,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: super().__init__( data_path=data_path, @@ -31,7 +31,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, ) def setup_client(self, config: Config) -> None: diff --git a/fl4health/clients/fedrep_client.py b/fl4health/clients/fedrep_client.py index 71a3b219e..9181442b1 100644 --- a/fl4health/clients/fedrep_client.py +++ b/fl4health/clients/fedrep_client.py @@ -15,7 +15,7 @@ from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric @@ -37,9 +37,9 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: - super().__init__(data_path, metrics, device, loss_meter_type, checkpointer, metrics_reporter) + super().__init__(data_path, metrics, device, loss_meter_type, checkpointer, reporters) self.fedrep_train_mode = FedRepTrainMode.HEAD def _prepare_train_representations(self) -> None: @@ -109,7 +109,10 @@ def _extract_epochs_or_steps_specified(self, config: Config) -> EpochsAndStepsTu epochs_specified = ("local_head_epochs" in config) and ("local_rep_epochs" in config) steps_specified = ("local_head_steps" in config) and ("local_rep_steps" in config) if epochs_specified and not steps_specified: - log(INFO, "Epochs for head and representation module specified. Proceeding with epoch-based training") + log( + INFO, + "Epochs for head and representation module specified. Proceeding with epoch-based training", + ) return ( narrow_dict_type(config, "local_head_epochs", int), narrow_dict_type(config, "local_rep_epochs", int), @@ -117,7 +120,10 @@ def _extract_epochs_or_steps_specified(self, config: Config) -> EpochsAndStepsTu None, ) elif steps_specified and not epochs_specified: - log(INFO, "Steps for head and representation module specified. Proceeding with step-based training") + log( + INFO, + "Steps for head and representation module specified. Proceeding with step-based training", + ) return ( None, None, @@ -207,6 +213,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict ValueError: If the steps or epochs for the representation and head module training processes are are correctly specified. """ + round_start_time = datetime.datetime.now() ( (local_head_epochs, local_rep_epochs, local_head_steps, local_rep_steps), current_server_round, @@ -216,15 +223,11 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict if not self.initialized: self.setup_client(config) - self.metrics_reporter.add_to_metrics_at_round( - current_server_round, - data={"fit_start": datetime.datetime.now()}, - ) - self.set_parameters(parameters, config, fitting_round=True) self.update_before_train(current_server_round) + fit_start_time = datetime.datetime.now() if local_head_epochs and local_rep_epochs: loss_dict, metrics = self.train_fedrep_by_epochs(local_head_epochs, local_rep_epochs, current_server_round) elif local_head_steps and local_rep_steps: @@ -234,6 +237,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict "Local epochs or local steps have not been correctly specified. They have values " f"{local_head_epochs}, {local_rep_epochs}, {local_head_steps}, {local_rep_steps}" ) + fit_time = datetime.datetime.now() - fit_start_time # Check if we should run an evaluation with validation data after fit # (for example, this is used by FedDGGA) @@ -243,12 +247,16 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict # We perform a pre-aggregation checkpoint if applicable self._maybe_checkpoint(validation_loss, validation_metrics, CheckpointMode.PRE_AGGREGATION) - self.metrics_reporter.add_to_metrics_at_round( - current_server_round, - data={ + # Report data at end of round + self.reports_manager.report( + { "fit_metrics": metrics, - "loss_dict": loss_dict, + "fit_losses": loss_dict, + "round": current_server_round, + "round_start": str(round_start_time), + "fit_time_elapsed": str(fit_time), }, + current_server_round, ) # FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics @@ -284,7 +292,10 @@ def train_fedrep_by_epochs( # Second we train the representation module for rep_epochs with the head module frozen in place self._prepare_train_representations() - log(INFO, f"Beginning FedRep Representation Training Phase for {rep_epochs} Epochs") + log( + INFO, + f"Beginning FedRep Representation Training Phase for {rep_epochs} Epochs", + ) loss_dict_rep, metrics_dict_rep = self.train_by_epochs(rep_epochs, current_round) log(INFO, "Converting the loss and metrics dictionary keys for Rep training") # The loss and metrics coming from train_by_epochs are generically keyed, for example "backward." To avoid @@ -320,7 +331,10 @@ def train_fedrep_by_steps( # Second we train the representation module for rep_steps with the head module frozen in place self._prepare_train_representations() - log(INFO, f"Beginning FedRep Representation Training Phase for {rep_steps} Steps") + log( + INFO, + f"Beginning FedRep Representation Training Phase for {rep_steps} Steps", + ) loss_dict_rep, metrics_dict_rep = self.train_by_steps(rep_steps, current_round) log(INFO, "Converting the loss and metrics dictionary keys for Rep training") # The loss and metrics coming from train_by_steps are generically keyed, for example "backward." To avoid diff --git a/fl4health/clients/fenda_ditto_client.py b/fl4health/clients/fenda_ditto_client.py index 69abbe4db..9cad768fb 100644 --- a/fl4health/clients/fenda_ditto_client.py +++ b/fl4health/clients/fenda_ditto_client.py @@ -11,7 +11,7 @@ from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric from fl4health.utils.parameter_extraction import check_shape_match @@ -26,7 +26,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, freeze_global_feature_extractor: bool = False, ) -> None: @@ -67,12 +67,11 @@ def __init__( checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. progress_bar (bool): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False + freeze_global_feature_extractor (bool, optional): Determines whether we freeze the FENDA global feature extractor during training. If freeze_global_feature_extractor is False, both the global and the local feature extractor in the local FENDA model will be trained. Otherwise, the global feature extractor @@ -86,7 +85,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, progress_bar=progress_bar, ) self.global_model: SequentiallySplitExchangeBaseModel diff --git a/fl4health/clients/flash_client.py b/fl4health/clients/flash_client.py index cf6ccde19..da8150015 100644 --- a/fl4health/clients/flash_client.py +++ b/fl4health/clients/flash_client.py @@ -62,10 +62,12 @@ def train_by_epochs( self.model.train() local_step = 0 previous_loss = float("inf") + report_data: dict = {"round": current_round} for local_epoch in range(epochs): self.train_metric_manager.clear() self.train_loss_meter.clear() self._log_header_str(current_round, local_epoch) + report_data.update({"fit_epoch": local_epoch}) for input, target in self.train_loader: if self.is_empty_batch(input): log(INFO, "Empty batch generated by data loader. Skipping step.") @@ -77,6 +79,9 @@ def train_by_epochs( self.train_loss_meter.update(losses) self.train_metric_manager.update(preds, target) self.update_after_step(local_step, current_round) + report_data.update({"fit_losses": losses.as_dict(), "fit_step": self.total_steps}) + report_data.update(self.get_client_specific_reports()) + self.reports_manager.report(report_data, current_round, local_epoch, self.total_steps) self.total_steps += 1 local_step += 1 @@ -84,8 +89,12 @@ def train_by_epochs( loss_dict = self.train_loss_meter.compute().as_dict() current_loss, _ = self.validate() - self._log_results(loss_dict, metrics, current_round=current_round, current_epoch=local_epoch) - self._handle_reporting(loss_dict, metrics, current_round=current_round) + self._log_results( + loss_dict, + metrics, + current_round=current_round, + current_epoch=local_epoch, + ) if self.gamma is not None and previous_loss - current_loss < self.gamma / (local_epoch + 1): log( diff --git a/fl4health/clients/instance_level_dp_client.py b/fl4health/clients/instance_level_dp_client.py index 2d125b62f..9b093300d 100644 --- a/fl4health/clients/instance_level_dp_client.py +++ b/fl4health/clients/instance_level_dp_client.py @@ -7,6 +7,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointModule from fl4health.clients.basic_client import BasicClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric @@ -25,6 +26,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: super().__init__( data_path=data_path, @@ -32,6 +34,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, + reporters=reporters, ) self.clipping_bound: float self.noise_multiplier: float diff --git a/fl4health/clients/mr_mtl_client.py b/fl4health/clients/mr_mtl_client.py index 7879bcba0..80f4e231c 100644 --- a/fl4health/clients/mr_mtl_client.py +++ b/fl4health/clients/mr_mtl_client.py @@ -9,7 +9,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointModule from fl4health.clients.adaptive_drift_constraint_client import AdaptiveDriftConstraintClient -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchFeatureType, TorchPredType, TorchTargetType @@ -23,7 +23,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, ) -> None: """ @@ -48,8 +48,8 @@ def __init__( checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. progress_bar (bool): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False """ @@ -59,7 +59,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, progress_bar=progress_bar, ) # NOTE: The initial global model is used to house the aggregate weight updates at the beginning of a round, diff --git a/fl4health/clients/nnunet_client.py b/fl4health/clients/nnunet_client.py index 87a8b47d1..f3f333f18 100644 --- a/fl4health/clients/nnunet_client.py +++ b/fl4health/clients/nnunet_client.py @@ -22,7 +22,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointModule from fl4health.clients.basic_client import BasicClient, LoggingMode -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric, MetricManager @@ -74,7 +74,7 @@ def __init__( intermediate_client_state_dir: Optional[Path] = None, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, client_name: Optional[str] = None, ) -> None: """ @@ -137,9 +137,8 @@ def __init__( Checkpointer module defining when and how to do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics - reporter instance to record the metrics during the execution. - Defaults to an instance of MetricsReporter with default init parameters. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. """ metrics = metrics if metrics else [] # Parent method sets up several class attributes @@ -149,7 +148,7 @@ def __init__( device=device, # self.device loss_meter_type=loss_meter_type, checkpointer=checkpointer, # self.checkpointer - metrics_reporter=metrics_reporter, # self.metrics_reporter + reporters=reporters, progress_bar=progress_bar, intermediate_client_state_dir=intermediate_client_state_dir, client_name=client_name, @@ -283,7 +282,9 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler: # compatible with Torch LRScheduler # Create and return LR Scheduler. This is nnunet default for version 2.5.1 return PolyLRSchedulerWrapper( - self.optimizers[optimizer_key], initial_lr=self.nnunet_trainer.initial_lr, max_steps=total_steps + self.optimizers[optimizer_key], + initial_lr=self.nnunet_trainer.initial_lr, + max_steps=total_steps, ) def create_plans(self, config: Config) -> Dict[str, Any]: @@ -337,7 +338,10 @@ def create_plans(self, config: Config) -> Dict[str, Any]: new_bs = max(min(old_bs, bs_5percent), 2) plans["configurations"][c]["batch_size"] = new_bs else: - log(WARNING, ("Did not find a 'batch_size' key in the nnunet plans " f"dict for nnunet config: {c}")) + log( + WARNING, + ("Did not find a 'batch_size' key in the nnunet plans " f"dict for nnunet config: {c}"), + ) # Can't run nnunet preprocessing without saving plans file os.makedirs(join(nnUNet_preprocessed, self.dataset_name), exist_ok=True) @@ -375,7 +379,10 @@ def maybe_preprocess(self, nnunet_config: NnunetConfig) -> None: configurations=[nnunet_config.value], ) elif self.verbose: - log(INFO, "\tnnunet preprocessed data seems to already exist. Skipping preprocessing") + log( + INFO, + "\tnnunet preprocessed data seems to already exist. Skipping preprocessing", + ) @use_default_signal_handlers # Fingerprint extraction spawns subprocess def maybe_extract_fingerprint(self) -> None: @@ -393,11 +400,20 @@ def maybe_extract_fingerprint(self) -> None: with redirect_stdout(self.stream2debug): extract_fingerprints(dataset_ids=[self.dataset_id]) if self.verbose: - log(INFO, f"\tExtracted dataset fingerprint in {time.time()-start:.1f}s") + log( + INFO, + f"\tExtracted dataset fingerprint in {time.time()-start:.1f}s", + ) elif self.verbose: - log(INFO, "\tnnunet dataset fingerprint already exists. Skipping fingerprint extraction") + log( + INFO, + "\tnnunet dataset fingerprint already exists. Skipping fingerprint extraction", + ) elif self.verbose: - log(INFO, "\tThis client instance has already extracted the dataset fingerprint. Skipping.") + log( + INFO, + "\tThis client instance has already extracted the dataset fingerprint. Skipping.", + ) # Avoid extracting fingerprint multiple times when always_preprocess is true self.fingerprint_extracted = True @@ -499,7 +515,10 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, Dict[str, torch ) def compute_loss_and_additional_losses( - self, preds: TorchPredType, features: Dict[str, torch.Tensor], target: TorchTargetType + self, + preds: TorchPredType, + features: Dict[str, torch.Tensor], + target: TorchTargetType, ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]: """ Checks the pred and target types and computes the loss @@ -570,10 +589,16 @@ def mask_data(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Ten # Tile the mask to be one hot encoded mask_here = torch.tile(mask, (1, pred.shape[1], *[1 for _ in range(2, pred.ndim)])) - return pred * mask_here, new_target # Mask the input tensor and return the modified target + return ( + pred * mask_here, + new_target, + ) # Mask the input tensor and return the modified target def update_metric_manager( - self, preds: TorchPredType, target: TorchTargetType, metric_manager: MetricManager + self, + preds: TorchPredType, + target: TorchTargetType, + metric_manager: MetricManager, ) -> None: """ Update the metrics with preds and target. Overridden because we might @@ -637,7 +662,10 @@ def empty_cache(self) -> None: torch.mps.empty_cache() def get_client_specific_logs( - self, current_round: Optional[int], current_epoch: Optional[int], logging_mode: LoggingMode + self, + current_round: Optional[int], + current_epoch: Optional[int], + logging_mode: LoggingMode, ) -> Tuple[str, List[Tuple[LogLevel, str]]]: if logging_mode == LoggingMode.TRAIN: lr = float(self.optimizers["global"].param_groups[0]["lr"]) diff --git a/fl4health/clients/partial_weight_exchange_client.py b/fl4health/clients/partial_weight_exchange_client.py index 8a67b3899..7ae7037a2 100644 --- a/fl4health/clients/partial_weight_exchange_client.py +++ b/fl4health/clients/partial_weight_exchange_client.py @@ -13,7 +13,7 @@ from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.parameter_exchange.partial_parameter_exchanger import PartialParameterExchanger -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric @@ -26,7 +26,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, store_initial_model: bool = False, ) -> None: """ @@ -45,6 +45,8 @@ def __init__( checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. store_initial_model (bool): Indicates whether the client should store a copy of the model weights at the beginning of each training round. The model copy might be required to select the subset of model parameters to be exchanged with the server, depending on the selection criterion used. @@ -56,7 +58,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, ) # Initial model parameters to be used in selecting parameters to be exchanged during training. self.initial_model: Optional[nn.Module] @@ -108,7 +110,10 @@ def get_parameters(self, config: Config) -> NDArrays: NDArrays: The list of weights to be sent to the server from the client """ if not self.initialized: - log(INFO, "Setting up client and providing full model parameters to the server for initialization") + log( + INFO, + "Setting up client and providing full model parameters to the server for initialization", + ) # If initialized==False, the server is requesting model parameters from which to initialize all other # clients. As such get_parameters is being called before fit or evaluate, so we must call diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index 2fc0a98ba..7bdfad00b 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -15,6 +15,7 @@ from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric @@ -35,6 +36,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: super().__init__( data_path=data_path, @@ -42,6 +44,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, + reporters=reporters, ) self.learning_rate: float # eta_l in paper self.client_control_variates: Optional[NDArrays] = None # c_i in paper @@ -58,7 +61,10 @@ def get_parameters(self, config: Config) -> NDArrays: Packs the parameters and control variates into a single NDArrays to be sent to the server for aggregation """ if not self.initialized: - log(INFO, "Setting up client and providing full model parameters to the server for initialization") + log( + INFO, + "Setting up client and providing full model parameters to the server for initialization", + ) # If initialized==False, the server is requesting model parameters from which to initialize all other # clients. As such get_parameters is being called before fit or evaluate, so we must call @@ -161,7 +167,9 @@ def modify_grad(self) -> None: ] for param, client_cv, server_cv in zip( - model_params_with_grad, self.client_control_variates, self.server_control_variates + model_params_with_grad, + self.client_control_variates, + self.server_control_variates, ): assert param.grad is not None tensor_type = param.grad.dtype @@ -187,7 +195,10 @@ def transform_gradients(self, losses: TrainingLosses) -> None: self.modify_grad() def compute_updated_control_variates( - self, local_steps: int, delta_model_weights: NDArrays, delta_control_variates: NDArrays + self, + local_steps: int, + delta_model_weights: NDArrays, + delta_control_variates: NDArrays, ) -> NDArrays: """ Computes the updated local control variates according to option 2 in Equation 4 of paper diff --git a/fl4health/reporting/__init__.py b/fl4health/reporting/__init__.py index e69de29bb..59a2d47e2 100644 --- a/fl4health/reporting/__init__.py +++ b/fl4health/reporting/__init__.py @@ -0,0 +1,5 @@ +from .json_reporter import JsonReporter +from .wandb_reporter import WandBReporter + +# Must add unused imports to __all__ so that flake8 knows whats going on +__all__ = ["JsonReporter", "WandBReporter"] diff --git a/fl4health/reporting/base_reporter.py b/fl4health/reporting/base_reporter.py new file mode 100644 index 000000000..ae417ff8c --- /dev/null +++ b/fl4health/reporting/base_reporter.py @@ -0,0 +1,51 @@ +"""Base Class for Reporters. + +Super simple for now but keeping it in a seperate file in case we add more base methods. +""" + +from typing import Any + + +class BaseReporter: + def report( + self, + data: dict, + round: int | None = None, + epoch: int | None = None, + step: int | None = None, + ) -> None: + """A method called by clients or servers to send data to the reporter. + + The report method is called by the client/server at frequent intervals (ie step, + epoch, round) and sometimes outside of a FL round (for high level summary data). + It is up to the reporter to determine when and what to report. + + Args: + data (dict): The data to maybe report from the server or client. + round (int | None, optional): The current FL round. If None, this indicates + that the method was called outside of a round (eg. for summary + information). Defaults to None. + epoch (int | None, optional): The current epoch. If None then this method + was not called within the scope of an epoch. Defaults to None. + step (int | None, optional): The current step (total). If None then this + method was called outside the scope of a training or evaluation step + (eg. at the end of an epoch or round) Defaults to None. + """ + raise NotImplementedError + + def initialize(self, **kwargs: Any) -> None: + """Method for initializing reporters with client/server information + + This method is called once by the client or server during initialization. + + Args: + kwargs (Any): arbitrary keyword arguments containing information from the + client or server that might be useful for initializing the reporter. + This information should be treated as optional and this method should + work even if no keyword arguments are passed. + """ + pass + + def shutdown(self) -> None: + """Called by the client/server on shutdown.""" + pass diff --git a/fl4health/reporting/fl_wandb.py b/fl4health/reporting/fl_wandb.py deleted file mode 100644 index 92d223a7a..000000000 --- a/fl4health/reporting/fl_wandb.py +++ /dev/null @@ -1,167 +0,0 @@ -import os -from logging import INFO -from typing import Any, Dict, List, Optional, Tuple - -import wandb -from flwr.common.logger import log -from flwr.common.typing import Scalar -from flwr.server.history import History -from wandb.wandb_run import Run - - -class WandBReporter: - def __init__( - self, - project_name: str, - run_name: str, - group_name: str, - entity: str, - notes: Optional[str], - tags: Optional[List], - config: Dict[str, Any], - local_log_directory: str = "./fl_wandb_logs", - ) -> None: - # Name of the project under which to store all of the logged values - self.project_name = project_name - # Name of the run under the group (server or client associated) - self.run_name = run_name - # Name of grouping on the W and B dashboard - self.group_name = group_name - # W and B username under which these experiments are to be logged - self.entity = entity - # Any notes to go along with the experiment to be logged - self.notes = notes - # Any tags to make searching easier.\ - self.tags = tags - # Initialize the WandB logger - self._maybe_create_local_log_directory(local_log_directory) - wandb.init( - dir=local_log_directory, - project=self.project_name, - name=self.run_name, - group=self.group_name, - entity=self.entity, - notes=self.notes, - tags=self.tags, - config=config, - ) - assert wandb.run is not None - self.wandb_run: Run = wandb.run - - def _maybe_create_local_log_directory(self, local_log_directory: str) -> None: - log_directory_exists = os.path.isdir(local_log_directory) - if not log_directory_exists: - os.mkdir(local_log_directory) - log(INFO, f"Logging directory {local_log_directory} does not exist. Creating it.") - else: - log(INFO, f"Logging directory {local_log_directory} exists.") - - def _log_metrics(self, metric_dict: Dict[str, Any]) -> None: - self.wandb_run.log(metric_dict) - - def shutdown_reporter(self) -> None: - self.wandb_run.finish() - - -class ServerWandBReporter(WandBReporter): - def __init__( - self, - project_name: str, - run_name: str, - group_name: str, - entity: str, - notes: Optional[str], - tags: Optional[List], - fl_config: Dict[str, Any], - ) -> None: - super().__init__(project_name, run_name, group_name, entity, notes, tags, fl_config) - - def _convert_losses_history( - self, history_to_log: List[Dict[str, Scalar]], loss_history: List[Tuple[int, float]], loss_stub: str - ) -> None: - for server_round, loss in loss_history: - # Server rounds are indexed starting at 1 - history_to_log[server_round - 1][f"{loss_stub}_loss"] = loss - - def _flatten_metrics_history( - self, - history_to_log: List[Dict[str, Scalar]], - metrics_history: Dict[str, List[Tuple[int, Scalar]]], - metric_stub: str, - ) -> None: - for metric_name, metric_history in metrics_history.items(): - metric = f"{metric_stub}_{metric_name}" - for server_round, history_value in metric_history: - # Server rounds are indexed starting at 1 - history_to_log[server_round - 1][metric] = history_value - - def report_metrics(self, server_rounds: int, history: History) -> None: - # Servers construct a history object that collects aggregated metrics over the set of server rounds conducted. - # So we need to reformat the history object into a W and B log-able object - history_to_log: List[Dict[str, Scalar]] = [ - {"server_round": server_round} for server_round in range(server_rounds) - ] - if len(history.losses_centralized) > 0: - self._convert_losses_history(history_to_log, history.losses_centralized, "centralized") - if len(history.losses_distributed) > 0: - self._convert_losses_history(history_to_log, history.losses_distributed, "distributed_val") - if history.metrics_centralized: - self._flatten_metrics_history(history_to_log, history.metrics_centralized, "centralized") - if history.metrics_distributed: - self._flatten_metrics_history(history_to_log, history.metrics_distributed, "distributed") - if history.metrics_distributed_fit: - self._flatten_metrics_history(history_to_log, history.metrics_distributed_fit, "distributed_fit") - - for server_metrics in history_to_log: - self._log_metrics(server_metrics) - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> Optional["ServerWandBReporter"]: - assert "reporting_config" in config - reporter_config = config["reporting_config"] - if reporter_config["enabled"]: - # Strip out the reporting configuration variables. - fl_config = {key: value for key, value in config.items() if key != "reporting_config"} - return ServerWandBReporter( - reporter_config["project_name"], - reporter_config["run_name"], - reporter_config["group_name"], - reporter_config["entity"], - reporter_config.get("notes"), - reporter_config.get("tags"), - fl_config, - ) - else: - return None - - -class ClientWandBReporter(WandBReporter): - def __init__( - self, - client_name: str, - project_name: str, - group_name: str, - entity: str, - ) -> None: - self.client_name = client_name - config = {"client_name": client_name} - run_name = f"Client_{client_name}" - super().__init__(project_name, run_name, group_name, entity, None, None, config) - - def log_model_type(self, model_type: str) -> None: - self.add_client_model_type(f"{self.client_name}_model", model_type) - - def report_metrics(self, metrics: Dict[str, Any]) -> None: - # Attach client name for W and B logging - client_metrics = {f"{self.client_name}_{key}": metric for key, metric in metrics.items()} - self._log_metrics(client_metrics) - - def add_client_model_type(self, client_name: str, model_type: str) -> None: - self.wandb_run.config[client_name] = model_type - - @classmethod - def from_config(cls, client_name: str, config: Dict[str, Any]) -> Optional["ClientWandBReporter"]: - if "reporting_enabled" in config and config["reporting_enabled"]: - return ClientWandBReporter(client_name, config["project_name"], config["group_name"], config["entity"]) - else: - return None diff --git a/fl4health/reporting/json_reporter.py b/fl4health/reporting/json_reporter.py new file mode 100644 index 000000000..2d73b5585 --- /dev/null +++ b/fl4health/reporting/json_reporter.py @@ -0,0 +1,90 @@ +import datetime +import json +import uuid +from logging import INFO +from pathlib import Path +from typing import Any + +from flwr.common.logger import log + +from fl4health.reporting.base_reporter import BaseReporter + + +class DateTimeEncoder(json.JSONEncoder): + """ + Converts a datetime object to string in order to make json encoding easier. + """ + + def default(self, obj: Any) -> Any: + if isinstance(obj, datetime.datetime): + return str(obj) + else: + return json.JSONEncoder.default(self, obj) + + +class FileReporter(BaseReporter): + def __init__( + self, + run_id: str | None = None, + output_folder: str | Path = Path("metrics"), + ): + """Reports data each round and saves as a json. + + Args: + run_id (str | None, optional): the identifier for the run which these + metrics are from. If left as None will check if an id is provided during + initialize, otherwise uses a UUID. + output_folder (str | Path): the folder to save the metrics to. The metrics + will be saved in a file named {output_folder}/{run_id}.json. Optional, + default is "metrics". + """ + self.run_id = run_id + + self.output_folder = Path(output_folder) + self.metrics: dict[str, Any] = {} + + self.output_folder.mkdir(exist_ok=True) + assert self.output_folder.is_dir(), f"Output folder '{self.output_folder}' is not a valid directory." + + def initialize(self, **kwargs: Any) -> None: + # If run_id was not specified on init try first to initialize with client name + if self.run_id is None: + self.run_id = kwargs.get("id") + # If client name was not provided, init run id manually + if self.run_id is None: + self.run_id = str(uuid.uuid4()) + + def report( + self, + data: dict[str, Any], + round: int | None = None, + epoch: int | None = None, + step: int | None = None, + ) -> None: + if round is None: # Reports outside of a fit round + self.metrics.update(data) + # Ensure we don't report for each epoch or step + elif epoch is None and step is None: + if "rounds" not in self.metrics: + self.metrics["rounds"] = {} + if round not in self.metrics["rounds"]: + self.metrics["rounds"][round] = {} + + self.metrics["rounds"][round].update(data) + + def dump(self) -> None: + raise NotImplementedError + + def shutdown(self) -> None: + self.dump() + + +class JsonReporter(FileReporter): + def dump(self) -> None: + assert isinstance(self.run_id, str) + """Dumps the current metrics to a JSON file at {output_folder}/{run_id.json}""" + output_file_path = Path(self.output_folder, self.run_id).with_suffix(".json") + log(INFO, f"Dumping metrics to {output_file_path}") + + with open(output_file_path, "w") as output_file: + json.dump(self.metrics, output_file, indent=4) diff --git a/fl4health/reporting/metrics.py b/fl4health/reporting/metrics.py deleted file mode 100644 index 1991c107b..000000000 --- a/fl4health/reporting/metrics.py +++ /dev/null @@ -1,84 +0,0 @@ -import datetime -import json -import uuid -from logging import INFO -from pathlib import Path -from typing import Any, Dict, Optional - -from flwr.common.logger import log - - -class MetricsReporter: - """ - Stores metrics for a training execution and saves it to a JSON file. - """ - - def __init__( - self, - run_id: Optional[str] = None, - output_folder: Path = Path("metrics"), - ): - """ - Args: - run_id (str): the identifier for the run which these metrics are from. - Optional, default is a random UUID. - output_folder (str): the folder to save the metrics to. The metrics will be saved in a file - named {output_folder}/{run_id}.json. Optional, default is "metrics". - """ - if run_id is not None: - self.run_id = run_id - else: - self.run_id = str(uuid.uuid4()) - - self.output_folder = output_folder - self.metrics: Dict[str, Any] = {} - - self.output_folder.mkdir(exist_ok=True) - assert self.output_folder.is_dir(), f"Output folder '{self.output_folder}' is not a valid directory." - - def add_to_metrics(self, data: Dict[str, Any]) -> None: - """ - Adds a dictionary of data into the main metrics dictionary. - - Args: - data (Dict[str, Any]): Data to be added to the metrics dictionary via .update(). - """ - self.metrics.update(data) - - def add_to_metrics_at_round(self, fl_round: int, data: Dict[str, Any]) -> None: - """ - Adds a dictionary of data into the metrics dictionary for a specific FL round. - - Args: - fl_round (int): the FL round these metrics are from. - data (Dict[str, Any]): Data to be added to the round's metrics dictionary via .update(). - """ - if "rounds" not in self.metrics: - self.metrics["rounds"] = {} - - if fl_round not in self.metrics["rounds"]: - self.metrics["rounds"][fl_round] = {} - - self.metrics["rounds"][fl_round].update(data) - - def dump(self) -> None: - """ - Dumps the current metrics to a JSON file at {self.output_folder}/{self.run_id}.json - """ - output_file_path = Path(self.output_folder, self.run_id).with_suffix(".json") - log(INFO, f"Dumping metrics to {output_file_path}") - - with open(output_file_path, "w") as output_file: - json.dump(self.metrics, output_file, indent=4, cls=DateTimeEncoder) - - -class DateTimeEncoder(json.JSONEncoder): - """ - Converts a datetime object to string in order to make json encoding easier. - """ - - def default(self, obj: Any) -> Any: - if isinstance(obj, datetime.datetime): - return str(obj) - else: - return json.JSONEncoder.default(self, obj) diff --git a/fl4health/reporting/report_manager.py b/fl4health/reporting/report_manager.py new file mode 100644 index 000000000..1e2b44edf --- /dev/null +++ b/fl4health/reporting/report_manager.py @@ -0,0 +1,21 @@ +from collections.abc import Sequence +from typing import Any + +from fl4health.reporting.base_reporter import BaseReporter + + +class ReportsManager: + def __init__(self, reporters: Sequence[BaseReporter] | None = None) -> None: + self.reporters = [] if reporters is None else list(reporters) + + def initialize(self, **kwargs: Any) -> None: + for r in self.reporters: + r.initialize(**kwargs) + + def report(self, data: dict, round: int | None = None, epoch: int | None = None, step: int | None = None) -> None: + for r in self.reporters: + r.report(data, round, epoch, step) + + def shutdown(self) -> None: + for r in self.reporters: + r.shutdown() diff --git a/fl4health/reporting/wandb_reporter.py b/fl4health/reporting/wandb_reporter.py new file mode 100644 index 000000000..71201bbfc --- /dev/null +++ b/fl4health/reporting/wandb_reporter.py @@ -0,0 +1,123 @@ +from enum import Enum +from pathlib import Path +from typing import Any + +import wandb +import wandb.wandb_run + +from fl4health.reporting.base_reporter import BaseReporter + + +class StepType(Enum): + ROUND = "round" + EPOCH = "epoch" + BATCH = "batch" + + +# TODO: Add ability to parse data types and save certain data types in specific ways +# (eg. Artifacts, Tables, etc.) + + +class WandBReporter(BaseReporter): + def __init__(self, timestep: StepType | str = StepType.ROUND, **kwargs: Any) -> None: + """Reporter that logs data to a wandb server. + + Args: + timestep (StepType | str, optional): How frequently to log data. Either + every 'round', 'epoch' or 'step'. Defaults to StepType.ROUND. + **kwargs (Any): + Keyword arguments to wandb.init + + """ + # Create wandb metadata dir if necessary + if kwargs.get("dir") is not None: + Path(kwargs["dir"]).mkdir(exist_ok=True) + + # Create run and set attrbutes + self.wandb_init_kwargs = kwargs + self.timestep_type = StepType(timestep) if isinstance(timestep, str) else timestep + self.run_started = False + self.initialized = False + # To maybe be initialized later + self.run_id = kwargs.get("id") + self.run: wandb.wandb_run.Run + + def initialize(self, **kwargs: Any) -> None: + """Checks if an id was provided by the client or server. + + If an id was passed to the WandBReporter on init then it takes priority over the + one passed by the client/server. + """ + if self.run_id is None: + self.run_id = kwargs.get("id") + self.initialized = True + + def start_run(self, **kwargs: Any) -> None: + """Initializes the wandb run. + + Args: + kwargs (Any): Keyword arguments for wandb.init() + """ + if not self.initialized: + self.initialize() + + self.run = wandb.init(id=self.run_id, **kwargs) + self.run_id = self.run._run_id # If run_id was None, we need to reset run id + self.run_started = True + + def get_wandb_timestep( + self, + round: int | None, + epoch: int | None, + step: int | None, + ) -> int | None: + """Determines the current step based on the timestep type. + + Args: + round (int | None): The current round or None if called outside of a round. + epoch (int | None): The current epoch or None if called outside of a epoch. + step (int | None): The current step (total) or None if called outside of + step. + + Returns: + int | None: Returns None if the reporter should not report metrics on this + call. If an integer is returned then it is what the reporter should use + as the current wandb step. + """ + if self.timestep_type == StepType.ROUND and epoch is None and step is None: + return round + elif self.timestep_type == StepType.EPOCH and step is None: + return epoch + elif self.timestep_type == StepType.BATCH: + return step + return None + + def report( + self, + data: dict, + round: int | None = None, + epoch: int | None = None, + batch: int | None = None, + ) -> None: + # If round is None, assume data is summary information + if round is None: + if not self.run_started: + self.start_run(**self.wandb_init_kwargs) + self.run.summary.update(data) + + # Get wandb step based on timestep_type + timestep = self.get_wandb_timestep(round, epoch, batch) + + # If timestep is None, then we should not report on this call + if timestep is None: + return + + # Check if wandb run has been initialized + if not self.run_started: + self.start_run(**self.wandb_init_kwargs) + + # Log data + self.run.log(data, step=timestep) + + def shutdown(self) -> None: + self.run.finish() diff --git a/fl4health/server/adaptive_constraint_servers/ditto_server.py b/fl4health/server/adaptive_constraint_servers/ditto_server.py index 4a4ece73f..cb37218b4 100644 --- a/fl4health/server/adaptive_constraint_servers/ditto_server.py +++ b/fl4health/server/adaptive_constraint_servers/ditto_server.py @@ -3,8 +3,7 @@ from flwr.server.client_manager import ClientManager from fl4health.checkpointing.checkpointer import TorchCheckpointer -from fl4health.reporting.fl_wandb import ServerWandBReporter -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint @@ -14,9 +13,8 @@ def __init__( self, client_manager: ClientManager, strategy: FedAvgWithAdaptiveConstraint, - wandb_reporter: Optional[ServerWandBReporter] = None, checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: """ This is a very basic wrapper class over the FlServer to ensure that the strategy used for Ditto is of type @@ -28,18 +26,17 @@ def __init__( strategy (FedAvgWithAdaptiveConstraint): The aggregation strategy to be used by the server to handle. client updates and other information potentially sent by the participating clients. For Ditto, the strategy must be a derivative of the FedAvgWithAdaptiveConstraint class. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. checkpointer (Optional[Union[TorchCheckpointer, Sequence [TorchCheckpointer]]], optional): To be provided if the server should perform server side checkpointing based on some criteria. If none, then no server-side checkpointing is performed. Multiple checkpointers can also be passed in a sequence to checkpointer based on multiple criteria. Ensure checkpoint names are different for each checkpoint or they will overwrite on another. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the server should send data to before and after each round. """ assert isinstance( strategy, FedAvgWithAdaptiveConstraint ), "Strategy must be of base type FedAvgWithAdaptiveConstraint" - super().__init__(client_manager, strategy, wandb_reporter, checkpointer, metrics_reporter) + super().__init__( + client_manager=client_manager, strategy=strategy, checkpointer=checkpointer, reporters=reporters + ) diff --git a/fl4health/server/adaptive_constraint_servers/fedprox_server.py b/fl4health/server/adaptive_constraint_servers/fedprox_server.py index 53412702b..4004f81a7 100644 --- a/fl4health/server/adaptive_constraint_servers/fedprox_server.py +++ b/fl4health/server/adaptive_constraint_servers/fedprox_server.py @@ -7,8 +7,7 @@ from fl4health.checkpointing.checkpointer import TorchCheckpointer from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint -from fl4health.reporting.fl_wandb import ServerWandBReporter -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.base_server import FlServerWithCheckpointing from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint @@ -19,9 +18,8 @@ def __init__( client_manager: ClientManager, strategy: FedAvgWithAdaptiveConstraint, model: Optional[nn.Module] = None, - wandb_reporter: Optional[ServerWandBReporter] = None, checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: """ This is a wrapper class around FlServerWithCheckpointing for using the FedProx method that enforces that the @@ -32,20 +30,20 @@ def __init__( client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if they are to be sampled at all. parameter_exchanger (ExchangerType): This is the parameter exchanger to be used to hydrate the model. - strategy (FedAvgWithAdaptiveConstraint): The aggregation strategy to be used by the server to handle. - client updates and other information potentially sent by the participating clients. For FedProx, the - strategy must be a derivative of the FedAvgWithAdaptiveConstraint class. - model (Optional[nn.Module], optional): This is the torch model to be hydrated by the - _hydrate_model_for_checkpointing function, Defaults to None - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. + strategy (FedAvgWithAdaptiveConstraint): The aggregation strategy to be used + by the server to handle. client updates and other information + potentially sent by the participating clients. For FedProx, the strategy + must be a derivative of the FedAvgWithAdaptiveConstraint class. + model (Optional[nn.Module], optional): This is the torch model to be + hydrated by the _hydrate_model_for_checkpointing function, Defaults to + None checkpointer (Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]], optional): To be provided - if the server should perform server side checkpointing based on some criteria. If none, then no - server-side checkpointing is performed. Multiple checkpointers can also be passed in a sequence to - checkpoint based on multiple criteria. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. + if the server should perform server side checkpointing based on some + criteria. If none, then no server-side checkpointing is performed. + Multiple checkpointers can also be passed in a sequence to checkpoint + based on multiple criteria. Defaults to None. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the server should send data to before and after each round. """ assert isinstance( strategy, FedAvgWithAdaptiveConstraint @@ -55,10 +53,9 @@ def __init__( client_manager=client_manager, parameter_exchanger=parameter_exchanger, model=model, - wandb_reporter=wandb_reporter, strategy=strategy, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, ) def _hydrate_model_for_checkpointing(self) -> nn.Module: diff --git a/fl4health/server/adaptive_constraint_servers/mrmtl_server.py b/fl4health/server/adaptive_constraint_servers/mrmtl_server.py index 3374a2042..340f6fc93 100644 --- a/fl4health/server/adaptive_constraint_servers/mrmtl_server.py +++ b/fl4health/server/adaptive_constraint_servers/mrmtl_server.py @@ -3,8 +3,7 @@ from flwr.server.client_manager import ClientManager from fl4health.checkpointing.checkpointer import TorchCheckpointer -from fl4health.reporting.fl_wandb import ServerWandBReporter -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint @@ -14,9 +13,8 @@ def __init__( self, client_manager: ClientManager, strategy: FedAvgWithAdaptiveConstraint, - wandb_reporter: Optional[ServerWandBReporter] = None, checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: """ This is a very basic wrapper class over the FlServer to ensure that the strategy used for MR-MTL is of type @@ -28,18 +26,17 @@ def __init__( strategy (FedAvgWithAdaptiveConstraint): The aggregation strategy to be used by the server to handle. client updates and other information potentially sent by the participating clients. For MR-MTL, the strategy must be a derivative of the FedAvgWithAdaptiveConstraint class. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. checkpointer (Optional[Union[TorchCheckpointer, Sequence [TorchCheckpointer]]], optional): To be provided if the server should perform server side checkpointing based on some criteria. If none, then no server-side checkpointing is performed. Multiple checkpointers can also be passed in a sequence to checkpointer based on multiple criteria. Ensure checkpoint names are different for each checkpoint or they will overwrite on another. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the server should send data to before and after each round. """ assert isinstance( strategy, FedAvgWithAdaptiveConstraint ), "Strategy must be of base type FedAvgWithAdaptiveConstraint" - super().__init__(client_manager, strategy, wandb_reporter, checkpointer, metrics_reporter) + super().__init__( + client_manager=client_manager, strategy=strategy, checkpointer=checkpointer, reporters=reporters + ) diff --git a/fl4health/server/base_server.py b/fl4health/server/base_server.py index 09f81029f..cb26a7386 100644 --- a/fl4health/server/base_server.py +++ b/fl4health/server/base_server.py @@ -1,5 +1,4 @@ import datetime -import timeit from logging import DEBUG, INFO, WARN, WARNING from pathlib import Path from typing import Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union @@ -17,8 +16,7 @@ from fl4health.checkpointing.checkpointer import PerRoundCheckpointer, TorchCheckpointer from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger -from fl4health.reporting.fl_wandb import ServerWandBReporter -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.polling import poll_clients from fl4health.strategies.strategy_with_poll import StrategyWithPolling from fl4health.utils.config import narrow_dict_type_and_set_attribute @@ -32,42 +30,56 @@ def __init__( self, client_manager: ClientManager, strategy: Optional[Strategy] = None, - wandb_reporter: Optional[ServerWandBReporter] = None, + reporters: Sequence[BaseReporter] | None = None, checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None, - metrics_reporter: Optional[MetricsReporter] = None, server_name: Optional[str] = None, ) -> None: """ Base Server for the library to facilitate strapping additional/useful machinery to the base flwr server. Args: - client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if - they are to be sampled at all. - strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle. - client updates and other information potentially sent by the participating clients. If None the - strategy is FedAvg as set by the flwr Server. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. + client_manager (ClientManager): Determines the mechanism by which clients + are sampled by the server, if they are to be sampled at all. + strategy (Optional[Strategy], optional): The aggregation strategy to be + used by the server to handle. client updates and other information + potentially sent by the participating clients. If None the strategy is + FedAvg as set by the flwr Server. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the server should send data to before and after each round. + checkpointer (TorchCheckpointer | Sequence [TorchCheckpointer], optional): + To be provided if the server should perform server side checkpointing + based on some criteria. If none, then no server-side checkpointing is + performed. Multiple checkpointers can also be passed in a sequence to + checkpointer based on multiple criteria. Ensure checkpoint names are + different for each checkpoint or they will overwrite on another. Defaults to None. - checkpointer (Optional[Union[TorchCheckpointer, Sequence [TorchCheckpointer]]], optional): To be provided - if the server should perform server side checkpointing based on some criteria. If none, then no - server-side checkpointing is performed. Multiple checkpointers can also be passed in a sequence to - checkpointer based on multiple criteria. Ensure checkpoint names are different for each checkpoint - or they will overwrite on another. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. - server_name (Optional[str]): An optional string name to uniquely identify server. + server_name (Optional[str]): An optional string name to uniquely identify + server. """ super().__init__(client_manager=client_manager, strategy=strategy) - self.wandb_reporter = wandb_reporter self.checkpointer = [checkpointer] if isinstance(checkpointer, TorchCheckpointer) else checkpointer self.server_name = server_name if server_name is not None else generate_hash() - if metrics_reporter is not None: - self.metrics_reporter = metrics_reporter - else: - self.metrics_reporter = MetricsReporter() + self.reporters = [] if reporters is None else list(reporters) + for r in self.reporters: + r.initialize(id=self.server_name) + + def report_centralized_eval(self, history: History, num_rounds: int) -> None: + if len(history.losses_centralized) == 0: + return + + # Parse and report history for loss and metrics on centralized validation set. + for round in range(num_rounds): + for r in self.reporters: + r.report( + {"val - loss - centralized": history.losses_centralized[round][1]}, + round + 1, + ) + round_metrics = {} + for metric, vals in history.metrics_centralized.items(): + round_metrics.update({metric: vals[round][1]}) + r.report({"eval_metrics_centralized": round_metrics}, round + 1) def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: """ @@ -83,20 +95,21 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float FL training results, including things like aggregated loss and metrics. Tuple also contains the elapsed time in seconds for the round. """ - self.metrics_reporter.add_to_metrics({"type": "server", "fit_start": datetime.datetime.now()}) - + start_time = datetime.datetime.now() history, elapsed_time = super().fit(num_rounds, timeout) - if self.wandb_reporter: - # report history to W and B - self.wandb_reporter.report_metrics(num_rounds, history) - - self.metrics_reporter.add_to_metrics( - data={ - "fit_end": datetime.datetime.now(), - "metrics_centralized": history.metrics_centralized, - "losses_centralized": history.losses_centralized, - } - ) + end_time = datetime.datetime.now() + for r in self.reporters: + r.report( + { + "fit_elapsed_time": str(start_time - end_time), + "fit_start": str(start_time), + "fit_end": str(end_time), + "num_rounds": num_rounds, + "host_type": "server", + } + ) + + self.report_centralized_eval(history, num_rounds) return history, elapsed_time @@ -105,25 +118,24 @@ def fit_round( server_round: int, timeout: Optional[float], ) -> Optional[Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]]: - self.metrics_reporter.add_to_metrics_at_round(server_round, data={"fit_start": datetime.datetime.now()}) - + round_start = datetime.datetime.now() fit_round_results = super().fit_round(server_round, timeout) + round_end = datetime.datetime.now() - if fit_round_results is not None: - _, metrics_aggregated, _ = fit_round_results - self.metrics_reporter.add_to_metrics_at_round( + for r in self.reporters: + r.report( + {"fit_round_start": str(round_start), "fit_round_end": str(round_end)}, server_round, - data={ - "metrics_aggregated": metrics_aggregated, - "fit_end": datetime.datetime.now(), - }, ) + if fit_round_results is not None: + _, metrics, _ = fit_round_results + r.report({"fit_metrics": metrics}, server_round) return fit_round_results def shutdown(self) -> None: - if self.wandb_reporter: - self.wandb_reporter.shutdown_reporter() + for r in self.reporters: + r.shutdown() def _hydrate_model_for_checkpointing(self) -> nn.Module: """ @@ -140,7 +152,10 @@ def _hydrate_model_for_checkpointing(self) -> nn.Module: raise NotImplementedError() def _maybe_checkpoint( - self, loss_aggregated: float, metrics_aggregated: Dict[str, Scalar], server_round: int + self, + loss_aggregated: float, + metrics_aggregated: Dict[str, Scalar], + server_round: int, ) -> None: if self.checkpointer: try: @@ -159,7 +174,10 @@ def _maybe_checkpoint( ) elif server_round == 1: # No checkpointer, just log message on the first round - log(INFO, "No checkpointer present. Models will not be checkpointed on server-side.") + log( + INFO, + "No checkpointer present. Models will not be checkpointed on server-side.", + ) def poll_clients_for_sample_counts(self, timeout: Optional[float]) -> List[int]: """ @@ -179,7 +197,9 @@ def poll_clients_for_sample_counts(self, timeout: Optional[float]) -> List[int]: assert isinstance(self.strategy, StrategyWithPolling) client_instructions = self.strategy.configure_poll(server_round=1, client_manager=self._client_manager) results, _ = poll_clients( - client_instructions=client_instructions, max_workers=self.max_workers, timeout=timeout + client_instructions=client_instructions, + max_workers=self.max_workers, + timeout=timeout, ) sample_counts: List[int] = [ @@ -276,7 +296,10 @@ def _evaluate_round( # Collect `evaluate` results from all clients participating in this round # flwr sets group_id to server_round by default, so we follow that convention results, failures = evaluate_clients( - client_instructions, max_workers=self.max_workers, timeout=timeout, group_id=server_round + client_instructions, + max_workers=self.max_workers, + timeout=timeout, + group_id=server_round, ) log( DEBUG, @@ -295,25 +318,32 @@ def evaluate_round( server_round: int, timeout: Optional[float], ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: - self.metrics_reporter.add_to_metrics_at_round(server_round, data={"evaluate_start": datetime.datetime.now()}) - # By default the checkpointing works off of the aggregated evaluation loss from each of the clients # NOTE: parameter aggregation occurs **before** evaluation, so the parameters held by the server have been # updated prior to this function being called. + start_time = datetime.datetime.now() eval_round_results = self._evaluate_round(server_round, timeout) + end_time = datetime.datetime.now() if eval_round_results: loss_aggregated, metrics_aggregated, _ = eval_round_results if loss_aggregated: self._maybe_checkpoint(loss_aggregated, metrics_aggregated, server_round) - - self.metrics_reporter.add_to_metrics_at_round( - server_round, - data={ - "metrics_aggregated": metrics_aggregated, - "loss_aggregated": loss_aggregated, - "evaluate_end": datetime.datetime.now(), - }, - ) + # Report evaluation results + for r in self.reporters: + r.report( + { + "val - loss - aggregated": loss_aggregated, + "round": server_round, + "eval_round_start": str(start_time), + "eval_round_end": str(end_time), + }, + server_round, + ) + if len(metrics_aggregated) > 0: + r.report( + {"eval_metrics_aggregated": metrics_aggregated}, + server_round, + ) return eval_round_results @@ -327,10 +357,9 @@ def __init__( client_manager: ClientManager, parameter_exchanger: ExchangerType, model: Optional[nn.Module] = None, - wandb_reporter: Optional[ServerWandBReporter] = None, strategy: Optional[Strategy] = None, + reporters: Sequence[BaseReporter] | None = None, checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None, - metrics_reporter: Optional[MetricsReporter] = None, intermediate_server_state_dir: Optional[Path] = None, server_name: Optional[str] = None, ) -> None: @@ -350,20 +379,22 @@ def __init__( strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle client updates and other information potentially sent by the participating clients. If None the strategy is FedAvg as set by the flwr Server. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. - checkpointer (Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]], optional): To be provided - if the server should perform server side checkpointing based on some criteria. If none, then no - server-side checkpointing is performed. Multiple checkpointers can also be passed in a sequence to - checkpoint based on multiple criteria. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the server should send data to before and after each round. + checkpointer (Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]], optional): + To be provided if the server should perform server side checkpointing + based on some criteria. If none, then no server-side checkpointing is performed. Multiple checkpointers + can also be passed in a sequence to checkpoint based on multiple criteria. Defaults to None. intermediate_server_state_dir (Path): A directory to store and load checkpoints from for the server during an FL experiment. server_name (Optional[str]): An optional string name to uniquely identify server. """ super().__init__( - client_manager, strategy, wandb_reporter, checkpointer, metrics_reporter, server_name=server_name + client_manager, + strategy, + reporters, + checkpointer, + server_name=server_name, ) self.server_model = model # To facilitate model rehydration from server-side state for checkpointing @@ -408,14 +439,20 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float metrics computed during training and validation. The second element of the tuple is the elapsed time in seconds. """ - self.metrics_reporter.add_to_metrics({"type": "server", "fit_start": datetime.datetime.now()}) - if self.per_round_checkpointer is not None: + start_time = datetime.datetime.now() history, elapsed_time = self.fit_with_per_epoch_checkpointing(num_rounds, timeout) - - if self.wandb_reporter: - # report history to W and B - self.wandb_reporter.report_metrics(num_rounds, history) + end_time = datetime.datetime.now() + for r in self.reporters: + r.report( + { + "fit_elapsed_time": str(start_time - end_time), + "fit_start": str(start_time), + "fit_end": str(end_time), + "num_rounds": num_rounds, + "host_type": "server", + } + ) else: # parent method includes call to report metrics to wandb if specified history, elapsed_time = super().fit(num_rounds, timeout) @@ -468,7 +505,7 @@ def fit_with_per_epoch_checkpointing(self, num_rounds: int, timeout: Optional[fl # Run federated learning for num_rounds log(INFO, "FL starting") - start_time = timeit.default_timer() + start_time = datetime.datetime.now() while self.current_round < (num_rounds + 1): # Train model and replace previous global model @@ -489,7 +526,7 @@ def fit_with_per_epoch_checkpointing(self, num_rounds: int, timeout: Optional[fl self.current_round, loss_cen, metrics_cen, - timeit.default_timer() - start_time, + (datetime.datetime.now() - start_time).total_seconds(), ) self.history.add_loss_centralized(server_round=self.current_round, loss=loss_cen) self.history.add_metrics_centralized(server_round=self.current_round, metrics=metrics_cen) @@ -509,10 +546,10 @@ def fit_with_per_epoch_checkpointing(self, num_rounds: int, timeout: Optional[fl self.save_server_state() # Bookkeeping - end_time = timeit.default_timer() + end_time = datetime.datetime.now() elapsed_time = end_time - start_time - log(INFO, "FL finished in %s", elapsed_time) - return self.history, elapsed_time + log(INFO, "FL finished in %s", str(elapsed_time)) + return self.history, elapsed_time.total_seconds() def save_server_state(self) -> None: """ @@ -526,13 +563,16 @@ def save_server_state(self) -> None: "model": self.server_model, "history": self.history, "current_round": self.current_round, - "metrics_reporter": self.metrics_reporter, + "reporters": self.reporters, "server_name": self.server_name, } self.per_round_checkpointer.save_checkpoint(ckpt) - log(INFO, f"Saving server state to checkpoint at {self.per_round_checkpointer.checkpoint_path}") + log( + INFO, + f"Saving server state to checkpoint at {self.per_round_checkpointer.checkpoint_path}", + ) def load_server_state(self) -> None: """ @@ -543,11 +583,14 @@ def load_server_state(self) -> None: ckpt = self.per_round_checkpointer.load_checkpoint() - log(INFO, f"Loading server state from checkpoint at {self.per_round_checkpointer.checkpoint_path}") + log( + INFO, + f"Loading server state from checkpoint at {self.per_round_checkpointer.checkpoint_path}", + ) narrow_dict_type_and_set_attribute(self, ckpt, "server_name", "server_name", str) narrow_dict_type_and_set_attribute(self, ckpt, "current_round", "current_round", int) - narrow_dict_type_and_set_attribute(self, ckpt, "metrics_reporter", "metrics_reporter", MetricsReporter) + narrow_dict_type_and_set_attribute(self, ckpt, "reporters", "reporters", list) narrow_dict_type_and_set_attribute(self, ckpt, "history", "history", History) narrow_dict_type_and_set_attribute(self, ckpt, "model", "parameters", nn.Module, func=get_all_model_parameters) @@ -559,9 +602,8 @@ def __init__( self, client_manager: ClientManager, strategy: Optional[Strategy] = None, - wandb_reporter: Optional[ServerWandBReporter] = None, + reporters: Sequence[BaseReporter] | None = None, checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None, - metrics_reporter: Optional[MetricsReporter] = None, ) -> None: """ Server with an initialize hook method that is called prior to fit. Override the self.initialize method to do @@ -575,16 +617,15 @@ def __init__( strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle. client updates and other information potentially sent by the participating clients. If None the strategy is FedAvg as set by the flwr Server. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. - checkpointer (Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]], optional): To be provided - if the server should perform server side checkpointing based on some criteria. If none, then no - server-side checkpointing is performed. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the server should send data to before and after each round. + checkpointer (Optional[Union[TorchCheckpointer, Sequence + [TorchCheckpointer]]], optional): To be provided if the server + should perform server side checkpointing based on some + criteria. If none, then no server-side checkpointing is + performed. Defaults to None. """ - super().__init__(client_manager, strategy, wandb_reporter, checkpointer, metrics_reporter) + super().__init__(client_manager, strategy, reporters, checkpointer) self.initialized = False def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) -> Parameters: diff --git a/fl4health/server/client_level_dp_fed_avg_server.py b/fl4health/server/client_level_dp_fed_avg_server.py index e201a8964..dac561fd8 100644 --- a/fl4health/server/client_level_dp_fed_avg_server.py +++ b/fl4health/server/client_level_dp_fed_avg_server.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from logging import INFO from math import ceil from typing import List, Optional, Tuple @@ -14,7 +15,7 @@ FlClientLevelAccountantFixedSamplingNoReplacement, FlClientLevelAccountantPoissonSampling, ) -from fl4health.reporting.fl_wandb import ServerWandBReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.base_server import FlServer from fl4health.strategies.client_dp_fedavgm import ClientLevelDPFedAvgM @@ -26,8 +27,8 @@ def __init__( strategy: ClientLevelDPFedAvgM, server_noise_multiplier: float, num_server_rounds: int, - wandb_reporter: Optional[ServerWandBReporter] = None, checkpointer: Optional[TorchCheckpointer] = None, + reporters: Sequence[BaseReporter] | None = None, delta: Optional[int] = None, ) -> None: """ @@ -40,20 +41,19 @@ def __init__( client updates and other information potentially sent by the participating clients. server_noise_multiplier (float): Magnitude of noise added to the weights aggregation process by the server. num_server_rounds (int): Number of rounds of FL training carried out by the server. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform server side checkpointing based on some criteria. If none, then no server-side checkpointing is performed. Defaults to None. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the server should send data to before and after each round. delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to being 1/total_samples in the FL run. Defaults to None. """ super().__init__( client_manager=client_manager, strategy=strategy, - wandb_reporter=wandb_reporter, checkpointer=checkpointer, + reporters=reporters, ) self.accountant: ClientLevelAccountant self.server_noise_multiplier = server_noise_multiplier @@ -102,7 +102,8 @@ def setup_privacy_accountant(self, sample_counts: List[int]) -> None: if isinstance(self._client_manager, PoissonSamplingClientManager): self.accountant = FlClientLevelAccountantPoissonSampling( - client_sampling_rate=self.strategy.fraction_fit, noise_multiplier=self.server_noise_multiplier + client_sampling_rate=self.strategy.fraction_fit, + noise_multiplier=self.server_noise_multiplier, ) else: assert isinstance(self._client_manager, FixedSamplingByFractionClientManager) @@ -115,4 +116,7 @@ def setup_privacy_accountant(self, sample_counts: List[int]) -> None: # Note that this assumes that the FL round has exactly n_clients participating. epsilon = self.accountant.get_epsilon(self.num_server_rounds, target_delta) - log(INFO, f"Model privacy after full training will be ({epsilon}, {target_delta})") + log( + INFO, + f"Model privacy after full training will be ({epsilon}, {target_delta})", + ) diff --git a/fl4health/server/evaluate_server.py b/fl4health/server/evaluate_server.py index 7c6ac523b..a9139462d 100644 --- a/fl4health/server/evaluate_server.py +++ b/fl4health/server/evaluate_server.py @@ -1,5 +1,5 @@ import datetime -import timeit +from collections.abc import Sequence from logging import INFO, WARNING from pathlib import Path from typing import Dict, List, Optional, Tuple, Union @@ -14,7 +14,8 @@ from flwr.server.server import EvaluateResultsAndFailures, Server, evaluate_clients from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter +from fl4health.utils.random import generate_hash class EvaluateServer(Server): @@ -27,7 +28,7 @@ def __init__( evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, accept_failures: bool = True, min_available_clients: int = 1, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: """ Args: @@ -43,8 +44,8 @@ def __init__( accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 1. Defaults to 1. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - during the execution. Defaults to an instance of MetricsReporter with default init parameters. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. """ # We aren't aggregating model weights, so setting the strategy to be none. super().__init__(client_manager=client_manager, strategy=None) @@ -63,11 +64,10 @@ def __init__( f"Fraction Evaluate is {self.fraction_evaluate}. " "Thus, some clients may not participate in evaluation", ) - - if metrics_reporter is not None: - self.metrics_reporter = metrics_reporter - else: - self.metrics_reporter = MetricsReporter() + self.server_name = generate_hash() + self.reporters = [] if reporters is None else list(reporters) + for r in self.reporters: + r.initialize(id=self.server_name) def load_model_checkpoint_to_parameters(self) -> Parameters: assert self.model_checkpoint_path @@ -92,33 +92,38 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float Tuple[History, float]: The first element of the tuple is a History object containing the aggregated metrics returned from the clients. Tuple also contains elapsed time in seconds for round. """ - self.metrics_reporter.add_to_metrics({"type": "server", "fit_start": datetime.datetime.now()}) - history = History() # Run Federated Evaluation log(INFO, "Federated Evaluation Starting") - start_time = timeit.default_timer() - + start_time = datetime.datetime.now() # We're only performing federated evaluation. So we make use of the evaluate round function, but simply # perform such evaluation once. res_fed = self.federated_evaluate(timeout=timeout) + end_time = datetime.datetime.now() + + for r in self.reporters: + r.report( + { + "fit_elapsed_time": str(start_time - end_time), + "fit_start": str(start_time), + "fit_end": str(end_time), + "num_rounds": num_rounds, + "host_type": "server", + } + ) if res_fed: _, evaluate_metrics_fed, _ = res_fed if evaluate_metrics_fed: history.add_metrics_distributed(server_round=0, metrics=evaluate_metrics_fed) - self.metrics_reporter.add_to_metrics( - { - "metrics": evaluate_metrics_fed, - "fit_end": datetime.datetime.now(), - } - ) + if evaluate_metrics_fed: + for r in self.reporters: + r.report({"fit_metrics": evaluate_metrics_fed}) # Bookkeeping - end_time = timeit.default_timer() elapsed = end_time - start_time - log(INFO, "Federated Evaluation Finished in %s", elapsed) - return history, elapsed + log(INFO, "Federated Evaluation Finished in %s", str(elapsed)) + return history, elapsed.total_seconds() def federated_evaluate( self, @@ -152,9 +157,15 @@ def federated_evaluate( # Collect `evaluate` results from all clients participating in this round results, failures = evaluate_clients( - client_instructions, max_workers=self.max_workers, timeout=timeout, group_id=0 + client_instructions, + max_workers=self.max_workers, + timeout=timeout, + group_id=0, + ) + log( + INFO, + f"Federated Evaluation received {len(results)} results and {len(failures)} failures", ) - log(INFO, f"Federated Evaluation received {len(results)} results and {len(failures)} failures") # Aggregate the evaluation results, note that we assume that the losses have been packed and aggregated with # the metrics. A dummy loss is returned by each of the clients. We therefore return none for the aggregated diff --git a/fl4health/server/fedpm_server.py b/fl4health/server/fedpm_server.py index aedc2101e..2d6507c2b 100644 --- a/fl4health/server/fedpm_server.py +++ b/fl4health/server/fedpm_server.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import Dict, Optional, Tuple from flwr.common import Parameters @@ -6,7 +7,7 @@ from flwr.server.server import FitResultsAndFailures from fl4health.checkpointing.checkpointer import TorchCheckpointer -from fl4health.reporting.fl_wandb import ServerWandBReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.base_server import FlServer from fl4health.strategies.fedpm import FedPm @@ -16,34 +17,36 @@ def __init__( self, client_manager: ClientManager, strategy: FedPm, - wandb_reporter: Optional[ServerWandBReporter] = None, checkpointer: Optional[TorchCheckpointer] = None, reset_frequency: int = 1, + reporters: Sequence[BaseReporter] | None = None, ) -> None: """ Custom FL Server for the FedPM algorithm to allow for resetting the beta priors in Bayesian aggregation, as specified in http://arxiv.org/pdf/2209.15328. Args: - client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if - they are to be sampled at all. - strategy (Scaffold): The aggregation strategy to be used by the server to handle client updates and - other information potentially sent by the participating clients. This strategy must be of SCAFFOLD - type. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. - checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform - server side checkpointing based on some criteria. If none, then no server-side checkpointing is - performed. Defaults to None. - reset_frequency (int): Determines the frequency with which the beta priors are reset. Defaults to 1. + client_manager (ClientManager): Determines the mechanism by which clients + are sampled by the server, if they are to be sampled at all. + strategy (Scaffold): The aggregation strategy to be used by the server to + handle client updates and other information potentially sent by the + participating clients. This strategy must be of SCAFFOLD type. + checkpointer (Optional[TorchCheckpointer], optional): To be provided if the + server should perform server side checkpointing based on some criteria. + If none, then no server-side checkpointing is performed. Defaults to + None. + reset_frequency (int): Determines the frequency with which the beta priors + are reset. Defaults to 1. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the server should send data to before and after each + round. """ FlServer.__init__( self, client_manager=client_manager, strategy=strategy, - wandb_reporter=wandb_reporter, checkpointer=checkpointer, + reporters=reporters, ) self.reset_frequency = reset_frequency diff --git a/fl4health/server/instance_level_dp_server.py b/fl4health/server/instance_level_dp_server.py index 372ebb565..7c9b7aa1e 100644 --- a/fl4health/server/instance_level_dp_server.py +++ b/fl4health/server/instance_level_dp_server.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from logging import INFO from math import ceil from typing import List, Optional, Tuple @@ -9,7 +10,7 @@ from fl4health.checkpointing.opacus_checkpointer import OpacusCheckpointer from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager from fl4health.privacy.fl_accountants import FlInstanceLevelAccountant -from fl4health.reporting.fl_wandb import ServerWandBReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.base_server import FlServer from fl4health.strategies.basic_fedavg import BasicFedAvg from fl4health.strategies.strategy_with_poll import StrategyWithPolling @@ -25,8 +26,8 @@ def __init__( strategy: BasicFedAvg, local_epochs: Optional[int] = None, local_steps: Optional[int] = None, - wandb_reporter: Optional[ServerWandBReporter] = None, checkpointer: Optional[OpacusCheckpointer] = None, + reporters: Sequence[BaseReporter] | None = None, delta: Optional[float] = None, ) -> None: """ @@ -49,20 +50,19 @@ def __init__( strategy (OpacusBasicFedAvg): The aggregation strategy to be used by the server to handle client updates and other information potentially sent by the participating clients. this must be an OpacusBasicFedAvg strategy to ensure proper treatment of the model in the Opacus framework - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. checkpointer (Optional[OpacusCheckpointer], optional): To be provided if the server should perform server side checkpointing based on some criteria. If none, then no server-side checkpointing is performed. Defaults to None. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to being 1/total_samples in the FL run. Defaults to None. """ super().__init__( client_manager=client_manager, strategy=strategy, - wandb_reporter=wandb_reporter, checkpointer=checkpointer, + reporters=reporters, ) # Ensure that one of local_epochs and local_steps is passed (and not both) @@ -133,4 +133,7 @@ def setup_privacy_accountant(self, sample_counts: List[int]) -> None: target_delta = 1.0 / total_samples if self.delta is None else self.delta epsilon = self.accountant.get_epsilon(self.num_server_rounds, target_delta) - log(INFO, f"Model privacy after full training will be ({epsilon}, {target_delta})") + log( + INFO, + f"Model privacy after full training will be ({epsilon}, {target_delta})", + ) diff --git a/fl4health/server/nnunet_server.py b/fl4health/server/nnunet_server.py index fa9769e3f..12d154e98 100644 --- a/fl4health/server/nnunet_server.py +++ b/fl4health/server/nnunet_server.py @@ -1,8 +1,9 @@ import pickle import warnings +from collections.abc import Callable, Sequence from logging import INFO from pathlib import Path -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Tuple, Union import torch.nn as nn from flwr.common import Parameters @@ -15,8 +16,7 @@ from fl4health.checkpointing.checkpointer import TorchCheckpointer from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger -from fl4health.reporting.fl_wandb import ServerWandBReporter -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.base_server import FlServerWithCheckpointing, FlServerWithInitializer from fl4health.utils.config import narrow_dict_type, narrow_dict_type_and_set_attribute from fl4health.utils.nnunet_utils import NnunetConfig @@ -26,8 +26,8 @@ from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from nnunetv2.utilities.plans_handling.plans_handler import PlansManager -FIT_CFG_FN = Callable[[int, Parameters, ClientManager], List[Tuple[ClientProxy, FitIns]]] -EVAL_CFG_FN = Callable[[int, Parameters, ClientManager], List[Tuple[ClientProxy, EvaluateIns]]] +FIT_CFG_FN = Callable[[int, Parameters, ClientManager], list[Tuple[ClientProxy, FitIns]]] +EVAL_CFG_FN = Callable[[int, Parameters, ClientManager], list[Tuple[ClientProxy, EvaluateIns]]] CFG_FN = Union[FIT_CFG_FN, EVAL_CFG_FN] @@ -61,13 +61,12 @@ def __init__( self, client_manager: ClientManager, parameter_exchanger: ParameterExchanger, - model: Optional[nn.Module] = None, - wandb_reporter: Optional[ServerWandBReporter] = None, - strategy: Optional[Strategy] = None, - checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None, - metrics_reporter: Optional[MetricsReporter] = None, - intermediate_server_state_dir: Optional[Path] = None, - server_name: Optional[str] = None, + model: nn.Module | None = None, + strategy: Strategy | None = None, + checkpointer: TorchCheckpointer | Sequence[TorchCheckpointer] | None = None, + reporters: Sequence[BaseReporter] | None = None, + intermediate_server_state_dir: Path | None = None, + server_name: str | None = None, ) -> None: """ A Basic FlServer with added functionality to ask a client to initialize @@ -75,34 +74,36 @@ def __init__( for use with NnUNetClient. Args: - client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if - they are to be sampled at all. - model (nn.Module): This is the torch model to be hydrated by the _hydrate_model_for_checkpointing function - parameter_exchanger (ExchangerType): This is the parameter exchanger to be used to hydrate the model. - strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle - client updates and other information potentially sent by the participating clients. If None the - strategy is FedAvg as set by the flwr Server. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. - checkpointer (Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]], optional): To be provided - if the server should perform server side checkpointing based on some criteria. If none, then no - server-side checkpointing is performed. Multiple checkpointers can also be passed in a sequence to + client_manager (ClientManager): Determines the mechanism by which clients + are sampled by the server, if they are to be sampled at all. + model (nn.Module): This is the torch model to be hydrated by the + _hydrate_model_for_checkpointing function + parameter_exchanger (ExchangerType): This is the parameter exchanger to be + used to hydrate the model. + strategy (Optional[Strategy], optional): The aggregation strategy to be + used by the server to handle client updates and other information + potentially sent by the participating clients. If None the strategy is + FedAvg as set by the flwr Server. + checkpointer (TorchCheckpointer | Sequence[TorchCheckpointer], optional): + To be provided if the server should perform server side checkpointing + based on some criteria. If none, then no server-side checkpointing is + performed. Multiple checkpointers can also be passed in a sequence to checkpoint based on multiple criteria. Defaults to None. - metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics - intermediate_server_state_dir (Path): A directory to store and load checkpoints from for the server - during an FL experiment. - server_name (Optional[str]): An optional string name to uniquely identify server. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. + intermediate_server_state_dir (Path): A directory to store and load + checkpoints from for the server during an FL experiment. + server_name (Optional[str]): An optional string name to uniquely identify + server. """ FlServerWithCheckpointing.__init__( self, client_manager=client_manager, model=model, parameter_exchanger=parameter_exchanger, - wandb_reporter=wandb_reporter, strategy=strategy, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, intermediate_server_state_dir=intermediate_server_state_dir, server_name=server_name, ) @@ -182,7 +183,10 @@ def initialize(self, server_round: int, timeout: Optional[float] = None) -> None # Sample properties from a random client to initialize plans log(INFO, "") log(INFO, "[PRE-INIT]") - log(INFO, "Requesting initialization of global nnunet plans from one random client via get_properties") + log( + INFO, + "Requesting initialization of global nnunet plans from one random client via get_properties", + ) random_client = self._client_manager.sample(1)[0] ins = GetPropertiesIns(config=config) properties_res = random_client.get_properties(ins=ins, timeout=timeout, group_id=server_round) @@ -246,7 +250,7 @@ def save_server_state(self) -> None: "model": self.server_model, "history": self.history, "current_round": self.current_round, - "metrics_reporter": self.metrics_reporter, + "reporters": self.reporters, "server_name": self.server_name, "nnunet_plans_bytes": self.nnunet_plans_bytes, "num_input_channels": self.num_input_channels, @@ -257,7 +261,10 @@ def save_server_state(self) -> None: self.per_round_checkpointer.save_checkpoint(ckpt) - log(INFO, f"Saving server state to checkpoint at {self.per_round_checkpointer.checkpoint_path}") + log( + INFO, + f"Saving server state to checkpoint at {self.per_round_checkpointer.checkpoint_path}", + ) def load_server_state(self) -> None: """ @@ -268,12 +275,15 @@ def load_server_state(self) -> None: ckpt = self.per_round_checkpointer.load_checkpoint() - log(INFO, f"Loading server state from checkpoint at {self.per_round_checkpointer.checkpoint_path}") + log( + INFO, + f"Loading server state from checkpoint at {self.per_round_checkpointer.checkpoint_path}", + ) # Standard attributes to load narrow_dict_type_and_set_attribute(self, ckpt, "current_round", "current_round", int) narrow_dict_type_and_set_attribute(self, ckpt, "server_name", "server_name", str) - narrow_dict_type_and_set_attribute(self, ckpt, "metrics_reporter", "metrics_reporter", MetricsReporter) + narrow_dict_type_and_set_attribute(self, ckpt, "reporters", "reporters", list) narrow_dict_type_and_set_attribute(self, ckpt, "history", "history", History) narrow_dict_type_and_set_attribute(self, ckpt, "model", "parameters", nn.Module, func=get_all_model_parameters) diff --git a/fl4health/server/scaffold_server.py b/fl4health/server/scaffold_server.py index a391554ab..f320ef05b 100644 --- a/fl4health/server/scaffold_server.py +++ b/fl4health/server/scaffold_server.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from logging import DEBUG, ERROR, INFO from typing import Optional, Tuple @@ -8,7 +9,7 @@ from flwr.server.server import fit_clients from fl4health.checkpointing.checkpointer import TorchCheckpointer -from fl4health.reporting.fl_wandb import ServerWandBReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.base_server import FlServer from fl4health.server.instance_level_dp_server import InstanceLevelDpServer from fl4health.strategies.scaffold import OpacusScaffold, Scaffold @@ -19,8 +20,8 @@ def __init__( self, client_manager: ClientManager, strategy: Scaffold, - wandb_reporter: Optional[ServerWandBReporter] = None, checkpointer: Optional[TorchCheckpointer] = None, + reporters: Sequence[BaseReporter] | None = None, warm_start: bool = False, # Whether or not to initialize control variates of each client as local gradient ) -> None: """ @@ -33,23 +34,25 @@ def __init__( strategy (Scaffold): The aggregation strategy to be used by the server to handle client updates and other information potentially sent by the participating clients. This strategy must be of SCAFFOLD type. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. - checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform - server side checkpointing based on some criteria. If none, then no server-side checkpointing is - performed. Defaults to None. - warm_start (bool, optional): Whether or not to initialize control variates of each client as - local gradients. The clients will perform a training pass (without updating the weights) in order to - provide a "warm" estimate of the SCAFFOLD control variates. If false, variates are initialized to 0. - Defaults to False. + checkpointer (Optional[TorchCheckpointer], optional): To be provided if the + server should perform server side checkpointing based on some criteria. + If none, then no server-side checkpointing is performed. Defaults to + None. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the server should send data to before and after each + round. + warm_start (bool, optional): Whether or not to initialize control variates + of each client as local gradients. The clients will perform a training + pass (without updating the weights) in order to provide a "warm" + estimate of the SCAFFOLD control variates. If false, variates are + initialized to 0. Defaults to False. """ FlServer.__init__( self, client_manager=client_manager, strategy=strategy, - wandb_reporter=wandb_reporter, checkpointer=checkpointer, + reporters=reporters, ) self.warm_start = warm_start @@ -78,9 +81,14 @@ def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) - # control variates are initialized as average local gradient over training steps # while the model weights remain unchanged (until the FL rounds start) if self.warm_start: - log(INFO, "Using Warm Start Strategy. Waiting for clients to be available for polling") + log( + INFO, + "Using Warm Start Strategy. Waiting for clients to be available for polling", + ) client_instructions = self.strategy.configure_fit_all( - server_round=0, parameters=initial_parameters, client_manager=self._client_manager + server_round=0, + parameters=initial_parameters, + client_manager=self._client_manager, ) if not client_instructions: log(ERROR, "Warm Start initialization failed: no clients selected", 1) @@ -91,7 +99,12 @@ def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) - clients (out of {self._client_manager.num_available()})", ) - results, failures = fit_clients(client_instructions, self.max_workers, timeout, group_id=server_round) + results, failures = fit_clients( + client_instructions, + self.max_workers, + timeout, + group_id=server_round, + ) log( DEBUG, @@ -149,9 +162,9 @@ def __init__( local_epochs: Optional[int] = None, local_steps: Optional[int] = None, delta: Optional[float] = None, - wandb_reporter: Optional[ServerWandBReporter] = None, checkpointer: Optional[TorchCheckpointer] = None, warm_start: bool = False, + reporters: Sequence[BaseReporter] | None = None, ) -> None: """ Custom FL Server for Instance Level Differentially Private Scaffold algorithm as specified in @@ -173,9 +186,6 @@ def __init__( strategy (Scaffold): The aggregation strategy to be used by the server to handle client updates and other information potentially sent by the participating clients. This strategy must be of SCAFFOLD type. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log - information and results to a Weights and Biases account. If None is provided, no logging occurs. - Defaults to None. checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform server side checkpointing based on some criteria. If none, then no server-side checkpointing is performed. Defaults to None. @@ -185,6 +195,8 @@ def __init__( Defaults to False. delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to being 1/total_samples in the FL run. Defaults to None. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + reporters which the client should send data to. """ # Require the strategy to be an OpacusStrategy to handle the Opacus model conversion etc. assert isinstance( @@ -194,9 +206,9 @@ def __init__( self, client_manager=client_manager, strategy=strategy, - wandb_reporter=wandb_reporter, checkpointer=checkpointer, warm_start=warm_start, + reporters=reporters, ) InstanceLevelDpServer.__init__( self, diff --git a/fl4health/server/tabular_feature_alignment_server.py b/fl4health/server/tabular_feature_alignment_server.py index ad6754e59..cff33517a 100644 --- a/fl4health/server/tabular_feature_alignment_server.py +++ b/fl4health/server/tabular_feature_alignment_server.py @@ -1,7 +1,7 @@ import random from functools import partial from logging import DEBUG, INFO, WARNING -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Sequence, Tuple from flwr.common import Parameters from flwr.common.logger import log @@ -18,7 +18,7 @@ SOURCE_SPECIFIED, ) from fl4health.feature_alignment.tab_features_info_encoder import TabularFeaturesInfoEncoder -from fl4health.reporting.fl_wandb import ServerWandBReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.server.base_server import FlServer from fl4health.server.polling import poll_clients from fl4health.strategies.basic_fedavg import BasicFedAvg @@ -52,16 +52,18 @@ def __init__( config: Config, initialize_parameters: Callable[..., Parameters], strategy: BasicFedAvg, - wandb_reporter: Optional[ServerWandBReporter] = None, checkpointer: Optional[TorchCheckpointer] = None, tabular_features_source_of_truth: Optional[TabularFeaturesInfoEncoder] = None, + reporters: Sequence[BaseReporter] | None = None, ) -> None: if strategy.on_fit_config_fn is not None: log(WARNING, "strategy.on_fit_config_fn will be overwritten.") if strategy.initial_parameters is not None: log(WARNING, "strategy.initial_parameters will be overwritten.") - super().__init__(client_manager, strategy, wandb_reporter, checkpointer) + super().__init__( + client_manager=client_manager, strategy=strategy, checkpointer=checkpointer, reporters=reporters + ) # The server performs one or two rounds of polls before the normal federated training. # The first one gathers feature information if the server does not already have it, # and the second one gathers the input/output dimensions of the model. diff --git a/fl4health/utils/random.py b/fl4health/utils/random.py index 323df029f..b84f2a90b 100644 --- a/fl4health/utils/random.py +++ b/fl4health/utils/random.py @@ -1,5 +1,5 @@ import random -import string +import uuid from logging import INFO from typing import Optional @@ -39,11 +39,12 @@ def unset_all_random_seeds() -> None: def generate_hash(length: int = 8) -> str: """ Generates unique hash used as id for client. + NOTE: This generation is unaffected by setting of random seeds. Args: - length (int): The length of the hash. + length (int): The length of the hash generated. Maximum length is 32 Returns: - str: client id + str: hash """ - return "".join(random.choice(string.ascii_lowercase) for _ in range(length)) + return str(uuid.uuid4()).replace("-", "")[:length] diff --git a/fl4health/utils/typing.py b/fl4health/utils/typing.py index 7bafcd14d..0128a6878 100644 --- a/fl4health/utils/typing.py +++ b/fl4health/utils/typing.py @@ -1,17 +1,17 @@ import logging +from collections.abc import Callable from enum import Enum -from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from flwr.common.typing import NDArrays -TorchInputType = Union[torch.Tensor, Dict[str, torch.Tensor]] -TorchTargetType = Union[torch.Tensor, Dict[str, torch.Tensor]] -TorchPredType = Dict[str, torch.Tensor] -TorchFeatureType = Dict[str, torch.Tensor] +TorchInputType = torch.Tensor | dict[str, torch.Tensor] +TorchTargetType = torch.Tensor | dict[str, torch.Tensor] +TorchPredType = dict[str, torch.Tensor] +TorchFeatureType = dict[str, torch.Tensor] TorchTransformFunction = Callable[[torch.Tensor], torch.Tensor] -LayerSelectionFunction = Callable[[nn.Module, Optional[nn.Module]], Tuple[NDArrays, List[str]]] +LayerSelectionFunction = Callable[[nn.Module, nn.Module | None], tuple[NDArrays, list[str]]] class LogLevel(Enum): diff --git a/poetry.lock b/poetry.lock index 93744f578..cbd69ec8c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -194,17 +194,6 @@ files = [ {file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"}, ] -[[package]] -name = "appdirs" -version = "1.4.4" -description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -optional = false -python-versions = "*" -files = [ - {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"}, - {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"}, -] - [[package]] name = "appnope" version = "0.1.4" @@ -487,6 +476,31 @@ numpy = "*" [package.extras] doc = ["gitpython", "numpydoc", "sphinx"] +[[package]] +name = "build" +version = "1.2.2.post1" +description = "A simple, correct Python build frontend" +optional = false +python-versions = ">=3.8" +files = [ + {file = "build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5"}, + {file = "build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "os_name == \"nt\""} +importlib-metadata = {version = ">=4.6", markers = "python_full_version < \"3.10.2\""} +packaging = ">=19.1" +pyproject_hooks = "*" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} + +[package.extras] +docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"] +test = ["build[uv,virtualenv]", "filelock (>=3)", "pytest (>=6.2.4)", "pytest-cov (>=2.12)", "pytest-mock (>=2)", "pytest-rerunfailures (>=9.1)", "pytest-xdist (>=1.34)", "setuptools (>=42.0.0)", "setuptools (>=56.0.0)", "setuptools (>=56.0.0)", "setuptools (>=67.8.0)", "wheel (>=0.36.0)"] +typing = ["build[uv]", "importlib-metadata (>=5.1)", "mypy (>=1.9.0,<1.10.0)", "tomli", "typing-extensions (>=3.7.4.3)"] +uv = ["uv (>=0.1.18)"] +virtualenv = ["virtualenv (>=20.0.35)"] + [[package]] name = "cachecontrol" version = "0.14.0" @@ -740,6 +754,21 @@ files = [ numpy = "*" scipy = "*" +[[package]] +name = "cleo" +version = "2.1.0" +description = "Cleo allows you to create beautiful and testable command-line interfaces." +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "cleo-2.1.0-py3-none-any.whl", hash = "sha256:4a31bd4dd45695a64ee3c4758f583f134267c2bc518d8ae9a29cf237d009b07e"}, + {file = "cleo-2.1.0.tar.gz", hash = "sha256:0b2c880b5d13660a7ea651001fb4acb527696c01f15c9ee650f377aa543fd523"}, +] + +[package.dependencies] +crashtest = ">=0.4.1,<0.5.0" +rapidfuzz = ">=3.0.0,<4.0.0" + [[package]] name = "click" version = "8.1.7" @@ -990,6 +1019,17 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "crashtest" +version = "0.4.1" +description = "Manage Python errors with ease" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "crashtest-0.4.1-py3-none-any.whl", hash = "sha256:8d23eac5fa660409f57472e3851dab7ac18aba459a8d19cbbba86d3d5aecd2a5"}, + {file = "crashtest-0.4.1.tar.gz", hash = "sha256:80d7b1f316ebfbd429f648076d6275c877ba30ba48979de4191714a75266f0ce"}, +] + [[package]] name = "cryptography" version = "42.0.8" @@ -1338,6 +1378,13 @@ files = [ {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d"}, {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393"}, {file = "dm_tree-0.1.8-cp311-cp311-win_amd64.whl", hash = "sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"}, + {file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"}, {file = "dm_tree-0.1.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb"}, @@ -1417,6 +1464,93 @@ mpmath = ">=1.2,<2.0" numpy = ">=1.21,<2.0" scipy = ">=1.7,<2.0" +[[package]] +name = "dulwich" +version = "0.21.7" +description = "Python Git Library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "dulwich-0.21.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d4c0110798099bb7d36a110090f2688050703065448895c4f53ade808d889dd3"}, + {file = "dulwich-0.21.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2bc12697f0918bee324c18836053644035362bb3983dc1b210318f2fed1d7132"}, + {file = "dulwich-0.21.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:471305af74790827fcbafe330fc2e8bdcee4fb56ca1177c8c481b1c8f806c4a4"}, + {file = "dulwich-0.21.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d54c9d0e845be26f65f954dff13a1cd3f2b9739820c19064257b8fd7435ab263"}, + {file = "dulwich-0.21.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:12d61334a575474e707614f2e93d6ed4cdae9eb47214f9277076d9e5615171d3"}, + {file = "dulwich-0.21.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e274cebaf345f0b1e3b70197f2651de92b652386b68020cfd3bf61bc30f6eaaa"}, + {file = "dulwich-0.21.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:817822f970e196e757ae01281ecbf21369383285b9f4a83496312204cf889b8c"}, + {file = "dulwich-0.21.7-cp310-cp310-win32.whl", hash = "sha256:7836da3f4110ce684dcd53489015fb7fa94ed33c5276e3318b8b1cbcb5b71e08"}, + {file = "dulwich-0.21.7-cp310-cp310-win_amd64.whl", hash = "sha256:4a043b90958cec866b4edc6aef5fe3c2c96a664d0b357e1682a46f6c477273c4"}, + {file = "dulwich-0.21.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ce8db196e79c1f381469410d26fb1d8b89c6b87a4e7f00ff418c22a35121405c"}, + {file = "dulwich-0.21.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:62bfb26bdce869cd40be443dfd93143caea7089b165d2dcc33de40f6ac9d812a"}, + {file = "dulwich-0.21.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c01a735b9a171dcb634a97a3cec1b174cfbfa8e840156870384b633da0460f18"}, + {file = "dulwich-0.21.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa4d14767cf7a49c9231c2e52cb2a3e90d0c83f843eb6a2ca2b5d81d254cf6b9"}, + {file = "dulwich-0.21.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bca4b86e96d6ef18c5bc39828ea349efb5be2f9b1f6ac9863f90589bac1084d"}, + {file = "dulwich-0.21.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a7b5624b02ef808cdc62dabd47eb10cd4ac15e8ac6df9e2e88b6ac6b40133673"}, + {file = "dulwich-0.21.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c3a539b4696a42fbdb7412cb7b66a4d4d332761299d3613d90a642923c7560e1"}, + {file = "dulwich-0.21.7-cp311-cp311-win32.whl", hash = "sha256:675a612ce913081beb0f37b286891e795d905691dfccfb9bf73721dca6757cde"}, + {file = "dulwich-0.21.7-cp311-cp311-win_amd64.whl", hash = "sha256:460ba74bdb19f8d498786ae7776745875059b1178066208c0fd509792d7f7bfc"}, + {file = "dulwich-0.21.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4c51058ec4c0b45dc5189225b9e0c671b96ca9713c1daf71d622c13b0ab07681"}, + {file = "dulwich-0.21.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4bc4c5366eaf26dda3fdffe160a3b515666ed27c2419f1d483da285ac1411de0"}, + {file = "dulwich-0.21.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a0650ec77d89cb947e3e4bbd4841c96f74e52b4650830112c3057a8ca891dc2f"}, + {file = "dulwich-0.21.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f18f0a311fb7734b033a3101292b932158cade54b74d1c44db519e42825e5a2"}, + {file = "dulwich-0.21.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c589468e5c0cd84e97eb7ec209ab005a2cb69399e8c5861c3edfe38989ac3a8"}, + {file = "dulwich-0.21.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d62446797163317a397a10080c6397ffaaca51a7804c0120b334f8165736c56a"}, + {file = "dulwich-0.21.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e84cc606b1f581733df4350ca4070e6a8b30be3662bbb81a590b177d0c996c91"}, + {file = "dulwich-0.21.7-cp312-cp312-win32.whl", hash = "sha256:c3d1685f320907a52c40fd5890627945c51f3a5fa4bcfe10edb24fec79caadec"}, + {file = "dulwich-0.21.7-cp312-cp312-win_amd64.whl", hash = "sha256:6bd69921fdd813b7469a3c77bc75c1783cc1d8d72ab15a406598e5a3ba1a1503"}, + {file = "dulwich-0.21.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7d8ab29c660125db52106775caa1f8f7f77a69ed1fe8bc4b42bdf115731a25bf"}, + {file = "dulwich-0.21.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0d2e4485b98695bf95350ce9d38b1bb0aaac2c34ad00a0df789aa33c934469b"}, + {file = "dulwich-0.21.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e138d516baa6b5bafbe8f030eccc544d0d486d6819b82387fc0e285e62ef5261"}, + {file = "dulwich-0.21.7-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f34bf9b9fa9308376263fd9ac43143c7c09da9bc75037bb75c6c2423a151b92c"}, + {file = "dulwich-0.21.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2e2c66888207b71cd1daa2acb06d3984a6bc13787b837397a64117aa9fc5936a"}, + {file = "dulwich-0.21.7-cp37-cp37m-win32.whl", hash = "sha256:10893105c6566fc95bc2a67b61df7cc1e8f9126d02a1df6a8b2b82eb59db8ab9"}, + {file = "dulwich-0.21.7-cp37-cp37m-win_amd64.whl", hash = "sha256:460b3849d5c3d3818a80743b4f7a0094c893c559f678e56a02fff570b49a644a"}, + {file = "dulwich-0.21.7-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:74700e4c7d532877355743336c36f51b414d01e92ba7d304c4f8d9a5946dbc81"}, + {file = "dulwich-0.21.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c92e72c43c9e9e936b01a57167e0ea77d3fd2d82416edf9489faa87278a1cdf7"}, + {file = "dulwich-0.21.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d097e963eb6b9fa53266146471531ad9c6765bf390849230311514546ed64db2"}, + {file = "dulwich-0.21.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:808e8b9cc0aa9ac74870b49db4f9f39a52fb61694573f84b9c0613c928d4caf8"}, + {file = "dulwich-0.21.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1957b65f96e36c301e419d7adaadcff47647c30eb072468901bb683b1000bc5"}, + {file = "dulwich-0.21.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4b09bc3a64fb70132ec14326ecbe6e0555381108caff3496898962c4136a48c6"}, + {file = "dulwich-0.21.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5882e70b74ac3c736a42d3fdd4f5f2e6570637f59ad5d3e684760290b58f041"}, + {file = "dulwich-0.21.7-cp38-cp38-win32.whl", hash = "sha256:29bb5c1d70eba155ded41ed8a62be2f72edbb3c77b08f65b89c03976292f6d1b"}, + {file = "dulwich-0.21.7-cp38-cp38-win_amd64.whl", hash = "sha256:25c3ab8fb2e201ad2031ddd32e4c68b7c03cb34b24a5ff477b7a7dcef86372f5"}, + {file = "dulwich-0.21.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:8929c37986c83deb4eb500c766ee28b6670285b512402647ee02a857320e377c"}, + {file = "dulwich-0.21.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cc1e11be527ac06316539b57a7688bcb1b6a3e53933bc2f844397bc50734e9ae"}, + {file = "dulwich-0.21.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0fc3078a1ba04c588fabb0969d3530efd5cd1ce2cf248eefb6baf7cbc15fc285"}, + {file = "dulwich-0.21.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40dcbd29ba30ba2c5bfbab07a61a5f20095541d5ac66d813056c122244df4ac0"}, + {file = "dulwich-0.21.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8869fc8ec3dda743e03d06d698ad489b3705775fe62825e00fa95aa158097fc0"}, + {file = "dulwich-0.21.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d96ca5e0dde49376fbcb44f10eddb6c30284a87bd03bb577c59bb0a1f63903fa"}, + {file = "dulwich-0.21.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e0064363bd5e814359657ae32517fa8001e8573d9d040bd997908d488ab886ed"}, + {file = "dulwich-0.21.7-cp39-cp39-win32.whl", hash = "sha256:869eb7be48243e695673b07905d18b73d1054a85e1f6e298fe63ba2843bb2ca1"}, + {file = "dulwich-0.21.7-cp39-cp39-win_amd64.whl", hash = "sha256:404b8edeb3c3a86c47c0a498699fc064c93fa1f8bab2ffe919e8ab03eafaaad3"}, + {file = "dulwich-0.21.7-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e598d743c6c0548ebcd2baf94aa9c8bfacb787ea671eeeb5828cfbd7d56b552f"}, + {file = "dulwich-0.21.7-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4a2d76c96426e791556836ef43542b639def81be4f1d6d4322cd886c115eae1"}, + {file = "dulwich-0.21.7-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6c88acb60a1f4d31bd6d13bfba465853b3df940ee4a0f2a3d6c7a0778c705b7"}, + {file = "dulwich-0.21.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ecd315847dea406a4decfa39d388a2521e4e31acde3bd9c2609c989e817c6d62"}, + {file = "dulwich-0.21.7-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d05d3c781bc74e2c2a2a8f4e4e2ed693540fbe88e6ac36df81deac574a6dad99"}, + {file = "dulwich-0.21.7-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6de6f8de4a453fdbae8062a6faa652255d22a3d8bce0cd6d2d6701305c75f2b3"}, + {file = "dulwich-0.21.7-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e25953c7acbbe4e19650d0225af1c0c0e6882f8bddd2056f75c1cc2b109b88ad"}, + {file = "dulwich-0.21.7-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:4637cbd8ed1012f67e1068aaed19fcc8b649bcf3e9e26649826a303298c89b9d"}, + {file = "dulwich-0.21.7-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:858842b30ad6486aacaa607d60bab9c9a29e7c59dc2d9cb77ae5a94053878c08"}, + {file = "dulwich-0.21.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:739b191f61e1c4ce18ac7d520e7a7cbda00e182c3489552408237200ce8411ad"}, + {file = "dulwich-0.21.7-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:274c18ec3599a92a9b67abaf110e4f181a4f779ee1aaab9e23a72e89d71b2bd9"}, + {file = "dulwich-0.21.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:2590e9b431efa94fc356ae33b38f5e64f1834ec3a94a6ac3a64283b206d07aa3"}, + {file = "dulwich-0.21.7-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ed60d1f610ef6437586f7768254c2a93820ccbd4cfdac7d182cf2d6e615969bb"}, + {file = "dulwich-0.21.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8278835e168dd097089f9e53088c7a69c6ca0841aef580d9603eafe9aea8c358"}, + {file = "dulwich-0.21.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffc27fb063f740712e02b4d2f826aee8bbed737ed799962fef625e2ce56e2d29"}, + {file = "dulwich-0.21.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:61e3451bd3d3844f2dca53f131982553be4d1b1e1ebd9db701843dd76c4dba31"}, + {file = "dulwich-0.21.7.tar.gz", hash = "sha256:a9e9c66833cea580c3ac12927e4b9711985d76afca98da971405d414de60e968"}, +] + +[package.dependencies] +urllib3 = ">=1.25" + +[package.extras] +fastimport = ["fastimport"] +https = ["urllib3 (>=1.24.1)"] +paramiko = ["paramiko"] +pgp = ["gpg"] + [[package]] name = "dynamic-network-architectures" version = "0.3.1" @@ -1527,6 +1661,20 @@ files = [ [package.extras] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +[[package]] +name = "fastjsonschema" +version = "2.20.0" +description = "Fastest Python implementation of JSON schema" +optional = false +python-versions = "*" +files = [ + {file = "fastjsonschema-2.20.0-py3-none-any.whl", hash = "sha256:5875f0b0fa7a0043a91e93a9b8f793bcbbba9691e7fd83dca95c28ba26d21f0a"}, + {file = "fastjsonschema-2.20.0.tar.gz", hash = "sha256:3d48fc5300ee96f5d116f10fe6f28d938e6008f59a6a025c2649475b87f76a23"}, +] + +[package.extras] +devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] + [[package]] name = "fft-conv-pytorch" version = "1.2.0" @@ -2409,6 +2557,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "installer" +version = "0.7.0" +description = "A library for installing Python wheels." +optional = false +python-versions = ">=3.7" +files = [ + {file = "installer-0.7.0-py3-none-any.whl", hash = "sha256:05d1933f0a5ba7d8d6296bb6d5018e7c94fa473ceb10cf198a92ccea19c27b53"}, + {file = "installer-0.7.0.tar.gz", hash = "sha256:a26d3e3116289bb08216e0d0f7d925fcef0b0194eedfa0c944bcaaa106c4b631"}, +] + [[package]] name = "intel-openmp" version = "2021.4.0" @@ -2545,6 +2704,24 @@ files = [ {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"}, ] +[[package]] +name = "jaraco-classes" +version = "3.4.0" +description = "Utility functions for Python class constructs" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jaraco.classes-3.4.0-py3-none-any.whl", hash = "sha256:f662826b6bed8cace05e7ff873ce0f9283b5c924470fe664fff1c2f00f581790"}, + {file = "jaraco.classes-3.4.0.tar.gz", hash = "sha256:47a024b51d0239c0dd8c8540c6c7f484be3b8fcf0b2d85c13825780d3b3f3acd"}, +] + +[package.dependencies] +more-itertools = "*" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)"] + [[package]] name = "jedi" version = "0.19.1" @@ -2564,6 +2741,21 @@ docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alab qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] +[[package]] +name = "jeepney" +version = "0.8.0" +description = "Low-level, pure Python DBus protocol wrapper." +optional = false +python-versions = ">=3.7" +files = [ + {file = "jeepney-0.8.0-py3-none-any.whl", hash = "sha256:c0a454ad016ca575060802ee4d590dd912e35c122fa04e70306de3d076cce755"}, + {file = "jeepney-0.8.0.tar.gz", hash = "sha256:5efe48d255973902f6badc3ce55e2aa6c5c3b3bc642059ef3a91247bcfcc5806"}, +] + +[package.extras] +test = ["async-timeout", "pytest", "pytest-asyncio (>=0.17)", "pytest-trio", "testpath", "trio"] +trio = ["async_generator", "trio"] + [[package]] name = "jinja2" version = "3.1.4" @@ -2660,6 +2852,29 @@ files = [ {file = "keras-2.15.0.tar.gz", hash = "sha256:81871d298c064dc4ac6b58440fdae67bfcf47c8d7ad28580fab401834c06a575"}, ] +[[package]] +name = "keyring" +version = "24.3.1" +description = "Store and access your passwords safely." +optional = false +python-versions = ">=3.8" +files = [ + {file = "keyring-24.3.1-py3-none-any.whl", hash = "sha256:df38a4d7419a6a60fea5cef1e45a948a3e8430dd12ad88b0f423c5c143906218"}, + {file = "keyring-24.3.1.tar.gz", hash = "sha256:c3327b6ffafc0e8befbdb597cacdb4928ffe5c1212f7645f186e6d9957a898db"}, +] + +[package.dependencies] +importlib-metadata = {version = ">=4.11.4", markers = "python_version < \"3.12\""} +"jaraco.classes" = "*" +jeepney = {version = ">=0.4.2", markers = "sys_platform == \"linux\""} +pywin32-ctypes = {version = ">=0.2.0", markers = "sys_platform == \"win32\""} +SecretStorage = {version = ">=3.2", markers = "sys_platform == \"linux\""} + +[package.extras] +completion = ["shtab (>=1.1.0)"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)"] + [[package]] name = "kiwisolver" version = "1.4.5" @@ -2810,6 +3025,7 @@ description = "Clang Python Bindings, mirrored from the official LLVM repo: http optional = false python-versions = "*" files = [ + {file = "libclang-18.1.1-1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a"}, {file = "libclang-18.1.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5"}, {file = "libclang-18.1.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8"}, {file = "libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl", hash = "sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b"}, @@ -3353,6 +3569,17 @@ tqdm = ["tqdm (>=4.47.0)"] transformers = ["transformers (>=4.36.0,<4.41.0)"] zarr = ["zarr"] +[[package]] +name = "more-itertools" +version = "10.5.0" +description = "More routines for operating on iterables, beyond itertools" +optional = false +python-versions = ">=3.8" +files = [ + {file = "more-itertools-10.5.0.tar.gz", hash = "sha256:5482bfef7849c25dc3c6dd53a6173ae4795da2a41a80faea6700d9f5846c5da6"}, + {file = "more_itertools-10.5.0-py3-none-any.whl", hash = "sha256:037b0d3203ce90cca8ab1defbbdac29d5f993fc20131f3664dc8d6acfa872aef"}, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -4514,6 +4741,20 @@ pyparsing = "*" docs = ["Sphinx (>=3.3.1)", "doc8 (>=0.8.1)", "sphinx-rtd-theme (>=0.5.0)"] testing = ["aboutcode-toolkit (>=6.0.0)", "black", "pytest (>=6,!=7.0.0)", "pytest-xdist (>=2)"] +[[package]] +name = "pkginfo" +version = "1.11.2" +description = "Query metadata from sdists / bdists / installed packages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pkginfo-1.11.2-py3-none-any.whl", hash = "sha256:9ec518eefccd159de7ed45386a6bb4c6ca5fa2cb3bd9b71154fae44f6f1b36a3"}, + {file = "pkginfo-1.11.2.tar.gz", hash = "sha256:c6bc916b8298d159e31f2c216e35ee5b86da7da18874f879798d0a1983537c86"}, +] + +[package.extras] +testing = ["pytest", "pytest-cov", "wheel"] + [[package]] name = "platformdirs" version = "4.2.2" @@ -4571,6 +4812,68 @@ files = [ {file = "ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3"}, ] +[[package]] +name = "poetry" +version = "1.8.3" +description = "Python dependency management and packaging made easy." +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "poetry-1.8.3-py3-none-any.whl", hash = "sha256:88191c69b08d06f9db671b793d68f40048e8904c0718404b63dcc2b5aec62d13"}, + {file = "poetry-1.8.3.tar.gz", hash = "sha256:67f4eb68288eab41e841cc71a00d26cf6bdda9533022d0189a145a34d0a35f48"}, +] + +[package.dependencies] +build = ">=1.0.3,<2.0.0" +cachecontrol = {version = ">=0.14.0,<0.15.0", extras = ["filecache"]} +cleo = ">=2.1.0,<3.0.0" +crashtest = ">=0.4.1,<0.5.0" +dulwich = ">=0.21.2,<0.22.0" +fastjsonschema = ">=2.18.0,<3.0.0" +installer = ">=0.7.0,<0.8.0" +keyring = ">=24.0.0,<25.0.0" +packaging = ">=23.1" +pexpect = ">=4.7.0,<5.0.0" +pkginfo = ">=1.10,<2.0" +platformdirs = ">=3.0.0,<5" +poetry-core = "1.9.0" +poetry-plugin-export = ">=1.6.0,<2.0.0" +pyproject-hooks = ">=1.0.0,<2.0.0" +requests = ">=2.26,<3.0" +requests-toolbelt = ">=1.0.0,<2.0.0" +shellingham = ">=1.5,<2.0" +tomli = {version = ">=2.0.1,<3.0.0", markers = "python_version < \"3.11\""} +tomlkit = ">=0.11.4,<1.0.0" +trove-classifiers = ">=2022.5.19" +virtualenv = ">=20.23.0,<21.0.0" +xattr = {version = ">=1.0.0,<2.0.0", markers = "sys_platform == \"darwin\""} + +[[package]] +name = "poetry-core" +version = "1.9.0" +description = "Poetry PEP 517 Build Backend" +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "poetry_core-1.9.0-py3-none-any.whl", hash = "sha256:4e0c9c6ad8cf89956f03b308736d84ea6ddb44089d16f2adc94050108ec1f5a1"}, + {file = "poetry_core-1.9.0.tar.gz", hash = "sha256:fa7a4001eae8aa572ee84f35feb510b321bd652e5cf9293249d62853e1f935a2"}, +] + +[[package]] +name = "poetry-plugin-export" +version = "1.8.0" +description = "Poetry plugin to export the dependencies to various formats" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "poetry_plugin_export-1.8.0-py3-none-any.whl", hash = "sha256:adbe232cfa0cc04991ea3680c865cf748bff27593b9abcb1f35fb50ed7ba2c22"}, + {file = "poetry_plugin_export-1.8.0.tar.gz", hash = "sha256:1fa6168a85d59395d835ca564bc19862a7c76061e60c3e7dfaec70d50937fc61"}, +] + +[package.dependencies] +poetry = ">=1.8.0,<3.0.0" +poetry-core = ">=1.7.0,<3.0.0" + [[package]] name = "pre-commit" version = "3.8.0" @@ -5024,6 +5327,17 @@ files = [ flake8 = "5.0.4" tomli = {version = "*", markers = "python_version < \"3.11\""} +[[package]] +name = "pyproject-hooks" +version = "1.2.0" +description = "Wrappers to call pyproject.toml-based build backend hooks." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913"}, + {file = "pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8"}, +] + [[package]] name = "pyre-extensions" version = "0.0.30" @@ -5121,6 +5435,13 @@ files = [ {file = "python_gdcm-3.0.24.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e537b9c3c582e0a19cd89791634da5ff48f1d61eeee633bf6e806c7bed10aba"}, {file = "python_gdcm-3.0.24.1-cp312-cp312-win32.whl", hash = "sha256:0fe3684df3be2abcf4ec6931e45f4caa8bd2aa60a84e65ddd612428f0fa39bcc"}, {file = "python_gdcm-3.0.24.1-cp312-cp312-win_amd64.whl", hash = "sha256:530e6b3f3904fd87c7e69ad0aee383f7a87213a8bf339314741ca64e3b6a3e94"}, + {file = "python_gdcm-3.0.24.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:45c5927af717f06f7ff8e0d6124746ef15e314954ae105d3a98410b6e327fb15"}, + {file = "python_gdcm-3.0.24.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ee88f9cdcd4f5e98da0e608d9692e96173ec8832d5aba1f0234db8af0835d9bd"}, + {file = "python_gdcm-3.0.24.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a89a8b1b666f2c3ebe6afbac3fcc3e256566ac8e55080ef03dd1ef7c98cd6b1"}, + {file = "python_gdcm-3.0.24.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5bc32309aeba3d3675ae0e5641aae03a8b9dc66ace058979debd2fea849dda7"}, + {file = "python_gdcm-3.0.24.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8ba54c9908ce117734b32340b2e5bbf96d5544109b1af468e2b99f2c6341a15"}, + {file = "python_gdcm-3.0.24.1-cp313-cp313-win32.whl", hash = "sha256:5920e63ac12b9a430108cd804ee2709fcefb9781bda6b6cb2f7d311a8dc61a04"}, + {file = "python_gdcm-3.0.24.1-cp313-cp313-win_amd64.whl", hash = "sha256:c586099268f0baf3cfda5851fa6115dc93394930da148fe3c350081b59e7551a"}, {file = "python_gdcm-3.0.24.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:de850cedc1dc58b8b5ee16c72cf67c5ee1021963a0bcbc0de58e162824afd6ff"}, {file = "python_gdcm-3.0.24.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8eddf3dc5d7793f3af407972f5185ec5d7edc4989ccaeafbf0d3e5e74f5ba88e"}, {file = "python_gdcm-3.0.24.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8b9f5075cbd39fd4448ef83a298e5dbf0886dea8805d1de90df7de56d7930839"}, @@ -5178,6 +5499,17 @@ files = [ {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, ] +[[package]] +name = "pywin32-ctypes" +version = "0.2.3" +description = "A (partial) reimplementation of pywin32 using ctypes/cffi" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pywin32-ctypes-0.2.3.tar.gz", hash = "sha256:d162dc04946d704503b2edc4d55f3dba5c1d539ead017afa00142c38b9885755"}, + {file = "pywin32_ctypes-0.2.3-py3-none-any.whl", hash = "sha256:8a1513379d709975552d202d942d9837758905c8d01eb82b8bcc30918929e7b8"}, +] + [[package]] name = "pyyaml" version = "6.0.2" @@ -5436,6 +5768,106 @@ files = [ [package.dependencies] six = "*" +[[package]] +name = "rapidfuzz" +version = "3.10.0" +description = "rapid fuzzy string matching" +optional = false +python-versions = ">=3.9" +files = [ + {file = "rapidfuzz-3.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:884453860de029380dded8f3c1918af2d8eb5adf8010261645c7e5c88c2b5428"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:718c9bd369288aca5fa929df6dbf66fdbe9768d90940a940c0b5cdc96ade4309"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a68e3724b7dab761c01816aaa64b0903734d999d5589daf97c14ef5cc0629a8e"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1af60988d47534246d9525f77288fdd9de652608a4842815d9018570b959acc6"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3084161fc3e963056232ef8d937449a2943852e07101f5a136c8f3cfa4119217"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6cd67d3d017296d98ff505529104299f78433e4b8af31b55003d901a62bbebe9"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b11a127ac590fc991e8a02c2d7e1ac86e8141c92f78546f18b5c904064a0552c"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:aadce42147fc09dcef1afa892485311e824c050352e1aa6e47f56b9b27af4cf0"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:b54853c2371bf0e38d67da379519deb6fbe70055efb32f6607081641af3dc752"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:ce19887268e90ee81a3957eef5e46a70ecc000713796639f83828b950343f49e"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:f39a2a5ded23b9b9194ec45740dce57177b80f86c6d8eba953d3ff1a25c97766"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0ec338d5f4ad8d9339a88a08db5c23e7f7a52c2b2a10510c48a0cef1fb3f0ddc"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-win32.whl", hash = "sha256:56fd15ea8f4c948864fa5ebd9261c67cf7b89a1c517a0caef4df75446a7af18c"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:43dfc5e733808962a822ff6d9c29f3039a3cfb3620706f5953e17cfe4496724c"}, + {file = "rapidfuzz-3.10.0-cp310-cp310-win_arm64.whl", hash = "sha256:ae7966f205b5a7fde93b44ca8fed37c1c8539328d7f179b1197de34eceaceb5f"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bb0013795b40db5cf361e6f21ee7cda09627cf294977149b50e217d7fe9a2f03"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:69ef5b363afff7150a1fbe788007e307b9802a2eb6ad92ed51ab94e6ad2674c6"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c582c46b1bb0b19f1a5f4c1312f1b640c21d78c371a6615c34025b16ee56369b"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:288f6f6e7410cacb115fb851f3f18bf0e4231eb3f6cb5bd1cec0e7b25c4d039d"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9e29a13d2fd9be3e7d8c26c7ef4ba60b5bc7efbc9dbdf24454c7e9ebba31768"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea2da0459b951ee461bd4e02b8904890bd1c4263999d291c5cd01e6620177ad4"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:457827ba82261aa2ae6ac06a46d0043ab12ba7216b82d87ae1434ec0f29736d6"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5d350864269d56f51ab81ab750c9259ae5cad3152c0680baef143dcec92206a1"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:a9b8f51e08c3f983d857c3889930af9ddecc768453822076683664772d87e374"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:7f3a6aa6e70fc27e4ff5c479f13cc9fc26a56347610f5f8b50396a0d344c5f55"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:803f255f10d63420979b1909ef976e7d30dec42025c9b067fc1d2040cc365a7e"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2026651761bf83a0f31495cc0f70840d5c0d54388f41316e3f9cb51bd85e49a5"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-win32.whl", hash = "sha256:4df75b3ebbb8cfdb9bf8b213b168620b88fd92d0c16a8bc9f9234630b282db59"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:f9f0bbfb6787b97c51516f3ccf97737d504db5d239ad44527673b81f598b84ab"}, + {file = "rapidfuzz-3.10.0-cp311-cp311-win_arm64.whl", hash = "sha256:10fdad800441b9c97d471a937ba7d42625f1b530db05e572f1cb7d401d95c893"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7dc87073ba3a40dd65591a2100aa71602107443bf10770579ff9c8a3242edb94"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a425a0a868cf8e9c6e93e1cda4b758cdfd314bb9a4fc916c5742c934e3613480"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a86d5d1d75e61df060c1e56596b6b0a4422a929dff19cc3dbfd5eee762c86b61"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:34f213d59219a9c3ca14e94a825f585811a68ac56b4118b4dc388b5b14afc108"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:96ad46f5f56f70fab2be9e5f3165a21be58d633b90bf6e67fc52a856695e4bcf"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9178277f72d144a6c7704d7ae7fa15b7b86f0f0796f0e1049c7b4ef748a662ef"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76a35e9e19a7c883c422ffa378e9a04bc98cb3b29648c5831596401298ee51e6"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8a6405d34c394c65e4f73a1d300c001f304f08e529d2ed6413b46ee3037956eb"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:bd393683129f446a75d8634306aed7e377627098a1286ff3af2a4f1736742820"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:b0445fa9880ead81f5a7d0efc0b9c977a947d8052c43519aceeaf56eabaf6843"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:c50bc308fa29767ed8f53a8d33b7633a9e14718ced038ed89d41b886e301da32"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e89605afebbd2d4b045bccfdc12a14b16fe8ccbae05f64b4b4c64a97dad1c891"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-win32.whl", hash = "sha256:2db9187f3acf3cd33424ecdbaad75414c298ecd1513470df7bda885dcb68cc15"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:50e3d0c72ea15391ba9531ead7f2068a67c5b18a6a365fef3127583aaadd1725"}, + {file = "rapidfuzz-3.10.0-cp312-cp312-win_arm64.whl", hash = "sha256:9eac95b4278bd53115903d89118a2c908398ee8bdfd977ae844f1bd2b02b917c"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fe5231e8afd069c742ac5b4f96344a0fe4aff52df8e53ef87faebf77f827822c"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:886882367dbc985f5736356105798f2ae6e794e671fc605476cbe2e73838a9bb"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b33e13e537e3afd1627d421a142a12bbbe601543558a391a6fae593356842f6e"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:094c26116d55bf9c53abd840d08422f20da78ec4c4723e5024322321caedca48"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:545fc04f2d592e4350f59deb0818886c1b444ffba3bec535b4fbb97191aaf769"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:916a6abf3632e592b937c3d04c00a6efadd8fd30539cdcd4e6e4d92be7ca5d90"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb6ec40cef63b1922083d33bfef2f91fc0b0bc07b5b09bfee0b0f1717d558292"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c77a7330dd15c7eb5fd3631dc646fc96327f98db8181138766bd14d3e905f0ba"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:949b5e9eeaa4ecb4c7e9c2a4689dddce60929dd1ff9c76a889cdbabe8bbf2171"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b5363932a5aab67010ae1a6205c567d1ef256fb333bc23c27582481606be480c"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:5dd6eec15b13329abe66cc241b484002ecb0e17d694491c944a22410a6a9e5e2"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:79e7f98525b60b3c14524e0a4e1fedf7654657b6e02eb25f1be897ab097706f3"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-win32.whl", hash = "sha256:d29d1b9857c65f8cb3a29270732e1591b9bacf89de9d13fa764f79f07d8f1fd2"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:fa9720e56663cc3649d62b4b5f3145e94b8f5611e8a8e1b46507777249d46aad"}, + {file = "rapidfuzz-3.10.0-cp313-cp313-win_arm64.whl", hash = "sha256:eda4c661e68dddd56c8fbfe1ca35e40dd2afd973f7ebb1605f4d151edc63dff8"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cffbc50e0767396ed483900900dd58ce4351bc0d40e64bced8694bd41864cc71"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c038b9939da3035afb6cb2f465f18163e8f070aba0482923ecff9443def67178"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca366c2e2a54e2f663f4529b189fdeb6e14d419b1c78b754ec1744f3c01070d4"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7c4c82b1689b23b1b5e6a603164ed2be41b6f6de292a698b98ba2381e889eb9d"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98f6ebe28831a482981ecfeedc8237047878424ad0c1add2c7f366ba44a20452"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4bd1a7676ee2a4c8e2f7f2550bece994f9f89e58afb96088964145a83af7408b"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec9139baa3f85b65adc700eafa03ed04995ca8533dd56c924f0e458ffec044ab"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:26de93e6495078b6af4c4d93a42ca067b16cc0e95699526c82ab7d1025b4d3bf"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f3a0bda83c18195c361b5500377d0767749f128564ca95b42c8849fd475bb327"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:63e4c175cbce8c3adc22dca5e6154588ae673f6c55374d156f3dac732c88d7de"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:4dd3d8443970eaa02ab5ae45ce584b061f2799cd9f7e875190e2617440c1f9d4"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:e5ddb2388610799fc46abe389600625058f2a73867e63e20107c5ad5ffa57c47"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-win32.whl", hash = "sha256:2e9be5d05cd960914024412b5406fb75a82f8562f45912ff86255acbfdbfb78e"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-win_amd64.whl", hash = "sha256:47aca565a39c9a6067927871973ca827023e8b65ba6c5747f4c228c8d7ddc04f"}, + {file = "rapidfuzz-3.10.0-cp39-cp39-win_arm64.whl", hash = "sha256:b0732343cdc4273b5921268026dd7266f75466eb21873cb7635a200d9d9c3fac"}, + {file = "rapidfuzz-3.10.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:f744b5eb1469bf92dd143d36570d2bdbbdc88fe5cb0b5405e53dd34f479cbd8a"}, + {file = "rapidfuzz-3.10.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b67cc21a14327a0eb0f47bc3d7e59ec08031c7c55220ece672f9476e7a8068d3"}, + {file = "rapidfuzz-3.10.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fe5783676f0afba4a522c80b15e99dbf4e393c149ab610308a8ef1f04c6bcc8"}, + {file = "rapidfuzz-3.10.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d4688862f957c8629d557d084f20b2d803f8738b6c4066802a0b1cc472e088d9"}, + {file = "rapidfuzz-3.10.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20bd153aacc244e4c907d772c703fea82754c4db14f8aa64d75ff81b7b8ab92d"}, + {file = "rapidfuzz-3.10.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:50484d563f8bfa723c74c944b0bb15b9e054db9c889348c8c307abcbee75ab92"}, + {file = "rapidfuzz-3.10.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5897242d455461f2c5b82d7397b29341fd11e85bf3608a522177071044784ee8"}, + {file = "rapidfuzz-3.10.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:116c71a81e046ba56551d8ab68067ca7034d94b617545316d460a452c5c3c289"}, + {file = "rapidfuzz-3.10.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0a547e4350d1fa32624d3eab51eff8cf329f4cae110b4ea0402486b1da8be40"}, + {file = "rapidfuzz-3.10.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:399b9b79ccfcf50ca3bad7692bc098bb8eade88d7d5e15773b7f866c91156d0c"}, + {file = "rapidfuzz-3.10.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7947a425d1be3e744707ee58c6cb318b93a56e08f080722dcc0347e0b7a1bb9a"}, + {file = "rapidfuzz-3.10.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:94c48b4a2a4b1d22246f48e2b11cae01ec7d23f0c9123f8bb822839ad79d0a88"}, + {file = "rapidfuzz-3.10.0.tar.gz", hash = "sha256:6b62af27e65bb39276a66533655a2fa3c60a487b03935721c45b7809527979be"}, +] + +[package.extras] +all = ["numpy"] + [[package]] name = "rdflib" version = "7.0.0" @@ -5604,6 +6036,20 @@ requests = ">=2.0.0" [package.extras] rsa = ["oauthlib[signedtoken] (>=3.0.0)"] +[[package]] +name = "requests-toolbelt" +version = "1.0.0" +description = "A utility belt for advanced users of python-requests" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6"}, + {file = "requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06"}, +] + +[package.dependencies] +requests = ">=2.0.1,<3.0.0" + [[package]] name = "rich" version = "13.7.1" @@ -5973,6 +6419,21 @@ dev = ["flake8", "flit", "mypy", "pandas-stubs", "pre-commit", "pytest", "pytest docs = ["ipykernel", "nbconvert", "numpydoc", "pydata_sphinx_theme (==0.10.0rc2)", "pyyaml", "sphinx (<6.0.0)", "sphinx-copybutton", "sphinx-design", "sphinx-issues"] stats = ["scipy (>=1.7)", "statsmodels (>=0.12)"] +[[package]] +name = "secretstorage" +version = "3.3.3" +description = "Python bindings to FreeDesktop.org Secret Service API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "SecretStorage-3.3.3-py3-none-any.whl", hash = "sha256:f356e6628222568e3af06f2eba8df495efa13b3b63081dafd4f7d9a7b7bc9f99"}, + {file = "SecretStorage-3.3.3.tar.gz", hash = "sha256:2403533ef369eca6d2ba81718576c5e0f564d5cca1b58f73a8b23e7d4eeebd77"}, +] + +[package.dependencies] +cryptography = ">=2.0" +jeepney = ">=0.6" + [[package]] name = "semantic-version" version = "2.10.0" @@ -6751,6 +7212,17 @@ files = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +[[package]] +name = "tomlkit" +version = "0.13.2" +description = "Style preserving TOML library" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde"}, + {file = "tomlkit-0.13.2.tar.gz", hash = "sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79"}, +] + [[package]] name = "torch" version = "2.3.1" @@ -7150,6 +7622,17 @@ build = ["cmake (>=3.20)", "lit"] tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] tutorials = ["matplotlib", "pandas", "tabulate", "torch"] +[[package]] +name = "trove-classifiers" +version = "2024.9.12" +description = "Canonical source for classifiers on PyPI (pypi.org)." +optional = false +python-versions = "*" +files = [ + {file = "trove_classifiers-2024.9.12-py3-none-any.whl", hash = "sha256:f88a27a892891c87c5f8bbdf110710ae9e0a4725ea8e0fb45f1bcadf088a491f"}, + {file = "trove_classifiers-2024.9.12.tar.gz", hash = "sha256:4b46b3e134a4d01999ac5bc6e528afcc10cc48f0f724f185f267e276005768f4"}, +] + [[package]] name = "typer" version = "0.9.4" @@ -7375,41 +7858,48 @@ testing = ["coverage (>=5.0)", "pytest", "pytest-cov"] [[package]] name = "wandb" -version = "0.16.6" +version = "0.18.3" description = "A CLI and library for interacting with the Weights & Biases API." optional = false python-versions = ">=3.7" files = [ - {file = "wandb-0.16.6-py3-none-any.whl", hash = "sha256:5810019a3b981c796e98ea58557a7c380f18834e0c6bdaed15df115522e5616e"}, - {file = "wandb-0.16.6.tar.gz", hash = "sha256:86f491e3012d715e0d7d7421a4d6de41abef643b7403046261f962f3e512fe1c"}, + {file = "wandb-0.18.3-py3-none-any.whl", hash = "sha256:7da64f7da0ff7572439de10bfd45534e8811e71e78ac2ccc3b818f1c0f3a9aef"}, + {file = "wandb-0.18.3-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:6674d8a5c40c79065b9c7eb765136756d5ebc9457a5f9abc820a660fb23f8b67"}, + {file = "wandb-0.18.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:741f566e409a2684d3047e4cc25e8e914d78196b901190937b24b6abb8b052e5"}, + {file = "wandb-0.18.3-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:8be5e877570b693001c52dcc2089e48e6a4dcbf15f3adf5c9349f95148b59d58"}, + {file = "wandb-0.18.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d788852bd4739fa18de3918f309c3a955b5cef3247fae1c40df3a63af637e1a0"}, + {file = "wandb-0.18.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab81424eb207d78239a8d69c90521a70074fb81e3709055484e43c76fe44dc08"}, + {file = "wandb-0.18.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:2c91315b8b62423eae18577d66a4b4bb8e4341a7d5c849cb2963e3b3dff0bf6d"}, + {file = "wandb-0.18.3-py3-none-win32.whl", hash = "sha256:92a647dab783938ec87776a9fae8a13e72e6dad939c53e357cdea9d2570f0ad8"}, + {file = "wandb-0.18.3-py3-none-win_amd64.whl", hash = "sha256:29cac2cfa3124241fed22cfedc9a52e1500275ee9bbb0b428ce4bf63c4723bf0"}, + {file = "wandb-0.18.3.tar.gz", hash = "sha256:eb2574cea72bc908c6ce1b37edf7a889619e6e06e1b4714eecfe0662ded43c06"}, ] [package.dependencies] -appdirs = ">=1.4.3" -Click = ">=7.1,<8.0.0 || >8.0.0" +click = ">=7.1,<8.0.0 || >8.0.0" docker-pycreds = ">=0.4.0" -GitPython = ">=1.0.0,<3.1.29 || >3.1.29" -protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5", markers = "python_version > \"3.9\" or sys_platform != \"linux\""} +gitpython = ">=1.0.0,<3.1.29 || >3.1.29" +platformdirs = "*" +protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5.28.0 || >5.28.0,<6", markers = "python_version > \"3.9\" or sys_platform != \"linux\""} psutil = ">=5.0.0" -PyYAML = "*" +pyyaml = "*" requests = ">=2.0.0,<3" sentry-sdk = ">=1.0.0" setproctitle = "*" setuptools = "*" [package.extras] -async = ["httpx (>=0.23.0)"] aws = ["boto3"] azure = ["azure-identity", "azure-storage-blob"] gcp = ["google-cloud-storage"] -importers = ["filelock", "mlflow", "polars", "rich", "tenacity"] +importers = ["filelock", "mlflow", "polars (<=1.2.1)", "rich", "tenacity"] kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"] -launch = ["PyYAML (>=6.0.0)", "awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "tomli", "typing-extensions"] -media = ["bokeh", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit-pypi", "soundfile"] +launch = ["awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "jsonschema", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "pyyaml (>=6.0.0)", "tomli", "typing-extensions"] +media = ["bokeh", "imageio", "moviepy", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit", "soundfile"] models = ["cloudpickle"] perf = ["orjson"] -reports = ["pydantic (>=2.0.0)"] sweeps = ["sweeps (>=0.2.0)"] +workspaces = ["wandb-workspaces"] [[package]] name = "wcwidth" @@ -7547,6 +8037,79 @@ files = [ {file = "wrapt-1.14.1.tar.gz", hash = "sha256:380a85cf89e0e69b7cfbe2ea9f765f004ff419f34194018a6827ac0e3edfed4d"}, ] +[[package]] +name = "xattr" +version = "1.1.0" +description = "Python wrapper for extended filesystem attributes" +optional = false +python-versions = ">=3.8" +files = [ + {file = "xattr-1.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ef2fa0f85458736178fd3dcfeb09c3cf423f0843313e25391db2cfd1acec8888"}, + {file = "xattr-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ccab735d0632fe71f7d72e72adf886f45c18b7787430467ce0070207882cfe25"}, + {file = "xattr-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9013f290387f1ac90bccbb1926555ca9aef75651271098d99217284d9e010f7c"}, + {file = "xattr-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dcd5dfbcee73c7be057676ecb900cabb46c691aff4397bf48c579ffb30bb963"}, + {file = "xattr-1.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6480589c1dac7785d1f851347a32c4a97305937bf7b488b857fe8b28a25de9e9"}, + {file = "xattr-1.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08f61cbed52dc6f7c181455826a9ff1e375ad86f67dd9d5eb7663574abb32451"}, + {file = "xattr-1.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:918e1f83f2e8a072da2671eac710871ee5af337e9bf8554b5ce7f20cdb113186"}, + {file = "xattr-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0f06e0c1e4d06b4e0e49aaa1184b6f0e81c3758c2e8365597918054890763b53"}, + {file = "xattr-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:46a641ac038a9f53d2f696716147ca4dbd6a01998dc9cd4bc628801bc0df7f4d"}, + {file = "xattr-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7e4ca0956fd11679bb2e0c0d6b9cdc0f25470cc00d8da173bb7656cc9a9cf104"}, + {file = "xattr-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6881b120f9a4b36ccd8a28d933bc0f6e1de67218b6ce6e66874e0280fc006844"}, + {file = "xattr-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dab29d9288aa28e68a6f355ddfc3f0a7342b40c9012798829f3e7bd765e85c2c"}, + {file = "xattr-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0c80bbf55339c93770fc294b4b6586b5bf8e85ec00a4c2d585c33dbd84b5006"}, + {file = "xattr-1.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1418705f253b6b6a7224b69773842cac83fcbcd12870354b6e11dd1cd54630f"}, + {file = "xattr-1.1.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:687e7d18611ef8d84a6ecd8f4d1ab6757500c1302f4c2046ce0aa3585e13da3f"}, + {file = "xattr-1.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:b6ceb9efe0657a982ccb8b8a2efe96b690891779584c901d2f920784e5d20ae3"}, + {file = "xattr-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b489b7916f239100956ea0b39c504f3c3a00258ba65677e4c8ba1bd0b5513446"}, + {file = "xattr-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0a9c431b0e66516a078125e9a273251d4b8e5ba84fe644b619f2725050d688a0"}, + {file = "xattr-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1a5921ea3313cc1c57f2f53b63ea8ca9a91e48f4cc7ebec057d2447ec82c7efe"}, + {file = "xattr-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f6ad2a7bd5e6cf71d4a862413234a067cf158ca0ae94a40d4b87b98b62808498"}, + {file = "xattr-1.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0683dae7609f7280b0c89774d00b5957e6ffcb181c6019c46632b389706b77e6"}, + {file = "xattr-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54cb15cd94e5ef8a0ef02309f1bf973ba0e13c11e87686e983f371948cfee6af"}, + {file = "xattr-1.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff6223a854229055e803c2ad0c0ea9a6da50c6be30d92c198cf5f9f28819a921"}, + {file = "xattr-1.1.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d44e8f955218638c9ab222eed21e9bd9ab430d296caf2176fb37abe69a714e5c"}, + {file = "xattr-1.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:caab2c2986c30f92301f12e9c50415d324412e8e6a739a52a603c3e6a54b3610"}, + {file = "xattr-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:d6eb7d5f281014cd44e2d847a9107491af1bf3087f5afeded75ed3e37ec87239"}, + {file = "xattr-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:47a3bdfe034b4fdb70e5941d97037405e3904accc28e10dbef6d1c9061fb6fd7"}, + {file = "xattr-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:00d2b415cf9d6a24112d019e721aa2a85652f7bbc9f3b9574b2d1cd8668eb491"}, + {file = "xattr-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:78b377832dd0ee408f9f121a354082c6346960f7b6b1480483ed0618b1912120"}, + {file = "xattr-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6461a43b585e5f2e049b39bcbfcb6391bfef3c5118231f1b15d10bdb89ef17fe"}, + {file = "xattr-1.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24d97f0d28f63695e3344ffdabca9fcc30c33e5c8ccc198c7524361a98d526f2"}, + {file = "xattr-1.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ad47d89968c9097900607457a0c89160b4771601d813e769f68263755516065"}, + {file = "xattr-1.1.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc53cab265f6e8449bd683d5ee3bc5a191e6dd940736f3de1a188e6da66b0653"}, + {file = "xattr-1.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cd11e917f5b89f2a0ad639d9875943806c6c9309a3dd02da5a3e8ef92db7bed9"}, + {file = "xattr-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:9c5a78c7558989492c4cb7242e490ffb03482437bf782967dfff114e44242343"}, + {file = "xattr-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cebcf8a303a44fbc439b68321408af7267507c0d8643229dbb107f6c132d389c"}, + {file = "xattr-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b0d73150f2f9655b4da01c2369eb33a294b7f9d56eccb089819eafdbeb99f896"}, + {file = "xattr-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:793c01deaadac50926c0e1481702133260c7cb5e62116762f6fe1543d07b826f"}, + {file = "xattr-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e189e440bcd04ccaad0474720abee6ee64890823ec0db361fb0a4fb5e843a1bf"}, + {file = "xattr-1.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afacebbc1fa519f41728f8746a92da891c7755e6745164bd0d5739face318e86"}, + {file = "xattr-1.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b1664edf003153ac8d1911e83a0fc60db1b1b374ee8ac943f215f93754a1102"}, + {file = "xattr-1.1.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dda2684228798e937a7c29b0e1c7ef3d70e2b85390a69b42a1c61b2039ba81de"}, + {file = "xattr-1.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b735ac2625a4fc2c9343b19f806793db6494336338537d2911c8ee4c390dda46"}, + {file = "xattr-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:fa6a7af7a4ada43f15ccc58b6f9adcdbff4c36ba040013d2681e589e07ae280a"}, + {file = "xattr-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1059b2f726e2702c8bbf9bbf369acfc042202a4cc576c2dec6791234ad5e948"}, + {file = "xattr-1.1.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e2255f36ebf2cb2dbf772a7437ad870836b7396e60517211834cf66ce678b595"}, + {file = "xattr-1.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dba4f80b9855cc98513ddf22b7ad8551bc448c70d3147799ea4f6c0b758fb466"}, + {file = "xattr-1.1.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4cb70c16e7c3ae6ba0ab6c6835c8448c61d8caf43ea63b813af1f4dbe83dd156"}, + {file = "xattr-1.1.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83652910ef6a368b77b00825ad67815e5c92bfab551a848ca66e9981d14a7519"}, + {file = "xattr-1.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:7a92aff66c43fa3e44cbeab7cbeee66266c91178a0f595e044bf3ce51485743b"}, + {file = "xattr-1.1.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d4f71b673339aeaae1f6ea9ef8ea6c9643c8cd0df5003b9a0eaa75403e2e06c"}, + {file = "xattr-1.1.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a20de1c47b5cd7b47da61799a3b34e11e5815d716299351f82a88627a43f9a96"}, + {file = "xattr-1.1.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23705c7079b05761ff2fa778ad17396e7599c8759401abc05b312dfb3bc99f69"}, + {file = "xattr-1.1.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:27272afeba8422f2a9d27e1080a9a7b807394e88cce73db9ed8d2dde3afcfb87"}, + {file = "xattr-1.1.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd43978966de3baf4aea367c99ffa102b289d6c2ea5f3d9ce34a203dc2f2ab73"}, + {file = "xattr-1.1.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ded771eaf27bb4eb3c64c0d09866460ee8801d81dc21097269cf495b3cac8657"}, + {file = "xattr-1.1.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96ca300c0acca4f0cddd2332bb860ef58e1465d376364f0e72a1823fdd58e90d"}, + {file = "xattr-1.1.0.tar.gz", hash = "sha256:fecbf3b05043ed3487a28190dec3e4c4d879b2fcec0e30bafd8ec5d4b6043630"}, +] + +[package.dependencies] +cffi = ">=1.16.0" + +[package.extras] +test = ["pytest"] + [[package]] name = "xmltodict" version = "0.13.0" @@ -7811,4 +8374,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10.0,<3.11" -content-hash = "45445e9560514a9eafd900c4791f5825d55cb46aa9787be556119280bab35485" +content-hash = "ca85e2a2e883f093f03610ba735bf7641c675dd0c227a28a7d39eec551665f05" diff --git a/pyproject.toml b/pyproject.toml index 46d28bab6..43ae4a9a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,8 @@ torchvision = "^0.18.0" torchinfo = "^1.8.0" ipykernel = "^6.25.1" scikit-learn = "1.5.0" # Pin as it was causing issues with nnunet -wandb = "^0.16.1" +wandb = "^0.18.0" +poetry = "^1.8.3" [tool.poetry.group.dev-local.dependencies] torchtext = "^0.14.1" diff --git a/research/ag_news/dynamic_layer_exchange/client.py b/research/ag_news/dynamic_layer_exchange/client.py index 39897aae8..130c941b6 100644 --- a/research/ag_news/dynamic_layer_exchange/client.py +++ b/research/ag_news/dynamic_layer_exchange/client.py @@ -21,7 +21,7 @@ from fl4health.parameter_exchange.layer_exchanger import DynamicLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.parameter_exchange.parameter_selection_criteria import LayerSelectionFunctionConstructor -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric @@ -39,7 +39,7 @@ def __init__( norm_threshold: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, store_initial_model: bool = True, ) -> None: super().__init__( @@ -48,7 +48,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, store_initial_model=store_initial_model, ) assert 0 < exchange_percentage <= 1.0 and norm_threshold > 0 @@ -195,7 +195,9 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ # Note that the server must be started with the same grpc_max_message_length. Otherwise communication # of larger messages would still be blocked. fl.client.start_client( - server_address=args.server_address, client=client.to_client(), grpc_max_message_length=1600000000 + server_address=args.server_address, + client=client.to_client(), + grpc_max_message_length=1600000000, ) client.shutdown() diff --git a/research/ag_news/sparse_tensor_exchange/client.py b/research/ag_news/sparse_tensor_exchange/client.py index cc86dac5f..003ca420e 100644 --- a/research/ag_news/sparse_tensor_exchange/client.py +++ b/research/ag_news/sparse_tensor_exchange/client.py @@ -21,7 +21,7 @@ from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.parameter_exchange.parameter_selection_criteria import largest_final_magnitude_scores from fl4health.parameter_exchange.sparse_coo_parameter_exchanger import SparseCooParameterExchanger -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric @@ -38,7 +38,7 @@ def __init__( sparsity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, store_initial_model: bool = True, ) -> None: super().__init__( @@ -47,7 +47,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, store_initial_model=store_initial_model, ) self.sparsity_level = sparsity_level @@ -163,7 +163,9 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ # Note that the server must be started with the same grpc_max_message_length. Otherwise communication # of larger messages would still be blocked. fl.client.start_client( - server_address=args.server_address, client=client.to_client(), grpc_max_message_length=1600000000 + server_address=args.server_address, + client=client.to_client(), + grpc_max_message_length=1600000000, ) client.shutdown() diff --git a/research/cifar10/fedavg/server.py b/research/cifar10/fedavg/server.py index e3176f630..a19bbae2d 100644 --- a/research/cifar10/fedavg/server.py +++ b/research/cifar10/fedavg/server.py @@ -69,7 +69,7 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, initial_parameters=get_all_model_parameters(model), ) - server = FlServerWithCheckpointing(client_manager, parameter_exchanger, model, None, strategy, checkpointer) + server = FlServerWithCheckpointing(client_manager, parameter_exchanger, model, strategy, None, checkpointer) fl.server.start_server( server=server, diff --git a/research/picai/fedavg/client.py b/research/picai/fedavg/client.py index b3c31bdaf..3fedbbf80 100644 --- a/research/picai/fedavg/client.py +++ b/research/picai/fedavg/client.py @@ -15,7 +15,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointModule from fl4health.clients.basic_client import BasicClient -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric, TorchMetric from research.picai.data.data_utils import ( @@ -39,7 +39,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, intermediate_client_state_dir: Optional[Path] = None, overviews_dir: Path = Path("./"), @@ -51,7 +51,7 @@ def __init__( device=device, loss_meter_type=loss_meter_type, checkpointer=checkpointer, - metrics_reporter=metrics_reporter, + reporters=reporters, progress_bar=progress_bar, intermediate_client_state_dir=intermediate_client_state_dir, ) @@ -115,10 +115,18 @@ def get_optimizer(self, config: Config) -> Optimizer: if __name__ == "__main__": parser = argparse.ArgumentParser(description="FL Client Main") parser.add_argument( - "--artifact_dir", action="store", type=str, help="Path to dir to store run artifacts", required=True + "--artifact_dir", + action="store", + type=str, + help="Path to dir to store run artifacts", + required=True, ) parser.add_argument( - "--base_dir", action="store", type=str, help="Path to base directory containing PICAI dataset", required=True + "--base_dir", + action="store", + type=str, + help="Path to base directory containing PICAI dataset", + required=True, ) parser.add_argument( "--overviews_dir", @@ -134,7 +142,12 @@ def get_optimizer(self, config: Config) -> Optimizer: help="Server Address for the clients to communicate with the server through", default="0.0.0.0:8080", ) - parser.add_argument("--data_partition", type=int, help="The data partition to train the client on", default=0) + parser.add_argument( + "--data_partition", + type=int, + help="The data partition to train the client on", + default=0, + ) args = parser.parse_args() DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -143,7 +156,8 @@ def get_optimizer(self, config: Config) -> Optimizer: metrics = [ TorchMetric( - name="MLAP", metric=MultilabelAveragePrecision(average="macro", num_labels=2, thresholds=3).to(DEVICE) + name="MLAP", + metric=MultilabelAveragePrecision(average="macro", num_labels=2, thresholds=3).to(DEVICE), ) ] diff --git a/research/picai/reporting/client.py b/research/picai/reporting/client.py new file mode 100644 index 000000000..27e8eea9f --- /dev/null +++ b/research/picai/reporting/client.py @@ -0,0 +1,62 @@ +import argparse +from pathlib import Path +from typing import Optional, Tuple + +import flwr as fl +import torch +import torch.nn as nn +from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss +from torch.optim import Optimizer +from torch.utils.data import DataLoader + +from examples.models.cnn_model import Net +from fl4health.clients.basic_client import BasicClient +from fl4health.reporting import WandBReporter +from fl4health.utils.config import narrow_dict_type +from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data +from fl4health.utils.metrics import Accuracy + + +class CifarClient(BasicClient): + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + batch_size = narrow_dict_type(config, "batch_size", int) + train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size) + return train_loader, val_loader + + def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: + batch_size = narrow_dict_type(config, "batch_size", int) + test_loader, _ = load_cifar10_test_data(self.data_path, batch_size) + return test_loader + + def get_criterion(self, config: Config) -> _Loss: + return torch.nn.CrossEntropyLoss() + + def get_optimizer(self, config: Config) -> Optimizer: + return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) + + def get_model(self, config: Config) -> nn.Module: + return Net().to(self.device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Client Main") + parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset") + args = parser.parse_args() + + DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + data_path = Path(args.dataset_path) + reporter = WandBReporter( + "batch", + project="test", + entity="haider-vector-collab", + name="CIFAR Client", + tags=["debug"], + group="experiment1", + config={"dataset": "CIFAR"}, + job_type="client", + ) + # reporter = JsonReporter() + client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE, reporters=[reporter]) + fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) + client.shutdown() diff --git a/research/picai/reporting/config.yaml b/research/picai/reporting/config.yaml new file mode 100644 index 000000000..d2186be76 --- /dev/null +++ b/research/picai/reporting/config.yaml @@ -0,0 +1,10 @@ +# Parameters that describe server +n_server_rounds: 3 # The number of rounds to run FL + +# Parameters that describe clients +n_clients: 2 # The number of clients in the FL experiment +local_epochs: 3 # The number of epochs to complete for client +batch_size: 32 # The batch size for client training + +# checkpointing +checkpoint_path: "research/picai/reporting" diff --git a/research/picai/reporting/server.py b/research/picai/reporting/server.py new file mode 100644 index 000000000..ee1b32540 --- /dev/null +++ b/research/picai/reporting/server.py @@ -0,0 +1,107 @@ +import argparse +from functools import partial +from typing import Any, Dict, Optional + +import flwr as fl +from flwr.common.typing import Config +from flwr.server.client_manager import SimpleClientManager +from flwr.server.strategy import FedAvg + +from examples.models.cnn_model import Net +from examples.utils.functions import make_dict_with_epochs_or_steps +from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger +from fl4health.reporting import WandBReporter +from fl4health.server.base_server import FlServerWithCheckpointing +from fl4health.utils.config import load_config +from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn +from fl4health.utils.parameter_extraction import get_all_model_parameters + + +def fit_config( + batch_size: int, + current_server_round: int, + local_epochs: Optional[int] = None, + local_steps: Optional[int] = None, +) -> Config: + return { + **make_dict_with_epochs_or_steps(local_epochs, local_steps), + "batch_size": batch_size, + "current_server_round": current_server_round, + } + + +def main(config: Dict[str, Any]) -> None: + # This function will be used to produce a config that is sent to each client to initialize their own environment + fit_config_fn = partial( + fit_config, + config["batch_size"], + local_epochs=config.get("local_epochs"), + local_steps=config.get("local_steps"), + ) + + # Initializing the model on the server side + model = Net() + # To facilitate checkpointing + parameter_exchanger = FullParameterExchanger() + checkpointers = [ + BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl"), + LatestTorchCheckpointer(config["checkpoint_path"], "latest_model.pkl"), + ] + + # Server performs simple FedAveraging as its server-side optimization strategy + strategy = FedAvg( + min_fit_clients=config["n_clients"], + min_evaluate_clients=config["n_clients"], + # Server waits for min_available_clients before starting FL rounds + min_available_clients=config["n_clients"], + on_fit_config_fn=fit_config_fn, + # We use the same fit config function, as nothing changes for eval + on_evaluate_config_fn=fit_config_fn, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + initial_parameters=get_all_model_parameters(model), + ) + config.update({"strategy": "FedAvg"}) + reporter = WandBReporter( + "round", + project="test", + entity="haider-vector-collab", + tags=["debug"], + name="FedAvgServer", + config=config, + group="experiment1", + job_type="server", + ) + # reporter = JsonReporter() + server = FlServerWithCheckpointing( + client_manager=SimpleClientManager(), + parameter_exchanger=parameter_exchanger, + model=model, + reporters=[reporter], + strategy=strategy, + checkpointer=checkpointers, + ) + + fl.server.start_server( + server=server, + server_address="0.0.0.0:8080", + config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), + ) + server.shutdown() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FL Server Main") + parser.add_argument( + "--config_path", + action="store", + type=str, + help="Path to configuration file.", + default="examples/basic_example/config.yaml", + ) + args = parser.parse_args() + + config = load_config(args.config_path) + + main(config) diff --git a/tests/clients/test_basic_client.py b/tests/clients/test_basic_client.py index 34cf4c0aa..178bfe248 100644 --- a/tests/clients/test_basic_client.py +++ b/tests/clients/test_basic_client.py @@ -1,4 +1,5 @@ import datetime +from collections.abc import Sequence from pathlib import Path from typing import Dict, Optional from unittest.mock import MagicMock @@ -9,29 +10,38 @@ from freezegun import freeze_time from fl4health.clients.basic_client import BasicClient, LoggingMode +from fl4health.reporting import JsonReporter +from fl4health.reporting.base_reporter import BaseReporter +from tests.test_utils.assert_metrics_dict import assert_metrics_dict freezegun.configure(extend_ignore_list=["transformers"]) # type: ignore @freeze_time("2012-12-12 12:12:12") -def test_metrics_reporter_setup_client() -> None: - fl_client = MockBasicClient() +def test_json_reporter_setup_client() -> None: + reporter = JsonReporter() + fl_client = MockBasicClient(reporters=[reporter]) fl_client.setup_client({}) - assert fl_client.metrics_reporter.metrics == { - "type": "client", - "initialized": datetime.datetime(2012, 12, 12, 12, 12, 12), + metric_dict = { + "host_type": "client", + "initialized": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), } + errors = assert_metrics_dict(metric_dict, reporter.metrics) + assert len(errors) == 0, f"Metrics check failed. Errors: {errors}" @freeze_time("2012-12-12 12:12:12") -def test_metrics_reporter_shutdown() -> None: - fl_client = MockBasicClient() +def test_json_reporter_shutdown() -> None: + reporter = JsonReporter() + fl_client = MockBasicClient(reporters=[reporter]) fl_client.shutdown() - assert fl_client.metrics_reporter.metrics == { - "shutdown": datetime.datetime(2012, 12, 12, 12, 12, 12), + metric_dict = { + "shutdown": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), } + errors = assert_metrics_dict(metric_dict, reporter.metrics) + assert len(errors) == 0, f"Metrics check failed. Errors: {errors}" @freeze_time("2012-12-12 12:12:12") @@ -39,23 +49,26 @@ def test_metrics_reporter_fit() -> None: test_current_server_round = 2 test_loss_dict = {"test_loss": 123.123} test_metrics: Dict[str, Scalar] = {"test_metric": 1234} + reporter = JsonReporter() - fl_client = MockBasicClient(loss_dict=test_loss_dict, metrics=test_metrics) + fl_client = MockBasicClient(loss_dict=test_loss_dict, metrics=test_metrics, reporters=[reporter]) fl_client.fit([], {"current_server_round": test_current_server_round, "local_epochs": 0}) - - assert fl_client.metrics_reporter.metrics == { - "type": "client", - "initialized": datetime.datetime(2012, 12, 12, 12, 12, 12), + metric_dict = { + "host_type": "client", + "initialized": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), "rounds": { test_current_server_round: { - "fit_start": datetime.datetime(2012, 12, 12, 12, 12, 12), - "loss_dict": test_loss_dict, + "round_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "fit_losses": test_loss_dict, "fit_metrics": test_metrics, - "fit_end": datetime.datetime(2012, 12, 12, 12, 12, 12), + "round": test_current_server_round, }, }, } + errors = assert_metrics_dict(metric_dict, reporter.metrics) + assert len(errors) == 0, f"Metrics check failed. Errors: {errors}" + @freeze_time("2012-12-12 12:12:12") def test_metrics_reporter_evaluate() -> None: @@ -69,22 +82,29 @@ def test_metrics_reporter_evaluate() -> None: "test - loss": 123.123, "test - num_examples": 0, } - - fl_client = MockBasicClient(loss=test_loss, metrics=test_metrics, test_set_metrics=test_metrics_testing) + reporter = JsonReporter() + fl_client = MockBasicClient( + loss=test_loss, + metrics=test_metrics, + test_set_metrics=test_metrics_testing, + reporters=[reporter], + ) fl_client.evaluate([], {"current_server_round": test_current_server_round, "local_epochs": 0}) - assert fl_client.metrics_reporter.metrics == { - "type": "client", - "initialized": datetime.datetime(2012, 12, 12, 12, 12, 12), + metric_dict = { + "host_type": "client", + "initialized": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), "rounds": { test_current_server_round: { - "evaluate_start": datetime.datetime(2012, 12, 12, 12, 12, 12), - "loss": test_loss, - "evaluate_metrics": test_metrics_final, - "evaluate_end": datetime.datetime(2012, 12, 12, 12, 12, 12), + "eval_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "eval_loss": test_loss, + "eval_metrics": test_metrics_final, + "eval_end": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), }, }, } + errors = assert_metrics_dict(metric_dict, reporter.metrics) + assert len(errors) == 0, f"Metrics check failed. Errors: {errors}" def test_evaluate_after_fit_enabled() -> None: @@ -116,8 +136,9 @@ def __init__( metrics: Optional[Dict[str, Scalar]] = None, test_set_metrics: Optional[Dict[str, Scalar]] = None, loss: Optional[float] = 0, + reporters: Sequence[BaseReporter] | None = None, ): - super().__init__(Path(""), [], torch.device(0)) + super().__init__(Path(""), [], torch.device(0), reporters=reporters) self.mock_loss_dict = loss_dict if self.mock_loss_dict is None: diff --git a/tests/clients/test_evaluate_client.py b/tests/clients/test_evaluate_client.py index d2f935b4a..06f91fd04 100644 --- a/tests/clients/test_evaluate_client.py +++ b/tests/clients/test_evaluate_client.py @@ -1,5 +1,6 @@ import datetime import math +from collections.abc import Sequence from pathlib import Path from typing import Dict, Optional, Union from unittest.mock import MagicMock @@ -10,16 +11,26 @@ from freezegun import freeze_time from fl4health.clients.evaluate_client import EvaluateClient +from fl4health.reporting import JsonReporter +from fl4health.reporting.base_reporter import BaseReporter from tests.clients.fixtures import get_basic_client, get_evaluation_client # noqa +from tests.test_utils.assert_metrics_dict import assert_metrics_dict from tests.test_utils.models_for_test import SingleLayerWithSeed def test_evaluate_merge_metrics(caplog: pytest.LogCaptureFixture) -> None: - global_metrics: Dict[str, Scalar] = {"global_metric_1": 0.22, "local_metric_2": 0.11} + global_metrics: Dict[str, Scalar] = { + "global_metric_1": 0.22, + "local_metric_2": 0.11, + } local_metrics: Dict[str, Scalar] = {"local_metric_1": 0.1, "local_metric_2": 0.99} merged_metrics = EvaluateClient.merge_metrics(global_metrics, local_metrics) # Test merge is good, local metrics are folded in last, so they take precedence when overlap exists - assert merged_metrics == {"global_metric_1": 0.22, "local_metric_1": 0.1, "local_metric_2": 0.99} + assert merged_metrics == { + "global_metric_1": 0.22, + "local_metric_1": 0.1, + "local_metric_2": 0.99, + } # Test that we are warned about duplicate metric keys assert "metric_name: local_metric_2 already exists in dictionary." in caplog.text @@ -31,7 +42,9 @@ def test_evaluate_merge_metrics(caplog: pytest.LogCaptureFixture) -> None: @pytest.mark.parametrize("model", [SingleLayerWithSeed()]) -def test_evaluating_identical_global_and_local_models(get_evaluation_client: EvaluateClient) -> None: # noqa +def test_evaluating_identical_global_and_local_models( + get_evaluation_client: EvaluateClient, # noqa +) -> None: evaluate_client = get_evaluation_client loss, metrics = evaluate_client.validate() @@ -44,7 +57,9 @@ def test_evaluating_identical_global_and_local_models(get_evaluation_client: Eva @pytest.mark.parametrize("model", [SingleLayerWithSeed()]) -def test_evaluating_different_global_and_local_models(get_evaluation_client: EvaluateClient) -> None: # noqa +def test_evaluating_different_global_and_local_models( + get_evaluation_client: EvaluateClient, # noqa +) -> None: evaluate_client = get_evaluation_client evaluate_client.global_model = SingleLayerWithSeed(seed=37) @@ -82,13 +97,16 @@ def test_evaluating_only_global_models(get_evaluation_client: EvaluateClient) -> @freeze_time("2012-12-12 12:12:12") def test_metrics_reporter_setup_client() -> None: - evaluate_client = MockEvaluateClient() + reporter = JsonReporter() + evaluate_client = MockEvaluateClient(reporters=[reporter]) evaluate_client.setup_client({}) - assert evaluate_client.metrics_reporter.metrics == { - "type": "client", - "initialized": datetime.datetime(2012, 12, 12, 12, 12, 12), + metrics_to_assert = { + "host_type": "client", + "initialized": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), } + errors = assert_metrics_dict(metrics_to_assert, reporter.metrics) + assert len(errors) == 0, f"Metrics check failed. Errors: {errors}" @freeze_time("2012-12-12 12:12:12") @@ -96,22 +114,33 @@ def test_metrics_reporter_evaluate() -> None: test_loss = 123.123 test_metrics: Dict[str, Union[bool, bytes, float, int, str]] = {"test_metric": 1234} - evaluate_client = MockEvaluateClient(loss=test_loss, metrics=test_metrics) + reporter = JsonReporter() + evaluate_client = MockEvaluateClient(loss=test_loss, metrics=test_metrics, reporters=[reporter]) evaluate_client.evaluate([], {}) - - assert evaluate_client.metrics_reporter.metrics == { - "type": "client", - "initialized": datetime.datetime(2012, 12, 12, 12, 12, 12), - "evaluate_start": datetime.datetime(2012, 12, 12, 12, 12, 12), - "loss": test_loss, - "metrics": test_metrics, - "evaluate_end": datetime.datetime(2012, 12, 12, 12, 12, 12), + print(reporter.metrics) + metric_dict = { + "host_type": "client", + "initialized": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "rounds": { + 0: { + "eval_metrics": test_metrics, + "eval_loss": test_loss, + "eval_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + } + }, } + errors = assert_metrics_dict(metric_dict, reporter.metrics) + assert len(errors) == 0, f"Metrics check failed. Errors: {errors}" class MockEvaluateClient(EvaluateClient): - def __init__(self, loss: Optional[float] = None, metrics: Optional[Dict[str, Scalar]] = None): - super().__init__(Path(""), [], torch.device(0)) + def __init__( + self, + loss: Optional[float] = None, + metrics: Optional[Dict[str, Scalar]] = None, + reporters: Sequence[BaseReporter] | None = None, + ): + super().__init__(Path(""), [], torch.device(0), reporters=reporters) # Mocking methods self.get_data_loader = MagicMock() # type: ignore diff --git a/tests/losses/test_mkmmd_loss.py b/tests/losses/test_mkmmd_loss.py index ddc0ab0b5..55263375f 100644 --- a/tests/losses/test_mkmmd_loss.py +++ b/tests/losses/test_mkmmd_loss.py @@ -618,4 +618,4 @@ def test_optimizer_betas_in_non_degenerate_case() -> None: one_hot_betas = torch.zeros_like(betas_local) one_hot_betas[1, 0] = 1 - assert torch.all(betas_local.eq(one_hot_betas)) + assert torch.allclose(one_hot_betas, betas_local, rtol=0.0, atol=1e-6) diff --git a/tests/reporting/test_metrics.py b/tests/reporting/test_json_reporter.py similarity index 59% rename from tests/reporting/test_metrics.py rename to tests/reporting/test_json_reporter.py index 4cb4ebb8f..bda99d6e3 100644 --- a/tests/reporting/test_metrics.py +++ b/tests/reporting/test_json_reporter.py @@ -3,47 +3,48 @@ from pathlib import Path from unittest.mock import Mock, patch -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting import JsonReporter -@patch("fl4health.reporting.metrics.uuid") -def test_metrics_reporter_init(mock_uuid: Mock) -> None: +@patch("fl4health.reporting.json_reporter.uuid") +def test_json_reporter_init(mock_uuid: Mock) -> None: test_uuid = "test uuid" mock_uuid.uuid4.return_value = test_uuid - metrics_reporter = MetricsReporter() + metrics_reporter = JsonReporter() + metrics_reporter.initialize() assert metrics_reporter.run_id == test_uuid assert metrics_reporter.metrics == {} -def test_metrics_reporter_add_to_metrics() -> None: +def test_json_reporter_add_summary_data() -> None: test_data_1 = {"test data 1": 123} test_data_2 = {"test data 2": 456} - metrics_reporter = MetricsReporter() + metrics_reporter = JsonReporter() - metrics_reporter.add_to_metrics(test_data_1) + metrics_reporter.report(test_data_1) assert metrics_reporter.metrics == test_data_1 - metrics_reporter.add_to_metrics(test_data_2) + metrics_reporter.report(test_data_2) assert metrics_reporter.metrics == {**test_data_1, **test_data_2} -def test_metrics_reporter_add_to_metrics_at_round() -> None: +def test_metrics_reporter_add_round_data() -> None: test_data_1 = {"test data 1": 123} test_data_2 = {"test data 2": 456} - metrics_reporter = MetricsReporter() + metrics_reporter = JsonReporter() - metrics_reporter.add_to_metrics_at_round(2, test_data_1) + metrics_reporter.report(test_data_1, round=2) assert metrics_reporter.metrics == { "rounds": { 2: test_data_1, }, } - metrics_reporter.add_to_metrics_at_round(4, test_data_1) + metrics_reporter.report(test_data_1, round=4) assert metrics_reporter.metrics == { "rounds": { 2: test_data_1, @@ -51,7 +52,7 @@ def test_metrics_reporter_add_to_metrics_at_round() -> None: }, } - metrics_reporter.add_to_metrics_at_round(2, test_data_2) + metrics_reporter.report(test_data_2, round=2) assert metrics_reporter.metrics == { "rounds": { 2: {**test_data_1, **test_data_2}, @@ -63,14 +64,14 @@ def test_metrics_reporter_add_to_metrics_at_round() -> None: def test_metrics_reporter_dump(tmp_path: Path) -> None: test_data_1 = {"test data 1": 123} test_data_2 = {"test data 2": 456} - test_date = datetime.datetime.now() + test_date = str(datetime.datetime.now()) test_run_id = "test" test_json_file_name = f"{tmp_path}/{test_run_id}.json" - metrics_reporter = MetricsReporter(run_id=test_run_id, output_folder=tmp_path) - metrics_reporter.add_to_metrics(test_data_1) - metrics_reporter.add_to_metrics({"date": test_date}) - metrics_reporter.add_to_metrics_at_round(2, test_data_2) + metrics_reporter = JsonReporter(run_id=test_run_id, output_folder=tmp_path) + metrics_reporter.report(test_data_1) + metrics_reporter.report({"date": test_date}) + metrics_reporter.report(test_data_2, round=2) metrics_reporter.dump() with open(test_json_file_name, "r") as file: diff --git a/tests/reporting/test_wandb_reporter.py b/tests/reporting/test_wandb_reporter.py deleted file mode 100644 index 1a9a27e6f..000000000 --- a/tests/reporting/test_wandb_reporter.py +++ /dev/null @@ -1,20 +0,0 @@ -from pathlib import Path -from unittest import mock - -from fl4health.reporting.fl_wandb import ClientWandBReporter, ServerWandBReporter - - -def test_server_wandb_reporter(tmp_path: Path) -> None: - with mock.patch.object(ServerWandBReporter, "__init__", lambda a, b, c, d, e, f, g, h: None): - reporter = ServerWandBReporter("", "", "", "", None, None, {}) - log_dir = str(tmp_path.joinpath("fl_wandb_logs")) - reporter._maybe_create_local_log_directory(log_dir) - assert log_dir in list(map(lambda x: str(x), tmp_path.iterdir())) - - -def test_client_wandb_reporter(tmp_path: Path) -> None: - with mock.patch.object(ClientWandBReporter, "__init__", lambda a, b, c, d, e: None): - reporter = ClientWandBReporter("", "", "", "") - log_dir = str(tmp_path.joinpath("fl_wandb_logs")) - reporter._maybe_create_local_log_directory(log_dir) - assert log_dir in list(map(lambda x: str(x), tmp_path.iterdir())) diff --git a/tests/server/test_base_server.py b/tests/server/test_base_server.py index 8dd50c229..576dabc9d 100644 --- a/tests/server/test_base_server.py +++ b/tests/server/test_base_server.py @@ -16,10 +16,12 @@ from fl4health.client_managers.base_sampling_manager import SimpleClientManager from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger +from fl4health.reporting import JsonReporter from fl4health.server.base_server import FlServer, FlServerWithCheckpointing from fl4health.strategies.basic_fedavg import BasicFedAvg from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, TestMetricPrefix +from tests.test_utils.assert_metrics_dict import assert_metrics_dict from tests.test_utils.custom_client_proxy import CustomClientProxy from tests.test_utils.models_for_test import LinearTransform @@ -83,7 +85,6 @@ def test_fl_server_with_checkpointing(tmp_path: Path) -> None: client_manager=PoissonSamplingClientManager(), parameter_exchanger=parameter_exchanger, model=initial_model, - wandb_reporter=None, strategy=None, checkpointer=checkpointer, ) @@ -101,20 +102,29 @@ def test_fl_server_with_checkpointing(tmp_path: Path) -> None: @freeze_time("2012-12-12 12:12:12") def test_metrics_reporter_fit(mock_fit: Mock) -> None: test_history = History() - test_history.metrics_centralized = {"test metrics centralized": [(123, "loss")]} - test_history.losses_centralized = [(123, 123.123)] + test_history.metrics_centralized = {"test_metric1": [(1, 123.123), (2, 123)]} + test_history.losses_centralized = [(1, 123.123), (2, 123)] mock_fit.return_value = (test_history, 1) - - fl_server = FlServer(SimpleClientManager()) - fl_server.fit(3, None) - - assert fl_server.metrics_reporter.metrics == { - "type": "server", - "fit_start": datetime.datetime(2012, 12, 12, 12, 12, 12), - "fit_end": datetime.datetime(2012, 12, 12, 12, 12, 12), - "metrics_centralized": test_history.metrics_centralized, - "losses_centralized": test_history.losses_centralized, + reporter = JsonReporter() + fl_server = FlServer(SimpleClientManager(), reporters=[reporter]) + fl_server.fit(2, None) + metrics_to_assert = { + "host_type": "server", + "fit_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "fit_end": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "rounds": { + 1: { + "eval_metrics_centralized": {"test_metric1": 123.123}, + "val - loss - centralized": 123.123, + }, + 2: { + "eval_metrics_centralized": {"test_metric1": 123}, + "val - loss - centralized": 123, + }, + }, } + errors = assert_metrics_dict(metrics_to_assert, reporter.metrics) + assert len(errors) == 0, f"Metrics check failed. Errors: {errors}, {reporter.metrics}" @patch("fl4health.server.base_server.Server.fit_round") @@ -124,18 +134,21 @@ def test_metrics_reporter_fit_round(mock_fit_round: Mock) -> None: test_metrics_aggregated = "test metrics aggregated" mock_fit_round.return_value = (None, test_metrics_aggregated, None) - fl_server = FlServer(SimpleClientManager()) + reporter = JsonReporter() + fl_server = FlServer(SimpleClientManager(), reporters=[reporter]) fl_server.fit_round(test_round, None) - assert fl_server.metrics_reporter.metrics == { + metrics_to_assert = { "rounds": { test_round: { - "fit_start": datetime.datetime(2012, 12, 12, 12, 12, 12), - "metrics_aggregated": test_metrics_aggregated, - "fit_end": datetime.datetime(2012, 12, 12, 12, 12, 12), + "fit_round_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "fit_metrics": test_metrics_aggregated, + "fit_round_end": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), }, }, } + errors = assert_metrics_dict(metrics_to_assert, reporter.metrics) + assert len(errors) == 0, f"Metrics check failed. Errors: {errors}. {reporter.metrics}" def test_unpack_metrics() -> None: @@ -201,7 +214,10 @@ def test_handle_result_aggregation() -> None: }, ) - results: List[Tuple[ClientProxy, EvaluateRes]] = [(client_proxy1, eval_res1), (client_proxy2, eval_res2)] + results: List[Tuple[ClientProxy, EvaluateRes]] = [ + (client_proxy1, eval_res1), + (client_proxy2, eval_res2), + ] failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] server_round = 1 @@ -226,18 +242,25 @@ def test_metrics_reporter_evaluate_round(mock_evaluate_round: Mock) -> None: test_round = 2 test_loss_aggregated = "test loss aggregated" test_metrics_aggregated = "test metrics aggregated" - mock_evaluate_round.return_value = (test_loss_aggregated, test_metrics_aggregated, (None, None)) + mock_evaluate_round.return_value = ( + test_loss_aggregated, + test_metrics_aggregated, + (None, None), + ) - fl_server = FlServer(SimpleClientManager()) + reporter = JsonReporter() + fl_server = FlServer(SimpleClientManager(), reporters=[reporter]) fl_server.evaluate_round(test_round, None) - assert fl_server.metrics_reporter.metrics == { + metrics_to_assert = { "rounds": { test_round: { - "evaluate_start": datetime.datetime(2012, 12, 12, 12, 12, 12), - "loss_aggregated": test_loss_aggregated, - "metrics_aggregated": test_metrics_aggregated, - "evaluate_end": datetime.datetime(2012, 12, 12, 12, 12, 12), + "eval_round_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "val - loss - aggregated": test_loss_aggregated, + "eval_metrics_aggregated": test_metrics_aggregated, + "eval_round_end": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), }, }, } + errors = assert_metrics_dict(metrics_to_assert, reporter.metrics) + assert len(errors) == 0, f"Metrics check failed. Errors: {errors}" diff --git a/tests/server/test_evaluate_server.py b/tests/server/test_evaluate_server.py index ae004209b..49b928154 100644 --- a/tests/server/test_evaluate_server.py +++ b/tests/server/test_evaluate_server.py @@ -4,21 +4,26 @@ from freezegun import freeze_time from fl4health.client_managers.base_sampling_manager import SimpleClientManager +from fl4health.reporting import JsonReporter from fl4health.server.evaluate_server import EvaluateServer +from tests.test_utils.assert_metrics_dict import assert_metrics_dict @patch("fl4health.server.evaluate_server.EvaluateServer.federated_evaluate") @freeze_time("2012-12-12 12:12:12") def test_metrics_reporter_fit(mock_federated_evaluate: Mock) -> None: + pass test_evaluate_metrics = {"test evaluate metrics": 123} mock_federated_evaluate.return_value = (None, test_evaluate_metrics, None) - evaluate_server = EvaluateServer(SimpleClientManager(), 0.5) + reporter = JsonReporter() + evaluate_server = EvaluateServer(SimpleClientManager(), 0.5, reporters=[reporter]) evaluate_server.fit(3, None) - - assert evaluate_server.metrics_reporter.metrics == { - "type": "server", - "fit_start": datetime.datetime(2012, 12, 12, 12, 12, 12), - "fit_end": datetime.datetime(2012, 12, 12, 12, 12, 12), - "metrics": test_evaluate_metrics, + metrics_to_assert = { + "host_type": "server", + "fit_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "fit_end": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "fit_metrics": test_evaluate_metrics, } + errors = assert_metrics_dict(metrics_to_assert, reporter.metrics) + assert len(errors) == 0, f"Metric check failed. Errors: {errors}" diff --git a/tests/smoke_tests/apfl_client_metrics.json b/tests/smoke_tests/apfl_client_metrics.json index 1f7148a59..ede630e6d 100644 --- a/tests/smoke_tests/apfl_client_metrics.json +++ b/tests/smoke_tests/apfl_client_metrics.json @@ -2,57 +2,57 @@ "rounds": { "1": { "fit_metrics": { - "train - personal - accuracy": 0.6078, - "train - global - accuracy": 0.6265, - "train - local - accuracy": 0.4890 + "train - personal - accuracy": 0.6063, + "train - global - accuracy": 0.6078, + "train - local - accuracy": 0.4797 }, - "loss_dict": { - "backward": 1.4258, - "global": 1.2897, - "local": 1.6243 + "fit_losses": { + "backward": 1.4428, + "global": 1.3272, + "local": 1.6362 }, - "evaluate_metrics": { - "val - personal - accuracy": 0.6254, - "val - global - accuracy": 0.6757, - "val - local - accuracy": 0.5723 + "eval_metrics": { + "val - personal - accuracy": 0.5269, + "val - global - accuracy": 0.5331, + "val - local - accuracy": 0.4984 }, - "loss": 1.4845 + "eval_loss": 1.5862 }, "2": { "fit_metrics": { - "train - personal - accuracy": 0.7845, - "train - global - accuracy": 0.7921, - "train - local - accuracy": 0.775 + "train - personal - accuracy": 0.8063, + "train - global - accuracy": 0.8266, + "train - local - accuracy": 0.7688 }, - "loss_dict": { - "backward": 0.6868, - "global": 0.6557, - "local": 0.7524 + "fit_losses": { + "backward": 0.6089, + "global": 0.6306, + "local": 0.7349 }, - "evaluate_metrics": { - "val - personal - accuracy": 0.6237, - "val - global - accuracy": 0.6583, - "val - local - accuracy": 0.5203 + "eval_metrics": { + "val - personal - accuracy": 0.7117, + "val - global - accuracy": 0.7162, + "val - local - accuracy": 0.6274 }, - "loss": 1.1075 + "eval_loss": 0.9102 }, "3": { "fit_metrics": { - "train - personal - accuracy": 0.8015, - "train - global - accuracy": 0.8078, - "train - local - accuracy": 0.78125 + "train - personal - accuracy": 0.85, + "train - global - accuracy": 0.8703, + "train - local - accuracy": 0.8422 }, - "loss_dict": { - "backward": 0.5872, - "global": 0.5809, - "local": 0.6471 + "fit_losses": { + "backward": 0.4666, + "global": 0.4760, + "local": 0.6050 }, - "evaluate_metrics": { - "val - personal - accuracy": 0.7182, - "val - global - accuracy": 0.7571, - "val - local - accuracy": 0.5864 + "eval_metrics": { + "val - personal - accuracy": 0.7681, + "val - global - accuracy": 0.7943, + "val - local - accuracy": 0.7094 }, - "loss": 0.8714 + "eval_loss": 0.7815 } } } diff --git a/tests/smoke_tests/apfl_server_metrics.json b/tests/smoke_tests/apfl_server_metrics.json index a02ac1401..ee68c9c64 100644 --- a/tests/smoke_tests/apfl_server_metrics.json +++ b/tests/smoke_tests/apfl_server_metrics.json @@ -1,28 +1,28 @@ { "rounds": { "1": { - "metrics_aggregated": { - "val - personal - accuracy": 0.6254, - "val - global - accuracy": 0.6757, - "val - local - accuracy": 0.5723 + "eval_metrics_aggregated": { + "val - personal - accuracy": 0.5269, + "val - global - accuracy": 0.5331, + "val - local - accuracy": 0.4984 }, - "loss_aggregated": 1.4845 + "val - loss - aggregated": 1.5862 }, "2": { - "metrics_aggregated": { - "val - personal - accuracy": 0.6237, - "val - global - accuracy": 0.6583, - "val - local - accuracy": 0.5203 + "eval_metrics_aggregated": { + "val - personal - accuracy": 0.7117, + "val - global - accuracy": 0.7162, + "val - local - accuracy": 0.6274 }, - "loss_aggregated": 1.1076 + "val - loss - aggregated": 0.9102 }, "3": { - "metrics_aggregated": { - "val - personal - accuracy": 0.7182, - "val - global - accuracy": 0.7571, - "val - local - accuracy": 0.5864 + "eval_metrics_aggregated": { + "val - personal - accuracy": 0.7681, + "val - global - accuracy": 0.7943, + "val - local - accuracy": 0.7094 }, - "loss_aggregated": 0.87143 + "val - loss - aggregated": 0.7814 } } } diff --git a/tests/smoke_tests/basic_client_metrics.json b/tests/smoke_tests/basic_client_metrics.json index bb85cf85f..6fbceeae0 100644 --- a/tests/smoke_tests/basic_client_metrics.json +++ b/tests/smoke_tests/basic_client_metrics.json @@ -4,7 +4,7 @@ "fit_metrics": { "train - prediction - accuracy": 0.084375 }, - "loss_dict": { + "fit_losses": { "backward": 3.4583 } }, @@ -12,31 +12,31 @@ "fit_metrics": { "train - prediction - accuracy": 0.1 }, - "loss_dict": { + "fit_losses": { "backward": 2.32976 }, - "evaluate_metrics": { + "eval_metrics": { "val - prediction - accuracy": 0.0942, "test - num_examples": 10000, "test - loss": 2.30616, "test - prediction - accuracy": 0.0966 }, - "loss": 2.3042 + "eval_loss": 2.3042 }, "3": { "fit_metrics": { "train - prediction - accuracy": 0.096875 }, - "loss_dict": { + "fit_losses": { "backward": 2.31093 }, - "evaluate_metrics": { + "eval_metrics": { "val - prediction - accuracy": 0.0936, "test - num_examples": 10000, "test - loss": 2.30109, "test - prediction - accuracy": 0.0972 }, - "loss": 2.2999 + "eval_loss": 2.2999 } } } diff --git a/tests/smoke_tests/basic_server_metrics.json b/tests/smoke_tests/basic_server_metrics.json index 1d4c55d4e..b8ad7f447 100644 --- a/tests/smoke_tests/basic_server_metrics.json +++ b/tests/smoke_tests/basic_server_metrics.json @@ -1,28 +1,28 @@ { "rounds": { "1": { - "metrics_aggregated": { + "eval_metrics_aggregated": { "val - prediction - accuracy": 0.1031, "test - prediction - accuracy": 0.1039, "test - loss - aggregated": 2.3613 }, - "loss_aggregated": 2.3567 + "val - loss - aggregated": 2.3567 }, "2": { - "metrics_aggregated": { + "eval_metrics_aggregated": { "val - prediction - accuracy": 0.0942, "test - prediction - accuracy": 0.0966, "test - loss - aggregated": 2.3061 }, - "loss_aggregated": 2.3042 + "val - loss - aggregated": 2.3042 }, "3": { - "metrics_aggregated": { + "eval_metrics_aggregated": { "val - prediction - accuracy": 0.0936, "test - prediction - accuracy": 0.0972, "test - loss - aggregated": 2.3010 }, - "loss_aggregated": 2.2999 + "val - loss - aggregated": 2.2999 } } } diff --git a/tests/smoke_tests/feddg_ga_client_metrics.json b/tests/smoke_tests/feddg_ga_client_metrics.json index 0be323091..bfd329d12 100644 --- a/tests/smoke_tests/feddg_ga_client_metrics.json +++ b/tests/smoke_tests/feddg_ga_client_metrics.json @@ -1,71 +1,69 @@ { - "type": "client", "rounds": { "1": { "fit_metrics": { - "train - personal - accuracy": 0.6078, - "train - global - accuracy": 0.6265, - "train - local - accuracy": 0.4890, - "val - personal - accuracy": 0.6254, - "val - global - accuracy": 0.6757, - "val - local - accuracy": 0.5723, - "val - loss": 1.4845 + "train - personal - accuracy": 0.6063, + "train - global - accuracy": 0.6078, + "train - local - accuracy": 0.4797, + "val - personal - accuracy": 0.5269, + "val - global - accuracy": 0.5331, + "val - local - accuracy": 0.4984, + "val - loss": 1.5862 }, - "loss_dict": { - "global": 1.2897, - "local": 1.6243, - "backward": 1.4258 + "fit_losses": { + "backward": 1.4428, + "global": 1.3272, + "local": 1.6362 }, - "evaluate_metrics": { - "val - personal - accuracy": 0.6254, - "val - global - accuracy": 0.6757, - "val - local - accuracy": 0.5723 + "eval_metrics": { + "val - personal - accuracy": 0.5269, + "val - global - accuracy": 0.5331, + "val - local - accuracy": 0.4984 }, - "loss": 1.4845 + "eval_loss": 1.5862 }, "2": { "fit_metrics": { - "train - personal - accuracy": 0.7656, - "train - global - accuracy": {"target_value": 0.7609, "custom_tolerance": 0.005}, + "train - personal - accuracy": 0.7906, + "train - global - accuracy": {"target_value": 0.8109, "custom_tolerance": 0.005}, "train - local - accuracy": {"target_value": 0.7453, "custom_tolerance": 0.005}, - "val - personal - accuracy": {"target_value": 0.6778, "custom_tolerance": 0.005}, - "val - global - accuracy": 0.76, - "val - local - accuracy": {"target_value": 0.5066, "custom_tolerance": 0.005}, - "val - loss": {"target_value": 0.9618, "custom_tolerance": 0.005} + "val - personal - accuracy": {"target_value": 0.6584, "custom_tolerance": 0.005}, + "val - global - accuracy": 0.6031, + "val - local - accuracy": {"target_value": 0.614, "custom_tolerance": 0.005}, + "val - loss": {"target_value": 1.0008, "custom_tolerance": 0.005} }, - "loss_dict": { - "global": 0.7053, - "local": 0.8164, - "backward": 0.7433 + "fit_losses": { + "global": 0.6785, + "local": 0.7770, + "backward": 0.6460 }, - "evaluate_metrics": { - "val - personal - accuracy": {"target_value": 0.6778, "custom_tolerance": 0.005}, - "val - global - accuracy": 0.76, - "val - local - accuracy": {"target_value": 0.5066, "custom_tolerance": 0.005} + "eval_metrics": { + "val - personal - accuracy": {"target_value": 0.6584, "custom_tolerance": 0.005}, + "val - local - accuracy": {"target_value": 0.614, "custom_tolerance": 0.005} }, - "loss": {"target_value": 0.9618, "custom_tolerance": 0.005} + "eval_loss": {"target_value": 1.0008, "custom_tolerance": 0.005} }, "3": { "fit_metrics": { - "train - personal - accuracy": 0.8218, - "train - global - accuracy": 0.8468, - "train - local - accuracy": {"target_value": 0.8, "custom_tolerance": 0.005}, - "val - personal - accuracy": {"target_value": 0.739, "custom_tolerance": 0.005}, - "val - global - accuracy": 0.78, - "val - local - accuracy": {"target_value": 0.5602, "custom_tolerance": 0.005}, - "val - loss": {"target_value": 0.8043, "custom_tolerance": 0.005} + "train - personal - accuracy": 0.8359, + "train - global - accuracy": 0.8656, + "train - local - accuracy": {"target_value": 0.8078, "custom_tolerance": 0.005}, + "val - personal - accuracy": {"target_value": 0.8218, "custom_tolerance": 0.005}, + "val - global - accuracy": 0.8497, + "val - local - accuracy": {"target_value": 0.7508, "custom_tolerance": 0.005}, + "val - loss": {"target_value": 0.6042, "custom_tolerance": 0.005} }, - "loss_dict": { - "global": 0.4757, - "local": {"target_value": 0.5995, "custom_tolerance": 0.005}, - "backward": {"target_value": 0.5091, "custom_tolerance": 0.005} + "fit_losses": { + "global": 0.5084, + "local": {"target_value": 0.6561, "custom_tolerance": 0.005}, + "backward": {"target_value": 0.5327, "custom_tolerance": 0.005} }, - "evaluate_metrics": { - "val - personal - accuracy": {"target_value": 0.739, "custom_tolerance": 0.005}, - "val - global - accuracy": 0.78, - "val - local - accuracy": {"target_value": 0.5602, "custom_tolerance": 0.005} + "eval_metrics": { + "val - personal - accuracy": {"target_value": 0.8218, "custom_tolerance": 0.005}, + "val - global - accuracy": 0.8497, + "val - local - accuracy": {"target_value": 0.7508, "custom_tolerance": 0.005} }, - "loss": {"target_value": 0.8043, "custom_tolerance": 0.005} + "eval_loss": {"target_value": 0.6042, "custom_tolerance": 0.005} } } } diff --git a/tests/smoke_tests/feddg_ga_server_metrics.json b/tests/smoke_tests/feddg_ga_server_metrics.json index 73444552b..648943759 100644 --- a/tests/smoke_tests/feddg_ga_server_metrics.json +++ b/tests/smoke_tests/feddg_ga_server_metrics.json @@ -1,31 +1,28 @@ { - "type": "server", "rounds": { "1": { - "metrics_aggregated": { - "val - personal - accuracy": 0.6254, - "val - global - accuracy": 0.6757, - "val - local - accuracy": 0.5723 + "eval_metrics_aggregated": { + "val - personal - accuracy": 0.5269, + "val - global - accuracy": 0.5331, + "val - local - accuracy": 0.4984 }, - "loss_aggregated": 1.4845 + "val - loss - aggregated": 1.5862 }, "2": { - "metrics_aggregated": { - "val - personal - accuracy": {"target_value": 0.6778, "custom_tolerance": 0.005}, - "val - global - accuracy": 0.76, - "val - local - accuracy": {"target_value": 0.5066, "custom_tolerance": 0.005} + "eval_metrics_aggregated": { + "val - personal - accuracy": {"target_value": 0.6584, "custom_tolerance": 0.005}, + "val - global - accuracy": 0.6031, + "val - local - accuracy": {"target_value": 0.614, "custom_tolerance": 0.005} }, - "loss_aggregated": {"target_value": 0.9618, "custom_tolerance": 0.005} + "val - loss - aggregated": {"target_value": 1.0008, "custom_tolerance": 0.005} }, "3": { - "metrics_aggregated": { - "val - personal - accuracy": {"target_value": 0.739, "custom_tolerance": 0.005}, - "val - global - accuracy": 0.78, - "val - local - accuracy": {"target_value": 0.5602, "custom_tolerance": 0.005} + "eval_metrics_aggregated": { + "val - personal - accuracy": {"target_value": 0.8218, "custom_tolerance": 0.005}, + "val - global - accuracy": 0.8497, + "val - local - accuracy": {"target_value": 0.7508, "custom_tolerance": 0.005} }, - "loss_aggregated": {"target_value": 0.8043, "custom_tolerance": 0.005} + "val - loss - aggregated": {"target_value": 0.6042, "custom_tolerance": 0.005} } - }, - "metrics_centralized": {}, - "losses_centralized": [] + } } diff --git a/tests/smoke_tests/fedprox_client_metrics.json b/tests/smoke_tests/fedprox_client_metrics.json index 98ab12cf5..119b9e3c2 100644 --- a/tests/smoke_tests/fedprox_client_metrics.json +++ b/tests/smoke_tests/fedprox_client_metrics.json @@ -1,34 +1,34 @@ { "rounds": { "1": { - "fit_metrics": {"train - prediction - accuracy": 0.2234}, - "loss_dict": { - "loss": 2.1261, - "backward": 2.1525, - "penalty_loss": 0.0263 + "fit_metrics": {"train - prediction - accuracy": 0.2484}, + "fit_losses": { + "loss": 2.1330, + "backward": 2.1598, + "penalty_loss": 0.0268 }, - "evaluate_metrics": {"val - prediction - accuracy": 0.4627}, - "loss": 2.0268 + "eval_metrics": {"val - prediction - accuracy": 0.3633}, + "eval_loss": 1.9861 }, "2": { - "fit_metrics": {"train - prediction - accuracy": 0.43125}, - "loss_dict": { - "loss": 1.8922, - "backward": 1.8922, - "penalty_loss": 0.0 + "fit_metrics": {"train - prediction - accuracy": 0.4531}, + "fit_losses": { + "penalty_loss": 0.0, + "loss": 1.7784, + "backward": 1.7784 }, - "evaluate_metrics": {"val - prediction - accuracy": 0.4323}, - "loss": 1.7356 + "eval_metrics": {"val - prediction - accuracy": 0.5016}, + "eval_loss": 1.4836 }, "3": { - "fit_metrics": {"train - prediction - accuracy": 0.5359}, - "loss_dict": { - "loss": 1.5627, - "backward": 1.5627, - "penalty_loss": 0.0 + "fit_metrics": {"train - prediction - accuracy": 0.6016}, + "fit_losses": { + "penalty_loss": 0.0, + "loss": 1.3226, + "backward": 1.3226 }, - "evaluate_metrics": {"val - prediction - accuracy": 0.5362}, - "loss": 1.4518 + "eval_metrics": {"val - prediction - accuracy": 0.6901}, + "eval_loss": 1.1124 } } } diff --git a/tests/smoke_tests/fedprox_config.yaml b/tests/smoke_tests/fedprox_config.yaml index 6116a1b2d..524475d32 100644 --- a/tests/smoke_tests/fedprox_config.yaml +++ b/tests/smoke_tests/fedprox_config.yaml @@ -14,12 +14,3 @@ proximal_weight_patience : 1 # The number of rounds to wait before increasing or n_clients: 2 # The number of clients in the FL experiment local_steps: 5 # The number of local steps (one per batch) to complete for client batch_size: 128 # The batch size for client training - -reporting_config: - enabled: False - project_name: FL4Health # Name of the project under which everything should be logged - run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name - group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored - entity: "your_entity_here" # WandB user name - notes: "Testing WB reporting" - tags: ["Test", "FedProx"] diff --git a/tests/smoke_tests/fedprox_server_metrics.json b/tests/smoke_tests/fedprox_server_metrics.json index f88dc0d1f..531e53ef5 100644 --- a/tests/smoke_tests/fedprox_server_metrics.json +++ b/tests/smoke_tests/fedprox_server_metrics.json @@ -1,16 +1,16 @@ { "rounds": { "1": { - "metrics_aggregated": {"val - prediction - accuracy": 0.4627}, - "loss_aggregated": 2.0268 + "eval_metrics_aggregated": {"val - prediction - accuracy": 0.3633}, + "val - loss - aggregated": 1.9861 }, "2": { - "metrics_aggregated": {"val - prediction - accuracy": 0.4323}, - "loss_aggregated": 1.7356 + "eval_metrics_aggregated": {"val - prediction - accuracy": 0.5016}, + "val - loss - aggregated": 1.4836 }, "3": { - "metrics_aggregated": {"val - prediction - accuracy": 0.5362}, - "loss_aggregated": 1.4518 + "eval_metrics_aggregated": {"val - prediction - accuracy": 0.6901}, + "val - loss - aggregated": 1.1124 } } } diff --git a/tests/smoke_tests/load_from_checkpoint_example/client.py b/tests/smoke_tests/load_from_checkpoint_example/client.py index ffd8a6aab..9a7719432 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/client.py +++ b/tests/smoke_tests/load_from_checkpoint_example/client.py @@ -13,7 +13,8 @@ from examples.models.cnn_model import Net from fl4health.checkpointing.client_module import ClientCheckpointModule from fl4health.clients.basic_client import BasicClient -from fl4health.reporting.metrics import MetricsReporter +from fl4health.reporting import JsonReporter +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data from fl4health.utils.losses import LossMeterType @@ -29,7 +30,7 @@ def __init__( device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, checkpointer: Optional[ClientCheckpointModule] = None, - metrics_reporter: Optional[MetricsReporter] = None, + reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, intermediate_client_state_dir: Optional[Path] = None, client_name: Optional[str] = None, @@ -41,7 +42,7 @@ def __init__( device, loss_meter_type, checkpointer, - metrics_reporter, + reporters, progress_bar, intermediate_client_state_dir, client_name, @@ -110,7 +111,8 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict intermediate_client_state_dir=args.intermediate_client_state_dir, client_name=args.client_name, seed=args.seed, + reporters=[JsonReporter()], ) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) - client.metrics_reporter.dump() + client.shutdown() diff --git a/tests/smoke_tests/load_from_checkpoint_example/server.py b/tests/smoke_tests/load_from_checkpoint_example/server.py index bb2c51d57..20278d40c 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/server.py +++ b/tests/smoke_tests/load_from_checkpoint_example/server.py @@ -12,6 +12,7 @@ from examples.utils.functions import make_dict_with_epochs_or_steps from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger +from fl4health.reporting import JsonReporter from fl4health.server.base_server import FlServerWithCheckpointing from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn @@ -68,9 +69,9 @@ def main(config: Dict[str, Any], intermediate_server_state_dir: str, server_name client_manager=SimpleClientManager(), model=model, parameter_exchanger=parameter_exchanger, - wandb_reporter=None, strategy=strategy, checkpointer=checkpointers, + reporters=[JsonReporter()], intermediate_server_state_dir=Path(intermediate_server_state_dir), server_name=server_name, ) @@ -81,7 +82,6 @@ def main(config: Dict[str, Any], intermediate_server_state_dir: str, server_name config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), ) - server.metrics_reporter.dump() server.shutdown() diff --git a/tests/smoke_tests/run_smoke_test.py b/tests/smoke_tests/run_smoke_test.py index 58ab81eb8..142011521 100644 --- a/tests/smoke_tests/run_smoke_test.py +++ b/tests/smoke_tests/run_smoke_test.py @@ -13,7 +13,11 @@ from fl4health.utils.load_data import load_cifar10_data, load_mnist_data -logging.basicConfig(format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S") +logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) logger = logging.getLogger() @@ -547,7 +551,7 @@ def _assert_metrics(metric_type: MetricType, metrics_to_assert: Optional[Dict[st with open(file) as f: metrics = json.load(f) - if metrics["type"] != metric_type.value: + if metrics["host_type"] != metric_type.value: continue metrics_found = True diff --git a/tests/smoke_tests/scaffold_client_metrics.json b/tests/smoke_tests/scaffold_client_metrics.json index 5ffb2c456..32a202f59 100644 --- a/tests/smoke_tests/scaffold_client_metrics.json +++ b/tests/smoke_tests/scaffold_client_metrics.json @@ -1,30 +1,30 @@ { "rounds": { "0": { - "fit_metrics": {"train - prediction - accuracy": 0.2}, - "loss_dict": {"backward": 2.2647} + "fit_metrics": {"train - prediction - accuracy": 0.2031}, + "fit_losses": {"backward": 2.2655} }, "1": { - "fit_metrics": {"train - prediction - accuracy": 0.1781}, - "loss_dict": {"backward": 2.267}, - "evaluate_metrics": {"val - prediction - accuracy": {"target_value": 0.1763, "custom_tolerance": 0.005}}, - "loss": 2.2789 + "fit_metrics": {"train - prediction - accuracy": 0.18125}, + "fit_losses": {"backward": 2.2684}, + "eval_metrics": {"val - prediction - accuracy": {"target_value": 0.1824, "custom_tolerance": 0.005}}, + "eval_loss": 2.2785 }, "2": { - "fit_metrics": {"train - prediction - accuracy": {"target_value": 0.40625, "custom_tolerance": 0.05}}, - "loss_dict": { - "backward": {"target_value": 2.1465, "custom_tolerance": 0.005} + "fit_metrics": {"train - prediction - accuracy": {"target_value": 0.3906, "custom_tolerance": 0.05}}, + "fit_losses": { + "backward": {"target_value": 2.1567, "custom_tolerance": 0.005} }, - "evaluate_metrics": {"val - prediction - accuracy": {"target_value": 0.3444, "custom_tolerance": 0.05}}, - "loss": {"target_value": 2.2516, "custom_tolerance": 0.005} + "eval_metrics": {"val - prediction - accuracy": {"target_value": 0.3332, "custom_tolerance": 0.05}}, + "eval_loss": {"target_value": 2.2509, "custom_tolerance": 0.005} }, "3": { - "fit_metrics": {"train - prediction - accuracy": {"target_value": 0.4390, "custom_tolerance": 0.05}}, - "loss_dict": { - "backward": {"target_value": 2.0843, "custom_tolerance": 0.005} + "fit_metrics": {"train - prediction - accuracy": {"target_value": 0.4078, "custom_tolerance": 0.05}}, + "fit_losses": { + "backward": {"target_value": 2.0964, "custom_tolerance": 0.005} }, - "evaluate_metrics": {"val - prediction - accuracy": {"target_value": 0.4734, "custom_tolerance": 0.005}}, - "loss": {"target_value": 2.2294, "custom_tolerance": 0.05} + "eval_metrics": {"val - prediction - accuracy": {"target_value": 0.4062, "custom_tolerance": 0.005}}, + "eval_loss": {"target_value": 2.2070, "custom_tolerance": 0.05} } } } diff --git a/tests/smoke_tests/scaffold_server_metrics.json b/tests/smoke_tests/scaffold_server_metrics.json index c9ead7152..cd9d60875 100644 --- a/tests/smoke_tests/scaffold_server_metrics.json +++ b/tests/smoke_tests/scaffold_server_metrics.json @@ -1,16 +1,16 @@ { "rounds": { "1": { - "metrics_aggregated": {"val - prediction - accuracy": {"target_value": 0.1766, "custom_tolerance": 0.005}}, - "loss_aggregated": 2.2789 + "eval_metrics_aggregated": {"val - prediction - accuracy": {"target_value": 0.1824, "custom_tolerance": 0.005}}, + "val - loss - aggregated": 2.2785 }, "2": { - "metrics_aggregated": {"val - prediction - accuracy": {"target_value": 0.3444, "custom_tolerance": 0.05}}, - "loss_aggregated": {"target_value": 2.2516, "custom_tolerance": 0.005} + "eval_metrics_aggregated": {"val - prediction - accuracy": {"target_value": 0.3332, "custom_tolerance": 0.05}}, + "val - loss - aggregated": {"target_value": 2.2509, "custom_tolerance": 0.005} }, "3": { - "metrics_aggregated": {"val - prediction - accuracy": 0.4734}, - "loss_aggregated": {"target_value": 2.2294, "custom_tolerance": 0.05} + "eval_metrics_aggregated": {"val - prediction - accuracy": 0.4062}, + "val - loss - aggregated": {"target_value": 2.2070, "custom_tolerance": 0.05} } } } diff --git a/tests/test_utils/assert_metrics_dict.py b/tests/test_utils/assert_metrics_dict.py new file mode 100644 index 000000000..d6dc2b298 --- /dev/null +++ b/tests/test_utils/assert_metrics_dict.py @@ -0,0 +1,55 @@ +from typing import Any, Optional + +from pytest import approx + +DEFAULT_TOLERANCE = 0.0005 + + +def assert_metrics_dict(metrics_to_assert: dict[str, Any], metrics_saved: dict[str, Any]) -> list[str]: + errors = [] + + def _assert(value: Any, saved_value: Any) -> Optional[str]: + # helper function to avoid code repetition + tolerance = DEFAULT_TOLERANCE + if isinstance(value, dict): + # if the value is a dictionary, extract the target value and the custom tolerance + tolerance = value["custom_tolerance"] + value = value["target_value"] + + if approx(value, abs=tolerance) != saved_value: + return ( + f"Saved value for metric '{metric_key}' ({saved_value}) does not match the requested " + f"value ({value}) within requested tolerance ({tolerance})." + ) + + return None + + for metric_key in metrics_to_assert: + if metric_key not in metrics_saved: + errors.append(f"Metric '{metric_key}' not found in saved metrics.") + continue + + value_to_assert = metrics_to_assert[metric_key] + + if isinstance(value_to_assert, dict): + if "target_value" not in value_to_assert and "custom_tolerance" not in value_to_assert: + # if it's a dictionary, call this function recursively + # except when the dictionary has "target_value" and "custom_tolerance", which should + # be treated as a regular dictionary + errors.extend(assert_metrics_dict(value_to_assert, metrics_saved[metric_key])) + continue + + if isinstance(value_to_assert, list) and len(value_to_assert) > 0: + # if it's a list, call an assertion for each element of the list + for i in range(len(value_to_assert)): + error = _assert(value_to_assert[i], metrics_saved[metric_key][i]) + if error is not None: + errors.append(error) + continue + + # if it's just a regular value, perform the assertion + error = _assert(value_to_assert, metrics_saved[metric_key]) + if error is not None: + errors.append(error) + + return errors