Skip to content

Commit

Permalink
Add sklearn template for flwr new (#3251)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Apr 23, 2024
1 parent e089f75 commit 6716f89
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/py/flwr/cli/new/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class MlFramework(str, Enum):
NUMPY = "NumPy"
PYTORCH = "PyTorch"
TENSORFLOW = "TensorFlow"
SKLEARN = "sklearn"


class TemplateNotFound(Exception):
Expand Down
94 changes: 94 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""$project_name: A Flower / Scikit-Learn app."""

import warnings

import numpy as np
from flwr.client import NumPyClient, ClientApp
from flwr_datasets import FederatedDataset
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss


def get_model_parameters(model):
if model.fit_intercept:
params = [
model.coef_,
model.intercept_,
]
else:
params = [model.coef_]
return params


def set_model_params(model, params):
model.coef_ = params[0]
if model.fit_intercept:
model.intercept_ = params[1]
return model


def set_initial_params(model):
n_classes = 10 # MNIST has 10 classes
n_features = 784 # Number of features in dataset
model.classes_ = np.array([i for i in range(10)])

model.coef_ = np.zeros((n_classes, n_features))
if model.fit_intercept:
model.intercept_ = np.zeros((n_classes,))


class FlowerClient(NumPyClient):
def __init__(self, model, X_train, X_test, y_train, y_test):
self.model = model
self.X_train = X_train
self.X_test = X_test
self.y_train = y_train
self.y_test = y_test

def get_parameters(self, config):
return get_model_parameters(self.model)

def fit(self, parameters, config):
set_model_params(self.model, parameters)

# Ignore convergence failure due to low local epochs
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.model.fit(self.X_train, self.y_train)

return get_model_parameters(self.model), len(self.X_train), {}

def evaluate(self, parameters, config):
set_model_params(self.model, parameters)

loss = log_loss(self.y_test, self.model.predict_proba(self.X_test))
accuracy = self.model.score(self.X_test, self.y_test)

return loss, len(self.X_test), {"accuracy": accuracy}

fds = FederatedDataset(dataset="mnist", partitioners={"train": 2})

def client_fn(cid: str):
dataset = fds.load_partition(int(cid), "train").with_format("numpy")

X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]

# Split the on edge data: 80% train, 20% test
X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]

# Create LogisticRegression Model
model = LogisticRegression(
penalty="l2",
max_iter=1, # local epoch
warm_start=True, # prevent refreshing weights when fitting
)

# Setting initial parameters, akin to model.compile for keras models
set_initial_params(model)

return FlowerClient(model, X_train, X_test, y_train, y_test).to_client()


# Flower ClientApp
app = ClientApp(client_fn=client_fn)
17 changes: 17 additions & 0 deletions src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""$project_name: A Flower / Scikit-Learn app."""

from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg


strategy = FedAvg(
fraction_fit=1.0,
fraction_evaluate=1.0,
min_available_clients=2,
)

# Create ServerApp
app = ServerApp(
config=ServerConfig(num_rounds=3),
strategy=strategy,
)
20 changes: 20 additions & 0 deletions src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "$project_name"
version = "1.0.0"
description = ""
authors = [
{ name = "The Flower Authors", email = "[email protected]" },
]
license = {text = "Apache License (2.0)"}
dependencies = [
"flwr[simulation]>=1.8.0,<2.0",
"flwr-datasets[vision]>=0.0.2,<1.0.0",
"scikit-learn>=1.1.1",
]

[tool.hatch.build.targets.wheel]
packages = ["."]

0 comments on commit 6716f89

Please sign in to comment.