diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh
index 13d7a54f..595a3858 100644
--- a/include/flashinfer/attention/prefill.cuh
+++ b/include/flashinfer/attention/prefill.cuh
@@ -1804,7 +1804,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(
   FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(
       &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
   // we expect each sm execute two threadblocks
-  const int max_smem_per_threadblock = max_smem_per_sm / 2;
+  const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
+  const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
 
   constexpr uint32_t num_warps_x = get_num_warps_x();
   constexpr uint32_t num_warps_z = get_num_warps_z();
@@ -1949,7 +1950,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
   FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
                                               cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
   // we expect each sm execute two threadblocks
-  const int max_smem_per_threadblock = max_smem_per_sm / 2;
+  const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
+  const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
   const uint32_t max_num_frags_z_reg =
       (HEAD_DIM >= 128 && num_frags_x == 2 &&
        pos_encoding_mode == PosEncodingMode::kRoPELlama &&
@@ -2089,7 +2091,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
   FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
                                               cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
   // we expect each sm execute two threadblocks
-  const int max_smem_per_threadblock = max_smem_per_sm / 2;
+  const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeIn) * 16) ? 2: 1;
+  const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
   const uint32_t max_num_frags_z_reg =
       (HEAD_DIM >= 128 && num_frags_x == 2 &&
        pos_encoding_mode == PosEncodingMode::kRoPELlama &&
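
The hunks above all apply the same change: instead of always assuming two co-resident threadblocks per SM, the dispatch code now falls back to one CTA per SM when the per-SM shared memory budget is too small for two. Below is a minimal standalone sketch of that heuristic, not part of the patch; `MaxSmemPerThreadblock`, `head_dim`, and `elem_size` are hypothetical names standing in for `HEAD_DIM` and `sizeof(DTypeIn)`, and the ~100 KB/SM budget in `main` is just an illustrative value.

```cpp
#include <cstdio>

// Sketch of the occupancy heuristic from the diff: split the per-SM shared
// memory budget across two CTAs only when it exceeds an estimate of the tile
// footprint (16 * head_dim * elem_size * 16 bytes); otherwise let a single
// CTA keep the whole budget.
static int MaxSmemPerThreadblock(int max_smem_per_sm, int head_dim, int elem_size) {
  const int num_ctas_per_sm =
      max_smem_per_sm > (16 * head_dim * elem_size * 16) ? 2 : 1;
  return max_smem_per_sm / num_ctas_per_sm;
}

int main() {
  // With a ~100 KB/SM budget and fp16 inputs, head_dim=128 still fits two CTAs
  // per SM (threshold 64 KB), while head_dim=256 (threshold 128 KB) falls back
  // to a single CTA that keeps the entire budget.
  printf("head_dim=128 -> %d bytes/CTA\n", MaxSmemPerThreadblock(100 * 1024, 128, 2));
  printf("head_dim=256 -> %d bytes/CTA\n", MaxSmemPerThreadblock(100 * 1024, 256, 2));
  return 0;
}
```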