Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow access to different nutpie backends via pip-style syntax #7498

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Sep 10, 2024

Description

Adds a pip-style syntax to the nuts_sampler argument that allows access to alternative compile backends, when relevant. This lets you get the nutpie jax backend by setting nuts_sampler='nutpie[jax]'. For backwards compatibility, nuts_sampler='nutpie' is equivalent to nuts_sampler='nutpie[numba]'.

The current PR only deals with nutpie, but we could easily extend this to include the default PyMC sampler, to compile to JAX, numba, or pytorch directly, without going through nutpie. I'm willing to do that extension in this PR if it is deemed worthwhile..

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7498.org.readthedocs.build/en/7498/

@jessegrabowski
Copy link
Member Author

I created some global variables to track available backends in nutpie, anticipating that we will eventually expand the list to include pytorch. It would be nice if these globals were defined in the nutpie library itself and we just imported them here.

@ricardoV94
Copy link
Member

The current PR only deals with nutpie, but we could easily extend this to include the default PyMC sampler, to compile to JAX, numba, or pytorch directly

I'm not sure how the other backends (specially JAX) behave with multiprocessing tbh :O Otherwise the idea sounds cool. Perhaps @aseyboldt can weigh in as he has a better picture of how we handle multiprocessing in pm.sample

@jessegrabowski
Copy link
Member Author

Should we go work on a pmap for pytensor?

@ricardoV94
Copy link
Member

Should we go work on a pmap for pytensor?

The PyMC codebase would still need to know about it and work around differently for JAX

Copy link

codecov bot commented Sep 10, 2024

Codecov Report

Attention: Patch coverage is 47.82609% with 12 lines in your changes missing coverage. Please review.

Project coverage is 92.80%. Comparing base (5352798) to head (a5b3241).

Files with missing lines Patch % Lines
pymc/sampling/mcmc.py 45.45% 12 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7498      +/-   ##
==========================================
- Coverage   92.85%   92.80%   -0.06%     
==========================================
  Files         105      105              
  Lines       17591    17612      +21     
==========================================
+ Hits        16335    16344       +9     
- Misses       1256     1268      +12     
Files with missing lines Coverage Δ
pymc/sampling/jax.py 94.81% <100.00%> (+0.02%) ⬆️
pymc/sampling/mcmc.py 85.34% <45.45%> (-1.90%) ⬇️

@ricardoV94
Copy link
Member

With numba backend we could actually do it with threads (with nogil)? Maybe worth opening an issue to investigate different backends for pymc samplers.

Also this shouldn't have to be nuts specific so for that a backend argument may make more sense. Something like this was also raised for VI recently IIRC

@jessegrabowski
Copy link
Member Author

Should probably just be compile_kwargs like every other pymc function, but I liked my cute syntax.

@ricardoV94
Copy link
Member

ricardoV94 commented Sep 10, 2024

Should probably just be compile_kwargs like every other pymc function, but I liked my cute syntax.

Not if we need to change how the samplers/threads are orchestrated

@ricardoV94
Copy link
Member

But compile kwargs already works anyway

@jessegrabowski
Copy link
Member Author

But compile kwargs already works anyway

What do you mean, pm.sample already takes compile_kwargs? It's undocumented if so.

@ricardoV94
Copy link
Member

But compile kwargs already works anyway

What do you mean, pm.sample already takes compile_kwargs? It's undocumented if so.

It doesn't? Maybe I've played with global mode then

Copy link
Contributor

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jessegrabowski, this looks very nice. I left a few suggestions though

