Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reporting restructure #254

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ea175da
Reporting restructure
scarere Oct 10, 2024
be770a7
addressing minor comments from John
scarere Oct 15, 2024
7f75b7f
Merge branch 'main' into base_reporter
scarere Oct 15, 2024
0cd5a61
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
88d7c3b
Fixing remaining merge conflicts
scarere Oct 15, 2024
49dcc51
added report manager and change batch to step
scarere Oct 15, 2024
a1f1616
Merge branch 'main' into base_reporter
scarere Oct 15, 2024
9b9357b
fixed tests on local machine
scarere Oct 16, 2024
28c5bf3
removing the generic from list assertion
emersodb Oct 16, 2024
137f639
Updated gitignore and fedprox example
scarere Oct 17, 2024
d39c83c
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 17, 2024
80f60b1
Checking in changes to help fix smoke tests. Want to see what the res…
emersodb Oct 17, 2024
86128f3
Partial fix of the apfl smoke tests
emersodb Oct 17, 2024
8815992
Fixing APFL server metrics comparison
emersodb Oct 17, 2024
97982b6
Fixing APFL client metrics
emersodb Oct 17, 2024
4308041
Enabling FedDGGA smokek test
emersodb Oct 17, 2024
590f23c
Partial fix of the FedGD-GA Smoke tests
emersodb Oct 17, 2024
4669faf
More fixes of the FedGD-GA Smoke tests
emersodb Oct 17, 2024
e135e1a
More fixes of the FedGD-GA Smoke tests
emersodb Oct 17, 2024
a485f84
Hopefully final fix of the FedGD-GA Smoke tests
emersodb Oct 17, 2024
c205bd5
partial fix of the fedgga client metrics
emersodb Oct 17, 2024
cd8b1bf
Hopefully final smoke test fix
emersodb Oct 17, 2024
cff9588
fixing json
emersodb Oct 17, 2024
54b0b6b
another json fix
emersodb Oct 17, 2024
970fc74
re-enabling the remainder of the smoke tests
emersodb Oct 17, 2024
ab9183a
Properly migrating the evaluate client
emersodb Oct 17, 2024
12b880b
Fixing a few last smoke test bugs.
emersodb Oct 17, 2024
be2ddf1
Fixing a small bug in the unit tests associated with the evaluate cli…
emersodb Oct 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ dmypy.json
# vscode
launch.json
settings.json
.devcontainer*

#mac
.DS_Store
Expand Down
1 change: 0 additions & 1 deletion examples/ae_examples/cvae_dim_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
1 change: 0 additions & 1 deletion examples/ae_examples/fedprox_vae_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
6 changes: 3 additions & 3 deletions examples/apfl_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -58,7 +59,6 @@ def get_criterion(self, config: Config) -> _Loss:
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistApflClient(data_path, [Accuracy()], DEVICE)
client = MnistApflClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

client.metrics_reporter.dump()
client.shutdown() # This will tell the JsonReporter to dump data
4 changes: 2 additions & 2 deletions examples/apfl_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from examples.models.cnn_model import MnistNetWithBnAndFrozen
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.reporting import JsonReporter
from fl4health.server.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
Expand Down Expand Up @@ -59,15 +60,14 @@ def main(config: Dict[str, Any]) -> None:
)

client_manager = SimpleClientManager()
server = FlServer(client_manager, strategy)
server = FlServer(client_manager, strategy, reporters=[JsonReporter()])

fl.server.start_server(
server=server,
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

server.metrics_reporter.dump()
server.shutdown()


Expand Down
1 change: 0 additions & 1 deletion examples/basic_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointers,
)
Expand Down
5 changes: 2 additions & 3 deletions examples/ditto_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from examples.models.cnn_model import MnistNet
from fl4health.clients.ditto_client import DittoClient
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -68,10 +69,8 @@ def get_criterion(self, config: Config) -> _Loss:
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistDittoClient(data_path, [Accuracy()], DEVICE)
client = MnistDittoClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
fl.client.start_client(server_address=args.server_address, client=client.to_client())

# Shutdown the client gracefully
client.shutdown()

