diff --git a/flashinfer-aot/csrc_aot/batch_decode.cu b/flashinfer-aot/csrc_aot/batch_decode.cu index f9e4796f..fbf6381e 100644 --- a/flashinfer-aot/csrc_aot/batch_decode.cu +++ b/flashinfer-aot/csrc_aot/batch_decode.cu @@ -87,36 +87,25 @@ std::vector BatchDecodeWithPagedKVCachePlan( std::vector BatchDecodeWithPagedKVCacheRun( torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - std::vector plan_info_vec, torch::Tensor q, - std::optional paged_kv_cache, std::optional paged_k_cache, - std::optional paged_v_cache, torch::Tensor paged_kv_indptr, - torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, - std::optional alibi_slopes, unsigned int kv_layout_code, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + std::vector plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache, + torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, + unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse) { DecodePlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(kv_layout_code); - bool paged_kv_defined = paged_kv_cache.has_value(); auto device = q.device(); int64_t batch_size = q.size(0); int64_t num_qo_heads = q.size(1); int64_t num_kv_heads, page_size; - if (paged_kv_defined) { - if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_kv_cache->size(2); - page_size = paged_kv_cache->size(3); - } else { - page_size = paged_kv_cache->size(2); - num_kv_heads = paged_kv_cache->size(3); - } + + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->size(1); - page_size = paged_k_cache->size(2); - } else { - page_size = paged_k_cache->size(1); - num_kv_heads = paged_k_cache->size(2); - } + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); } uint32_t head_dim = q.size(2); @@ -137,8 +126,14 @@ std::vector BatchDecodeWithPagedKVCacheRun( // get q_scalar_type and kv_scalar_type auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = - paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); + auto kv_scalar_type = paged_k_cache.scalar_type(); + + // get kv_cache_strides + const int64_t* kv_cache_strides = nullptr; + auto k_strides = paged_k_cache.strides(); + auto v_strides = paged_v_cache.strides(); + TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); + kv_cache_strides = k_strides.data(); DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { using DTypeQ = q_type; @@ -154,10 +149,9 @@ std::vector BatchDecodeWithPagedKVCacheRun( paged_kv_t paged_kv( num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout, - static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() - : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), - static_cast(paged_v_cache.has_value() ? 
paged_v_cache->data_ptr() : nullptr), + static_cast(paged_k_cache.data_ptr()), + static_cast(paged_v_cache.data_ptr()), + kv_cache_strides, static_cast(paged_kv_indices.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), static_cast(paged_kv_last_page_len.data_ptr())); diff --git a/flashinfer-aot/csrc_aot/batch_prefill.cu b/flashinfer-aot/csrc_aot/batch_prefill.cu index ce943378..f38b0762 100644 --- a/flashinfer-aot/csrc_aot/batch_prefill.cu +++ b/flashinfer-aot/csrc_aot/batch_prefill.cu @@ -198,38 +198,26 @@ std::vector BatchPrefillWithRaggedKVCacheRun( std::vector BatchPrefillWithPagedKVCacheRun( unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - std::optional paged_kv_cache, std::optional paged_k_cache, - std::optional paged_v_cache, std::optional maybe_custom_mask, - std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { PrefillPlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(layout); - bool paged_kv_defined = paged_kv_cache.has_value(); auto device = q.device(); int64_t batch_size = paged_kv_indptr.size(0) - 1; int64_t num_qo_heads = q.size(1); int64_t num_kv_heads, page_size; uint32_t head_dim = q.size(2); - if (paged_kv_defined) { - if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_kv_cache->size(2); - page_size = paged_kv_cache->size(3); - } else { - page_size = paged_kv_cache->size(2); - num_kv_heads = paged_kv_cache->size(3); - } + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->size(1); - page_size = paged_k_cache->size(2); - } else { - page_size = paged_k_cache->size(1); - num_kv_heads = paged_k_cache->size(2); - } + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); } cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); @@ -248,8 +236,14 @@ std::vector BatchPrefillWithPagedKVCacheRun( using IdType = int32_t; bool use_logits_soft_cap = logits_soft_cap > 0.f; auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = - paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); + auto kv_scalar_type = paged_k_cache.scalar_type(); + + // get kv_cache_strides + const int64_t* kv_cache_strides = nullptr; + auto k_strides = paged_k_cache.strides(); + auto v_strides = paged_v_cache.strides(); + TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); + kv_cache_strides = k_strides.data(); DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE(q_scalar_type, kv_scalar_type, q_type, kv_type, [&] { using DTypeQ = q_type; @@ -260,12 +254,9 @@ std::vector BatchPrefillWithPagedKVCacheRun( return DISPATCH_LOGITS_SOFT_CAP(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { paged_kv_t paged_kv( num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout, - static_cast(paged_kv_cache.has_value() ? 
paged_kv_cache->data_ptr() - : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() - : nullptr), - static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() - : nullptr), + static_cast(paged_k_cache.data_ptr()), + static_cast(paged_v_cache.data_ptr()), + kv_cache_strides, static_cast(paged_kv_indices.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), static_cast(paged_kv_last_page_len.data_ptr())); diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops.cu b/flashinfer-aot/csrc_aot/flashinfer_ops.cu index c9f0313f..b2fc051d 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops.cu +++ b/flashinfer-aot/csrc_aot/flashinfer_ops.cu @@ -16,9 +16,8 @@ #include void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor append_indptr, std::optional paged_kv_cache, - std::optional paged_k_cache, - std::optional paged_v_cache, torch::Tensor kv_indices, + torch::Tensor append_indptr, torch::Tensor paged_k_cache, + torch::Tensor paged_v_cache, torch::Tensor kv_indices, torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, unsigned int layout); diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu index a9a666a3..6483d983 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu +++ b/flashinfer-aot/csrc_aot/flashinfer_ops_decode.cu @@ -31,12 +31,11 @@ std::vector BatchDecodeWithPagedKVCachePlan( std::vector BatchDecodeWithPagedKVCacheRun( torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, - std::vector plan_info_vec, torch::Tensor q, - std::optional paged_kv_cache, std::optional paged_k_cache, - std::optional paged_v_cache, torch::Tensor paged_kv_indptr, - torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, - std::optional alibi_slopes, unsigned int kv_layout_code, int window_left, - float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); + std::vector plan_info_vec, torch::Tensor q, torch::Tensor paged_k_cache, + torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, + unsigned int kv_layout_code, int window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu index 955a6cb1..9ef91d8b 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu +++ b/flashinfer-aot/csrc_aot/flashinfer_ops_prefill.cu @@ -39,10 +39,9 @@ std::vector BatchPrefillWithRaggedKVCacheRun( std::vector BatchPrefillWithPagedKVCacheRun( unsigned int mask_mode_code, torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - std::optional paged_kv_cache, std::optional paged_k_cache, - std::optional paged_v_cache, std::optional maybe_custom_mask, - std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, - torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, + std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, + torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, std::optional maybe_qk_indptr, unsigned int layout, int32_t 
window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 5320e84c..02a9564d 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -413,8 +413,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ DTypeKV* k_smem = (DTypeKV*)smem; DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * sizeof(DTypeKV)); - DTypeKV** k_ptrs_smem = (DTypeKV**)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * - head_dim * sizeof(DTypeKV)); + size_t* kv_offset_smem = (size_t*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * + head_dim * sizeof(DTypeKV)); float* smem_md = (float*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * sizeof(DTypeKV)); @@ -453,34 +453,35 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { uint32_t q, r; paged_kv.page_size.divmod(packed_page_iter_base + ((j * bdz + tz) * bdy + ty) * bdx + tx, q, r); - k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = - paged_kv.protective_get_k_ptr(q, kv_head_idx, r, 0, last_indptr); + kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = + paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr); } block.sync(); - DTypeKV* k_ptrs[tile_size_per_bdx]; + size_t kv_offset[tile_size_per_bdx]; #pragma unroll for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - k_ptrs[j] = - k_ptrs_smem[((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j] + tx * vec_size; + kv_offset[j] = + kv_offset_smem[((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j] + tx * vec_size; } #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { cp_async::pred_load( k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - k_ptrs[j], ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); + paged_kv.k_data + kv_offset[j], + ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_ptr_delta(); cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - v_ptr, ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); + paged_kv.v_data + kv_offset[j], + ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); stage_idx = (stage_idx + 1) % num_stages_smem; @@ -499,8 +500,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ packed_page_iter_base + ((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz + ((j * bdz + tz) * bdy + ty) * bdx + tx), q, r); - k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = - paged_kv.protective_get_k_ptr(q, kv_head_idx, r, 0, last_indptr); + kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = + paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr); } } // compute qk @@ -516,10 +517,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - k_ptrs[j] = k_ptrs_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) * - 
tile_size_per_bdx + - j] + - tx * vec_size; + kv_offset[j] = kv_offset_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) * + tile_size_per_bdx + + j] + + tx * vec_size; } // load k tiles @@ -528,7 +529,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ cp_async::pred_load( k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - k_ptrs[j], + paged_kv.k_data + kv_offset[j], (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); @@ -543,11 +544,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ // load v tiles #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_ptr_delta(); cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, - v_ptr, + paged_kv.v_data + kv_offset[j], (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 54642f5a..3c3ce130 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1856,11 +1856,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage lane_idx / kv_frag_cols + kv_frag_rows * num_warps_x * num_warps_z * i, page_iter, entry_idx); - kv_offset[i] = page_iter < last_indptr - ? paged_kv.get_elem_offset( - __ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx, - (lane_idx % kv_frag_cols) * num_elems_per_128b()) - : 0; + kv_offset[i] = paged_kv.protective_get_kv_offset( + page_iter, kv_head_idx, entry_idx, + (lane_idx % kv_frag_cols) * num_elems_per_128b(), last_indptr); } page_produce_kv( k_smem, &kv_smem_offset_w, paged_kv, 0, kv_offset, chunk_size); @@ -1902,11 +1900,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage lane_idx / kv_frag_cols + kv_frag_rows * num_warps_x * num_warps_z * i, page_iter, entry_idx); - kv_offset[i] = page_iter < last_indptr - ? paged_kv.get_elem_offset( - __ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx, - (lane_idx % kv_frag_cols) * num_elems_per_128b()) - : 0; + kv_offset[i] = paged_kv.protective_get_kv_offset( + page_iter, kv_head_idx, entry_idx, + (lane_idx % kv_frag_cols) * num_elems_per_128b(), last_indptr); } cp_async::wait_group<1>(); block.sync(); diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index 18c1af20..e15adbcb 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -73,47 +73,6 @@ struct paged_kv_t { last_page_len(nullptr), rope_pos_offset(nullptr) {} - /*! - * \brief Construct a paged key-value cache - * \param num_heads The number of heads - * \param page_size The size of each page - * \param head_dim The dimension of each head - * \param batch_size The batch size - * \param layout The layout of last 3 dimensions in KV-Cache. - * \param kv_data The flattened key-value cache - * \param k_data The flattened key cache - * \param v_data The flattened value cache - * \param indices The page indices array - * \param indptr The page indptr array - * \param last_page_len The offset of the last page for each request in the batch - * \param rope_pos_offset The start position of each request in the batch. 
-   */
-  __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
-                                      uint32_t batch_size, QKVLayout layout, DType* kv_data,
-                                      DType* k_data, DType* v_data, IdType* indices, IdType* indptr,
-                                      IdType* last_page_len, IdType* rope_pos_offset = nullptr)
-      : num_heads(num_heads),
-        page_size(page_size),
-        head_dim(head_dim),
-        batch_size(batch_size),
-        indices(indices),
-        indptr(indptr),
-        last_page_len(last_page_len),
-        rope_pos_offset(rope_pos_offset) {
-    bool kv_defined = kv_data != nullptr;
-    if (kv_defined) {
-      stride_page = 2 * num_heads * page_size * head_dim;
-      this->k_data = kv_data;
-      this->v_data = kv_data + num_heads * page_size * head_dim;
-    } else {
-      stride_page = num_heads * page_size * head_dim;
-      this->k_data = k_data;
-      this->v_data = v_data;
-    }
-    stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
-    stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
-  }
-
   /*!
    * \brief Construct a paged key-value cache
    * \param num_heads The number of heads
@@ -136,13 +95,13 @@ struct paged_kv_t {
         page_size(page_size),
         head_dim(head_dim),
         batch_size(batch_size),
-        k_data(k_data),
-        v_data(v_data),
         indices(indices),
         indptr(indptr),
         last_page_len(last_page_len),
         rope_pos_offset(rope_pos_offset) {
     stride_page = num_heads * page_size * head_dim;
+    this->k_data = k_data;
+    this->v_data = v_data;
     stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
     stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
   }
@@ -154,33 +113,33 @@ struct paged_kv_t {
    * \brief Construct a paged key-value cache
    * \param num_heads The number of heads
    * \param page_size The size of each page
    * \param head_dim The dimension of each head
    * \param batch_size The batch size
    * \param layout The layout of last 3 dimensions in KV-Cache.
-   * \param kv_data The flattened key-value cache
+   * \param k_data The flattened key cache
+   * \param v_data The flattened value cache
+   * \param kv_strides custom strides of each dimension of k_data and v_data
    * \param indices The page indices array
    * \param indptr The page indptr array
    * \param last_page_len The offset of the last page for each request in the batch
    * \param rope_pos_offset The start position of each request in the batch.
+   * \note This constructor should only be used when page_storage == kIndices
    */
   __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
-                                      uint32_t batch_size, QKVLayout layout, DType* kv_data,
-                                      IdType* indices, IdType* indptr, IdType* last_page_len,
+                                      uint32_t batch_size, QKVLayout layout, DType* k_data,
+                                      DType* v_data, const int64_t* kv_strides, IdType* indices,
+                                      IdType* indptr, IdType* last_page_len,
                                       IdType* rope_pos_offset = nullptr)
       : num_heads(num_heads),
         page_size(page_size),
         head_dim(head_dim),
         batch_size(batch_size),
-        k_data(kv_data),
-        v_data(kv_data + num_heads * page_size * head_dim),
        indices(indices),
        indptr(indptr),
        last_page_len(last_page_len),
        rope_pos_offset(rope_pos_offset) {
-    stride_page = 2 * num_heads * page_size * head_dim;
-    stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
-    stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
-  }
-
-  __host__ __device__ __forceinline__ int64_t kv_ptr_delta() const {
-    return (int64_t(v_data) - int64_t(k_data)) / sizeof(DType);
+    stride_page = kv_strides[0];
+    this->k_data = k_data;
+    this->v_data = v_data;
+    stride_n = layout == QKVLayout::kHND ? kv_strides[2] : kv_strides[1];
+    stride_h = layout == QKVLayout::kHND ? kv_strides[1] : kv_strides[2];
   }
 
   __host__ __device__ __forceinline__ uint32_t get_length(uint32_t batch_idx) const {
@@ -220,16 +179,22 @@ struct paged_kv_t {
     return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
   }
 
-  __device__ __forceinline__ DType* protective_get_k_ptr(IdType page_iter, uint32_t head_idx,
-                                                         uint32_t entry_idx, uint32_t feat_idx,
-                                                         IdType last_indptr) const {
+  __device__ __forceinline__ size_t protective_get_kv_offset(IdType page_iter, uint32_t head_idx,
+                                                             uint32_t entry_idx, uint32_t feat_idx,
+                                                             IdType last_indptr) const {
     if (page_iter < last_indptr) {
-      return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
+      return get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
     } else {
-      return k_data;
+      return 0;
     }
   }
 
+  __device__ __forceinline__ DType* protective_get_k_ptr(IdType page_iter, uint32_t head_idx,
+                                                         uint32_t entry_idx, uint32_t feat_idx,
+                                                         IdType last_indptr) const {
+    return k_data + protective_get_kv_offset(page_iter, head_idx, entry_idx, feat_idx, last_indptr);
+  }
+
   __device__ __forceinline__ DType* get_v_ptr(IdType page_iter, uint32_t head_idx,
                                               uint32_t entry_idx, uint32_t feat_idx) const {
     return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
@@ -238,11 +203,7 @@ struct paged_kv_t {
   __device__ __forceinline__ DType* protective_get_v_ptr(IdType page_iter, uint32_t head_idx,
                                                          uint32_t entry_idx, uint32_t feat_idx,
                                                          IdType last_indptr) const {
-    if (page_iter < last_indptr) {
-      return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
-    } else {
-      return v_data;
-    }
+    return v_data + protective_get_kv_offset(page_iter, head_idx, entry_idx, feat_idx, last_indptr);
   }
 };
diff --git a/python/csrc/flashinfer_page_ops.cu b/python/csrc/flashinfer_page_ops.cu
index 39caf24f..604c4156 100644
--- a/python/csrc/flashinfer_page_ops.cu
+++ b/python/csrc/flashinfer_page_ops.cu
@@ -16,9 +16,8 @@
 #include <torch/extension.h>
 
 void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
-                           torch::Tensor append_indptr, std::optional<torch::Tensor> paged_kv_cache,
-                           std::optional<torch::Tensor> paged_k_cache,
-                           std::optional<torch::Tensor> paged_v_cache, torch::Tensor kv_indices,
+                           torch::Tensor append_indptr, torch::Tensor paged_k_cache,
+                           torch::Tensor paged_v_cache, torch::Tensor kv_indices,
                            torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
                            unsigned int layout);
diff --git a/python/csrc/page.cu b/python/csrc/page.cu
index 787aa1aa..2e002338 100644
--- a/python/csrc/page.cu
+++ b/python/csrc/page.cu
@@ -20,33 +20,24 @@ using namespace flashinfer;
 void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
-                           torch::Tensor append_indptr, std::optional<torch::Tensor> paged_kv_cache,
-                           std::optional<torch::Tensor> paged_k_cache,
-                           std::optional<torch::Tensor> paged_v_cache, torch::Tensor kv_indices,
+                           torch::Tensor append_indptr, torch::Tensor paged_k_cache,
+                           torch::Tensor paged_v_cache, torch::Tensor kv_indices,
                            torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
                            unsigned int layout) {
-  bool paged_kv_defined = paged_kv_cache.has_value();
   CHECK_INPUT(append_key);
   CHECK_INPUT(append_value);
   CHECK_INPUT(append_indptr);
-  if (paged_kv_defined) {
-    CHECK_INPUT(paged_kv_cache.value());
-  } else {
-    CHECK_INPUT(paged_k_cache.value());
-    CHECK_INPUT(paged_v_cache.value());
-  }
+  // NOTE(Zihao): doesn't have to be contiguous
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_k_cache);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_v_cache);
   CHECK_INPUT(kv_indices);
   CHECK_INPUT(kv_indptr);
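Side note (illustration, not part of the diff): the strided `paged_kv_t` constructor above consumes the raw PyTorch strides of the k/v cache, taking `stride_page = kv_strides[0]` and picking `stride_n`/`stride_h` by layout, so the cache only needs to be contiguous in its last dimension (which is what `CHECK_LAST_DIM_CONTIGUOUS_INPUT` enforces). The Python sketch below shows the offset arithmetic; the exact body of `get_elem_offset` is not shown in this diff, so the formula used here is an assumption that simply matches the stride fields set by the constructor. The diff continues below.

```python
import torch

# NHD layout: paged_k_cache is [num_pages, page_size, num_kv_heads, head_dim]
num_pages, page_size, num_kv_heads, head_dim = 8, 16, 4, 128
k_cache = torch.randn(num_pages, page_size, num_kv_heads, head_dim)

kv_strides = k_cache.stride()   # what the run functions pass as kv_cache_strides
stride_page = kv_strides[0]     # kv_strides[0]
stride_n = kv_strides[1]        # NHD branch: kv_strides[1]
stride_h = kv_strides[2]        # NHD branch: kv_strides[2]
assert kv_strides[3] == 1       # last-dim contiguity required by the new checks

page_idx, entry_idx, head_idx, feat_idx = 3, 5, 2, 7
# Assumed element-offset formula, consistent with the stride fields above:
offset = page_idx * stride_page + head_idx * stride_h + entry_idx * stride_n + feat_idx
assert k_cache.reshape(-1)[offset].item() == k_cache[page_idx, entry_idx, head_idx, feat_idx].item()
```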
CHECK_INPUT(kv_last_page_len); CHECK_DIM(3, append_key); CHECK_DIM(3, append_value); CHECK_DIM(1, append_indptr); - if (paged_kv_defined) { - CHECK_DIM(5, paged_kv_cache.value()); - } else { - CHECK_DIM(4, paged_k_cache.value()); - CHECK_DIM(4, paged_v_cache.value()); - } + CHECK_DIM(4, paged_k_cache); + CHECK_DIM(4, paged_v_cache); CHECK_DIM(1, kv_indices); CHECK_DIM(1, kv_indptr); CHECK_DIM(1, kv_last_page_len); @@ -60,12 +51,8 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, auto device = append_indptr.device(); CHECK_EQ(append_key.device(), device); CHECK_EQ(append_value.device(), device); - if (paged_kv_defined) { - CHECK_EQ(paged_kv_cache->device(), device); - } else { - CHECK_EQ(paged_k_cache->device(), device); - CHECK_EQ(paged_v_cache->device(), device); - } + CHECK_EQ(paged_k_cache.device(), device); + CHECK_EQ(paged_v_cache.device(), device); CHECK_EQ(kv_indices.device(), device); CHECK_EQ(kv_indptr.device(), device); CHECK_EQ(kv_last_page_len.device(), device); @@ -73,24 +60,13 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, QKVLayout kv_layout = QKVLayout(layout); unsigned int num_heads, page_size, head_dim; - if (paged_kv_defined) { - head_dim = paged_kv_cache->size(4); - if (kv_layout == QKVLayout::kHND) { - num_heads = paged_kv_cache->size(2); - page_size = paged_kv_cache->size(3); - } else { - page_size = paged_kv_cache->size(2); - num_heads = paged_kv_cache->size(3); - } + head_dim = paged_k_cache.size(3); + if (kv_layout == QKVLayout::kHND) { + num_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - head_dim = paged_k_cache->size(3); - if (kv_layout == QKVLayout::kHND) { - num_heads = paged_k_cache->size(1); - page_size = paged_k_cache->size(2); - } else { - page_size = paged_k_cache->size(1); - num_heads = paged_k_cache->size(2); - } + page_size = paged_k_cache.size(1); + num_heads = paged_k_cache.size(2); } CHECK_EQ(append_key.size(1), num_heads); @@ -100,17 +76,15 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto kv_scalar_dtype = - paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); + auto kv_scalar_dtype = paged_k_cache.scalar_type(); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_scalar_dtype, c_type, [&] { - paged_kv_t paged_kv( - num_heads, page_size, head_dim, batch_size, kv_layout, - static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), - static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), - static_cast(paged_v_cache.has_value() ? 
paged_v_cache->data_ptr() : nullptr), - static_cast(kv_indices.data_ptr()), static_cast(kv_indptr.data_ptr()), - static_cast(kv_last_page_len.data_ptr())); + paged_kv_t paged_kv(num_heads, page_size, head_dim, batch_size, kv_layout, + static_cast(paged_k_cache.data_ptr()), + static_cast(paged_v_cache.data_ptr()), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + static_cast(kv_last_page_len.data_ptr())); cudaError_t status = AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), static_cast(append_value.data_ptr()), diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 2526c940..7e05d6e5 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -211,10 +211,15 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) #define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") diff --git a/python/flashinfer/jit/batch_decode_templ.py b/python/flashinfer/jit/batch_decode_templ.py index 5ae58737..80daa551 100644 --- a/python/flashinfer/jit/batch_decode_templ.py +++ b/python/flashinfer/jit/batch_decode_templ.py @@ -66,8 +66,9 @@ torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, std::vector plan_info_vec, - torch::Tensor q, std::optional paged_kv_cache, - std::optional paged_k_cache, std::optional paged_v_cache, + torch::Tensor q, + torch::Tensor paged_k_cache, + torch::Tensor paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, std::optional alibi_slopes, @@ -76,27 +77,16 @@ DecodePlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(kv_layout_code); - bool paged_kv_defined = paged_kv_cache.has_value(); auto device = q.device(); int64_t batch_size = q.size(0); int64_t num_qo_heads = q.size(1); int64_t num_kv_heads, page_size; - if (paged_kv_defined) { - if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_kv_cache->size(2); - page_size = paged_kv_cache->size(3); - } else { - page_size = paged_kv_cache->size(2); - num_kv_heads = paged_kv_cache->size(3); - } + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->size(1); - page_size = paged_k_cache->size(2); - } else { - page_size = paged_k_cache->size(1); - num_kv_heads = paged_k_cache->size(2); - } + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); } cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); @@ -111,15 +101,18 @@ void* float_buffer = static_cast(float_workspace_buffer.data_ptr()); void* int_buffer = static_cast(int_workspace_buffer.data_ptr()); + const int64_t* kv_cache_strides = nullptr; + auto k_strides = paged_k_cache.strides(); + auto v_strides = paged_v_cache.strides(); + TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); + kv_cache_strides = k_strides.data(); + paged_kv_t<{{ dtype_kv }}, {{ 
dtype_idx }}> paged_kv( num_kv_heads, page_size, {{ head_dim }}, batch_size, kv_layout, - static_cast<{{ dtype_kv }}*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() - : nullptr), - static_cast<{{ dtype_kv }} *>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() - : nullptr), - static_cast<{{ dtype_kv }}*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr() - : nullptr), + static_cast<{{ dtype_kv }}*>(paged_k_cache.data_ptr()), + static_cast<{{ dtype_kv }}*>(paged_v_cache.data_ptr()), + kv_cache_strides, static_cast<{{ dtype_idx }}*>(paged_kv_indices.data_ptr()), static_cast<{{ dtype_idx }}*>(paged_kv_indptr.data_ptr()), static_cast<{{ dtype_idx }}*>(paged_kv_last_page_len.data_ptr())); diff --git a/python/flashinfer/jit/batch_prefill_templ.py b/python/flashinfer/jit/batch_prefill_templ.py index a6b0c0ac..01923cc1 100644 --- a/python/flashinfer/jit/batch_prefill_templ.py +++ b/python/flashinfer/jit/batch_prefill_templ.py @@ -160,9 +160,8 @@ torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, std::vector plan_info_vec, torch::Tensor q, - std::optional paged_kv_cache, - std::optional paged_k_cache, - std::optional paged_v_cache, + torch::Tensor paged_k_cache, + torch::Tensor paged_v_cache, std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, torch::Tensor qo_indptr, @@ -175,27 +174,16 @@ PrefillPlanInfo plan_info; plan_info.FromVector(plan_info_vec); QKVLayout kv_layout = static_cast(layout); - bool paged_kv_defined = paged_kv_cache.has_value(); auto device = q.device(); int64_t batch_size = paged_kv_indptr.size(0) - 1; int64_t num_qo_heads = q.size(1); int64_t num_kv_heads, page_size; - if (paged_kv_defined) { - if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_kv_cache->size(2); - page_size = paged_kv_cache->size(3); - } else { - page_size = paged_kv_cache->size(2); - num_kv_heads = paged_kv_cache->size(3); - } + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->size(1); - page_size = paged_k_cache->size(2); - } else { - page_size = paged_k_cache->size(1); - num_kv_heads = paged_k_cache->size(2); - } + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); } cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); @@ -209,15 +197,18 @@ void* float_buffer_ptr = static_cast(float_workspace_buffer.data_ptr()); void* int_buffer_ptr = static_cast(int_workspace_buffer.data_ptr()); + const int64_t* kv_cache_strides = nullptr; + auto k_strides = paged_k_cache.strides(); + auto v_strides = paged_v_cache.strides(); + TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); + kv_cache_strides = k_strides.data(); + paged_kv_t<{{ dtype_kv }}, {{ dtype_idx }}> paged_kv( num_kv_heads, page_size, {{ head_dim }}, batch_size, kv_layout, - static_cast<{{ dtype_kv }}*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() - : nullptr), - static_cast<{{ dtype_kv }} *>(paged_k_cache.has_value() ? paged_k_cache->data_ptr() - : nullptr), - static_cast<{{ dtype_kv }}*>(paged_v_cache.has_value() ? 
paged_v_cache->data_ptr() - : nullptr), + static_cast<{{ dtype_kv }}*>(paged_k_cache.data_ptr()), + static_cast<{{ dtype_kv }}*>(paged_v_cache.data_ptr()), + kv_cache_strides, static_cast<{{ dtype_idx }}*>(paged_kv_indices.data_ptr()), static_cast<{{ dtype_idx }}*>(paged_kv_indptr.data_ptr()), static_cast<{{ dtype_idx }}*>(paged_kv_last_page_len.data_ptr())); diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index acaed4a0..1e6a50de 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -100,16 +100,18 @@ def get_indptr(x: torch.Tensor) -> torch.Tensor: def _unpack_paged_kv_cache( paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], kv_layout: str, -) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: +) -> Tuple[torch.Tensor, torch.Tensor]: if isinstance(paged_kv_cache, tuple): paged_k_cache, paged_v_cache = paged_kv_cache return ( - None, _expand_4d(paged_k_cache, kv_layout), _expand_4d(paged_v_cache, kv_layout), ) elif torch.is_tensor(paged_kv_cache): - return (_expand_5d(paged_kv_cache, kv_layout), None, None) + # NOTE(Zihao): split on the second dimension + paged_kv_cache = _expand_5d(paged_kv_cache, kv_layout) + paged_k_cache, paged_v_cache = paged_kv_cache.unbind(dim=1) + return paged_k_cache, paged_v_cache else: raise KeyError( "Unrecongized paged_kv_cache type {}, expect a single tensor or a tuple of tensor.".format( diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index d69d93ce..fc998ded 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -51,15 +51,16 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { kv_indptr_host.push_back(kv_indptr_host.back() + pages_per_seq); kv_last_page_len_host.push_back((seqlen - 1) % page_size + 1); } - thrust::device_vector kv_data(num_pages * 2 * num_kv_heads * page_size * head_dim); + thrust::device_vector k_data(num_pages * num_kv_heads * page_size * head_dim); + thrust::device_vector v_data(num_pages * num_kv_heads * page_size * head_dim); thrust::device_vector kv_indptr(kv_indptr_host); thrust::device_vector kv_indices(kv_indicies_host); thrust::device_vector kv_last_page_len(kv_last_page_len_host); - paged_kv_t paged_kv(num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data.data()), - thrust::raw_pointer_cast(kv_indices.data()), - thrust::raw_pointer_cast(kv_indptr.data()), - thrust::raw_pointer_cast(kv_last_page_len.data())); + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout, + thrust::raw_pointer_cast(k_data.data()), thrust::raw_pointer_cast(v_data.data()), + thrust::raw_pointer_cast(kv_indices.data()), thrust::raw_pointer_cast(kv_indptr.data()), + thrust::raw_pointer_cast(kv_last_page_len.data())); // Allocate input data: thrust::device_vector q(batch_size * num_qo_heads * head_dim); thrust::device_vector o(batch_size * num_qo_heads * head_dim); @@ -113,15 +114,16 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { kv_indptr_host.push_back(kv_indptr_host.back() + pages_per_seq); kv_last_page_len_host.push_back((seqlen - 1) % page_size + 1); } - thrust::device_vector kv_data(num_pages * 2 * num_kv_heads * page_size * head_dim); + thrust::device_vector k_data(num_pages * num_kv_heads * page_size * head_dim); + thrust::device_vector v_data(num_pages * num_kv_heads * page_size * head_dim); thrust::device_vector kv_indptr(kv_indptr_host); thrust::device_vector kv_indices(kv_indicies_host); thrust::device_vector 
kv_last_page_len(kv_last_page_len_host); - paged_kv_t paged_kv(num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data.data()), - thrust::raw_pointer_cast(kv_indices.data()), - thrust::raw_pointer_cast(kv_indptr.data()), - thrust::raw_pointer_cast(kv_last_page_len.data())); + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout, + thrust::raw_pointer_cast(k_data.data()), thrust::raw_pointer_cast(v_data.data()), + thrust::raw_pointer_cast(kv_indices.data()), thrust::raw_pointer_cast(kv_indptr.data()), + thrust::raw_pointer_cast(kv_last_page_len.data())); // Allocate input data: thrust::device_vector q(batch_size * num_qo_heads * head_dim); diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index d4b60488..94824e89 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -76,7 +76,8 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { std::vector q_h = std::move(testcase_float_data[0]), shared_k_h = std::move(testcase_float_data[1]), shared_v_h = std::move(testcase_float_data[2]), - kv_data_h = std::move(testcase_float_data[3]); + k_data_h = std::move(testcase_float_data[3]), + v_data_h = std::move(testcase_float_data[4]); std::vector kv_indices_combined_h = std::move(testcase_int_data[1]), kv_indices_unique_h = std::move(testcase_int_data[2]), @@ -85,10 +86,10 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { kv_last_page_len_combined_h = std::move(testcase_int_data[5]), kv_last_page_len_unique_h = std::move(testcase_int_data[6]); - thrust::device_vector kv_data_d(kv_data_h); + thrust::device_vector k_data_d(k_data_h), v_data_d(v_data_h); thrust::device_vector q_d(q_h); - state.add_global_memory_reads(kv_data_h.size() + q_h.size(), "Read"); + state.add_global_memory_reads(k_data_h.size() + v_data_h.size() + q_h.size(), "Read"); state.add_global_memory_writes(q_h.size(), "Write"); if (use_cascade) { @@ -102,7 +103,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { kv_last_page_len_unique_d(kv_last_page_len_unique_h); paged_kv_t paged_kv_casacde_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data_d.data()), + thrust::raw_pointer_cast(k_data_d.data()), thrust::raw_pointer_cast(v_data_d.data()), thrust::raw_pointer_cast(kv_indices_unique_d.data()), thrust::raw_pointer_cast(kv_indptr_unique_d.data()), thrust::raw_pointer_cast(kv_last_page_len_unique_d.data())); @@ -163,7 +164,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { kv_last_page_len_combined_d(kv_last_page_len_combined_h); paged_kv_t paged_kv_baseline_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data_d.data()), + thrust::raw_pointer_cast(k_data_d.data()), thrust::raw_pointer_cast(v_data_d.data()), thrust::raw_pointer_cast(kv_indices_combined_d.data()), thrust::raw_pointer_cast(kv_indptr_combined_d.data()), thrust::raw_pointer_cast(kv_last_page_len_combined_d.data())); @@ -214,7 +215,8 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { std::vector q_h = std::move(testcase_float_data[0]), shared_k_h = std::move(testcase_float_data[1]), shared_v_h = std::move(testcase_float_data[2]), - kv_data_h = std::move(testcase_float_data[3]); + k_data_h = std::move(testcase_float_data[3]), + v_data_h = std::move(testcase_float_data[4]); std::vector qo_indptr_h = std::move(testcase_int_data[0]), kv_indices_combined_h = std::move(testcase_int_data[1]), 
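Side note (illustration, not part of the diff): the `_unpack_paged_kv_cache` change in `python/flashinfer/utils.py` above now always hands two 4-D tensors to the kernels, splitting a combined 5-D cache with `unbind(dim=1)`. The resulting views share storage with the original tensor and are therefore non-contiguous, which is exactly the case the stride-aware `paged_kv_t` handles. A small sketch with hypothetical sizes:

```python
import torch

total_num_pages, num_kv_heads, page_size, head_dim = 8, 4, 16, 128
# Combined cache, HND layout: [pages, 2, heads, page_size, head_dim]
kv_cache = torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim)

paged_k_cache, paged_v_cache = kv_cache.unbind(dim=1)  # same split as in utils.py

assert paged_k_cache.shape == (total_num_pages, num_kv_heads, page_size, head_dim)
assert not paged_k_cache.is_contiguous()                 # page stride still skips over the V half
assert paged_k_cache.stride() == paged_v_cache.stride()  # satisfies the TORCH_CHECK in the run functions
assert paged_k_cache.stride(-1) == 1                     # last dim stays contiguous, as required
```

Passing a `(k_cache, v_cache)` tuple instead takes the `_expand_4d` branch and skips the split entirely.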
@@ -224,11 +226,11 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { kv_last_page_len_combined_h = std::move(testcase_int_data[5]), kv_last_page_len_unique_h = std::move(testcase_int_data[6]); - thrust::device_vector kv_data_d(kv_data_h); + thrust::device_vector k_data_d(k_data_h), v_data_d(k_data_h); thrust::device_vector q_d(q_h); thrust::device_vector qo_indptr_d(qo_indptr_h); - state.add_global_memory_reads(kv_data_h.size() + q_h.size(), "Read"); + state.add_global_memory_reads(k_data_h.size() + v_data_h.size() + q_h.size(), "Read"); state.add_global_memory_writes(q_h.size(), "Write"); if (use_cascade) { @@ -242,7 +244,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { kv_last_page_len_unique_d(kv_last_page_len_unique_h); paged_kv_t paged_kv_casacde_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data_d.data()), + thrust::raw_pointer_cast(k_data_d.data()), thrust::raw_pointer_cast(v_data_d.data()), thrust::raw_pointer_cast(kv_indices_unique_d.data()), thrust::raw_pointer_cast(kv_indptr_unique_d.data()), thrust::raw_pointer_cast(kv_last_page_len_unique_d.data())); @@ -303,7 +305,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { kv_last_page_len_combined_d(kv_last_page_len_combined_h); paged_kv_t paged_kv_baseline_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data_d.data()), + thrust::raw_pointer_cast(k_data_d.data()), thrust::raw_pointer_cast(v_data_d.data()), thrust::raw_pointer_cast(kv_indices_combined_d.data()), thrust::raw_pointer_cast(kv_indptr_combined_d.data()), thrust::raw_pointer_cast(kv_last_page_len_combined_d.data())); diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 7862a03e..57c64329 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -37,7 +37,8 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si } std::vector q; std::vector o_ref; - std::vector kv_data; + std::vector k_data; + std::vector v_data; std::vector kv_indptr{0}; std::vector kv_indices; std::vector kv_last_page_len; @@ -71,18 +72,21 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si kv_indices.push_back(page_counter++); } } - kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim); - utils::vec_zero_(kv_data); + k_data.resize(page_counter * num_kv_heads * page_size * head_dim); + v_data.resize(page_counter * num_kv_heads * page_size * head_dim); + utils::vec_zero_(k_data); + utils::vec_zero_(v_data); assert(q.size() == batch_size * num_qo_heads * head_dim); assert(o_ref.size() == batch_size * num_qo_heads * head_dim); flashinfer::paged_kv_t paged_kv_cpu( - num_kv_heads, page_size, head_dim, batch_size, kv_layout, kv_data.data(), kv_indices.data(), - kv_indptr.data(), kv_last_page_len.data()); + num_kv_heads, page_size, head_dim, batch_size, kv_layout, k_data.data(), v_data.data(), + kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, keys, values, append_indptr); // copy data to device - thrust::device_vector kv_data_device(kv_data); + thrust::device_vector k_data_device(k_data); + thrust::device_vector v_data_device(v_data); thrust::device_vector kv_indptr_device(kv_indptr); thrust::device_vector kv_indices_device(kv_indices); thrust::device_vector kv_last_page_len_device(kv_last_page_len); @@ -92,7 +96,8 @@ void _TestBatchDecodingKernelCorrectness(size_t 
page_size, size_t batch_size, si // create paged_kv object flashinfer::paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data_device.data()), + thrust::raw_pointer_cast(k_data_device.data()), + thrust::raw_pointer_cast(v_data_device.data()), thrust::raw_pointer_cast(kv_indices_device.data()), thrust::raw_pointer_cast(kv_indptr_device.data()), thrust::raw_pointer_cast(kv_last_page_len_device.data())); diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu index c1cfdf35..69111902 100644 --- a/src/test_batch_prefill.cu +++ b/src/test_batch_prefill.cu @@ -38,7 +38,8 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n for (size_t request_idx = 0; request_idx < batch_size; ++request_idx) { append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]); } - std::vector kv_data; + std::vector k_data; + std::vector v_data; std::vector kv_indptr{0}; std::vector kv_indices; std::vector kv_last_page_len; @@ -61,22 +62,24 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n } } - kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim); + k_data.resize(page_counter * num_kv_heads * page_size * head_dim); + v_data.resize(page_counter * num_kv_heads * page_size * head_dim); flashinfer::paged_kv_t paged_kv_cpu( - num_kv_heads, page_size, head_dim, batch_size, kv_layout, kv_data.data(), kv_indices.data(), - kv_indptr.data(), kv_last_page_len.data()); + num_kv_heads, page_size, head_dim, batch_size, kv_layout, k_data.data(), v_data.data(), + kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); // copy data to device - thrust::device_vector kv_data_device(kv_data); + thrust::device_vector k_data_device(k_data); + thrust::device_vector v_data_device(v_data); thrust::device_vector kv_indptr_device(kv_indptr); thrust::device_vector kv_indices_device(kv_indices); thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object flashinfer::paged_kv_t paged_kv = paged_kv_cpu; - paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); - paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); + paged_kv.k_data = thrust::raw_pointer_cast(k_data_device.data()); + paged_kv.v_data = thrust::raw_pointer_cast(v_data_device.data()); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); paged_kv.indptr = thrust::raw_pointer_cast(kv_indptr_device.data()); paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data()); @@ -255,7 +258,8 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]); } - std::vector kv_data; + std::vector k_data; + std::vector v_data; std::vector kv_indptr{0}; std::vector kv_indices; std::vector kv_last_page_len; @@ -277,22 +281,24 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si } } - kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim); + k_data.resize(page_counter * num_kv_heads * page_size * head_dim); + v_data.resize(page_counter * num_kv_heads * page_size * head_dim); flashinfer::paged_kv_t paged_kv_cpu( - num_kv_heads, page_size, head_dim, batch_size, kv_layout, kv_data.data(), kv_indices.data(), - kv_indptr.data(), kv_last_page_len.data()); + 
num_kv_heads, page_size, head_dim, batch_size, kv_layout, k_data.data(), v_data.data(), + kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); // copy data to device - thrust::device_vector kv_data_device(kv_data); + thrust::device_vector k_data_device(k_data); + thrust::device_vector v_data_device(v_data); thrust::device_vector kv_indptr_device(kv_indptr); thrust::device_vector kv_indices_device(kv_indices); thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object flashinfer::paged_kv_t paged_kv = paged_kv_cpu; - paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); - paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); + paged_kv.k_data = thrust::raw_pointer_cast(k_data_device.data()); + paged_kv.v_data = thrust::raw_pointer_cast(v_data_device.data()); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); paged_kv.indptr = thrust::raw_pointer_cast(kv_indptr_device.data()); paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data()); @@ -380,7 +386,8 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]); } - std::vector kv_data; + std::vector k_data; + std::vector v_data; std::vector kv_indptr{0}; std::vector kv_indices; std::vector kv_last_page_len; @@ -402,22 +409,24 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness( } } - kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim); + k_data.resize(page_counter * num_kv_heads * page_size * head_dim); + v_data.resize(page_counter * num_kv_heads * page_size * head_dim); flashinfer::paged_kv_t paged_kv_cpu( - num_kv_heads, page_size, head_dim, batch_size, kv_layout, kv_data.data(), kv_indices.data(), - kv_indptr.data(), kv_last_page_len.data()); + num_kv_heads, page_size, head_dim, batch_size, kv_layout, k_data.data(), v_data.data(), + kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value, append_indptr); // copy data to device - thrust::device_vector kv_data_device(kv_data); + thrust::device_vector k_data_device(k_data); + thrust::device_vector v_data_device(v_data); thrust::device_vector kv_indptr_device(kv_indptr); thrust::device_vector kv_indices_device(kv_indices); thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object flashinfer::paged_kv_t paged_kv = paged_kv_cpu; - paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); - paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); + paged_kv.k_data = thrust::raw_pointer_cast(k_data_device.data()); + paged_kv.v_data = thrust::raw_pointer_cast(v_data_device.data()); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); paged_kv.indptr = thrust::raw_pointer_cast(kv_indptr_device.data()); paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data()); @@ -497,7 +506,8 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz std::vector q_lens{33}, kv_lens{32768}; std::vector q_indptr{0, 33}; std::vector append_indptr{0, 32768}; - std::vector kv_data; + std::vector k_data; + std::vector v_data; std::vector kv_indptr{0}; std::vector kv_indices; std::vector kv_last_page_len; @@ -515,22 +525,24 @@ void 
_TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz kv_indices.push_back(page_counter++); } - kv_data.resize(page_counter * 1 * 2 * num_kv_heads * page_size * head_dim); + k_data.resize(page_counter * 1 * num_kv_heads * page_size * head_dim); + v_data.resize(page_counter * 1 * num_kv_heads * page_size * head_dim); flashinfer::paged_kv_t paged_kv_cpu( - num_kv_heads, page_size, head_dim, 1, kv_layout, kv_data.data(), kv_indices.data(), - kv_indptr.data(), kv_last_page_len.data()); + num_kv_heads, page_size, head_dim, 1, kv_layout, k_data.data(), v_data.data(), + kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, {k}, {v}, append_indptr); // copy data to device - thrust::device_vector kv_data_device(kv_data); + thrust::device_vector k_data_device(k_data); + thrust::device_vector v_data_device(v_data); thrust::device_vector kv_indptr_device(kv_indptr); thrust::device_vector kv_indices_device(kv_indices); thrust::device_vector kv_last_page_len_device(kv_last_page_len); // create paged_kv object flashinfer::paged_kv_t paged_kv = paged_kv_cpu; - paged_kv.k_data = thrust::raw_pointer_cast(kv_data_device.data()); - paged_kv.v_data = paged_kv.k_data + paged_kv_cpu.kv_ptr_delta(); + paged_kv.k_data = thrust::raw_pointer_cast(k_data_device.data()); + paged_kv.v_data = thrust::raw_pointer_cast(v_data_device.data()); paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data()); paged_kv.indptr = thrust::raw_pointer_cast(kv_indptr_device.data()); paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data()); diff --git a/src/test_cascade.cu b/src/test_cascade.cu index 22b73bc2..4530ea9b 100644 --- a/src/test_cascade.cu +++ b/src/test_cascade.cu @@ -245,7 +245,8 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, std::vector q_h = std::move(testcase_float_data[0]), shared_k_h = std::move(testcase_float_data[1]), shared_v_h = std::move(testcase_float_data[2]), - kv_data_h = std::move(testcase_float_data[3]); + k_data_h = std::move(testcase_float_data[3]), + v_data_h = std::move(testcase_float_data[3]); std::vector kv_indices_combined_h = std::move(testcase_int_data[1]), kv_indices_unique_h = std::move(testcase_int_data[2]), @@ -254,8 +255,9 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, kv_last_page_len_combined_h = std::move(testcase_int_data[5]), kv_last_page_len_unique_h = std::move(testcase_int_data[6]); - thrust::device_vector shared_k_d(shared_k_h), shared_v_d(shared_v_h), kv_data_d(kv_data_h), - q_d(q_h), o_baseline_d(q_h.size()), o_cascade_0_d(q_h.size()), o_cascade_1_d(q_h.size()); + thrust::device_vector shared_k_d(shared_k_h), shared_v_d(shared_v_h), k_data_d(k_data_h), + v_data_d(v_data_h), q_d(q_h), o_baseline_d(q_h.size()), o_cascade_0_d(q_h.size()), + o_cascade_1_d(q_h.size()); thrust::device_vector tmp_0_d(16 * 1024 * 1024); thrust::device_vector lse_cascade_0_d(batch_size * num_qo_heads), lse_cascade_1_d(batch_size * num_qo_heads); @@ -268,14 +270,14 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, paged_kv_t paged_kv_baseline_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data_d.data()), + thrust::raw_pointer_cast(k_data_d.data()), thrust::raw_pointer_cast(v_data_d.data()), thrust::raw_pointer_cast(kv_indices_combined_d.data()), thrust::raw_pointer_cast(kv_indptr_combined_d.data()), thrust::raw_pointer_cast(kv_last_page_len_combined_d.data())); 
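Side note (illustration, not part of the diff): every C++ entry point touched above repeats the same layout bookkeeping when reading `num_kv_heads` and `page_size` off the 4-D `paged_k_cache`. The helper below is not part of the codebase, just a Python restatement of those branches for reference:

```python
def kv_dims(paged_k_cache, kv_layout):
    """Mirror of the HND/NHD branches in the C++ run functions above."""
    # HND: [num_pages, num_kv_heads, page_size, head_dim]
    # NHD: [num_pages, page_size, num_kv_heads, head_dim]
    if kv_layout == "HND":
        num_kv_heads, page_size = paged_k_cache.size(1), paged_k_cache.size(2)
    else:  # "NHD"
        page_size, num_kv_heads = paged_k_cache.size(1), paged_k_cache.size(2)
    head_dim = paged_k_cache.size(3)
    return num_kv_heads, page_size, head_dim
```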
paged_kv_t paged_kv_casacde_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data_d.data()), + thrust::raw_pointer_cast(k_data_d.data()), thrust::raw_pointer_cast(v_data_d.data()), thrust::raw_pointer_cast(kv_indices_unique_d.data()), thrust::raw_pointer_cast(kv_indptr_unique_d.data()), thrust::raw_pointer_cast(kv_last_page_len_unique_d.data())); @@ -370,7 +372,8 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, std::vector q_h = std::move(testcase_float_data[0]), shared_k_h = std::move(testcase_float_data[1]), shared_v_h = std::move(testcase_float_data[2]), - kv_data_h = std::move(testcase_float_data[3]); + k_data_h = std::move(testcase_float_data[3]), + v_data_h = std::move(testcase_float_data[4]); std::vector qo_indptr_h = std::move(testcase_int_data[0]), kv_indices_combined_h = std::move(testcase_int_data[1]), @@ -380,8 +383,9 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, kv_last_page_len_combined_h = std::move(testcase_int_data[5]), kv_last_page_len_unique_h = std::move(testcase_int_data[6]); - thrust::device_vector shared_k_d(shared_k_h), shared_v_d(shared_v_h), kv_data_d(kv_data_h), - q_d(q_h), o_baseline_d(q_h.size()), o_cascade_0_d(q_h.size()), o_cascade_1_d(q_h.size()); + thrust::device_vector shared_k_d(shared_k_h), shared_v_d(shared_v_h), k_data_d(k_data_h), + v_data_d(v_data_h), q_d(q_h), o_baseline_d(q_h.size()), o_cascade_0_d(q_h.size()), + o_cascade_1_d(q_h.size()); thrust::device_vector tmp_0_d(16 * 1024 * 1024); thrust::device_vector lse_cascade_0_d((batch_size * qo_append_length) * num_qo_heads), lse_cascade_1_d((batch_size * qo_append_length) * num_qo_heads); @@ -394,14 +398,14 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, paged_kv_t paged_kv_baseline_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data_d.data()), + thrust::raw_pointer_cast(k_data_d.data()), thrust::raw_pointer_cast(v_data_d.data()), thrust::raw_pointer_cast(kv_indices_combined_d.data()), thrust::raw_pointer_cast(kv_indptr_combined_d.data()), thrust::raw_pointer_cast(kv_last_page_len_combined_d.data())); paged_kv_t paged_kv_casacde_d( num_kv_heads, page_size, head_dim, batch_size, kv_layout, - thrust::raw_pointer_cast(kv_data_d.data()), + thrust::raw_pointer_cast(k_data_d.data()), thrust::raw_pointer_cast(v_data_d.data()), thrust::raw_pointer_cast(kv_indices_unique_d.data()), thrust::raw_pointer_cast(kv_indptr_unique_d.data()), thrust::raw_pointer_cast(kv_last_page_len_unique_d.data())); diff --git a/src/test_page.cu b/src/test_page.cu index f7b5bacd..b36b143d 100644 --- a/src/test_page.cu +++ b/src/test_page.cu @@ -79,19 +79,22 @@ void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, si } indptr_cpu.push_back(indptr_cpu.back() + page_indices[i].size()); } - paged_kv_t paged_kv_cpu(num_heads, page_size, head_dim, batch_size, kv_layout, - kv_data_cpu.data(), indices_cpu.data(), indptr_cpu.data(), - last_page_len.data()); + paged_kv_t paged_kv_cpu( + num_heads, page_size, head_dim, batch_size, kv_layout, + /*k_data=*/kv_data_cpu.data(), + /*v_data=*/kv_data_cpu.data() + page_size * num_heads * head_dim, indices_cpu.data(), + indptr_cpu.data(), last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, keys, values, append_indptr); thrust::device_vector indptr_gpu(indptr_cpu); thrust::device_vector indices_gpu(indices_cpu); thrust::device_vector last_page_len_gpu(last_page_len); - paged_kv_t 
diff --git a/src/test_page.cu b/src/test_page.cu
index f7b5bacd..b36b143d 100644
--- a/src/test_page.cu
+++ b/src/test_page.cu
@@ -79,19 +79,22 @@ void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, si
     }
     indptr_cpu.push_back(indptr_cpu.back() + page_indices[i].size());
   }
-  paged_kv_t paged_kv_cpu(num_heads, page_size, head_dim, batch_size, kv_layout,
-                          kv_data_cpu.data(), indices_cpu.data(), indptr_cpu.data(),
-                          last_page_len.data());
+  paged_kv_t paged_kv_cpu(
+      num_heads, page_size, head_dim, batch_size, kv_layout,
+      /*k_data=*/kv_data_cpu.data(),
+      /*v_data=*/kv_data_cpu.data() + page_size * num_heads * head_dim, indices_cpu.data(),
+      indptr_cpu.data(), last_page_len.data());
   cpu_reference::append_paged_kv_cache(paged_kv_cpu, keys, values, append_indptr);

   thrust::device_vector indptr_gpu(indptr_cpu);
   thrust::device_vector indices_gpu(indices_cpu);
   thrust::device_vector last_page_len_gpu(last_page_len);
-  paged_kv_t paged_kv_gpu(num_heads, page_size, head_dim, batch_size, kv_layout,
-                          thrust::raw_pointer_cast(kv_data_gpu.data()),
-                          thrust::raw_pointer_cast(indices_gpu.data()),
-                          thrust::raw_pointer_cast(indptr_gpu.data()),
-                          thrust::raw_pointer_cast(last_page_len_gpu.data()));
+  paged_kv_t paged_kv_gpu(
+      num_heads, page_size, head_dim, batch_size, kv_layout,
+      /*k_data=*/thrust::raw_pointer_cast(kv_data_gpu.data()),
+      /*v_data=*/thrust::raw_pointer_cast(kv_data_gpu.data()) + page_size * num_heads * head_dim,
+      thrust::raw_pointer_cast(indices_gpu.data()), thrust::raw_pointer_cast(indptr_gpu.data()),
+      thrust::raw_pointer_cast(last_page_len_gpu.data()));

   thrust::device_vector append_indptr_gpu(append_indptr);
   thrust::device_vector keys_gpu(append_indptr.back() * num_heads * head_dim);
diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu
index d9793fa3..b61b01d8 100644
--- a/src/tvm_wrapper.cu
+++ b/src/tvm_wrapper.cu
@@ -244,7 +244,8 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q
             output->dtype, dtype_out, {DISPATCH_TVM_CUDA_IDTYPE(page_table_values->dtype, dtype_idx, {
               paged_kv_t cache(
                   nhead_kv, page_size, nfeat, num_total_seqs, kv_layout,
-                  static_cast(pages->data),
+                  /*k_data=*/static_cast(pages->data),
+                  /*v_data=*/static_cast(pages->data) + pages->strides[1],
                   static_cast(page_table_values->data) +
                       page_table_values->byte_offset / sizeof(dtype_idx),
                   static_cast(page_table_indptr->data) +
@@ -391,7 +392,8 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_
             output->dtype, dtype_out, {DISPATCH_TVM_CUDA_IDTYPE(page_table_values->dtype, dtype_idx, {
               paged_kv_t cache(
                   nhead_kv, page_size, nfeat, num_total_seqs, kv_layout,
-                  static_cast(pages->data),
+                  /*k_data=*/static_cast(pages->data),
+                  /*v_data=*/static_cast(pages->data) + pages->strides[1],
                   static_cast(page_table_values->data) +
                       page_table_values->byte_offset / sizeof(dtype_idx),
                   static_cast(page_table_indptr->data) +
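Note (reviewer aside, not part of the patch): test_page.cu and tvm_wrapper.cu keep their fused page buffers and simply pass a second base pointer for V at a fixed per-page offset (page_size * num_heads * head_dim elements, or pages->strides[1] in the TVM wrapper); K and V then share the same doubled page stride. A rough PyTorch model of that pointer arithmetic, with hypothetical sizes and names:

    import torch

    # hypothetical fused buffer: [num_pages, 2, num_heads, page_size, head_dim]
    num_pages, num_heads, page_size, head_dim = 4, 8, 16, 64
    pages = torch.arange(num_pages * 2 * num_heads * page_size * head_dim, dtype=torch.float32)
    pages = pages.reshape(num_pages, 2, num_heads, page_size, head_dim)

    flat = pages.reshape(-1)
    k_base = 0
    v_base = pages.stride(1)        # == num_heads * page_size * head_dim
    page_stride = pages.stride(0)   # == 2 * num_heads * page_size * head_dim, shared by K and V

    # element (page p, head h, slot s, dim d), addressed from either base pointer
    p, h, s, d = 2, 3, 5, 7
    idx = p * page_stride + (h * page_size + s) * head_dim + d
    assert flat[k_base + idx] == pages[p, 0, h, s, d]
    assert flat[v_base + idx] == pages[p, 1, h, s, d]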
diff --git a/src/utils.h b/src/utils.h
index 6785180e..17501e19 100644
--- a/src/utils.h
+++ b/src/utils.h
@@ -150,26 +150,25 @@ create_shared_prefix_testcase_data(size_t batch_size, size_t shared_prefix_lengt
   std::vector kv_indices_combined_h(kv_indptr_combined_h.back());
   std::vector kv_indices_unique_h(kv_indptr_unique_h.back());

-  std::vector kv_data_h(num_pages * 2 * num_kv_heads * page_size * head_dim);
+  std::vector k_data_h(num_pages * num_kv_heads * page_size * head_dim);
+  std::vector v_data_h(num_pages * num_kv_heads * page_size * head_dim);
   uint32_t page_id = 0;
   for (; page_id < (shared_prefix_length / page_size); page_id++) {
     for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) {
       for (uint32_t head_idx = 0; head_idx < num_kv_heads; head_idx++) {
-        std::copy(
-            shared_k_h.begin() +
-                ((page_id * page_size + entry_idx) * num_kv_heads + head_idx) * head_dim,
-            shared_k_h.begin() +
-                ((page_id * page_size + entry_idx) * num_kv_heads + head_idx + 1) * head_dim,
-            kv_data_h.begin() +
-                (((page_id * 2 + 0) * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim);
-        std::copy(
-            shared_v_h.begin() +
-                ((page_id * page_size + entry_idx) * num_kv_heads + head_idx) * head_dim,
-            shared_v_h.begin() +
-                ((page_id * page_size + entry_idx) * num_kv_heads + head_idx + 1) * head_dim,
-            kv_data_h.begin() +
-                (((page_id * 2 + 1) * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim);
+        std::copy(shared_k_h.begin() +
+                      ((page_id * page_size + entry_idx) * num_kv_heads + head_idx) * head_dim,
+                  shared_k_h.begin() +
+                      ((page_id * page_size + entry_idx) * num_kv_heads + head_idx + 1) * head_dim,
+                  k_data_h.begin() +
+                      ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim);
+        std::copy(shared_v_h.begin() +
+                      ((page_id * page_size + entry_idx) * num_kv_heads + head_idx) * head_dim,
+                  shared_v_h.begin() +
+                      ((page_id * page_size + entry_idx) * num_kv_heads + head_idx + 1) * head_dim,
+                  v_data_h.begin() +
+                      ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim);
       }
     }
     for (uint32_t request_id = 0; request_id < batch_size; ++request_id) {
@@ -187,13 +186,11 @@ create_shared_prefix_testcase_data(size_t batch_size, size_t shared_prefix_lengt
         utils::vec_normal_(k);
         utils::vec_normal_(v);
         std::copy(k.begin(), k.end(),
-                  kv_data_h.begin() +
-                      (((page_id * 2 + 0) * num_kv_heads + head_idx) * page_size + entry_idx) *
-                          head_dim);
+                  k_data_h.begin() +
+                      ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim);
         std::copy(v.begin(), v.end(),
-                  kv_data_h.begin() +
-                      (((page_id * 2 + 1) * num_kv_heads + head_idx) * page_size + entry_idx) *
-                          head_dim);
+                  v_data_h.begin() +
+                      ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim);
       }
     }
     kv_indices_combined_h[request_id * ((shared_prefix_length + unique_kv_length) / page_size) +
@@ -202,7 +199,8 @@ create_shared_prefix_testcase_data(size_t batch_size, size_t shared_prefix_lengt
     }
   }
   return std::make_tuple>, std::vector>>(
-      {std::move(q_h), std::move(shared_k_h), std::move(shared_v_h), std::move(kv_data_h)},
+      {std::move(q_h), std::move(shared_k_h), std::move(shared_v_h), std::move(k_data_h),
+       std::move(v_data_h)},
      {std::move(qo_indptr), std::move(kv_indices_combined_h), std::move(kv_indices_unique_h),
       std::move(kv_indptr_combined_h), std::move(kv_indptr_unique_h),
       std::move(kv_last_page_len_combined_h), std::move(kv_last_page_len_unique_h)});
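Note (reviewer aside, not part of the patch): the new destination offset used in utils.h, ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim, is the flat index of an HND-style [num_pages, num_kv_heads, page_size, head_dim] buffer; the old (page_id * 2 + 0/1) factor that selected K or V inside a fused page is gone. A quick sanity check with toy sizes (all names below are illustrative):

    import torch

    num_pages, num_kv_heads, page_size, head_dim = 3, 4, 8, 16
    k_data_h = torch.arange(num_pages * num_kv_heads * page_size * head_dim, dtype=torch.float32)
    k_view = k_data_h.reshape(num_pages, num_kv_heads, page_size, head_dim)

    page_id, head_idx, entry_idx = 1, 2, 5
    offset = ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim
    assert torch.equal(k_data_h[offset : offset + head_dim], k_view[page_id, head_idx, entry_idx])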
diff --git a/tests/test_batch_decode_kernels.py b/tests/test_batch_decode_kernels.py
index aba07289..1c9e78f0 100644
--- a/tests/test_batch_decode_kernels.py
+++ b/tests/test_batch_decode_kernels.py
@@ -33,6 +33,7 @@
 @pytest.mark.parametrize(
     "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
 )
+@pytest.mark.parametrize("contiguous_kv", [True, False])
 def test_batch_decode_with_paged_kv_cache(
     batch_size,
     kv_len,
@@ -46,15 +47,33 @@ def test_batch_decode_with_paged_kv_cache(
     return_lse,
     q_dtype,
     kv_dtype,
+    contiguous_kv,
 ):
     q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(q_dtype)
     num_pages_per_seq = (kv_len + page_size - 1) // page_size
     total_num_pages = num_pages_per_seq * batch_size
-    kv_data = (
-        torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0)
-        if kv_layout == "HND"
-        else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim).to(0)
-    )
+    if kv_layout == "HND":
+        kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
+    else:
+        kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
+    if not contiguous_kv:
+        tmp = [kv_shape[0]]
+        for v in kv_shape[1:]:
+            tmp.append(2)
+            tmp.append(v)
+        kv_shape = tmp
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32).to(0)
+        kv_data = kv_data_fp32.to(kv_dtype)
+        kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
+        kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :]
+        # actual data is stored in non-contiguous memory
+        assert (
+            kv_data.stride(-4)
+            != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
+        )
+    else:
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32).to(0)
+        kv_data = kv_data_fp32.to(kv_dtype)
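Note (reviewer aside, not part of the patch): the contiguous_kv=False branch above allocates every dimension after the page dimension at twice its size and then keeps one element of each pair, which yields a view with the expected logical shape but non-standard strides; that is exactly what the assert on kv_data.stride(-4) checks. The fp32 copy is sliced the same way so the reference computation further down sees the same logical contents. A toy-size sketch of the trick (names are illustrative):

    import torch

    kv_shape = [6, 2, 4, 8, 16]  # [pages, kv, heads, page_size, head_dim], HND-style
    padded_shape = [kv_shape[0]]
    for v in kv_shape[1:]:
        padded_shape += [2, v]   # interleave a throwaway size-2 dim before each remaining dim
    padded = torch.randn(*padded_shape)
    strided = padded[:, 1, :, 1, :, 1, :, 1, :]  # keep one of each pair -> strided view

    assert list(strided.shape) == kv_shape
    assert not strided.is_contiguous()
    # the per-page stride no longer equals the product of the trailing dims
    assert strided.stride(-4) != strided.shape[-3] * strided.shape[-2] * strided.shape[-1]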
     kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq
     kv_indices = torch.arange(0, total_num_pages).to(0).int()
     kv_last_page_len = torch.full(
@@ -77,9 +96,9 @@ def test_batch_decode_with_paged_kv_cache(
         q_data_type=q_dtype,
     )
     if return_lse:
-        o, _ = wrapper.run_return_lse(q, kv_data.to(kv_dtype))
+        o, _ = wrapper.run_return_lse(q, kv_data)
     else:
-        o = wrapper.run(q, kv_data.to(kv_dtype))
+        o = wrapper.run(q, kv_data)

     for i in range(batch_size):
         perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
@@ -87,13 +106,13 @@ def test_batch_decode_with_paged_kv_cache(
         qi = q[i]
         ki = torch.cat(
             [
-                kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 0]
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0]
                 .permute(*perm_dims)
                 .reshape(-1, num_kv_heads, head_dim),
                 (
-                    kv_data[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]]
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]]
                     if kv_layout == "HND"
-                    else kv_data[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :]
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :]
                 )
                 .permute(*perm_dims_last)
                 .reshape(-1, num_kv_heads, head_dim),
@@ -102,13 +121,13 @@ def test_batch_decode_with_paged_kv_cache(
         ).to(kv_dtype)
         vi = torch.cat(
             [
-                kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1]
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1]
                 .permute(*perm_dims)
                 .reshape(-1, num_kv_heads, head_dim),
                 (
-                    kv_data[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]]
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]]
                     if kv_layout == "HND"
-                    else kv_data[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :]
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :]
                 )
                 .permute(*perm_dims_last)
                 .reshape(-1, num_kv_heads, head_dim),
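Note (reviewer aside, not part of the patch): the reference ki/vi above is gathered from the fp32 copy of the cache: the K (index 0) or V (index 1) planes of every full page owned by request i, plus the first kv_last_page_len[i] slots of its final page. The same gather for a single request, reduced to toy sizes and the HND layout (all values below are hypothetical):

    import torch

    num_pages, num_kv_heads, page_size, head_dim = 5, 2, 4, 8
    kv_data_fp32 = torch.randn(num_pages, 2, num_kv_heads, page_size, head_dim)
    kv_indptr = torch.tensor([0, 3, 5])      # request 0 owns pages [0, 3), request 1 owns [3, 5)
    kv_last_page_len = torch.tensor([2, 4])  # valid slots in each request's last page

    i = 0
    full = kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0]          # K of all but the last page
    full = full.permute(0, 2, 1, 3).reshape(-1, num_kv_heads, head_dim)  # -> (tokens, heads, dim)
    last = kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]]
    last = last.permute(1, 0, 2).reshape(-1, num_kv_heads, head_dim)
    ki = torch.cat([full, last], dim=0)
    assert ki.shape[0] == 2 * page_size + 2  # two full pages plus two slots of the last page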
"HND" - else kv_data[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :] + else kv_data_fp32[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), @@ -378,13 +438,13 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( ).to(kv_dtype) vi = torch.cat( [ - kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] + kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] .permute(*perm_dims) .reshape(-1, num_kv_heads, head_dim), ( - kv_data[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]] + kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]] if kv_layout == "HND" - else kv_data[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :] + else kv_data_fp32[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), diff --git a/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py index eb897834..c2ad5c15 100644 --- a/tests/test_batch_prefill_kernels.py +++ b/tests/test_batch_prefill_kernels.py @@ -32,6 +32,7 @@ @pytest.mark.parametrize("use_cuda_graph", [True]) @pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) @pytest.mark.parametrize("return_lse", [True, False]) +@pytest.mark.parametrize("contiguous_kv", [True, False]) def test_batch_prefill_with_paged_kv_cache( batch_size, kv_len, @@ -46,18 +47,34 @@ def test_batch_prefill_with_paged_kv_cache( use_cuda_graph, logits_soft_cap, return_lse, + contiguous_kv, ): q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size - kv_data = ( - torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() - if kv_layout == "HND" - else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) - .to(0) - .half() - ) + if kv_layout == "HND": + kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim] + else: + kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] + if not contiguous_kv: + tmp = [kv_shape[0]] + for v in kv_shape[1:]: + tmp.append(2) + tmp.append(v) + kv_shape = tmp + kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32).to(0) + kv_data = kv_data_fp32.half() + kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] + kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :] + # actual data is stored in non-contiguous memory + assert ( + kv_data.stride(-4) + != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] + ) + else: + kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32).to(0) + kv_data = kv_data_fp32.half() kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq kv_indices_cpu = torch.arange(0, total_num_pages).int() kv_last_page_len_cpu = torch.full( @@ -165,13 +182,15 @@ def test_batch_prefill_with_paged_kv_cache( qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] ki = torch.cat( [ - kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0] + kv_data_fp32[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0] .permute(*perm_dims) .reshape(-1, num_kv_heads, head_dim), ( - kv_data[kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]] + kv_data_fp32[ + kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i] + ] if kv_layout == "HND" - else kv_data[ + else kv_data_fp32[ kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], : ] ) @@ -179,16 +198,18 @@ def test_batch_prefill_with_paged_kv_cache( .reshape(-1, num_kv_heads, head_dim), ], dim=0, - ) + 
diff --git a/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py
index eb897834..c2ad5c15 100644
--- a/tests/test_batch_prefill_kernels.py
+++ b/tests/test_batch_prefill_kernels.py
@@ -32,6 +32,7 @@
 @pytest.mark.parametrize("use_cuda_graph", [True])
 @pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
 @pytest.mark.parametrize("return_lse", [True, False])
+@pytest.mark.parametrize("contiguous_kv", [True, False])
 def test_batch_prefill_with_paged_kv_cache(
     batch_size,
     kv_len,
@@ -46,18 +47,34 @@ def test_batch_prefill_with_paged_kv_cache(
     use_cuda_graph,
     logits_soft_cap,
     return_lse,
+    contiguous_kv,
 ):
     q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
     q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len
     num_pages_per_seq = (kv_len + page_size - 1) // page_size
     total_num_pages = num_pages_per_seq * batch_size
-    kv_data = (
-        torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half()
-        if kv_layout == "HND"
-        else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim)
-        .to(0)
-        .half()
-    )
+    if kv_layout == "HND":
+        kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
+    else:
+        kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
+    if not contiguous_kv:
+        tmp = [kv_shape[0]]
+        for v in kv_shape[1:]:
+            tmp.append(2)
+            tmp.append(v)
+        kv_shape = tmp
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32).to(0)
+        kv_data = kv_data_fp32.half()
+        kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
+        kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :]
+        # actual data is stored in non-contiguous memory
+        assert (
+            kv_data.stride(-4)
+            != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
+        )
+    else:
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32).to(0)
+        kv_data = kv_data_fp32.half()
     kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq
     kv_indices_cpu = torch.arange(0, total_num_pages).int()
     kv_last_page_len_cpu = torch.full(
@@ -165,13 +182,15 @@ def test_batch_prefill_with_paged_kv_cache(
         qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
         ki = torch.cat(
             [
-                kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0]
+                kv_data_fp32[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0]
                 .permute(*perm_dims)
                 .reshape(-1, num_kv_heads, head_dim),
                 (
-                    kv_data[kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]]
+                    kv_data_fp32[
+                        kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]
+                    ]
                     if kv_layout == "HND"
-                    else kv_data[
+                    else kv_data_fp32[
                         kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], :
                     ]
                 )
@@ -179,16 +198,18 @@ def test_batch_prefill_with_paged_kv_cache(
                 .reshape(-1, num_kv_heads, head_dim),
             ],
             dim=0,
-        )
+        ).half()
         vi = torch.cat(
             [
-                kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 1]
+                kv_data_fp32[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 1]
                 .permute(*perm_dims)
                 .reshape(-1, num_kv_heads, head_dim),
                 (
-                    kv_data[kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]]
+                    kv_data_fp32[
+                        kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]
+                    ]
                     if kv_layout == "HND"
-                    else kv_data[
+                    else kv_data_fp32[
                         kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], :
                     ]
                 )
@@ -196,7 +217,7 @@ def test_batch_prefill_with_paged_kv_cache(
                 .reshape(-1, num_kv_heads, head_dim),
             ],
             dim=0,
-        )
+        ).half()
         o_ref_i = flashinfer.single_prefill_with_kv_cache(
             qi,
             ki,
@@ -222,6 +243,7 @@ def test_batch_prefill_with_paged_kv_cache(
 @pytest.mark.parametrize("use_cuda_graph", [False, True])
 @pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
 @pytest.mark.parametrize("return_lse", [True, False])
+@pytest.mark.parametrize("contiguous_kv", [True, False])
 def test_batch_prefill_with_tuple_paged_kv_cache(
     batch_size,
     kv_len,
@@ -236,21 +258,40 @@ def test_batch_prefill_with_tuple_paged_kv_cache(
     use_cuda_graph,
     logits_soft_cap,
     return_lse,
+    contiguous_kv,
 ):
     q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
     q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len
     num_pages_per_seq = (kv_len + page_size - 1) // page_size
     total_num_pages = num_pages_per_seq * batch_size
-    kv_data = tuple(
-        (
-            torch.randn(total_num_pages, num_kv_heads, page_size, head_dim).to(0).half()
-            if kv_layout == "HND"
-            else torch.randn(total_num_pages, page_size, num_kv_heads, head_dim)
-            .to(0)
-            .half()
-        )
-        for _ in range(2)
-    )
+    if kv_layout == "HND":
+        kv_shape = [total_num_pages, num_kv_heads, page_size, head_dim]
+    else:
+        kv_shape = [total_num_pages, page_size, num_kv_heads, head_dim]
+    if not contiguous_kv:
+        tmp = [kv_shape[0]]
+        for v in kv_shape[1:]:
+            tmp.append(2)
+            tmp.append(v)
+        kv_shape = tmp
+        kv_data_fp32 = [
+            torch.randn(*kv_shape, dtype=torch.float32).to(0) for _ in range(2)
+        ]
+        kv_data = [kv_data_fp32[i].half() for i in range(2)]
+        for i in range(2):
+            kv_data_fp32[i] = kv_data_fp32[i][:, 1, :, 1, :, 1, :]
+            kv_data[i] = kv_data[i][:, 1, :, 1, :, 1, :]
+            # actual data is stored in non-contiguous memory
+            assert (
+                kv_data[i].stride(-4)
+                != kv_data[i].shape[-3] * kv_data[i].shape[-2] * kv_data[i].shape[-1]
+            )
+    else:
+        kv_data_fp32 = [
+            torch.randn(*kv_shape, dtype=torch.float32).to(0) for _ in range(2)
+        ]
+        kv_data = [kv_data_fp32[i].half() for i in range(2)]
+    kv_data = tuple(kv_data)
     kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq
     kv_indices_cpu = torch.arange(0, total_num_pages).int()
     kv_last_page_len_cpu = torch.full(
@@ -351,7 +392,7 @@ def test_batch_prefill_with_tuple_paged_kv_cache(

     g.replay()

-    k_cache, v_cache = kv_data
+    k_cache, v_cache = kv_data_fp32
     for i in range(batch_size):
         perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
         perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
@@ -370,7 +411,7 @@ def test_batch_prefill_with_tuple_paged_kv_cache(
                 .reshape(-1, num_kv_heads, head_dim),
             ],
             dim=0,
-        )
+        ).half()
         vi = torch.cat(
             [
                 v_cache[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1]
@@ -385,7 +426,7 @@ def test_batch_prefill_with_tuple_paged_kv_cache(
                 .reshape(-1, num_kv_heads, head_dim),
             ],
             dim=0,
-        )
+        ).half()
         o_ref_i = flashinfer.single_prefill_with_kv_cache(
             qi,
             ki,
"ALIBI"]) @pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) @pytest.mark.parametrize("return_lse", [True, False]) +@pytest.mark.parametrize("contiguous_kv", [True, False]) def test_batch_prefill_with_paged_kv_cache_custom_mask( batch_size, kv_len, @@ -421,18 +463,31 @@ def test_batch_prefill_with_paged_kv_cache_custom_mask( pos_encoding_mode, logits_soft_cap, return_lse, + contiguous_kv, ): q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size - kv_data = ( - torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half() - if kv_layout == "HND" - else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim) - .to(0) - .half() - ) + if kv_layout == "HND": + kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim] + else: + kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] + if not contiguous_kv: + tmp = [kv_shape[0]] + for v in kv_shape[1:]: + tmp.append(2) + tmp.append(v) + kv_shape = tmp + kv_data = torch.randn(*kv_shape, dtype=torch.float32).to(0).half() + kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] + # actual data is stored in non-contiguous memory + assert ( + kv_data.stride(-4) + != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] + ) + else: + kv_data = torch.randn(*kv_shape, dtype=torch.float32).to(0).half() kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq kv_indices = torch.arange(0, total_num_pages).to(0).int() kv_last_page_len = torch.full(