Comment on lines +314 to +324
if match is None:
return NUTPIE_DEFAULT_BACKEND
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also get a None match if the string is misformatted. For example, nutpie[jax would return a None match. I suggest that you test exact equality to set the default option, and if you get None then raise a ValueError.

Suggested change
if match is None:
return NUTPIE_DEFAULT_BACKEND
if string == "nutpie":
return NUTPIE_DEFAULT_BACKEND
elif match is None:
raise ValueError(f"Could not parse nutpie backend. Found {string!r}")

expected = (
", ".join([f'"{x}"' for x in NUTPIE_BACKENDS[:-1]]) + f' or "{last_option}"'
)
raise ValueError(f'Expected one of {expected}; found "{result}"')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise ValueError(f'Expected one of {expected}; found "{result}"')
raise ValueError(
'Could not parse nutpie backend. Expected one of {expected}; found "{result}"'
)

return m


@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
@pytest.mark.parametrize(
"nuts_sampler",
["pymc", "nutpie", "nutpie[jax]", "blackjax", "numpyro"],
)


def test_invalid_nutpie_backend_raises(pymc_model):
pytest.importorskip("nutpie")
with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'):
with pytest.raises(
ValueError,
match='Could not parse nutpie backend. Expected one of "numba" or "jax"; found "invalid"',
):

with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'):
with pymc_model:
sample(nuts_sampler="nutpie[invalid]", random_seed=123, chains=2, tune=500, draws=500)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
with pytest.raises(ValueError, match="Could not parse nutpie backend. Found 'nutpie[bad'"):
with pymc_model:
sample(nuts_sampler="nutpie[bad", random_seed=123, chains=2, tune=500, draws=500)

@twiecki
Copy link
Member

twiecki commented Sep 14, 2024

I posted this PR with comments into GPT o1-mini:

---

Summary

This pull request introduces a pip-style syntax for specifying different Nutpie backends within the nuts_sampler argument of PyMC's pm.sample function. Users can now select a specific backend by using a syntax like nuts_sampler='nutpie[jax]', with nuts_sampler='nutpie' defaulting to nuts_sampler='nutpie[numba]' for backward compatibility. This enhancement aims to provide greater flexibility and performance optimization by allowing users to choose the most suitable backend for their computational needs.

Key Changes:

  1. Code Enhancements:

    • Backend Parsing: Introduces a helper function _extract_backend to parse and extract backend specifications from the nuts_sampler string.
    • Type Definitions: Defines new type aliases for better type safety and clarity.
    • Sampler Integration: Updates sampling functions to handle the new backend specifications seamlessly.
    • Error Handling: Improves error messages to provide clearer guidance when invalid backend specifications are provided.
  2. Testing Improvements:

    • Expanded Test Coverage: Adds new test cases to cover various backend specifications, including valid and invalid scenarios.
    • Refactored Tests: Utilizes fixtures and helper functions to reduce redundancy and enhance test maintainability.
  3. Documentation Updates:

    • Docstrings: Enhances existing docstrings to reflect the new backend syntax and usage examples.
    • ReadTheDocs: Provides a preview link to the updated documentation for further review.
  4. GitHub Discussions:

    • Collaborative Feedback: Addresses feedback from reviewers, refining error handling and test cases based on suggestions.
    • Future Considerations: Mentions potential future expansions, such as supporting additional backends like PyTorch.

Detailed Review

1. Code Changes

a. pymc/sampling/mcmc.py

  • Type Aliases:

    ExternalNutsSampler = Literal["nutpie", "numpyro", "blackjax"]
    NutsSampler = Literal["pymc"] | ExternalNutsSampler
    NutpieBackend = Literal["numba", "jax"]
    • Purpose: Enhances type safety by clearly defining allowed values for samplers and Nutpie backends.
    • Suggestion: Consider placing these type aliases in a separate module if they are to be reused across different parts of the codebase.
  • Global Variables:

    NUTPIE_BACKENDS = get_args(NutpieBackend)
    NUTPIE_DEFAULT_BACKEND = cast(NutpieBackend, "numba")
    • Purpose: Manages available backends and sets the default backend to numba.
    • Suggestion: If the list of backends is expected to grow, consider loading them dynamically or configuring them via external settings.
  • Backend Extraction Function:

    def extract_backend(string: str) -> NutpieBackend:
        match = re.search(r"(?<=\[)[^\]]+(?=\])", string)
        if match is None:
            return NUTPIE_DEFAULT_BACKEND
        result = cast(NutpieBackend, match.group(0))
        if result not in NUTPIE_BACKENDS:
            last_option = f"{NUTPIE_BACKENDS[-1]}"
            expected = (
                ", ".join([f'"{x}"' for x in NUTPIE_BACKENDS[:-1]]) + f' or "{last_option}"'
            )
            raise ValueError(f'Expected one of {expected}; found "{result}"')
        return result
    • Purpose: Parses the backend from the nuts_sampler string and validates it.
    • Strengths:
      • Uses regex to extract backend specifications.
      • Provides clear error messages when invalid backends are specified.
    • Suggestions:
      • Regex Robustness: Ensure the regex accurately captures backends in all valid formats and handles edge cases (e.g., multiple brackets, unexpected characters).
      • Error Messaging: The error message could be more descriptive by including the entire input string to aid debugging.
  • Sampler Integration:

    if sampler.startswith("nutpie"):
        backend = extract_backend(sampler)
        compiled_model = nutpie.compile_pymc_model(model, backend=backend)
    • Purpose: Integrates the backend specification into the Nutpie sampler.
    • Suggestions:
      • Namespace Consistency: Ensure that all references to samplers and backends maintain consistent naming conventions.
      • Extensibility: If future backends are added, verify that the integration logic remains scalable.

b. pymc/sampling/jax.py

  • Type Alias Update:

    JaxNutsSampler = Literal["numpyro", "blackjax"]
    • Purpose: Defines allowed JAX-based samplers.
    • Suggestion: Similar to ExternalNutsSampler, consider centralizing type aliases for better maintainability.
  • Function Signature Update:

    def sample_jax_nuts(
        ...
        nuts_sampler: JaxNutsSampler,
        ...
    ) -> az.InferenceData:
    • Purpose: Enforces that only numpyro and blackjax are accepted as JAX-based samplers.
    • Suggestion: Ensure that all downstream functions and users are aware of this restriction to prevent runtime errors.

c. tests/sampling/test_mcmc_external.py

  • Test Refactoring:

    @pytest.fixture
    def pymc_model():
        with Model() as m:
            x = Normal("x", 100, 5)
            y = Data("y", [1, 2, 3, 4])
            Data("z", [100, 190, 310, 405])
    
            Normal("L", mu=x, sigma=0.1, observed=y)
    
        return m
    • Purpose: Defines a reusable model fixture for tests.
    • Strengths: Promotes DRY (Don't Repeat Yourself) principles by avoiding redundancy in test definitions.
  • New Test Cases:

    • Backend Options Test:

      @pytest.mark.parametrize("backend", ["numba", "jax"], ids=["numba", "jax"])
      def test_numba_backend_options(pymc_model, recwarn, backend):
          ...
      • Purpose: Validates that specifying different Nutpie backends works as expected.
    • Invalid Backend Test:

      def test_invalid_nutpie_backend_raises(pymc_model):
          pytest.importorskip("nutpie")
          with pytest.raises(ValueError, match='Could not parse nutpie backend. Expected one of "numba" or "jax"; found "invalid"'):
              with pymc_model:
                  sample(nuts_sampler="nutpie[invalid]", ...)
      • Purpose: Ensures that invalid backend specifications raise appropriate errors.
      • Strengths: Provides clear assertions that verify error handling mechanisms.
  • Test Helper Function:

    def check_external_sampler_output(warns, idata1, idata2, sample_kwargs):
        ...
    • Purpose: Abstracts common test assertions to enhance readability and maintainability.
    • Suggestion: Ensure that this function covers all necessary assertions to comprehensively validate sampler outputs.

d. Documentation Updates

  • Docstring Enhancements:

    nuts_sampler : str, default "pymc"
        Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"]. In addition, the compilation
        backend for the chosen sampler can be set using square brackets, if available. For example, "nutpie[jax]" will
        use the JAX backend for the nutpie sampler. Currently, "nutpie[jax]" and "nutpie[numba]" are allowed.
    • Purpose: Clarifies the usage of the nuts_sampler argument with examples.
    • Strengths: Provides users with clear guidance on how to specify different backends.
  • Documentation Preview:

    • Link: ReadTheDocs Preview
    • Suggestion: Verify that all new features are accurately reflected in the documentation, including edge cases and error scenarios.

2. Testing Enhancements
  • Increased Coverage:

    • Codecov Report: Indicates a patch coverage of 47.83% with 12 lines missing.
    • Action Required: Ensure that all new code paths, especially those involving error handling and backend parsing, are adequately tested to improve coverage.
  • Comprehensive Test Cases:

    • Valid Backend Specifications: Tests cover both numba and jax backends, ensuring they are correctly parsed and utilized.
    • Invalid Backend Specifications: Tests verify that improper backend inputs are gracefully handled with meaningful error messages.
    • Edge Cases: Tests include scenarios like missing brackets or unsupported backend names to ensure robustness.
  • Fixture Utilization:

    • pymc_model Fixture: Enhances test efficiency by providing a standardized model setup for multiple tests.
    • Suggestion: Consider adding more fixtures if additional setup is required for future tests, promoting scalability.

3. GitHub Discussion & Iterative Improvements
  • Collaborative Feedback:

    • Error Handling Enhancements: Reviewer @lucianopaz suggested refining error messages and handling misformatted backend specifications, which were addressed by the author.
    • Test Case Refinements: Suggestions to expand parameterizations and improve test assertions were implemented, enhancing test robustness.
  • Future Considerations:

    • Additional Backends: The author mentioned the potential to extend support beyond Nutpie, such as integrating with PyTorch, which could be explored in future PRs.
    • Multiprocessing Compatibility: Discussions around how different backends behave with multiprocessing indicate areas for further investigation to ensure seamless integration.
  • Commit History:

    • Clarity and Relevance: Commit messages are descriptive and align with the logical progression of changes, facilitating easier codebase navigation and history tracking.
    • Reversion Handling: The author promptly addressed issues raised during code review, demonstrating responsiveness and a commitment to quality.

Conclusion

This pull request effectively enhances PyMC's sampling functionality by introducing a flexible and user-friendly way to specify different Nutpie backends. The implementation is thoughtfully designed, with clear type definitions, robust error handling, and comprehensive testing to ensure reliability and maintainability.

Recommendations Before Merging:

  1. Improve Regex Parsing:

    • Robustness: Ensure that the regex used in the extract_backend function accurately captures backend specifications and handles all valid and invalid formats.
    • Edge Case Handling: Consider additional tests for edge cases, such as multiple backend specifications or unexpected characters.
  2. Enhance Test Coverage:

    • Complete Coverage: Address the missing lines reported by Codecov by adding tests that cover all new code paths, especially those involving error scenarios and backend extraction.
    • Performance Tests: If applicable, include tests that evaluate the performance impact of different backends to guide users in selecting the most suitable option.
  3. Documentation Completeness:

    • Usage Examples: Incorporate detailed usage examples in the documentation to demonstrate how to specify different backends effectively.
    • Error Messages: Document common error messages and their resolutions to assist users in troubleshooting.
  4. Future Backend Support:

    • Scalability: Plan for the potential addition of more backends by ensuring that the current implementation is easily extensible.
    • Modular Design: Consider refactoring the backend management logic into separate modules or classes to enhance modularity and separation of concerns.
  5. Code Cleanup:

    • Unused Imports: Review and remove any unused imports or variables to maintain code cleanliness.
    • Consistency: Ensure that naming conventions and coding styles are consistent across the new and existing code.

Overall, this PR represents a significant improvement to PyMC's flexibility and user experience, enabling more tailored and optimized sampling strategies. With the recommended refinements, it is well-positioned for successful integration into the main codebase.


Approved with minor recommendations.

@jessegrabowski
Copy link
Member Author

Close this as stale now that #7535 is merged? Or is there interest in this syntax still.

@ricardoV94
Copy link
Member

ricardoV94 commented Oct 15, 2024

This syntax is a bit more ergonomic / discoverable

@aseyboldt
Copy link
Member

I like the syntax :-)
But maybe we should go with jax by default? I know it is a change, but I think right now this is much more robust.

@jessegrabowski
Copy link
Member Author

Thoughts on the best way to ask for the gradient backend in this syntax? Or nutpie[jax] should just always set both, and consider the mixed case to be "advanced"?

@ricardoV94
Copy link
Member

I like the syntax :-) But maybe we should go with jax by default? I know it is a change, but I think right now this is much more robust.

That will break people's existing code if they don't have jax installed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Expose all nutpie compile backends through pm.sample
5 participants