Skip to content

Commit

Permalink
Second option for improving quickstart examples
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Apr 5, 2024
1 parent ccaeb70 commit a57fbb9
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 30 deletions.
8 changes: 7 additions & 1 deletion examples/quickstart-pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,13 @@ flower-superlink --insecure
Start 2 Flower `SuperNodes` in 2 separate terminal windows, using:

```bash
flower-client-app client:app --insecure
flower-client-app client:partition_0 --insecure
```

And:

```bash
flower-client-app client:partition_1 --insecure
```

### 3. Run the Flower App
Expand Down
76 changes: 47 additions & 29 deletions examples/quickstart-pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,59 +94,77 @@ def apply_transforms(batch):
# 2. Federation of the pipeline with Flower
# #############################################################################

# Get partition id
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--partition-id",
choices=[0, 1, 2],
default=0,
type=int,
help="Partition of the dataset divided into 3 iid partitions created artificially.",
)
partition_id = parser.parse_known_args()[0].partition_id

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data(partition_id=partition_id)


# Define Flower client
class FlowerClient(NumPyClient):
def __init__(self, net, data):
super().__init__()
self.net = net
self.trainloader, self.testloader = data

def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]
return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

def set_parameters(self, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
params_dict = zip(self.net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
self.net.load_state_dict(state_dict, strict=True)

def fit(self, parameters, config):
self.set_parameters(parameters)
train(net, trainloader, epochs=1)
return self.get_parameters(config={}), len(trainloader.dataset), {}
train(self.net, self.trainloader, epochs=1)
return self.get_parameters(config={}), len(self.trainloader.dataset), {}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {"accuracy": accuracy}
loss, accuracy = test(self.net, self.testloader)
return loss, len(self.testloader.dataset), {"accuracy": accuracy}


def client_fn(cid: str):
# Flower client_fn for client holding partition 0
def client_0_fn(cid: str):
net = Net().to(DEVICE)
data = load_data(partition_id=0)
"""Create and return an instance of Flower `Client`."""
return FlowerClient().to_client()

return FlowerClient(net, data).to_client()

# Flower ClientApp
app = ClientApp(
client_fn=client_fn,
# Flower ClientApp for client holding partition 0
partition_0 = ClientApp(
client_fn=client_0_fn,
)

# Flower client_fn for client holding partition 1
def client_1_fn(cid: str):
net = Net().to(DEVICE)
data = load_data(partition_id=1)
"""Create and return an instance of Flower `Client`."""
return FlowerClient(net, data).to_client()

# Flower ClientApp for client holding partition 1
partition_1 = ClientApp(
client_fn=client_1_fn,
)

# Legacy mode
if __name__ == "__main__":
from flwr.client import start_client

# Get partition id
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--partition-id",
choices=[0, 1, 2],
default=0,
type=int,
help="Partition of the dataset divided into 3 iid partitions created artificially.",
)
partition_id = parser.parse_known_args()[0].partition_id

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
data = load_data(partition_id=partition_id)

start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
client=FlowerClient(net, data).to_client(),
)

0 comments on commit a57fbb9

Please sign in to comment.