refactor: decouple kv-cache storage (#379)
In our previous design, the k-cache and v-cache were coupled together as a single
`(num_pages, 2, page_size, num_heads, head_dim)` or
`(num_pages, 2, num_heads, page_size, head_dim)` tensor.

In this PR, we decouple the k-cache and v-cache storage to enable more flexible
kv-cache layouts. The original coupled layout is still supported, but standalone
k-cache and v-cache tensors are now supported as well.
yzh119 authored Jul 18, 2024
1 parent 9cb28de commit d68a408
Showing 5 changed files with 101 additions and 81 deletions.
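As context for the diffs below, here is a minimal host-side sketch (not part of this commit) of how the refactored `paged_kv_t` can be constructed in both forms. The element type, include path, namespace, `QKVLayout::kNHD` value, and the helper names are assumptions for illustration only.

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <flashinfer/page.cuh>  // include path assumed

using namespace flashinfer;  // namespace assumed
using DType = half;
using IdType = int32_t;

// Hypothetical helper: new decoupled form, K and V in two independent
// allocations, each [max_num_pages, page_size, num_heads, head_dim] for NHD.
paged_kv_t<PageStorage::kIndices, DType, IdType> build_decoupled_cache(
    uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size,
    uint32_t max_num_pages, IdType* indices, IdType* indptr, IdType* last_page_len) {
  const size_t elems_per_page = size_t(num_heads) * page_size * head_dim;
  DType *k_data = nullptr, *v_data = nullptr;
  cudaMalloc(&k_data, max_num_pages * elems_per_page * sizeof(DType));
  cudaMalloc(&v_data, max_num_pages * elems_per_page * sizeof(DType));
  return paged_kv_t<PageStorage::kIndices, DType, IdType>(
      num_heads, page_size, head_dim, batch_size, QKVLayout::kNHD,
      k_data, v_data, indices, indptr, last_page_len);
}

// Hypothetical helper: legacy coupled form, a single
// [max_num_pages, 2, page_size, num_heads, head_dim] buffer; the
// single-pointer overload internally points v_data at the second half of
// each page (kv_data + num_heads * page_size * head_dim).
paged_kv_t<PageStorage::kIndices, DType, IdType> build_coupled_cache(
    uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size,
    uint32_t max_num_pages, IdType* indices, IdType* indptr, IdType* last_page_len) {
  const size_t elems_per_page = size_t(num_heads) * page_size * head_dim;
  DType* kv_data = nullptr;
  cudaMalloc(&kv_data, max_num_pages * 2 * elems_per_page * sizeof(DType));
  return paged_kv_t<PageStorage::kIndices, DType, IdType>(
      num_heads, page_size, head_dim, batch_size, QKVLayout::kNHD,
      kv_data, indices, indptr, last_page_len);
}
```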
4 changes: 2 additions & 2 deletions include/flashinfer/attention/decode.cuh
@@ -487,7 +487,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
cp_async::commit_group();
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta();
DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_ptr_delta();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
@@ -554,7 +554,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
// load v tiles
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta();
DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_ptr_delta();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
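The only change in the decode kernel is the rename from `kv_offset_delta()` to `kv_ptr_delta()`: since K and V may now live in separate buffers, the V pointer is obtained by adding the element distance between `v_data` and `k_data` to the already computed K pointer (see `page.cuh` below). A standalone host sketch of why a single delta is enough, assuming made-up sizes and a `float` element type; like the library code, it relies on flat address arithmetic across allocations.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t num_pages = 4, page_size = 16, num_heads = 8, head_dim = 128;
  const int64_t elems = num_pages * page_size * num_heads * head_dim;
  std::vector<float> k_data(elems), v_data(elems);  // decoupled K/V buffers

  // NHD-style offset, identical for both buffers: (page, entry, head, feat).
  auto offset = [&](int64_t page, int64_t entry, int64_t head, int64_t feat) {
    return ((page * page_size + entry) * num_heads + head) * head_dim + feat;
  };

  // Element distance between the buffers, computed the same way as
  // paged_kv_t::kv_ptr_delta() does for the decoupled case.
  const int64_t kv_ptr_delta =
      (reinterpret_cast<int64_t>(v_data.data()) -
       reinterpret_cast<int64_t>(k_data.data())) /
      static_cast<int64_t>(sizeof(float));

  float* k_ptr = k_data.data() + offset(2, 5, 3, 7);
  *(k_ptr + kv_ptr_delta) = 1.0f;              // write "through" the K pointer
  assert(v_data[offset(2, 5, 3, 7)] == 1.0f);  // lands in the matching V slot
  return 0;
}
```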
17 changes: 8 additions & 9 deletions include/flashinfer/attention/prefill.cuh
@@ -216,8 +216,7 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offs
static_assert(num_frags_z * 4 % num_warps_x == 0);
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
DType* gptr = produce_v ? paged_kv.data + paged_kv.kv_offset_delta() + kv_offset[i]
: paged_kv.data + kv_offset[i];
DType* gptr = produce_v ? paged_kv.v_data + kv_offset[i] : paged_kv.k_data + kv_offset[i];
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
@@ -1608,8 +1607,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
page_iter, entry_idx);
kv_offset[i] =
page_iter < last_indptr
? paged_kv.get_k_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx,
entry_idx, (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
? paged_kv.get_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
(lane_idx % 8) * num_elems_per_128b<DTypeIn>())
: 0;
}
page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
@@ -1645,11 +1644,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
paged_kv.page_size.divmod(
packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
page_iter, entry_idx);
kv_offset[i] = page_iter < last_indptr
? paged_kv.get_k_elem_offset(
__ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
(lane_idx % 8) * num_elems_per_128b<DTypeIn>())
: 0;
kv_offset[i] =
page_iter < last_indptr
? paged_kv.get_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx,
entry_idx, (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
: 0;
}
cp_async::wait_group<1>();
block.sync();
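Both prefill hunks follow from the same observation: once K and V have their own base pointers, they share identical page strides, so the separate `get_k_elem_offset`/`get_v_elem_offset` collapse into a single `get_elem_offset`, and `page_produce_kv` simply selects `k_data` or `v_data` as the base. A small self-contained sketch of that stride bookkeeping (plain host code; the concrete sizes are invented for illustration):

```cpp
#include <cstdio>

// Mirrors paged_kv_t::get_elem_offset(): one formula for both K and V.
size_t get_elem_offset(size_t page_idx, size_t head_idx, size_t entry_idx,
                       size_t feat_idx, size_t stride_page, size_t stride_h,
                       size_t stride_n) {
  return page_idx * stride_page + head_idx * stride_h + entry_idx * stride_n + feat_idx;
}

int main() {
  const size_t num_heads = 8, page_size = 16, head_dim = 128;
  // Standalone K (or V) cache: a page holds num_heads * page_size * head_dim
  // elements; the legacy coupled layout uses twice that as its page stride.
  const size_t stride_page = num_heads * page_size * head_dim;

  // NHD: within a page the order is (entry, head, feat).
  size_t nhd = get_elem_offset(/*page=*/3, /*head=*/2, /*entry=*/5, /*feat=*/7,
                               stride_page, /*stride_h=*/head_dim,
                               /*stride_n=*/num_heads * head_dim);
  // HND: within a page the order is (head, entry, feat).
  size_t hnd = get_elem_offset(3, 2, 5, 7, stride_page,
                               /*stride_h=*/page_size * head_dim,
                               /*stride_n=*/head_dim);
  std::printf("NHD offset = %zu, HND offset = %zu\n", nhd, hnd);
  return 0;
}
```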
145 changes: 81 additions & 64 deletions include/flashinfer/page.cuh
@@ -74,18 +74,20 @@ struct paged_kv_t {
uint32_t num_heads;
uint32_t head_dim;
uint32_t batch_size;
uint32_t stride_page;
uint32_t stride_n;
uint32_t stride_h;

// The flattened key-value cache, used when page_storage == kIndices
// Internal layout:
// [max_num_pages, 2, num_heads, page_size, head_dim] if layout == HND
// [max_num_pages, 2, page_size, num_heads, head_dim] if layout == NHD
DType* data;
// [max_num_pages, num_heads, page_size, head_dim] if layout == HND
// [max_num_pages, page_size, num_heads, head_dim] if layout == NHD
DType* k_data;
DType* v_data;
// [nnz_pages] The page indices array, used when page_storage == kIndices
IdType* indices;
// [nnz_pages] The page pointers array, used when page_storage == kPointer
DType** ptrs;
DType** kv_ptrs;

// [batch_size + 1] The page indptr array, with the first element 0, the last element nnz_pages
IdType* indptr;
@@ -102,11 +104,13 @@ struct paged_kv_t {
page_size(0),
head_dim(0),
batch_size(0),
stride_page(0),
stride_n(0),
stride_h(0),
data(nullptr),
k_data(nullptr),
v_data(nullptr),
indices(nullptr),
ptrs(nullptr),
kv_ptrs(nullptr),
indptr(nullptr),
last_page_len(nullptr),
rope_pos_offset(nullptr) {}
@@ -118,26 +122,29 @@ struct paged_kv_t {
* \param head_dim The dimension of each head
* \param batch_size The batch size
* \param layout The layout of last 3 dimensions in KV-Cache.
* \param data The flattened key-value cache
* \param k_data The flattened key cache
* \param v_data The flattened value cache
* \param indices The page indices array
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
uint32_t batch_size, QKVLayout layout, DType* data,
IdType* indices, IdType* indptr, IdType* last_page_len,
IdType* rope_pos_offset = nullptr)
uint32_t batch_size, QKVLayout layout, DType* k_data,
DType* v_data, IdType* indices, IdType* indptr,
IdType* last_page_len, IdType* rope_pos_offset = nullptr)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
data(data),
k_data(k_data),
v_data(v_data),
indices(indices),
indptr(indptr),
last_page_len(last_page_len),
rope_pos_offset(rope_pos_offset) {
stride_page = num_heads * page_size * head_dim;
stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
}
@@ -149,92 +156,100 @@ struct paged_kv_t {
* \param head_dim The dimension of each head
* \param batch_size The batch size
* \param layout The layout of last 3 dimensions in KV-Cache.
* \param ptrs The array of pointers to each active page
* \param kv_data The flattened key-value cache
* \param indices The page indices array
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
uint32_t batch_size, QKVLayout layout, DType** ptrs,
IdType* indptr, IdType* last_page_len,
uint32_t batch_size, QKVLayout layout, DType* kv_data,
IdType* indices, IdType* indptr, IdType* last_page_len,
IdType* rope_pos_offset = nullptr)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
ptrs(ptrs),
k_data(kv_data),
v_data(kv_data + num_heads * page_size * head_dim),
indices(indices),
indptr(indptr),
last_page_len(last_page_len),
rope_pos_offset(rope_pos_offset) {
stride_page = 2 * num_heads * page_size * head_dim;
stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
}

/*!
* \brief Compute the offset of k element in the allocated buffer.
* \param page_idx The page index
* \param head_idx The head index
* \param entry_idx The page entry index
* \param feat_idx The feature index
* \note This function should only be used when page_storage == kIndices
* \brief Construct a paged key-value cache
* \param num_heads The number of heads
* \param page_size The size of each page
* \param head_dim The dimension of each head
* \param batch_size The batch size
* \param layout The layout of last 3 dimensions in KV-Cache.
* \param kv_ptrs The array of pointers to each active kv page
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kPointer
*/
__host__ __device__ __forceinline__ size_t get_k_elem_offset(size_t page_idx, size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
return page_idx * 2 * page_size * num_heads * head_dim + head_idx * stride_h +
entry_idx * stride_n + feat_idx;
__host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
uint32_t batch_size, QKVLayout layout, DType** kv_ptrs,
IdType* indptr, IdType* last_page_len,
IdType* rope_pos_offset = nullptr)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
kv_ptrs(kv_ptrs),
indptr(indptr),
last_page_len(last_page_len),
rope_pos_offset(rope_pos_offset) {
stride_page = 2 * num_heads * page_size * head_dim;
stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
}

/*!
* \brief Compute the offset of k element inside the page.
* \param head_idx The head index
* \param entry_idx The page entry index
* \param feat_idx The feature index
*/
__host__ __device__ __forceinline__ size_t get_k_elem_offset_in_page(size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
return head_idx * stride_h + entry_idx * stride_n + feat_idx;
__host__ __device__ __forceinline__ int64_t kv_ptr_delta() const {
return page_storage == PageStorage::kPointer
? num_heads * page_size * head_dim
: (int64_t(v_data) - int64_t(k_data)) / sizeof(DType);
}

/*!
* \brief Compute the offset of v element in the allocated buffer.
* \brief Compute the offset of element in the allocated buffer.
* \param page_idx The page index
* \param head_idx The head index
* \param entry_idx The page entry index
* \param feat_idx The feature index
* \note This function should only be used when page_storage == kIndices
*/
__host__ __device__ __forceinline__ size_t get_v_elem_offset(size_t page_idx, size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
return (page_idx * 2 + 1) * page_size * num_heads * head_dim + head_idx * stride_h +
entry_idx * stride_n + feat_idx;
__host__ __device__ __forceinline__ size_t get_elem_offset(size_t page_idx, size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
return page_idx * stride_page + head_idx * stride_h + entry_idx * stride_n + feat_idx;
}

/*!
* \brief Compute the offset of v element inside the page.
* \brief Compute the offset of element inside the page.
* \param head_idx The head index
* \param entry_idx The page entry index
* \param feat_idx The feature index
*/
__host__ __device__ __forceinline__ size_t get_v_elem_offset_in_page(size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
__host__ __device__ __forceinline__ size_t get_elem_offset_in_page(size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
return head_idx * stride_h + entry_idx * stride_n + feat_idx;
}

__host__ __device__ __forceinline__ uint32_t kv_offset_delta() const {
return num_heads * page_size * head_dim;
}

__device__ __forceinline__ DType* get_k_ptr(IdType page_iter, uint32_t head_idx,
uint32_t entry_idx, uint32_t feat_idx) const {
if constexpr (page_storage == PageStorage::kIndices) {
return data + get_k_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return ptrs[page_iter] + get_k_elem_offset_in_page(head_idx, entry_idx, feat_idx);
return kv_ptrs[page_iter] + get_elem_offset_in_page(head_idx, entry_idx, feat_idx);
}
}

@@ -243,25 +258,26 @@ struct paged_kv_t {
IdType last_indptr) const {
if constexpr (page_storage == PageStorage::kIndices) {
if (page_iter < last_indptr) {
return data + get_k_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return data;
return k_data;
}
} else {
if (page_iter < last_indptr) {
return ptrs[page_iter] + get_k_elem_offset_in_page(head_idx, entry_idx, feat_idx);
return kv_ptrs[page_iter] + get_elem_offset_in_page(head_idx, entry_idx, feat_idx);
} else {
return *ptrs;
return *kv_ptrs;
}
}
}

__device__ __forceinline__ DType* get_v_ptr(IdType page_iter, uint32_t head_idx,
uint32_t entry_idx, uint32_t feat_idx) const {
if constexpr (page_storage == PageStorage::kIndices) {
return data + get_v_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return ptrs[page_iter] + get_v_elem_offset_in_page(head_idx, entry_idx, feat_idx);
return (kv_ptrs[page_iter] + kv_ptr_delta()) +
get_elem_offset_in_page(head_idx, entry_idx, feat_idx);
}
}

@@ -270,15 +286,16 @@ struct paged_kv_t {
IdType last_indptr) const {
if constexpr (page_storage == PageStorage::kIndices) {
if (page_iter < last_indptr) {
return data + get_v_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return data;
return v_data;
}
} else {
if (page_iter < last_indptr) {
return ptrs[page_iter] + get_v_elem_offset_in_page(head_idx, entry_idx, feat_idx);
return (kv_ptrs[page_iter] + kv_ptr_delta()) +
get_elem_offset_in_page(head_idx, entry_idx, feat_idx);
} else {
return *ptrs;
return *kv_ptrs;
}
}
}
@@ -312,7 +329,7 @@ __global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t<page_storage, DType, I
uint32_t entry_idx = (seq_len - 1) % paged_kv.page_size;

DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
DType* v_ptr = k_ptr + paged_kv.kv_offset_delta();
DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
vec_t<DType, vec_size>::memcpy(
k_ptr, key + (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size);

@@ -355,7 +372,7 @@ __global__ void AppendPagedKVCachePrefillKernel(paged_kv_t<page_storage, DType,
uint32_t entry_idx = page_seq_idx % paged_kv.page_size;

DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
DType* v_ptr = k_ptr + paged_kv.kv_offset_delta();
DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
vec_t<DType, vec_size>::memcpy(
k_ptr,
key + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size);
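For `PageStorage::kPointer`, pages themselves stay coupled: each entry of `kv_ptrs` points at one `[2, page_size, num_heads, head_dim]` (NHD) block, K first and V second, and `kv_ptr_delta()` falls back to `num_heads * page_size * head_dim`. A hedged sketch of assembling such a page table; the allocation strategy, helper name, and device-side pointer-array handling are assumptions, not part of this commit.

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <vector>

#include <flashinfer/page.cuh>  // include path assumed

using namespace flashinfer;  // namespace assumed
using DType = half;
using IdType = int32_t;

// Hypothetical helper: kv_ptrs_device must be a device buffer that can hold
// num_active_pages pointers; page_ptrs_host keeps the allocations reachable.
paged_kv_t<PageStorage::kPointer, DType, IdType> build_pointer_cache(
    uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size,
    uint32_t num_active_pages, IdType* indptr, IdType* last_page_len,
    std::vector<DType*>& page_ptrs_host, DType** kv_ptrs_device) {
  const size_t elems_per_page = size_t(num_heads) * page_size * head_dim;
  page_ptrs_host.resize(num_active_pages);
  for (uint32_t i = 0; i < num_active_pages; ++i) {
    // Each page may come from anywhere (e.g. a slab allocator); it only needs
    // 2 * elems_per_page elements so that V sits kv_ptr_delta() after K.
    cudaMalloc(&page_ptrs_host[i], 2 * elems_per_page * sizeof(DType));
  }
  cudaMemcpy(kv_ptrs_device, page_ptrs_host.data(),
             num_active_pages * sizeof(DType*), cudaMemcpyHostToDevice);
  return paged_kv_t<PageStorage::kPointer, DType, IdType>(
      num_heads, page_size, head_dim, batch_size, QKVLayout::kNHD,
      kv_ptrs_device, indptr, last_page_len);
}
```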
4 changes: 2 additions & 2 deletions src/cpu_reference.h
@@ -180,10 +180,10 @@ void append_paged_kv_cache(paged_kv_t<PageStorage::kIndices, T, IdxType> page_cp
for (size_t h = 0; h < num_heads; ++h) {
std::copy(ki.begin() + (j * num_heads + h) * head_dim,
ki.begin() + (j * num_heads + h + 1) * head_dim,
page_cpu.data + page_cpu.get_k_elem_offset(page_idx, h, entry_idx, 0));
page_cpu.k_data + page_cpu.get_elem_offset(page_idx, h, entry_idx, 0));
std::copy(vi.begin() + (j * num_heads + h) * head_dim,
vi.begin() + (j * num_heads + h + 1) * head_dim,
page_cpu.data + page_cpu.get_v_elem_offset(page_idx, h, entry_idx, 0));
page_cpu.v_data + page_cpu.get_elem_offset(page_idx, h, entry_idx, 0));
}
}
}
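The CPU reference keeps its structure and simply indexes `k_data` and `v_data` separately through the shared `get_elem_offset()`. As a sanity check, a standalone host sketch of appending one token's per-head K/V vectors into a decoupled NHD page; the sizes, element type, and slot choice are invented for the example.

```cpp
#include <cstring>
#include <vector>

int main() {
  const size_t num_heads = 4, page_size = 8, head_dim = 64, max_num_pages = 2;
  const size_t elems_per_page = page_size * num_heads * head_dim;
  std::vector<float> k_data(max_num_pages * elems_per_page);
  std::vector<float> v_data(max_num_pages * elems_per_page);

  // One new token: per-head K and V vectors laid out as [num_heads, head_dim].
  std::vector<float> ki(num_heads * head_dim, 1.0f), vi(num_heads * head_dim, 2.0f);

  // NHD strides for a standalone K (or V) cache.
  const size_t stride_page = elems_per_page;
  const size_t stride_n = num_heads * head_dim;  // stride between entries
  const size_t stride_h = head_dim;              // stride between heads
  auto offset = [&](size_t page, size_t head, size_t entry) {
    return page * stride_page + head * stride_h + entry * stride_n;
  };

  const size_t page_idx = 1, entry_idx = 3;  // destination slot for this token
  for (size_t h = 0; h < num_heads; ++h) {
    std::memcpy(&k_data[offset(page_idx, h, entry_idx)], &ki[h * head_dim],
                head_dim * sizeof(float));
    std::memcpy(&v_data[offset(page_idx, h, entry_idx)], &vi[h * head_dim],
                head_dim * sizeof(float));
  }
  return 0;
}
```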