Skip to content

Commit

Permalink
Hide RecordSet members and make it picklable (#3209)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Apr 18, 2024
1 parent 55d28e4 commit 4eb1bca
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 29 deletions.
95 changes: 67 additions & 28 deletions src/py/flwr/common/record/recordset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,64 +16,103 @@


from dataclasses import dataclass
from typing import Callable, Dict, Optional, Type, TypeVar
from typing import Dict, Optional, cast

from .configsrecord import ConfigsRecord
from .metricsrecord import MetricsRecord
from .parametersrecord import ParametersRecord
from .typeddict import TypedDict

T = TypeVar("T")

class RecordSetData:
"""Inner data container for the RecordSet class."""

@dataclass
class RecordSet:
"""RecordSet stores groups of parameters, metrics and configs."""

_parameters_records: TypedDict[str, ParametersRecord]
_metrics_records: TypedDict[str, MetricsRecord]
_configs_records: TypedDict[str, ConfigsRecord]
parameters_records: TypedDict[str, ParametersRecord]
metrics_records: TypedDict[str, MetricsRecord]
configs_records: TypedDict[str, ConfigsRecord]

def __init__(
self,
parameters_records: Optional[Dict[str, ParametersRecord]] = None,
metrics_records: Optional[Dict[str, MetricsRecord]] = None,
configs_records: Optional[Dict[str, ConfigsRecord]] = None,
) -> None:
def _get_check_fn(__t: Type[T]) -> Callable[[T], None]:
def _check_fn(__v: T) -> None:
if not isinstance(__v, __t):
raise TypeError(f"Expected `{__t}`, but `{type(__v)}` was passed.")

return _check_fn

self._parameters_records = TypedDict[str, ParametersRecord](
_get_check_fn(str), _get_check_fn(ParametersRecord)
self.parameters_records = TypedDict[str, ParametersRecord](
self._check_fn_str, self._check_fn_params
)
self._metrics_records = TypedDict[str, MetricsRecord](
_get_check_fn(str), _get_check_fn(MetricsRecord)
self.metrics_records = TypedDict[str, MetricsRecord](
self._check_fn_str, self._check_fn_metrics
)
self._configs_records = TypedDict[str, ConfigsRecord](
_get_check_fn(str), _get_check_fn(ConfigsRecord)
self.configs_records = TypedDict[str, ConfigsRecord](
self._check_fn_str, self._check_fn_configs
)
if parameters_records is not None:
self._parameters_records.update(parameters_records)
self.parameters_records.update(parameters_records)
if metrics_records is not None:
self._metrics_records.update(metrics_records)
self.metrics_records.update(metrics_records)
if configs_records is not None:
self._configs_records.update(configs_records)
self.configs_records.update(configs_records)

def _check_fn_str(self, key: str) -> None:
if not isinstance(key, str):
raise TypeError(
f"Expected `{str.__name__}`, but "
f"received `{type(key).__name__}` for the key."
)

def _check_fn_params(self, record: ParametersRecord) -> None:
if not isinstance(record, ParametersRecord):
raise TypeError(
f"Expected `{ParametersRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)

def _check_fn_metrics(self, record: MetricsRecord) -> None:
if not isinstance(record, MetricsRecord):
raise TypeError(
f"Expected `{MetricsRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)

def _check_fn_configs(self, record: ConfigsRecord) -> None:
if not isinstance(record, ConfigsRecord):
raise TypeError(
f"Expected `{ConfigsRecord.__name__}`, but "
f"received `{type(record).__name__}` for the value."
)


@dataclass
class RecordSet:
"""RecordSet stores groups of parameters, metrics and configs."""

def __init__(
self,
parameters_records: Optional[Dict[str, ParametersRecord]] = None,
metrics_records: Optional[Dict[str, MetricsRecord]] = None,
configs_records: Optional[Dict[str, ConfigsRecord]] = None,
) -> None:
data = RecordSetData(
parameters_records=parameters_records,
metrics_records=metrics_records,
configs_records=configs_records,
)
setattr(self, "_data", data) # noqa

@property
def parameters_records(self) -> TypedDict[str, ParametersRecord]:
"""Dictionary holding ParametersRecord instances."""
return self._parameters_records
data = cast(RecordSetData, getattr(self, "_data")) # noqa
return data.parameters_records

@property
def metrics_records(self) -> TypedDict[str, MetricsRecord]:
"""Dictionary holding MetricsRecord instances."""
return self._metrics_records
data = cast(RecordSetData, getattr(self, "_data")) # noqa
return data.metrics_records

@property
def configs_records(self) -> TypedDict[str, ConfigsRecord]:
"""Dictionary holding ConfigsRecord instances."""
return self._configs_records
data = cast(RecordSetData, getattr(self, "_data")) # noqa
return data.configs_records
18 changes: 17 additions & 1 deletion src/py/flwr/common/record/recordset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""RecordSet tests."""

import pickle
from copy import deepcopy
from typing import Callable, Dict, List, OrderedDict, Type, Union

Expand All @@ -33,7 +34,7 @@
Parameters,
)

from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord
from . import Array, ConfigsRecord, MetricsRecord, ParametersRecord, RecordSet


def get_ndarrays() -> NDArrays:
Expand Down Expand Up @@ -398,3 +399,18 @@ def test_count_bytes_configsrecord() -> None:

record_bytest_count = c_record.count_bytes()
assert bytes_in_dict == record_bytest_count


def test_record_is_picklable() -> None:
"""Test if RecordSet and *Record are picklable."""
# Prepare
p_record = ParametersRecord()
m_record = MetricsRecord({"aa": 123})
c_record = ConfigsRecord({"cc": bytes(9)})
rs = RecordSet()
rs.parameters_records["params"] = p_record
rs.metrics_records["metrics"] = m_record
rs.configs_records["configs"] = c_record

# Execute
pickle.dumps((p_record, m_record, c_record, rs))

0 comments on commit 4eb1bca

Please sign in to comment.