feat: Add mask to merge_state_in_place (#372)
This pushes the conditional logic down into the kernel, allowing for
better CUDA graph support with variable sequence lengths. I didn't see
much purpose in adding the `mask` parameter to the out-of-place
merge-state kernels.
Yard1 authored Jul 13, 2024
1 parent 17a5f1b commit e14fa81
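As a sketch of how the in-kernel mask enables CUDA graph reuse (hypothetical usage; the shapes, warmup step, and static mask buffer are illustrative): the merge is captured once with fixed shapes, and only the mask's contents are mutated between replays.

import torch
import flashinfer

seq_len, num_heads, head_dim = 512, 32, 128
v = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
s = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda:0")
v_other = torch.randn_like(v)
s_other = torch.randn_like(s)
mask = torch.ones(seq_len, dtype=torch.bool, device="cuda:0")  # static buffer; contents may change

flashinfer.merge_state_in_place(v, s, v_other, s_other, mask=mask)  # warmup before capture

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Captured once with fixed shapes; the kernel reads the mask at replay time.
    flashinfer.merge_state_in_place(v, s, v_other, s_other, mask=mask)

mask[256:] = False  # merge only the first 256 sequences on the next replay
g.replay()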
Showing 5 changed files with 88 additions and 7 deletions.
12 changes: 10 additions & 2 deletions include/flashinfer/attention/cascade.cuh
@@ -81,16 +81,22 @@ __global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__
* \param s The logsumexp value to be updated in-place. (n, h)
* \param v_other The other v to be merged. (n, h, d)
* \param s_other The other logsumexp value to be merged. (n, h)
* \param mask Optional mask indicating whether each sequence should be merged. (n)
* \param num_heads The number of heads of v and v_other.
* \param head_dim The dimension of each head.
* \note Both s and s_other are logsumexp values with base 2.
*/
template <uint32_t vec_size, typename DType>
__global__ void MergeStateInPlaceKernel(DType* __restrict__ v, float* __restrict__ s,
DType* __restrict__ v_other, float* __restrict__ s_other,
uint8_t* __restrict__ mask,
uint32_t num_heads, uint32_t head_dim) {
uint32_t pos = blockIdx.x;

if (mask != nullptr && mask[pos] == 0) return;

uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t head_idx = ty;

float s_val = s[pos * num_heads + head_idx];
@@ -383,13 +389,15 @@ cudaError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DType
* \param seq_len The sequence length.
* \param num_heads The number of heads of v and v_other.
* \param head_dim The dimension of each head.
* \param mask Optional mask indicating whether each sequence should be merged. (n)
* \param stream The CUDA stream to execute the kernel.
* \return status Indicates whether CUDA calls are successful
* \note Both s and s_other are logsumexp values with base 2.
*/
template <typename DType>
cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other, uint32_t seq_len,
uint32_t num_heads, uint32_t head_dim,
uint8_t* mask = nullptr,
cudaStream_t stream = nullptr) {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U);
@@ -398,7 +406,7 @@ cudaError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other
dim3 nblks(seq_len);
dim3 nthrs(bdx, bdy);
auto kernel = MergeStateInPlaceKernel<vec_size, DType>;
void* args[] = {&v, &s, &v_other, &s_other, &mask, &num_heads, &head_dim};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
return cudaSuccess;
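For orientation, a PyTorch sketch of what the masked in-place merge computes, assuming the standard base-2 logsumexp combination of attention states noted in the \note above; the function name and layout here are illustrative, not the library's API:

import torch

def merge_state_in_place_ref(v, s, v_other, s_other, mask=None):
    # v: (n, h, d) values, s: (n, h) base-2 logsumexp; rows where mask == 0 are left untouched.
    rows = torch.arange(s.size(0), device=s.device) if mask is None else mask.nonzero(as_tuple=True)[0]
    s_a, s_b = s[rows], s_other[rows]
    s_max = torch.maximum(s_a, s_b)
    a, b = torch.exp2(s_a - s_max), torch.exp2(s_b - s_max)  # renormalized weights for stability
    w = a + b
    merged = (a.unsqueeze(-1) * v[rows].float() + b.unsqueeze(-1) * v_other[rows].float()) / w.unsqueeze(-1)
    v[rows] = merged.to(v.dtype)
    s[rows] = s_max + torch.log2(w)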
11 changes: 9 additions & 2 deletions python/csrc/cascade.cu
@@ -63,7 +63,7 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
}

void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other,
torch::Tensor s_other, std::optional<torch::Tensor> mask) {
CHECK_INPUT(v);
CHECK_INPUT(s);
CHECK_INPUT(v_other);
Expand All @@ -82,6 +82,13 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
CHECK_EQ(v.size(1), s.size(1));
CHECK_EQ(s.scalar_type(), torch::kFloat32);
CHECK_EQ(s_other.scalar_type(), torch::kFloat32);
uint8_t* mask_ptr = nullptr;
if (mask.has_value()) {
CHECK_DIM(1, mask.value());
CHECK_EQ(v.size(0), mask.value().size(0));
CHECK_EQ(mask.value().device(), device);
mask_ptr = static_cast<uint8_t*>(mask.value().data_ptr());
}
unsigned int seq_len = v.size(0);
unsigned int num_heads = v.size(1);
unsigned int head_dim = v.size(2);
@@ -91,7 +98,7 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
cudaError_t status = MergeStateInPlace(
static_cast<c_type*>(v.data_ptr()), static_cast<float*>(s.data_ptr()),
static_cast<c_type*>(v_other.data_ptr()), static_cast<float*>(s_other.data_ptr()), seq_len,
num_heads, head_dim, mask_ptr, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"MergeStateInPlace kernel launch failed: ", cudaGetErrorString(status));
return true;
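A side note, not part of the diff: the binding casts the optional tensor's data_ptr() to uint8_t*, which works because PyTorch stores torch.bool tensors one byte per element, so a plain bool tensor is a valid mask on the Python side:

import torch

mask = torch.zeros(512, dtype=torch.bool, device="cuda:0")
mask[:128] = True                # merge only the first 128 sequences
assert mask.element_size() == 1  # bool elements are byte-sized, matching uint8_t*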
2 changes: 1 addition & 1 deletion python/csrc/flashinfer_ops.h
@@ -46,7 +46,7 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
torch::Tensor s_b);

void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other,
torch::Tensor s_other, std::optional<torch::Tensor> mask = std::nullopt);

std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s);

13 changes: 11 additions & 2 deletions python/flashinfer/cascade.py
@@ -98,7 +98,11 @@ def merge_state(


def merge_state_in_place(
v: torch.Tensor,
s: torch.Tensor,
v_other: torch.Tensor,
s_other: torch.Tensor,
mask: Optional[torch.Tensor] = None,
):
r"""Merge the self-attention state ``(v, s)`` with another state
``(v_other, s_other)`` in-place.
@@ -117,6 +121,11 @@ def merge_state_in_place(
s_other : torch.Tensor
The other logsumexp value to be merged, expected to be a float32 tensor,
shape: ``(seq_len, num_heads)``.
mask : Optional[torch.Tensor]
The boolean mask tensor indicating whether to merge the state for each
corresponding sequence. Useful for CUDA graphs. If not specified (default),
states are merged for all sequences.
shape: ``(seq_len,)``
Example
-------
@@ -131,7 +140,7 @@ def merge_state_in_place(
>>> s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> flashinfer.merge_state_in_place(v, s, v_other, s_other)
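>>> # Sketch of the new mask parameter: merge only the first half of the sequences.
>>> mask = torch.ones(seq_len, dtype=torch.bool).to("cuda:0")
>>> mask[seq_len // 2:] = False
>>> flashinfer.merge_state_in_place(v, s, v_other, s_other, mask=mask)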
"""
_kernels.merge_state_in_place(v, s, v_other, s_other, mask)


def merge_states(v: torch.Tensor, s: torch.Tensor):
57 changes: 57 additions & 0 deletions python/tests/test_shared_prefix_kernels.py
@@ -199,6 +199,63 @@ def test_batch_attention_with_shared_prefix_paged_kv_cache(
o_baseline.cpu().numpy(), o_cascade.cpu().numpy(), rtol=1e-3, atol=1e-3
)

@pytest.mark.parametrize("seed", [0])
@pytest.mark.parametrize("num_tries", [50])
def test_merge_state_in_place_with_mask(seed, num_tries):
seq_len = 512
num_heads = 32
head_dim = 128
va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
va_original = va.clone()
sa_original = sa.clone()

# No mask.
flashinfer.merge_state_in_place(va, sa, vb, sb)
va_merged_ref = va.clone()
sa_merged_ref = sa.clone()
assert not torch.allclose(va_merged_ref, va_original)
assert not torch.allclose(sa_merged_ref, sa_original)

# Mask with all 1s. Should be identical to no mask.
mask = torch.ones(seq_len, dtype=torch.bool).to("cuda:0")
va = va_original.clone()
sa = sa_original.clone()
flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask)
va_merged = va
sa_merged = sa
numpy.testing.assert_allclose(va_merged.cpu().numpy(), va_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3)
numpy.testing.assert_allclose(sa_merged.cpu().numpy(), sa_merged_ref.cpu().numpy(), rtol=1e-3, atol=1e-3)

# Mask with all zeros. Input and output should be identical.
mask = torch.zeros(seq_len, dtype=torch.bool).to("cuda:0")
va = va_original.clone()
sa = sa_original.clone()
flashinfer.merge_state_in_place(va, sa, vb, sb, mask=mask)
va_merged = va
sa_merged = sa
numpy.testing.assert_allclose(va_merged.cpu().numpy(), va_original.cpu().numpy(), rtol=1e-3, atol=1e-3)
numpy.testing.assert_allclose(sa_merged.cpu().numpy(), sa_original.cpu().numpy(), rtol=1e-3, atol=1e-3)

# Test some random masks.
randgen = torch.Generator(device="cuda:0")
randgen.manual_seed(seed)
for _ in range(num_tries):
rand_mask = (torch.rand(seq_len, generator=randgen, dtype=torch.float32, device="cuda:0") > 0.5).to(dtype=torch.bool)
true_indices = rand_mask.nonzero()
false_indices = (~rand_mask).nonzero()
va = va_original.clone()
sa = sa_original.clone()
flashinfer.merge_state_in_place(va, sa, vb, sb, mask=rand_mask)
va_merged = va
sa_merged = sa

numpy.testing.assert_allclose(va_merged[false_indices].cpu().numpy(), va_original[false_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)
numpy.testing.assert_allclose(sa_merged[false_indices].cpu().numpy(), sa_original[false_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)
numpy.testing.assert_allclose(va_merged[true_indices].cpu().numpy(), va_merged_ref[true_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)
numpy.testing.assert_allclose(sa_merged[true_indices].cpu().numpy(), sa_merged_ref[true_indices].cpu().numpy(), rtol=1e-3, atol=1e-3)

if __name__ == "__main__":
test_batch_attention_with_shared_prefix_paged_kv_cache(