client.metrics_reporter.dump()
7 changes: 4 additions & 3 deletions examples/dp_fed_examples/instance_level_dp/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import string
from collections.abc import Sequence
from functools import partial
from random import choices
from typing import Any, Dict, Optional
Expand All @@ -15,7 +16,7 @@
from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer, OpacusCheckpointer
from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.reporting.fl_wandb import ServerWandBReporter
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.server.instance_level_dp_server import InstanceLevelDpServer
from fl4health.strategies.basic_fedavg import OpacusBasicFedAvg
from fl4health.utils.config import load_config
Expand Down Expand Up @@ -71,8 +72,8 @@ def __init__(
strategy: OpacusBasicFedAvg,
local_epochs: Optional[int] = None,
local_steps: Optional[int] = None,
wandb_reporter: Optional[ServerWandBReporter] = None,
checkpointer: Optional[OpacusCheckpointer] = None,
reporters: Sequence[BaseReporter] | None = None,
delta: Optional[float] = None,
) -> None:
super().__init__(
Expand All @@ -83,8 +84,8 @@ def __init__(
strategy,
local_epochs,
local_steps,
wandb_reporter,
checkpointer,
reporters,
delta,
)
self.parameter_exchanger = FullParameterExchanger()
Expand Down
6 changes: 3 additions & 3 deletions examples/feddg_ga_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -58,7 +59,6 @@ def get_criterion(self, config: Config) -> _Loss:
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistApflClient(data_path, [Accuracy()], DEVICE)
client = MnistApflClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

client.metrics_reporter.dump()
client.shutdown()
5 changes: 3 additions & 2 deletions examples/feddg_ga_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.reporting import JsonReporter
from fl4health.server.base_server import FlServer
from fl4health.strategies.feddg_ga_strategy import FedDgGaStrategy
from fl4health.utils.config import load_config
Expand Down Expand Up @@ -66,15 +67,15 @@ def main(config: Dict[str, Any]) -> None:
# will return the same sampling until it is told to reset, which in FedDgGaStrategy
# is done right before fit_round.
client_manager = FixedSamplingClientManager()
server = FlServer(strategy=strategy, client_manager=client_manager)
server = FlServer(strategy=strategy, client_manager=client_manager, reporters=[JsonReporter()])

fl.server.start_server(
server=server,
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
)

server.metrics_reporter.dump()
server.shutdown()


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions examples/federated_eval_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from examples.models.cnn_model import Net
from fl4health.clients.evaluate_client import EvaluateClient
from fl4health.reporting.metrics import MetricsReporter
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_cifar10_test_data
from fl4health.utils.losses import LossMeterType
Expand All @@ -25,15 +25,15 @@ def __init__(
metrics: Sequence[Metric],
device: torch.device,
model_checkpoint_path: Optional[Path],
metrics_reporter: Optional[MetricsReporter] = None,
reporters: Sequence[BaseReporter] | None = None,
) -> None:
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
model_checkpoint_path=model_checkpoint_path,
loss_meter_type=LossMeterType.AVERAGE,
metrics_reporter=metrics_reporter,
reporters=reporters,
)

def initialize_global_model(self, config: Config) -> Optional[nn.Module]:
Expand Down
1 change: 0 additions & 1 deletion examples/fedpca_examples/dim_reduction/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
5 changes: 2 additions & 3 deletions examples/fedprox_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from examples.models.cnn_model import MnistNet
from fl4health.clients.fed_prox_client import FedProxClient
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -64,10 +65,8 @@ def get_criterion(self, config: Config) -> _Loss:
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistFedProxClient(data_path, [Accuracy()], DEVICE)
client = MnistFedProxClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
fl.client.start_client(server_address=args.server_address, client=client.to_client())

# Shutdown the client gracefully
client.shutdown()

client.metrics_reporter.dump()
7 changes: 3 additions & 4 deletions examples/fedprox_example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ local_epochs: 1 # The number of epochs to complete for client
batch_size: 128 # The batch size for client training

