Skip to content

Commit

Permalink
docs: clarify broadcasting semantics and output shape in `linalg.solv…
Browse files Browse the repository at this point in the history
…e` (#810)

PR-URL: #810
  • Loading branch information
asmeurer authored Jul 25, 2024
1 parent b569b03 commit b933915
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/array_api_stubs/_2022_12/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,12 +594,12 @@ def solve(x1: array, x2: array, /) -> array:
x1: array
coefficient array ``A`` having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Must be of full rank (i.e., all rows or, equivalently, columns must be linearly independent). Should have a floating-point data type.
x2: array
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-1]`` must be compatible with ``shape(x1)[:-1]`` (see :ref:`broadcasting`). Should have a floating-point data type.
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-2]`` must be compatible with ``shape(x1)[:-2]`` (see :ref:`broadcasting`). Should have a floating-point data type.
Returns
-------
out: array
an array containing the solution to the system ``AX = B`` for each square matrix. The returned array must have the same shape as ``x2`` (i.e., the array corresponding to ``B``) and must have a floating-point data type determined by :ref:`type-promotion`.
an array containing the solution to the system ``AX = B`` for each square matrix. If ``x2`` has shape ``(M,)``, the returned array must have shape equal to ``shape(x1)[:-2] + shape(x2)[-1:]``. Otherwise, if ``x2`` has shape ``(..., M, K)```, the returned array must have shape equal to ``(..., M, K)``, where ``...`` refers to the result of broadcasting ``shape(x1)[:-2]`` and ``shape(x2)[:-2]``. The returned array must have a floating-point data type determined by :ref:`type-promotion`.
Notes
-----
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_stubs/_2023_12/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,12 +623,12 @@ def solve(x1: array, x2: array, /) -> array:
x1: array
coefficient array ``A`` having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Must be of full rank (i.e., all rows or, equivalently, columns must be linearly independent). Should have a floating-point data type.
x2: array
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-1]`` must be compatible with ``shape(x1)[:-1]`` (see :ref:`broadcasting`). Should have a floating-point data type.
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-2]`` must be compatible with ``shape(x1)[:-2]`` (see :ref:`broadcasting`). Should have a floating-point data type.
Returns
-------
out: array
an array containing the solution to the system ``AX = B`` for each square matrix. The returned array must have the same shape as ``x2`` (i.e., the array corresponding to ``B``) and must have a floating-point data type determined by :ref:`type-promotion`.
an array containing the solution to the system ``AX = B`` for each square matrix. If ``x2`` has shape ``(M,)``, the returned array must have shape equal to ``shape(x1)[:-2] + shape(x2)[-1:]``. Otherwise, if ``x2`` has shape ``(..., M, K)```, the returned array must have shape equal to ``(..., M, K)``, where ``...`` refers to the result of broadcasting ``shape(x1)[:-2]`` and ``shape(x2)[:-2]``. The returned array must have a floating-point data type determined by :ref:`type-promotion`.
Notes
-----
Expand Down
2 changes: 1 addition & 1 deletion src/array_api_stubs/_2023_12/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def tile(x: array, repetitions: Tuple[int, ...], /) -> array:

def unstack(x: array, /, *, axis: int = 0) -> Tuple[array, ...]:
"""
Splits an array in a sequence of arrays along the given axis.
Splits an array into a sequence of arrays along the given axis.
Parameters
----------
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_stubs/_draft/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,12 +623,12 @@ def solve(x1: array, x2: array, /) -> array:
x1: array
coefficient array ``A`` having shape ``(..., M, M)`` and whose innermost two dimensions form square matrices. Must be of full rank (i.e., all rows or, equivalently, columns must be linearly independent). Should have a floating-point data type.
x2: array
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-1]`` must be compatible with ``shape(x1)[:-1]`` (see :ref:`broadcasting`). Should have a floating-point data type.
ordinate (or "dependent variable") array ``B``. If ``x2`` has shape ``(M,)``, ``x2`` is equivalent to an array having shape ``(..., M, 1)``. If ``x2`` has shape ``(..., M, K)``, each column ``k`` defines a set of ordinate values for which to compute a solution, and ``shape(x2)[:-2]`` must be compatible with ``shape(x1)[:-2]`` (see :ref:`broadcasting`). Should have a floating-point data type.
Returns
-------
out: array
an array containing the solution to the system ``AX = B`` for each square matrix. The returned array must have the same shape as ``x2`` (i.e., the array corresponding to ``B``) and must have a floating-point data type determined by :ref:`type-promotion`.
an array containing the solution to the system ``AX = B`` for each square matrix. If ``x2`` has shape ``(M,)``, the returned array must have shape equal to ``shape(x1)[:-2] + shape(x2)[-1:]``. Otherwise, if ``x2`` has shape ``(..., M, K)```, the returned array must have shape equal to ``(..., M, K)``, where ``...`` refers to the result of broadcasting ``shape(x1)[:-2]`` and ``shape(x2)[:-2]``. The returned array must have a floating-point data type determined by :ref:`type-promotion`.
Notes
-----
Expand Down
2 changes: 1 addition & 1 deletion src/array_api_stubs/_draft/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def tile(x: array, repetitions: Tuple[int, ...], /) -> array:

def unstack(x: array, /, *, axis: int = 0) -> Tuple[array, ...]:
"""
Splits an array in a sequence of arrays along the given axis.
Splits an array into a sequence of arrays along the given axis.
Parameters
----------
Expand Down

0 comments on commit b933915

Please sign in to comment.