Skip to content

Commit

Permalink
Properly migrating the evaluate client
Browse files Browse the repository at this point in the history
  • Loading branch information
emersodb committed Oct 17, 2024
1 parent 970fc74 commit ab9183a
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions fl4health/clients/evaluate_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.reporting.report_manager import ReportsManager
from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType
from fl4health.utils.metrics import Metric, MetricManager
from fl4health.utils.random import generate_hash
Expand Down Expand Up @@ -46,10 +47,8 @@ def __init__(
self.initialized = False

# Initialize reporters with client information.
self.reporters = [] if reporters is None else list(reporters)

for r in self.reporters:
r.initialize(id=self.client_name)
self.reports_manager = ReportsManager(reporters)
self.reports_manager.initialize(id=self.client_name)

# This data loader should be instantiated as the one on which to run evaluation
self.global_loss_meter = LossMeter[EvaluationLosses](loss_meter_type, EvaluationLosses)
Expand Down Expand Up @@ -88,8 +87,7 @@ def setup_client(self, config: Config) -> None:
self.criterion = self.get_criterion(config)
self.parameter_exchanger = self.get_parameter_exchanger(config)

for r in self.reporters:
r.report({"host_type": "client", "initialized": str(datetime.datetime.now())})
self.reports_manager.report({"host_type": "client", "initialized": str(datetime.datetime.now())})

self.initialized = True

Expand Down Expand Up @@ -117,17 +115,19 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di
assert self.local_model or self.global_model

loss, metric_values = self.validate()
elapsed = datetime.datetime.now() - start_time

for r in self.reporters:
r.report(
{
"eval_metrics": metric_values,
"eval_loss": loss,
"eval_start": str(start_time),
"eval_time_elapsed": str(elapsed),
},
)
end_time = datetime.datetime.now()
elapsed = end_time - start_time

self.reports_manager.report(
{
"eval_metrics": metric_values,
"eval_loss": loss,
"eval_start": str(start_time),
"eval_time_elapsed": str(elapsed),
"eval_end": str(end_time),
},
0,
)

# EvaluateRes should return the loss, number of examples on client, and a dictionary holding metrics
# calculation results.
Expand Down

0 comments on commit ab9183a

Please sign in to comment.