Add ZarrTrace #7540

Draft

Conversation

lucianopaz
Contributor

@lucianopaz lucianopaz commented Oct 16, 2024

Description

This PR is related to #7503. It specifically focuses on having a way to store intermediate trace results and the step methods' sampling state somewhere (see task 2 of #7508).

To be honest, the current situation of the MultiTrace and NDArray backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them. McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database. However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

These considerations brought me to the approach I'm pursuing in this PR: add a backend that uses zarr. Using zarr has the following benefits:

  1. xarray can read zarr stores directly, making it possible to write InferenceData objects to disk almost without having to call a converter.
  2. zarr works with hierarchically structured data. It's possible to store arrays for each variable inside of a group (e.g. posterior, observed_data) directly.
  3. zarr arrays handle numpy arrays nicely. Fixed-size binary data can be fit into zarr arrays seamlessly.
  4. zarr arrays can also store object-dtyped arrays using the numcodecs package. This makes it possible to use the same store to hold sample stats warning objects and the step methods' sampling_state in the same place as the actual samples from the posterior.
  5. zarr hierarchies can use many different kinds of storage: they can be held in memory, saved as a directory structure, inside of a zip file, or even remotely on s3 buckets.
  6. It's also possible to write to the same zarr object concurrently from different processes or threads, as long as a synchronization object is provided.
  7. zarr also stores the data using a compressed binary representation. The actual compressor can be customized.
  8. zarr arrays are chunked. This means that they don't need to be loaded entirely into memory, making it possible to leave a smaller memory footprint while sampling. Another benefit of chunking is that write operations on different chunks should be completely independent from each other.
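
To make point 8 concrete, here is a minimal pure-Python sketch of how an array index maps to a chunk. The function name `chunk_of` and the `(chain, draw, dimension)` layout are illustrative assumptions, not part of zarr or of this PR. The only point is that draws written by different chains land in different chunks, so those writes never touch the same stored object.

```python
def chunk_of(index, chunk_shape):
    """Return the chunk grid coordinates that hold the element at `index`."""
    return tuple(i // c for i, c in zip(index, chunk_shape))


# Hypothetical layout: (chain, draw, variable dimension), with each chunk
# holding 100 draws of a single chain.
chunk_shape = (1, 100, 3)

chain0_draw5 = chunk_of((0, 5, 0), chunk_shape)
chain1_draw5 = chunk_of((1, 5, 0), chunk_shape)
chain0_draw150 = chunk_of((0, 150, 2), chunk_shape)

print(chain0_draw5)    # (0, 0, 0)
print(chain1_draw5)    # (1, 0, 0)
print(chain0_draw150)  # (0, 1, 0)
```

With one chain per chunk along the first axis, two processes sampling different chains would, under this layout, always write to disjoint chunks.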

Having stated all of these considerations I intend to:

  • Build a zarr trace backend
  • Replace the MultiTrace and NDArray backend defaults with their Zarr counterparts
  • Store the sampling state in the zarr backend
  • Make it possible to load the zarr trace backend and resume sampling from it.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7540.org.readthedocs.build/en/7540/

@lucianopaz added the enhancements, trace-backend, and major labels Oct 16, 2024

codecov bot commented Oct 16, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.92%. Comparing base (5352798) to head (549518e).

@@            Coverage Diff             @@
##             main    #7540      +/-   ##
==========================================
+ Coverage   92.85%   92.92%   +0.06%     
==========================================
  Files         105      106       +1     
  Lines       17591    17754     +163     
==========================================
+ Hits        16335    16498     +163     
  Misses       1256     1256              
Files with missing lines Coverage Δ
pymc/backends/zarr.py 100.00% <100.00%> (ø)

@lucianopaz
Contributor Author

This is an important issue to keep track of when we'll eventually want to read the zarr store and create an InferenceData object using xarray and arviz

Comment on lines 95 to 109
_dtype = np.dtype(dtype)
if np.issubdtype(_dtype, np.floating):
    return (np.nan, _dtype, None)
elif np.issubdtype(_dtype, np.integer):
    return (-1_000_000, _dtype, None)
elif np.issubdtype(_dtype, "bool"):
    return (False, _dtype, None)
elif np.issubdtype(_dtype, "str"):
    return ("", _dtype, None)
elif np.issubdtype(_dtype, "datetime64"):
    return (np.datetime64(0, "Y"), _dtype, None)
elif np.issubdtype(_dtype, "timedelta64"):
    return (np.timedelta64(0, "Y"), _dtype, None)
else:
    return (None, _dtype, numcodecs.Pickle())
Contributor

Question from my own ignorance, since I don't understand so much how fill values are implemented. Are we just hoping that these fill values don't actually occur in the data?

If so, this seems especially perilous for bool 😅

Contributor Author

No, they are supposed to be the initialisation values for the entries. When the sampler completes its run, all entries will be filled with the correct value. Zarr just needs you to tell it what value to give to unwritten places. In the storage, these entries are never actually written, they are produced when you ask for the concrete values in the array.
The dangerous part is that xarray interprets fill_value as an indicator of whether the actual value should be masked to nan. This seems to be because the netcdf standard treats fill_value as something completely different.
To keep things as clean as possible, I’ll store the draw_idx of each chain in a separate group that should never be converted to xarray.
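
As a rough illustration of the fill-value behaviour described above, here is a toy one-dimensional chunked array in pure Python. It is a simplified model and not zarr's implementation; the class name `LazyFilledArray` and all of its attributes are hypothetical. Chunks are only materialized on the first write, and reads from unwritten chunks are synthesized from `fill_value`.

```python
class LazyFilledArray:
    """Toy 1-D chunked array: only written chunks are stored."""

    def __init__(self, length, chunk_size, fill_value):
        self.length = length
        self.chunk_size = chunk_size
        self.fill_value = fill_value
        self.stored_chunks = {}  # chunk index -> list of values

    def __setitem__(self, i, value):
        chunk_idx, offset = divmod(i, self.chunk_size)
        # Materialize the chunk on first write, pre-filled with fill_value.
        chunk = self.stored_chunks.setdefault(
            chunk_idx, [self.fill_value] * self.chunk_size
        )
        chunk[offset] = value

    def __getitem__(self, i):
        chunk_idx, offset = divmod(i, self.chunk_size)
        chunk = self.stored_chunks.get(chunk_idx)
        # Unwritten chunks are synthesized from fill_value on read.
        return self.fill_value if chunk is None else chunk[offset]


draws = LazyFilledArray(length=1000, chunk_size=100, fill_value=-1_000_000)
draws[0] = 7
draws[1] = -1_000_000  # a real sample may legitimately equal the fill value

print(draws[0])                  # 7
print(draws[999])                # -1000000
print(len(draws.stored_chunks))  # 1
```

Because `draws[1]` and `draws[999]` read back identically, the fill value alone cannot distinguish a written entry from an unwritten one, which is one reason a separate record of each chain's progress is useful.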

Contributor

Ah, that makes a lot more sense now, thanks for the explanation!

In case it's non-obvious to more than me, maybe it would be helpful to try to make this more self-evident. Perhaps by calling the function get_initial_fill_value_and_codec, or make some little comment that the fill value is used for initialization?

@michaelosthege
Member

the current situation of the MultiTrace and NDArray backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them.

Yes, therefore I would recommend not to use them for any new implementation.

McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database.

Just to clarify:

  • NumPyBackend is the go-to for in memory situations
  • ClickHouseBackend is for storing on disk (in a database that may even sit on a different machine!)

It should be quite simple to implement a ZarrBackend with McBackend!
I would recommend to do that first, because McBackend's test suite already covers all (?) of the nasty edge cases.

However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

Yes and No. I would say ArviZ is a first-class citizen, because Run.to_inferencedata() is in the base class.
There are two things which McBackend does not integrate tightly:

  1. xarray because it doesn't/didn't support sparse arrays (needed for sparse stats or variables with varying shape)
  2. InferenceData groups other than .posterior

I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default.
I see that you added "the other" groups as properties to the ZarrTrace. Maybe this is something we should do on a more abstract level? Have McBackend define the signature of InferenceData without requiring a specific implementation for it?

First we must find answers to:

  • How do prior/posterior/log_likelihood data points arrive?
  • Do they arrive in some kind of "sampling" process that may get parallelized or doesn't fit into memory?
  • If yes, should the storage backend even make a difference between MCMC and forward sampling?

Protocol buffers

I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it.

And they are only for the constant metadata {constant_data, observed_data, coords, names dtypes, ...} because this needs to be serializable from a semi-clean data structure supporting all the weird data types that users may put into their coords (timestamps anybody?).

From the Python perspective this could also be done with zarr or xarray (they can serialize to binary), but can you serialize/deserialize that in another language?
The protobufs can be compiled to C++ or Rust to easily read/write run metadata from those languages too!

Is that tight integration?

The important design decision is not which implementation is used to serialize/deserialize metadata, but rather to freeze and detach these (meta)data from chains, draws and stats:

  1. Determine it before starting the costly MCMC
  2. Serialize it to its own "blob" of data. Think of {constant_data, observed_data, coords, names dtypes, ...} as the "header" section of a trace.

@lucianopaz
Contributor Author

Thanks @michaelosthege for the feedback!

McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database.

Just to clarify:

* `NumPyBackend` is the go-to for _in memory_ situations

* `ClickHouseBackend` is for storing on disk (in a database that may even sit on a different machine!)

It should be quite simple to implement a ZarrBackend with McBackend! I would recommend to do that first, because McBackend's test suite already covers all (?) of the nasty edge cases.

I understand what the two backends for McBackend offer and that McBackend already has a test suite. Despite this, I'll try to argue in favor of writing something that's detached from McBackend.

However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

Yes and No. I would say ArviZ is a first-class citizen, because Run.to_inferencedata() is in the base class.

The way I see this is that McBackend offers a signature to convert from one kind of storage (like MultiTrace) into another one (arviz.InferenceData). I understand that with this method, you guarantee that there should always be a way to go from an McBackend Run to arviz.InferenceData, but you have to handle a lot of transformation logic in this conversion (just like the extra conversion logic that's already in pymc.backends.arviz). In my opinion, this isn't tight integration. Having the data stored in native zarr makes it possible to generate xarray.Dataset objects with a simple xr.open_zarr(store, group) call, and then these can be wrapped into an InferenceData object with a simple InferenceData(posterior=zarr_posterior, ...) (and potentially even into the future DataTree objects, since zarr hierarchies are already very much tree-like).

There are two things which McBackend does not integrate tightly:

1. `xarray` because it doesn't/didn't support sparse arrays (needed for sparse stats or variables with varying shape)

xarray does not support sparse arrays. At the moment, the posterior samples are initialized as "empty" zarr arrays (in practice, arrays filled with a fill_value). The nice thing about zarr arrays is that these filled, uninitialized entries take up almost no space because they aren't actually stored. If queried, their value is taken from the fill_value attribute. xarray still needs to figure out pydata/xarray#5475 though.

2. `InferenceData` groups other than `.posterior`

I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default. I see that you added "the other" groups as properties to the ZarrTrace.

The key thing is that I added these groups to the zarr hierarchy, having them as ZarrTrace properties is not necessary. By having them in a single shared zarr entity, they are stored almost like an InferenceData from zarr. I need to actually check if arviz has a from_zarr method, because that would be the direct conversion method from a ZarrTrace to an InferenceData object without having to add any extra conversion code.

Maybe this is something we should do on a more abstract level? Have McBackend define the signature of InferenceData without requiring a specific implementation for it?

First we must find answers to:

* How do prior/posterior/log_likelihood data points arrive?

* Do they arrive in some kind of "sampling" process that may get parallelized or doesn't fit into memory?

* If yes, should the storage backend even make a difference between MCMC and forward sampling?

I decided to only focus on MCMC for now, and I'm trying to make ZarrTrace handle concurrent writes to the zarr store from multiple processes during sampling. Having said that, it's almost effortless to add other groups to a zarr hierarchy, and the same store could house prior, prior_predictive, posterior_predictive and predictions as well without having to handle almost any extra logic.

Protocol buffers

I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it.

And they are only for the constant metadata {constant_data, observed_data, coords, names dtypes, ...} because this needs to be serializable from a semi-clean data structure supporting all the weird data types that users may put into their coords (timestamps anybody?).

From the Python perspective this could also be done with zarr or xarray (they can serialize to binary), but can you serialize/deserialize that in another language? The protobufs can be compiled to C++ or Rust to easily read/write run metadata from those languages too!

Yes, you can deserialize almost all of the contents in C++ or Rust. zarr stores can be read from Python, Julia, C++, Rust, JavaScript and Java. The only content that would not be readable in other languages would come from arrays with object dtype. At the moment, this is limited to two things:

  1. SamplerWarning
  2. StepMethodState

The latter isn't a problem in my opinion because it is related exclusively to the python pymc step methods, and I detached it to its own private group in the zarr hierarchy. The former might be more problematic, but since SamplerWarning is a dataclass, it could potentially be converted into a dictionary and then represented as a json object, which zarr can serialize without problems.
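
The dataclass-to-JSON route mentioned above could look roughly like the following sketch. `ToyWarning` is a hypothetical stand-in with made-up fields; the real SamplerWarning has its own fields, which may need extra handling.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ToyWarning:
    """Hypothetical stand-in; the real SamplerWarning fields may differ."""

    kind: str
    message: str
    step: Optional[int] = None


warning = ToyWarning(kind="divergence", message="Energy change too large", step=42)

# dataclass -> dict -> JSON text, which any language with a JSON parser can read.
encoded = json.dumps(asdict(warning))
decoded = ToyWarning(**json.loads(encoded))

print(decoded == warning)  # True
```

Fields that hold values JSON cannot represent directly (exception objects or arrays, for instance) would need their own conversion step before this round trip works.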

Having said that, there are other benefits that we would get if we were to rely on zarr directly, such as:

  • Offloading the maintenance cost of the storage backend code
  • Growing set of features that will become available to us as time goes by
  • Seamless compression of the arrays to save storage space
  • Integration with multiple on-disk storage options, ranging from directory structures and zip files to several SQL and NoSQL databases
  • Integration with distributed or cloud storage like S3, Hadoop, Google Cloud Storage and Azure storage blob, and also fsspec.

I think that these added benefits plus the drop in maintenance costs in the long run warrant using zarr directly and not through a new backend for McBackend.

@maresb
Contributor

maresb commented Oct 17, 2024

@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with (1, 1, ...) chunk size that I/O will be a bottleneck.

@lucianopaz
Contributor Author

@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with (1, 1, ...) chunk size that I/O will be a bottleneck.

No, I haven't. But I've made the chunksize customizable now via the draws_per_chunk parameter. @aseyboldt said that we could try to use a different chunk size depending on the dimensionality of the RV.

Anyway, my long term goal is to add something like checkpoints during sampling where the trace gets dumped into a file along with the sampling state of the step methods. I think that I'll eventually make the chunks align with that, so that we don't lose samples that were drawn before the checkpoint if sampling gets terminated afterwards (before having finished).
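
A back-of-the-envelope sketch of the trade-off discussed here: larger chunks mean fewer stored objects (and presumably fewer write operations), but more draws sitting in a chunk that is not yet complete. Both helper functions are hypothetical and assume that a variable's trailing dimensions fit in a single chunk and that only complete chunks are guaranteed to survive an interruption; neither assumption is confirmed by the PR.

```python
import math


def n_chunks(n_chains, n_draws, draws_per_chunk):
    """Number of chunk objects for one variable, one chunk row per chain."""
    return n_chains * math.ceil(n_draws / draws_per_chunk)


def draws_in_complete_chunks(draws_done, draws_per_chunk):
    """Draws of one chain that sit in fully written chunks."""
    return (draws_done // draws_per_chunk) * draws_per_chunk


print(n_chunks(4, 1000, 1))                # 4000
print(n_chunks(4, 1000, 100))              # 40
print(draws_in_complete_chunks(250, 100))  # 200
print(draws_in_complete_chunks(250, 1))    # 250
```

Under these assumptions, aligning checkpoints with chunk boundaries would mean that every draw taken before a checkpoint already sits in a complete chunk.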

@lucianopaz
Contributor Author

By the way, I've added a to_inferencedata method to the zarr trace. I had to do it because I wanted to ensure that the zarr store had consolidated metadata (if it didn't, xarray would complain) and because I needed to pass mask_and_scale=False to xarray.open_zarr (which arviz doesn't allow in from_zarr). Anyway, you can see for yourselves that the conversion code is extremely short because the stored data is already aligned with what arviz wants.
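
To see why `mask_and_scale=False` matters, consider this toy model of fill-value masking. It is a deliberately simplified pure-Python sketch, not xarray's decoding code; the function `decode` and the sample values are hypothetical. With masking enabled, every entry equal to the fill value is replaced by NaN, including entries that are real data.

```python
import math


def decode(values, fill_value, mask_and_scale):
    """Toy CF-style decoding: optionally replace fill_value entries by NaN."""
    if not mask_and_scale:
        return list(values)
    return [math.nan if v == fill_value else v for v in values]


# Hypothetical boolean sample stat stored with fill_value=False.
diverging = [False, True, False]

masked = decode(diverging, fill_value=False, mask_and_scale=True)
raw = decode(diverging, fill_value=False, mask_and_scale=False)

print(masked)  # [nan, True, nan]
print(raw)     # [False, True, False]
```

In the masked result, the legitimate `False` entries are indistinguishable from missing data, which is the hazard raised in the review thread for boolean fill values.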
