PR4: Add deep_mmd_loss files #170

Open · wants to merge 20 commits into base: main
185 changes: 115 additions & 70 deletions fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py

Large diffs are not rendered by default.

256 changes: 141 additions & 115 deletions fl4health/losses/deep_mmd_loss.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion fl4health/losses/mkmmd_loss.py
@@ -416,7 +416,7 @@ def optimize_betas(self, X: torch.Tensor, Y: torch.Tensor, lambda_m: float = 1e-
log(INFO, f"{e} We keep previous betas for layer {self.layer_name}.")
else:
log(INFO, f"{e} We keep previous betas.")
raw_betas = self.betas
raw_betas = self.betas.clone()
else:
# If we're trying to maximize the type II error, then we are trying to maximize a convex function over a
# convex polygon of beta values. So the maximum is found at one of the vertices
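
A note on the mkmmd_loss.py change above: raw_betas = self.betas.clone() replaces a plain assignment that aliased the stored betas. Below is a minimal sketch, not code from this repository, of the aliasing behaviour that .clone() avoids: without the copy, any later in-place update to raw_betas would silently mutate the stored betas as well.

import torch

# Sketch of tensor aliasing, not code from this repository.
betas = torch.tensor([0.5, 0.5])

aliased = betas          # plain assignment: shares storage with betas
copied = betas.clone()   # independent copy with its own storage

aliased.mul_(2.0)        # in-place update through the alias

print(betas)   # tensor([1., 1.]) -- mutated via the alias
print(copied)  # tensor([0.5000, 0.5000]) -- unaffected
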
129 changes: 83 additions & 46 deletions research/cifar10/ditto_deep_mmd/client.py
@@ -14,18 +14,18 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
from fl4health.checkpointing.client_module import ClientCheckpointModule
from fl4health.clients.deep_mmd_clients.ditto_deep_mmd_client import DittoDeepMmdClient
from fl4health.utils.config import narrow_config_type
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Accuracy, Metric
from fl4health.utils.random import set_all_random_seeds
from fl4health.utils.sampler import DirichletLabelBasedSampler
from research.cifar10.model import ConvNet
from research.cifar10.preprocess import get_preprocessed_data, get_test_preprocessed_data

NUM_CLIENTS = 10
BASELINE_LAYERS: OrderedDict[str, int] = OrderedDict()
BASELINE_LAYERS["bn1"] = 32768
BASELINE_LAYERS["bn2"] = 16384
@@ -40,70 +40,90 @@ def __init__(
device: torch.device,
client_number: int,
learning_rate: float,
lam: float,
heterogeneity_level: float,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
deep_mmd_loss_weight: float = 10,
deep_mmd_loss_depth: int = 1,
checkpointer: Optional[ClientCheckpointModule] = None,
use_partitioned_data: bool = False,
) -> None:
size_feature_extraction_layers = OrderedDict(list(BASELINE_LAYERS.items())[-1 * deep_mmd_loss_depth :])
flatten_feature_extraction_layers = {key: True for key in size_feature_extraction_layers}
feature_extraction_layers_with_size = OrderedDict(list(BASELINE_LAYERS.items())[-1 * deep_mmd_loss_depth :])
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
checkpointer=checkpointer,
lam=lam,
deep_mmd_loss_weight=deep_mmd_loss_weight,
flatten_feature_extraction_layers=flatten_feature_extraction_layers,
size_feature_extraction_layers=size_feature_extraction_layers,
feature_extraction_layers_with_size=feature_extraction_layers_with_size,
)
self.use_partitioned_data = use_partitioned_data
self.client_number = client_number
self.heterogeneity_level = heterogeneity_level
self.learning_rate: float = learning_rate

assert 0 <= client_number < NUM_CLIENTS
log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}")
def setup_client(self, config: Config) -> None:
# Check if the client number is within the range of the total number of clients
num_clients = narrow_dict_type(config, "n_clients", int)
assert 0 <= self.client_number < num_clients
super().setup_client(config)

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
n_clients = narrow_config_type(config, "n_clients", int)
        # Set client-specific hash_key for sampler to ensure heterogeneous data distribution among clients
sampler = DirichletLabelBasedSampler(
list(range(10)),
sample_percentage=1.0 / n_clients,
beta=self.heterogeneity_level,
hash_key=self.client_number,
)
# Set the same hash_key for the train_loader and val_loader to ensure the same data split
# of train and validation for all clients
train_loader, val_loader, _ = load_cifar10_data(
self.data_path,
batch_size,
validation_proportion=0.2,
sampler=sampler,
hash_key=100,
)
batch_size = narrow_dict_type(config, "batch_size", int)
# The partitioned data should be generated prior to running the clients via preprocess_data function
# in the research/cifar10/preprocess.py file
if self.use_partitioned_data:
train_loader, val_loader, _ = get_preprocessed_data(
self.data_path, self.client_number, batch_size, self.heterogeneity_level
)
else:
n_clients = narrow_dict_type(config, "n_clients", int)
# Set client-specific hash_key for sampler to ensure heterogeneous data distribution among clients
sampler = DirichletLabelBasedSampler(
list(range(10)),
sample_percentage=1.0 / n_clients,
beta=self.heterogeneity_level,
hash_key=self.client_number,
)
# Set the same hash_key for the train_loader and val_loader to ensure the same data split
# of train and validation for all clients
train_loader, val_loader, _ = load_cifar10_data(
self.data_path,
batch_size,
validation_proportion=0.2,
sampler=sampler,
hash_key=100,
)
return train_loader, val_loader

def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
batch_size = narrow_config_type(config, "batch_size", int)
n_clients = narrow_config_type(config, "n_clients", int)
sampler = DirichletLabelBasedSampler(
list(range(10)),
sample_percentage=1.0 / n_clients,
beta=self.heterogeneity_level,
hash_key=self.client_number,
)
test_loader, _ = load_cifar10_test_data(self.data_path, batch_size, sampler=sampler)
batch_size = narrow_dict_type(config, "batch_size", int)
# The partitioned data should be generated prior to running the clients via preprocess_data function
# in the research/cifar10/preprocess.py file
if self.use_partitioned_data:
test_loader, _ = get_test_preprocessed_data(
self.data_path, self.client_number, batch_size, self.heterogeneity_level
)
else:
n_clients = narrow_dict_type(config, "n_clients", int)
# Set client-specific hash_key for sampler to ensure heterogeneous data distribution among clients
            # Also, as the hash_key is the same for the train and test samplers, the test data distribution
            # will match the train data distribution
sampler = DirichletLabelBasedSampler(
list(range(10)),
sample_percentage=1.0 / n_clients,
beta=self.heterogeneity_level,
hash_key=self.client_number,
)
test_loader, _ = load_cifar10_test_data(self.data_path, batch_size, sampler=sampler)
return test_loader

def get_criterion(self, config: Config) -> _Loss:
return torch.nn.CrossEntropyLoss()

def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
        # Following the implementation in pFL-Bench: A Comprehensive Benchmark for Personalized
        # Federated Learning (https://arxiv.org/pdf/2405.17724), we use the SGD optimizer for the CIFAR-10 dataset
global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9)
local_optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
return {"global": global_optimizer, "local": local_optimizer}
@@ -128,6 +148,12 @@ def get_model(self, config: Config) -> nn.Module:
help="Path to the preprocessed Cifar 10 Dataset",
required=True,
)
parser.add_argument(
"--use_partitioned_data",
action="store_true",
help="Use preprocessed partitioned data for training, validation and testing",
default=False,
)
parser.add_argument(
"--run_name",
action="store",
@@ -151,9 +177,6 @@ def get_model(self, config: Config) -> nn.Module:
parser.add_argument(
"--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=0.1
)
parser.add_argument(
"--lam", action="store", type=float, help="Ditto loss weight for local model training", default=0.01
)
parser.add_argument(
"--seed",
action="store",
@@ -179,26 +202,40 @@ def get_model(self, config: Config) -> nn.Module:
"--deep_mmd_loss_depth",
action="store",
type=int,
help="Depth of applying the deep mmd loss",
help="Depth of applying the Deep MMD loss",
required=False,
default=1,
)
args = parser.parse_args()
if args.use_partitioned_data:
log(INFO, "Using preprocessed partitioned data for training, validation and testing")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Server Address: {args.server_address}")
log(INFO, f"Learning Rate: {args.learning_rate}")
log(INFO, f"Lambda: {args.lam}")
log(INFO, f"Mu: {args.mu}")
log(INFO, f"DEEP MMD Loss Depth: {args.deep_mmd_loss_depth}")

# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

# Adding extensive checkpointing for the client
checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
checkpoint_name = f"client_{args.client_number}_best_model.pkl"
checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name))
pre_aggregation_best_checkpoint_name = f"pre_aggregation_client_{args.client_number}_best_model.pkl"
pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
checkpointer = ClientCheckpointModule(
pre_aggregation=[
BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
],
post_aggregation=[
BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
],
)

data_path = Path(args.dataset_dir)
client = CifarDittoClient(
@@ -208,10 +245,10 @@ def get_model(self, config: Config) -> nn.Module:
client_number=args.client_number,
learning_rate=args.learning_rate,
heterogeneity_level=args.beta,
lam=args.lam,
checkpointer=checkpointer,
deep_mmd_loss_depth=args.deep_mmd_loss_depth,
deep_mmd_loss_weight=args.mu,
use_partitioned_data=args.use_partitioned_data,
)

fl.client.start_client(server_address=args.server_address, client=client.to_client())
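
A side note on the feature_extraction_layers_with_size construction in CifarDittoClient.__init__ above: it keeps only the last deep_mmd_loss_depth entries of BASELINE_LAYERS, i.e. the deepest layers that feed the Deep MMD loss. A minimal sketch of that selection follows, using only the two layer sizes visible in this diff (the remaining BASELINE_LAYERS entries are collapsed above).

from collections import OrderedDict

# Sketch only: BASELINE_LAYERS maps layer names to flattened feature sizes.
# Just the two entries visible in this diff are used; the rest are collapsed above.
BASELINE_LAYERS: OrderedDict[str, int] = OrderedDict()
BASELINE_LAYERS["bn1"] = 32768
BASELINE_LAYERS["bn2"] = 16384

deep_mmd_loss_depth = 1
# Keep only the last deep_mmd_loss_depth (deepest) layers, preserving insertion order.
feature_extraction_layers_with_size = OrderedDict(
    list(BASELINE_LAYERS.items())[-1 * deep_mmd_loss_depth :]
)
print(feature_extraction_layers_with_size)  # only the deepest visible layer: bn2 -> 16384
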
4 changes: 2 additions & 2 deletions research/cifar10/ditto_deep_mmd/config.yaml
@@ -1,7 +1,7 @@
# Parameters that describe server
n_server_rounds: 20 # The number of rounds to run FL
n_server_rounds: 10 # The number of rounds to run FL

# Parameters that describe clients
n_clients: 10 # The number of clients in the FL experiment
n_clients: 5 # The number of clients in the FL experiment
local_epochs: 5 # The number of epochs to complete for client
batch_size: 32 # The batch size for client training
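
For reference, a minimal sketch of how these values are read back with the load_config and narrow_dict_type helpers used elsewhere in this PR, assuming the script is run from the repository root.

from fl4health.utils.config import load_config, narrow_dict_type

# Sketch only: load the YAML above and pull typed values out of it.
config = load_config("research/cifar10/ditto_deep_mmd/config.yaml")

n_clients = narrow_dict_type(config, "n_clients", int)    # 5 after this change
batch_size = narrow_dict_type(config, "batch_size", int)  # 32
local_epochs = narrow_dict_type(config, "local_epochs", int)

print(n_clients, batch_size, local_epochs)
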
9 changes: 5 additions & 4 deletions research/cifar10/ditto_deep_mmd/run_fold_experiment.slrm
@@ -43,7 +43,7 @@
#
# Notes:
# 1) The sbatch command above should be run from the top level directory of the repository.
# 2) This example runs ditto. As such the data paths and python launch commands are hardcoded. If you want to change
# 2) This example runs ditto Deep MMD. As such the data paths and python launch commands are hardcoded. If you want to change
# the example you run, you need to explicitly modify the code below.
# 3) The logging directories need to ALREADY EXIST. The script does not create them.
###############################################
Expand Down Expand Up @@ -74,7 +74,7 @@ CLIENT_LR=$5
LAM_VALUE=$6
MU_VALUE=$7
DEEP_MMD_LOSS_DEPTH=$8
SERVER_ADDRESS=${9}
SERVER_ADDRESS=$9
CLIENT_BETA=${10}

# Create the artifact directory
@@ -134,13 +134,14 @@ do
--config_path ${SERVER_CONFIG_PATH} \
--server_address ${SERVER_ADDRESS} \
--seed ${SEED} \
--lam ${LAM_VALUE} \
> ${SERVER_OUTPUT_FILE} 2>&1 &

# Sleep for 20 seconds to allow the server to come up.
sleep 20

# Start n number of clients and divert the outputs to their own files
n_clients=10
n_clients=5
for (( c=0; c<${n_clients}; c++ ))
do
CLIENT_NAME="client_${c}"
@@ -154,12 +155,12 @@ do
--run_name ${RUN_NAME} \
--client_number ${c} \
--learning_rate ${CLIENT_LR} \
--lam ${LAM_VALUE} \
--mu ${MU_VALUE} \
--deep_mmd_loss_depth ${DEEP_MMD_LOSS_DEPTH} \
--server_address ${SERVER_ADDRESS} \
--seed ${SEED} \
--beta ${CLIENT_BETA} \
--use_partitioned_data \
> ${CLIENT_LOG_PATH} 2>&1 &
done

13 changes: 9 additions & 4 deletions research/cifar10/ditto_deep_mmd/server.py
@@ -7,8 +7,8 @@
from flwr.common.logger import log
from flwr.common.typing import Config
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters
@@ -33,7 +33,7 @@ def fit_config(
}


def main(config: Dict[str, Any], server_address: str) -> None:
def main(config: Dict[str, Any], server_address: str, lam: float) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
@@ -47,7 +47,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
# Initializing the model on the server side
model = ConvNet(in_channels=3)
# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
strategy = FedAvgWithAdaptiveConstraint(
min_fit_clients=config["n_clients"],
min_evaluate_clients=config["n_clients"],
# Server waits for min_available_clients before starting FL rounds
@@ -58,6 +58,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_all_model_parameters(model),
initial_loss_weight=lam,
)

server = PersonalServer(client_manager, strategy)
@@ -98,12 +99,16 @@ def main(config: Dict[str, Any], server_address: str) -> None:
help="Seed for the random number generators across python, torch, and numpy",
required=False,
)
parser.add_argument(
"--lam", action="store", type=float, help="Ditto loss weight for local model training", default=0.01
)
args = parser.parse_args()

config = load_config(args.config_path)
log(INFO, f"Server Address: {args.server_address}")
log(INFO, f"Lambda: {args.lam}")

# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

main(config, args.server_address)
main(config, args.server_address, args.lam)