diff --git a/examples/quickstart-pytorch/README.md b/examples/quickstart-pytorch/README.md index 978191cc0ec..03ac9619062 100644 --- a/examples/quickstart-pytorch/README.md +++ b/examples/quickstart-pytorch/README.md @@ -90,7 +90,13 @@ flower-superlink --insecure Start 2 Flower `SuperNodes` in 2 separate terminal windows, using: ```bash -flower-client-app client:app --insecure +flower-client-app client:partition_0 --insecure +``` + +And: + +```bash +flower-client-app client:partition_1 --insecure ``` ### 3. Run the Flower App diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index be4be88b8f8..5fc5fe8cc01 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -94,59 +94,80 @@ def apply_transforms(batch): # 2. Federation of the pipeline with Flower # ############################################################################# -# Get partition id -parser = argparse.ArgumentParser(description="Flower") -parser.add_argument( - "--partition-id", - choices=[0, 1, 2], - default=0, - type=int, - help="Partition of the dataset divided into 3 iid partitions created artificially.", -) -partition_id = parser.parse_known_args()[0].partition_id - -# Load model and data (simple CNN, CIFAR-10) -net = Net().to(DEVICE) -trainloader, testloader = load_data(partition_id=partition_id) - # Define Flower client class FlowerClient(NumPyClient): + def __init__(self, net, data): + super().__init__() + self.net = net + self.trainloader, self.testloader = data + def get_parameters(self, config): - return [val.cpu().numpy() for _, val in net.state_dict().items()] + return [val.cpu().numpy() for _, val in self.net.state_dict().items()] def set_parameters(self, parameters): - params_dict = zip(net.state_dict().keys(), parameters) + params_dict = zip(self.net.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - net.load_state_dict(state_dict, strict=True) + self.net.load_state_dict(state_dict, strict=True) def fit(self, parameters, config): self.set_parameters(parameters) - train(net, trainloader, epochs=1) - return self.get_parameters(config={}), len(trainloader.dataset), {} + train(self.net, self.trainloader, epochs=1) + return self.get_parameters(config={}), len(self.trainloader.dataset), {} def evaluate(self, parameters, config): self.set_parameters(parameters) - loss, accuracy = test(net, testloader) - return loss, len(testloader.dataset), {"accuracy": accuracy} + loss, accuracy = test(self.net, self.testloader) + return loss, len(self.testloader.dataset), {"accuracy": accuracy} -def client_fn(cid: str): +# Flower client_fn for client holding partition 0 +def client_0_fn(cid: str): + net = Net().to(DEVICE) + data = load_data(partition_id=0) """Create and return an instance of Flower `Client`.""" - return FlowerClient().to_client() + return FlowerClient(net, data).to_client() -# Flower ClientApp -app = ClientApp( - client_fn=client_fn, +# Flower ClientApp for client holding partition 0 +partition_0 = ClientApp( + client_fn=client_0_fn, ) +# Flower client_fn for client holding partition 1 +def client_1_fn(cid: str): + net = Net().to(DEVICE) + data = load_data(partition_id=1) + """Create and return an instance of Flower `Client`.""" + return FlowerClient(net, data).to_client() + + +# Flower ClientApp for client holding partition 1 +partition_1 = ClientApp( + client_fn=client_1_fn, +) + # Legacy mode if __name__ == "__main__": from flwr.client import start_client + # Get partition id + parser = argparse.ArgumentParser(description="Flower") + parser.add_argument( + "--partition-id", + choices=[0, 1, 2], + default=0, + type=int, + help="Partition of the dataset divided into 3 iid partitions created artificially.", + ) + partition_id = parser.parse_known_args()[0].partition_id + + # Load model and data (simple CNN, CIFAR-10) + net = Net().to(DEVICE) + data = load_data(partition_id=partition_id) + start_client( server_address="127.0.0.1:8080", - client=FlowerClient().to_client(), + client=FlowerClient(net, data).to_client(), )