feat: JIT compilation #507

Merged · yzh119 merged 65 commits into main from jit on Oct 7, 2024

Conversation

yzh119 (Collaborator) commented on Sep 25, 2024

This PR implements JIT compilation (#170) for flashinfer. After this PR, flashinfer compiles kernels just-in-time for different input data types and shapes and caches the compiled kernels on disk, instead of pre-compiling a fixed set of kernels into the wheel.

We also provide an AOT mode (installed from https://github.com/flashinfer-ai/flashinfer/tree/main/flashinfer-aot), which pre-compiles a set of flashinfer operators for production environments (see #510). In AOT mode, we use pre-compiled operators whenever possible and only JIT-compile kernels that were not pre-compiled.

Motivation

The pip wheel size is exploding as we add support for more data types, head dimensions, attention variants, and kernel implementations. Pre-compiling everything is not sustainable and impedes development speed.

This PR refactors the codebase to use torch's JIT Compiling Extensions feature instead of pre-compiling kernels into the wheel.
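
For context, the sketch below shows the kind of minimal, self-contained extension source that torch's JIT extension loader (torch.utils.cpp_extension.load) can compile into a shared library and cache on disk. The op here is a hypothetical placeholder invented for illustration, not flashinfer's actual generated code.

// Hypothetical, minimal extension source of the kind that torch's JIT extension
// loader compiles and caches on disk, so later runs reuse the cached build
// instead of recompiling. The op below is a placeholder for illustration only.
#include <torch/extension.h>

torch::Tensor scale_logits(torch::Tensor logits, double sm_scale) {
  // Placeholder op: scale attention logits by the softmax scale.
  return logits * sm_scale;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scale_logits", &scale_logits, "Scale attention logits (placeholder)");
}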

Attention Variants

Following FlexAttention, we describe every attention variant as a template class; each instance of the struct can carry closure variables defined in local or shared memory. Below are two examples (logits soft cap and ALIBI attention); the programming interface is tentative and will be updated as we improve the programmability of the JIT template:

template <typename ParamsT>
struct LogitsSoftCap {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ LogitsSoftCap(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

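  // Pre-scale q by sm_scale and by the reciprocal of logits_soft_cap before the QK^T product.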
  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
  }

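  // Soft cap: map the logits to logits_soft_cap * tanh(logits), with a log2e factor folded in.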
  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return params.logits_soft_cap * math::log2e * float(math::tanh(logits));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;
  }
};

template <typename ParamsT>
struct ALIBIAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;
  using IdType = typename ParamsT::IdType;

  uint32_t qo_len, kv_len;
  uint32_t window_left;

  __device__ __host__ ALIBIAttention(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
  }

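  // Fold the softmax scale (and a log2e factor) into q up front.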
  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale * math::log2e;
  }

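  // Add the ALIBI bias: per-head slope times the relative position (kv_idx - qo_idx).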
  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return logits + params.alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx));
  }

  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return true;
  }
};

Users can define their own ParamsT class and variant class to implement custom attention variants; we hope this refactor makes the codebase more concise and extensible.
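
As an illustration, a user-defined variant following the same tentative interface might look like the sketch below: plain softmax attention with a causal mask. The struct name and the params fields it reads (sm_scale, get_qo_len, get_kv_len) mirror the examples above but are assumptions for illustration, not a fixed public API; details such as the log2e scaling used by the examples above are omitted for brevity.

// Hypothetical user-defined variant: plain softmax attention with a causal mask,
// written against the same tentative interface as the two examples above.
template <typename ParamsT>
struct CausalVanillaAttention {
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;

  uint32_t qo_len, kv_len;

  __device__ __host__ CausalVanillaAttention(const ParamsT& params, uint32_t batch_idx,
                                             uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
  }

  // Fold the softmax scale into q once, before the QK^T product.
  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.sm_scale;
  }

  // Leave the logits unchanged.
  template <typename T>
  __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx,
                                               uint32_t qo_idx, uint32_t kv_idx,
                                               uint32_t qo_head_idx, uint32_t kv_head_idx) {
    return logits;
  }

  // Causal mask: query position qo_idx may only attend to key positions that are not
  // in its future; qo_len and kv_len account for the offset between the two sequences.
  __device__ __forceinline__ bool LogitsMask(const ParamsT& params, uint32_t batch_idx,
                                             uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx,
                                             uint32_t kv_head_idx) {
    return kv_idx + qo_len <= qo_idx + kv_len;
  }
};

At JIT-compilation time, such a struct (together with a matching ParamsT) would presumably be instantiated as the variant template parameter of the attention kernel.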

Roadmap

After this PR, we will add support for:

  1. PyPI wheels (#153: "Downloadable Package in PyPI")
  2. fp8 tensor cores attention (#502: "Does Flashinfer support 8-bit attention calculation?")
  3. different head dimensions (#142: "[Feature Request] Versatile head dimension", #454: "[Tentative] Adding 192 head dim (step_size = 12)", #455: "failed to dispatch head_dim 96")
  4. FlashAttention-3 (#369: "Feature: Flash Attention 3")
  5. multi-head latent attention (#237: "Support MLA (Multi-Head Latent Attention) in DeepSeek-v2")
  6. Generating ParamsT and attention variant descriptions from a Python DSL

The development of these features has been blocked by the wheel-size limitation (binary sizes >= 2GB trigger linking issues); I hope this PR makes development easier in the future.

yzh119 merged commit 3613a5b into main on Oct 7, 2024
yzh119 deleted the jit branch on October 11, 2024