reporting_config:
enabled: False
project_name: FL4Health # Name of the project under which everything should be logged
run_name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name
group_name: "FedProx Experiment" # Group under which each of the FL run logging will be stored
project: FL4Health # Name of the project under which everything should be logged
name: "FedProx Server" # Name of the run on the server-side, each client will also have it's own run name
group: "FedProx Experiment" # Group under which each of the FL run logging will be stored
entity: "your_entity_here" # WandB user name
notes: "Testing WB reporting"
tags: ["Test", "FedProx"]
10 changes: 6 additions & 4 deletions examples/fedprox_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from examples.models.cnn_model import MnistNet
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.reporting.fl_wandb import ServerWandBReporter
from fl4health.reporting import JsonReporter, WandBReporter
from fl4health.server.adaptive_constraint_servers.fedprox_server import FedProxServer
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.config import load_config
Expand Down Expand Up @@ -78,9 +78,12 @@ def main(config: Dict[str, Any], server_address: str) -> None:
loss_weight_patience=config["proximal_weight_patience"],
)

wandb_reporter = ServerWandBReporter.from_config(config)
wandb_reporter = WandBReporter("round", **config["reporting_config"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see why we're doing this with the kwargs implementation below. However, I'd recommend we limit the use of freeform arguments from unrolled dictionaries as much as possible if we can. It makes it harder for a user to know what to send along.

json_reporter = JsonReporter()
client_manager = SimpleClientManager()
server = FedProxServer(client_manager=client_manager, strategy=strategy, model=None, wandb_reporter=wandb_reporter)
server = FedProxServer(
client_manager=client_manager, strategy=strategy, model=None, reporters=[wandb_reporter, json_reporter]
)

fl.server.start_server(
server=server,
Expand All @@ -89,7 +92,6 @@ def main(config: Dict[str, Any], server_address: str) -> None:
)
# Shutdown the server gracefully
server.shutdown()
server.metrics_reporter.dump()


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ def main(config: Dict[str, Any]) -> None:

# Initializing the model on the server side
model: nn.Module = FedSimClrModel(
CifarSslEncoder(), CifarSslProjectionHead(), CifarSslPredictionHead(), pretrain=True
CifarSslEncoder(),
CifarSslProjectionHead(),
CifarSslPredictionHead(),
pretrain=True,
)
# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
Expand All @@ -67,7 +70,6 @@ def main(config: Dict[str, Any]) -> None:
client_manager=SimpleClientManager(),
parameter_exchanger=parameter_exchanger,
model=model,
wandb_reporter=None,
strategy=strategy,
checkpointer=checkpointer,
)
Expand Down
7 changes: 4 additions & 3 deletions examples/fenda_ditto_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from fl4health.model_bases.fenda_base import FendaModel
from fl4health.model_bases.parallel_split_models import ParallelFeatureJoinMode
from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -114,18 +115,18 @@ def get_criterion(self, config: Config) -> _Loss:
)

checkpointer = ClientCheckpointModule(
pre_aggregation=pre_aggregation_checkpointer, post_aggregation=post_aggregation_checkpointer
pre_aggregation=pre_aggregation_checkpointer,
post_aggregation=post_aggregation_checkpointer,
)
client = MnistFendaDittoClient(
data_path,
[Accuracy()],
DEVICE,
args.checkpoint_path,
checkpointer=checkpointer,
reporters=[JsonReporter()],
)
fl.client.start_client(server_address=args.server_address, client=client.to_client())

# Shutdown the client gracefully
client.shutdown()

client.metrics_reporter.dump()
6 changes: 3 additions & 3 deletions examples/mr_mtl_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from examples.models.cnn_model import MnistNet
from fl4health.clients.mr_mtl_client import MrMtlClient
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -64,10 +65,9 @@ def get_criterion(self, config: Config) -> _Loss:
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistMrMtlClient(data_path, [Accuracy()], DEVICE)
client = MnistMrMtlClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])

fl.client.start_client(server_address=args.server_address, client=client.to_client())

# Shutdown the client gracefully
client.shutdown()

client.metrics_reporter.dump()
5 changes: 2 additions & 3 deletions examples/scaffold_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.scaffold_client import ScaffoldClient
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
from fl4health.utils.metrics import Accuracy
Expand Down Expand Up @@ -55,8 +56,6 @@ def get_criterion(self, config: Config) -> _Loss:
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistScaffoldClient(data_path, [Accuracy()], DEVICE)
client = MnistScaffoldClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()

client.metrics_reporter.dump()
Loading
Loading