Feature/non contiguous kv cache (#513)
This PR solves #506.

Custom strides to support a non-contiguous kv cache.
Tests in `test_batch_prefill_kernels.py` and `test_batch_decode_kernels.py` are modified to exercise `kv_data` inputs as both contiguous and non-contiguous tensors.

---------

Signed-off-by: LinHeLurking <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
LinHeLurking and yzh119 authored Oct 9, 2024
1 parent 794bdda commit 85b1878
Showing 24 changed files with 484 additions and 438 deletions.
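For context, a non-contiguous paged KV cache arises naturally when the K and V caches are strided views into one combined buffer rather than two separate contiguous allocations. A minimal libtorch sketch of that situation (the 5-D combined layout and the NHD page shape are assumptions for illustration, not the only layout the kernels accept):

```cpp
#include <iostream>
#include <torch/torch.h>

int main() {
  // Combined buffer: [max_num_pages, 2, page_size, num_kv_heads, head_dim],
  // with K at index 0 and V at index 1 of the second dimension.
  auto kv_data = torch::randn({128, 2, 16, 4, 128},
                              torch::dtype(torch::kHalf).device(torch::kCUDA));
  auto paged_k_cache = kv_data.select(1, 0);  // strided view, not a copy
  auto paged_v_cache = kv_data.select(1, 1);  // strided view, not a copy

  // Both views are non-contiguous (the page stride is twice the contiguous value),
  // but they share identical strides, which is what the changed kernels require.
  std::cout << paged_k_cache.is_contiguous() << "\n";                          // 0
  std::cout << paged_k_cache.strides() << "\n";                                // [16384, 512, 128, 1]
  std::cout << (paged_k_cache.strides() == paged_v_cache.strides()) << "\n";   // 1
  return 0;
}
```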
50 changes: 22 additions & 28 deletions flashinfer-aot/csrc_aot/batch_decode.cu
@@ -87,36 +87,25 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, torch::Tensor q,
std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache, torch::Tensor paged_kv_indptr,
torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
std::optional<torch::Tensor> alibi_slopes, unsigned int kv_layout_code, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse) {
DecodePlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
bool paged_kv_defined = paged_kv_cache.has_value();
auto device = q.device();
int64_t batch_size = q.size(0);
int64_t num_qo_heads = q.size(1);
int64_t num_kv_heads, page_size;
if (paged_kv_defined) {
if (kv_layout == QKVLayout::kHND) {
num_kv_heads = paged_kv_cache->size(2);
page_size = paged_kv_cache->size(3);
} else {
page_size = paged_kv_cache->size(2);
num_kv_heads = paged_kv_cache->size(3);
}

if (kv_layout == QKVLayout::kHND) {
num_kv_heads = paged_k_cache.size(1);
page_size = paged_k_cache.size(2);
} else {
if (kv_layout == QKVLayout::kHND) {
num_kv_heads = paged_k_cache->size(1);
page_size = paged_k_cache->size(2);
} else {
page_size = paged_k_cache->size(1);
num_kv_heads = paged_k_cache->size(2);
}
page_size = paged_k_cache.size(1);
num_kv_heads = paged_k_cache.size(2);
}
uint32_t head_dim = q.size(2);

@@ -137,8 +126,14 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(

// get q_scalar_type and kv_scalar_type
auto q_scalar_type = q.scalar_type();
auto kv_scalar_type =
paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();
auto kv_scalar_type = paged_k_cache.scalar_type();

// get kv_cache_strides
const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
auto v_strides = paged_v_cache.strides();
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
kv_cache_strides = k_strides.data();

DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
using DTypeQ = q_type;
@@ -154,10 +149,9 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(

paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
static_cast<DTypeKV*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
: nullptr),
static_cast<DTypeKV*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr),
static_cast<DTypeKV*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr),
static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
kv_cache_strides,
static_cast<IdType*>(paged_kv_indices.data_ptr()),
static_cast<IdType*>(paged_kv_indptr.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
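The stride array captured above is just the PyTorch stride tuple of the K (and V) tensor, so an element's offset is the dot product of its indices with those strides. A small, dependency-free sketch of that arithmetic, assumed to mirror what `paged_kv_t` does with `kv_cache_strides` (the actual implementation lives in `include/flashinfer/page.cuh`):

```cpp
#include <cstdint>
#include <cstdio>

// Offset of element (page, entry, head, feat) in an NHD-layout paged cache
// [num_pages, page_size, num_kv_heads, head_dim] with arbitrary strides.
int64_t elem_offset(const int64_t* strides, int64_t page, int64_t entry, int64_t head,
                    int64_t feat) {
  return page * strides[0] + entry * strides[1] + head * strides[2] + feat * strides[3];
}

int main() {
  const int64_t page_size = 16, num_kv_heads = 4, head_dim = 128;
  // Contiguous cache: strides are running products of the trailing dimensions.
  const int64_t contiguous[4] = {page_size * num_kv_heads * head_dim,
                                 num_kv_heads * head_dim, head_dim, 1};
  // K sliced out of a combined [pages, 2, page_size, heads, dim] buffer:
  // only the page stride changes (it doubles); the inner strides are unchanged.
  const int64_t strided[4] = {2 * page_size * num_kv_heads * head_dim,
                              num_kv_heads * head_dim, head_dim, 1};
  std::printf("contiguous: %lld\n", (long long)elem_offset(contiguous, 3, 5, 1, 7));
  std::printf("strided:    %lld\n", (long long)elem_offset(strided, 3, 5, 1, 7));
  return 0;
}
```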
47 changes: 19 additions & 28 deletions flashinfer-aot/csrc_aot/batch_prefill.cu
@@ -198,38 +198,26 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache, std::optional<torch::Tensor> maybe_custom_mask,
std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse) {
PrefillPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
bool paged_kv_defined = paged_kv_cache.has_value();
auto device = q.device();
int64_t batch_size = paged_kv_indptr.size(0) - 1;
int64_t num_qo_heads = q.size(1);
int64_t num_kv_heads, page_size;
uint32_t head_dim = q.size(2);
if (paged_kv_defined) {
if (kv_layout == QKVLayout::kHND) {
num_kv_heads = paged_kv_cache->size(2);
page_size = paged_kv_cache->size(3);
} else {
page_size = paged_kv_cache->size(2);
num_kv_heads = paged_kv_cache->size(3);
}
if (kv_layout == QKVLayout::kHND) {
num_kv_heads = paged_k_cache.size(1);
page_size = paged_k_cache.size(2);
} else {
if (kv_layout == QKVLayout::kHND) {
num_kv_heads = paged_k_cache->size(1);
page_size = paged_k_cache->size(2);
} else {
page_size = paged_k_cache->size(1);
num_kv_heads = paged_k_cache->size(2);
}
page_size = paged_k_cache.size(1);
num_kv_heads = paged_k_cache.size(2);
}

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
@@ -248,8 +236,14 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
using IdType = int32_t;
bool use_logits_soft_cap = logits_soft_cap > 0.f;
auto q_scalar_type = q.scalar_type();
auto kv_scalar_type =
paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();
auto kv_scalar_type = paged_k_cache.scalar_type();

// get kv_cache_strides
const int64_t* kv_cache_strides = nullptr;
auto k_strides = paged_k_cache.strides();
auto v_strides = paged_v_cache.strides();
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
kv_cache_strides = k_strides.data();

DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] {
using DTypeQ = q_type;
@@ -260,12 +254,9 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
static_cast<DTypeKV*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
: nullptr),
static_cast<DTypeKV*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr()
: nullptr),
static_cast<DTypeKV*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr()
: nullptr),
static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
kv_cache_strides,
static_cast<IdType*>(paged_kv_indices.data_ptr()),
static_cast<IdType*>(paged_kv_indptr.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
5 changes: 2 additions & 3 deletions flashinfer-aot/csrc_aot/flashinfer_ops.cu
@@ -16,9 +16,8 @@
#include <torch/extension.h>

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
torch::Tensor append_indptr, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache, torch::Tensor kv_indices,
torch::Tensor append_indptr, torch::Tensor paged_k_cache,
torch::Tensor paged_v_cache, torch::Tensor kv_indices,
torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
unsigned int layout);

11 changes: 5 additions & 6 deletions flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu
@@ -31,12 +31,11 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(

std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, torch::Tensor q,
std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache, torch::Tensor paged_kv_indptr,
torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
std::optional<torch::Tensor> alibi_slopes, unsigned int kv_layout_code, int window_left,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse);
std::vector<int64_t> plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache,
torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> alibi_slopes,
unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
7 changes: 3 additions & 4 deletions flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu
@@ -39,10 +39,9 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCacheRun(
std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
unsigned int mask_mode_code, torch::Tensor float_workspace_buffer,
torch::Tensor int_workspace_buffer, std::vector<int64_t> plan_info_vec, torch::Tensor q,
std::optional<torch::Tensor> paged_kv_cache, std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache, std::optional<torch::Tensor> maybe_custom_mask,
std::optional<torch::Tensor> maybe_alibi_slopes, torch::Tensor qo_indptr,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
std::optional<torch::Tensor> maybe_custom_mask, std::optional<torch::Tensor> maybe_alibi_slopes,
torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len, std::optional<torch::Tensor> maybe_qk_indptr,
unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
38 changes: 19 additions & 19 deletions include/flashinfer/attention/decode.cuh
@@ -413,8 +413,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
DTypeKV* k_smem = (DTypeKV*)smem;
DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim *
sizeof(DTypeKV));
DTypeKV** k_ptrs_smem = (DTypeKV**)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz *
head_dim * sizeof(DTypeKV));
size_t* kv_offset_smem = (size_t*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz *
head_dim * sizeof(DTypeKV));
float* smem_md = (float*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim *
sizeof(DTypeKV));

@@ -453,34 +453,35 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
uint32_t q, r;
paged_kv.page_size.divmod(packed_page_iter_base + ((j * bdz + tz) * bdy + ty) * bdx + tx, q, r);
k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
paged_kv.protective_get_k_ptr(q, kv_head_idx, r, 0, last_indptr);
kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr);
}
block.sync();

DTypeKV* k_ptrs[tile_size_per_bdx];
size_t kv_offset[tile_size_per_bdx];
#pragma unroll
for (uint32_t iter = 0; iter < num_stages_smem; ++iter) {
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
k_ptrs[j] =
k_ptrs_smem[((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j] + tx * vec_size;
kv_offset[j] =
kv_offset_smem[((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j] + tx * vec_size;
}
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
k_ptrs[j], ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
paged_kv.k_data + kv_offset[j],
((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
}
cp_async::commit_group();
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_ptr_delta();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
v_ptr, ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
paged_kv.v_data + kv_offset[j],
((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
}
cp_async::commit_group();
stage_idx = (stage_idx + 1) % num_stages_smem;
@@ -499,8 +500,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
packed_page_iter_base + ((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
((j * bdz + tz) * bdy + ty) * bdx + tx),
q, r);
k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
paged_kv.protective_get_k_ptr(q, kv_head_idx, r, 0, last_indptr);
kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr);
}
}
// compute qk
@@ -516,10 +517,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__

#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
k_ptrs[j] = k_ptrs_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) *
tile_size_per_bdx +
j] +
tx * vec_size;
kv_offset[j] = kv_offset_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) *
tile_size_per_bdx +
j] +
tx * vec_size;
}

// load k tiles
Expand All @@ -528,7 +529,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
k_ptrs[j],
paged_kv.k_data + kv_offset[j],
(((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
}
cp_async::commit_group();
@@ -543,11 +544,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
// load v tiles
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_ptr_delta();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
v_ptr,
paged_kv.v_data + kv_offset[j],
(((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size);
}
cp_async::commit_group();
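The kernel-side counterpart of the host change: shared memory now stages `size_t` element offsets instead of raw `DTypeKV*` pointers, and the V tile is no longer derived from the K pointer via `kv_ptr_delta()`; the same offset is simply added to `paged_kv.k_data` and `paged_kv.v_data`, which is valid because the host asserted that the two tensors have identical strides. A toy, CPU-only illustration of that addressing idea (purely illustrative; the real loads go through `cp_async::pred_load`):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Two separately allocated caches that share the same logical layout and strides.
  std::vector<float> k_data(2 * 16 * 8), v_data(2 * 16 * 8);
  for (size_t i = 0; i < k_data.size(); ++i) { k_data[i] = i; v_data[i] = 1000 + i; }

  // Element offset for a tiny [num_pages=2, page_size=16, head_dim=8] cache.
  auto offset = [](size_t page, size_t entry, size_t feat) {
    return page * 16 * 8 + entry * 8 + feat;
  };

  // One offset, computed once (as kv_offset_smem stages it), addresses both caches.
  size_t off = offset(/*page=*/1, /*entry=*/5, /*feat=*/3);
  std::printf("k = %.0f, v = %.0f\n", k_data[off], v_data[off]);
  return 0;
}
```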
16 changes: 6 additions & 10 deletions include/flashinfer/attention/prefill.cuh
@@ -1856,11 +1856,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
lane_idx / kv_frag_cols +
kv_frag_rows * num_warps_x * num_warps_z * i,
page_iter, entry_idx);
kv_offset[i] = page_iter < last_indptr
? paged_kv.get_elem_offset(
__ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
(lane_idx % kv_frag_cols) * num_elems_per_128b<DTypeKV>())
: 0;
kv_offset[i] = paged_kv.protective_get_kv_offset(
page_iter, kv_head_idx, entry_idx,
(lane_idx % kv_frag_cols) * num_elems_per_128b<DTypeKV>(), last_indptr);
}
page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, 0, kv_offset, chunk_size);
@@ -1902,11 +1900,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
lane_idx / kv_frag_cols +
kv_frag_rows * num_warps_x * num_warps_z * i,
page_iter, entry_idx);
kv_offset[i] = page_iter < last_indptr
? paged_kv.get_elem_offset(
__ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
(lane_idx % kv_frag_cols) * num_elems_per_128b<DTypeKV>())
: 0;
kv_offset[i] = paged_kv.protective_get_kv_offset(
page_iter, kv_head_idx, entry_idx,
(lane_idx % kv_frag_cols) * num_elems_per_128b<DTypeKV>(), last_indptr);
}
cp_async::wait_group<1>();
block.sync();
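The removed lines above expressed the bounds guard as an inline ternary; the new `protective_get_kv_offset` helper folds that same guard into `paged_kv_t`. A hedged, host-side reconstruction of what it presumably does, based on the code it replaces (member names and the exact signature are assumptions; the real definition lives in `include/flashinfer/page.cuh`):

```cpp
#include <cstddef>
#include <cstdint>

// Simplified stand-in for paged_kv_t (NHD layout), keeping only what the guard needs.
struct PagedKVSketch {
  int64_t stride_page, stride_n, stride_h;  // element strides shared by the K and V tensors
  const int32_t* indices;                   // page table: logical page -> physical page

  size_t get_elem_offset(int64_t page_idx, int64_t head_idx, int64_t entry_idx,
                         int64_t feat_idx) const {
    return page_idx * stride_page + head_idx * stride_h + entry_idx * stride_n + feat_idx;
  }

  // Threads that iterate past the request's last page still reach the (predicated-off)
  // copy instruction, so they receive a harmless offset of 0 rather than reading
  // indices[] out of bounds, mirroring the `page_iter < last_indptr ? ... : 0` pattern.
  size_t protective_get_kv_offset(uint32_t page_iter, uint32_t head_idx, uint32_t entry_idx,
                                  uint32_t feat_idx, uint32_t last_indptr) const {
    return page_iter < last_indptr
               ? get_elem_offset(indices[page_iter], head_idx, entry_idx, feat_idx)
               : 0;
  }
};
```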