From 2496f5b37028bfa268653a7953443e27922aa8db Mon Sep 17 00:00:00 2001
From: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com>
Date: Sun, 28 Jul 2024 20:23:02 -0700
Subject: [PATCH] triton: cascade kernels (#396)

---
 python/flashinfer/triton/__init__.py        |   1 +
 python/flashinfer/triton/cascade.py         | 152 +++++++++++++++++++
 python/flashinfer/triton/kernels/cascade.py | 159 ++++++++++++++++++++
 python/flashinfer/triton/utils.py           |  28 ++++
 python/tests/test_triton_cascade.py         |  81 ++++++++++
 5 files changed, 421 insertions(+)
 create mode 100644 python/flashinfer/triton/__init__.py
 create mode 100644 python/flashinfer/triton/cascade.py
 create mode 100644 python/flashinfer/triton/kernels/cascade.py
 create mode 100644 python/flashinfer/triton/utils.py
 create mode 100644 python/tests/test_triton_cascade.py

diff --git a/python/flashinfer/triton/__init__.py b/python/flashinfer/triton/__init__.py
new file mode 100644
index 00000000..290519dc
--- /dev/null
+++ b/python/flashinfer/triton/__init__.py
@@ -0,0 +1 @@
+from . import cascade
diff --git a/python/flashinfer/triton/cascade.py b/python/flashinfer/triton/cascade.py
new file mode 100644
index 00000000..3a35b3ad
--- /dev/null
+++ b/python/flashinfer/triton/cascade.py
@@ -0,0 +1,152 @@
+from typing import Optional
+
+import torch
+
+from .kernels.cascade import (
+    merge_state_in_place_kernel,
+    merge_state_kernel,
+    merge_states_kernel,
+    variable_length_merge_states_kernel,
+)
+from .utils import check_device, check_dim, check_input, check_shape
+
+
+def merge_state(
+    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
+):
+    check_input(v_a)
+    check_input(s_a)
+    check_input(v_b)
+    check_input(s_b)
+    check_device([v_a, s_a, v_b, s_b])
+    check_dim(3, v_a)
+    check_dim(2, s_a)
+    check_dim(3, v_b)
+    check_dim(2, s_b)
+    check_shape(v_a, v_b)
+    check_shape(s_a, s_b)
+    assert v_a.size(0) == s_a.size(0)
+    assert v_a.size(1) == s_a.size(1)
+    s_a = s_a.to(torch.float32)
+    s_b = s_b.to(torch.float32)
+    seq_len = v_a.size(0)
+    num_heads = v_a.size(1)
+    head_dim = v_a.size(2)
+    v_merged = torch.empty_like(v_a).to(s_a.device)
+    s_merged = torch.empty((seq_len, num_heads)).to(s_a.device)
+    bdx = head_dim
+    bdy = num_heads
+
+    merge_state_kernel[lambda meta: (seq_len,)](
+        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
+    )
+
+    return v_merged, s_merged
+
+
+def merge_state_in_place(
+    v: torch.Tensor,
+    s: torch.Tensor,
+    v_other: torch.Tensor,
+    s_other: torch.Tensor,
+    mask: Optional[torch.Tensor] = None,
+):
+    check_input(v)
+    check_input(s)
+    check_input(v_other)
+    check_input(s_other)
+    check_device([v, s, v_other, s_other])
+    check_dim(3, v)
+    check_dim(2, s)
+    check_dim(3, v_other)
+    check_dim(2, s_other)
+    check_shape(v, v_other)
+    check_shape(s, s_other)
+    assert v.size(0) == s.size(0)
+    assert v.size(1) == s.size(1)
+    assert s.dtype == torch.float32
+    assert s_other.dtype == torch.float32
+    if mask is not None:
+        check_dim(1, mask)
+        assert v.size(0) == mask.size(0)
+        assert mask.device == v.device
+    seq_len = v.size(0)
+    num_heads = v.size(1)
+    head_dim = v.size(2)
+
+    bdx = head_dim
+    bdy = num_heads
+    merge_state_in_place_kernel[(seq_len,)](
+        v, s, v_other, s_other, num_heads, head_dim, mask, bdx=bdx, bdy=bdy
+    )
+
+
+def merge_states(v: torch.Tensor, s: torch.Tensor):
+    check_input(v)
+    check_input(s)
+    check_device([v, s])
+    check_dim(4, v)
+    check_dim(3, s)
+    assert v.size(0) == s.size(0)
+    assert v.size(1) == s.size(1)
+    assert v.size(2) == s.size(2)
+    seq_len = v.size(0)
+    num_index_sets
= v.size(1) + num_heads = v.size(2) + head_dim = v.size(3) + s = s.to(torch.float32) + v_merged = torch.empty( + (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device + ) + s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device) + + bdx = head_dim + bdy = num_heads + merge_states_kernel[(seq_len,)]( + v, + s, + v_merged, + s_merged, + num_index_sets, + num_heads, + head_dim, + bdx=bdx, + bdy=bdy, + ) + return v_merged, s_merged + + +def variable_length_merge_states( + v: torch.Tensor, s: torch.Tensor, indptr: torch.Tensor +): + check_input(v) + check_input(s) + check_device([v, s]) + check_dim(3, v) + check_dim(2, s) + assert v.size(0) == s.size(0) + assert v.size(1) == s.size(1) + seq_len = indptr.size(0) - 1 + num_heads = v.size(1) + head_dim = v.size(2) + s = s.to(torch.float32) + indptr = indptr.to(torch.int32) + v_merged = torch.empty( + (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device + ) + s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device) + + bdx = head_dim + bdy = num_heads + variable_length_merge_states_kernel[(seq_len,)]( + v, + s, + indptr, + v_merged, + s_merged, + num_heads, + head_dim, + bdx=bdx, + bdy=bdy, + ) + return v_merged, s_merged diff --git a/python/flashinfer/triton/kernels/cascade.py b/python/flashinfer/triton/kernels/cascade.py new file mode 100644 index 00000000..855fb9b7 --- /dev/null +++ b/python/flashinfer/triton/kernels/cascade.py @@ -0,0 +1,159 @@ +import triton +import triton.language as tl + + +@triton.jit +def state_merge(o, m, d, other_o, other_m, other_d): + m_max = tl.maximum(m, other_m) + d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max) + o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max) + return o, m_max, d + + +@triton.jit +def state_normalize(o, m, d): + o = o / d + return o, m, d + + +@triton.jit +def state_get_lse(o, m, d): + return m + tl.log2(d) + + +@triton.jit +def merge_state_kernel( + v_a_ptr, + s_a_ptr, + v_b_ptr, + s_b_ptr, + v_merged_ptr, + s_merged_ptr, + num_heads, + head_dim, + bdx: tl.constexpr, + bdy: tl.constexpr, +): + pos = tl.program_id(axis=0) + for tx in tl.range(bdx): + for head_idx in tl.range(bdy): + s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx) + s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx) + + offsets = (pos * num_heads + head_idx) * head_dim + tx + v_a = tl.load(v_a_ptr + offsets) + v_b = tl.load(v_b_ptr + offsets) + + v_merged, s_max, d = state_merge( + o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1 + ) + v_merged, s_max, d = state_normalize(v_merged, s_max, d) + v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx + tl.store(v_merged_ptr + v_merged_offset, v_merged) + + if s_merged_ptr: + tl.store( + s_merged_ptr + pos * num_heads + head_idx, + tl.log2(d) + s_max, + ) + + +@triton.jit +def merge_state_in_place_kernel( + v_ptr, + s_ptr, + v_other_ptr, + s_other_ptr, + num_heads, + head_dim, + mask_ptr, + bdx: tl.constexpr, + bdy: tl.constexpr, +): + pos = tl.program_id(axis=0) + if mask_ptr: + if tl.load(mask_ptr + pos) == 0: + return + + for head_idx in tl.range(bdy): + s_val = tl.load(s_ptr + pos * num_heads + head_idx) + s_other_val = tl.load(s_other_ptr + pos * num_heads + head_idx) + s_max = tl.maximum(s_val, s_other_val) + s_val = tl.exp2(s_val - s_max) + s_other_val = tl.exp2(s_other_val - s_max) + scale = s_val / (s_val + s_other_val) + other_scale = s_other_val / (s_val + s_other_val) + for tx in tl.range(bdx): + offset = (pos * num_heads + head_idx) * head_dim + tx 
+ v_vec = tl.load(v_ptr + offset) + v_other_vec = tl.load(v_other_ptr + offset) + v_vec = scale * v_vec + other_scale * v_other_vec + tl.store(v_ptr + offset, v_vec) + if s_ptr: + tl.store( + s_ptr + pos * num_heads + head_idx, + tl.log2(s_val + s_other_val) + s_max, + ) + + +@triton.jit +def merge_states_kernel( + v_ptr, + s_ptr, + v_merged_ptr, + s_merged_ptr, + num_index_sets, + num_heads, + head_dim, + bdx: tl.constexpr, + bdy: tl.constexpr, +): + pos = tl.program_id(axis=0) + + for tx in tl.range(bdx): + for head_idx in tl.range(bdy): + o, m, d = 0.0, -5e4, 1.0 + for iter in tl.range(num_index_sets): + s = tl.load( + s_ptr + (pos * num_index_sets + iter) * num_heads + head_idx + ) + v = tl.load( + v_ptr + + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim + + tx + ) + o, m, d = state_merge(o, m, d, v, s, 1) + o, m, d = state_normalize(o, m, d) + tl.store(v_merged_ptr + (pos * num_heads + head_idx) * head_dim + tx, o) + if s_merged_ptr: + tl.store( + s_merged_ptr + pos * num_heads + head_idx, state_get_lse(o, m, d) + ) + + +@triton.jit +def variable_length_merge_states_kernel( + v_ptr, + s_ptr, + indptr, + v_merged_ptr, + s_merged_ptr, + num_heads, + head_dim, + bdx: tl.constexpr, + bdy: tl.constexpr, +): + pos = tl.program_id(axis=0) + for tx in tl.range(bdx): + for head_idx in tl.range(bdy): + o, m, d = 0.0, -5e4, 1.0 + for iter in tl.range(tl.load(indptr + pos), tl.load(indptr + pos + 1)): + s = tl.load(s_ptr + iter * num_heads + head_idx) + v = tl.load(v_ptr + (iter * num_heads + head_idx) * head_dim + tx) + o, m, d = state_merge(o, m, d, v, s, 1) + o, m, d = state_normalize(o, m, d) + tl.store(v_merged_ptr + (pos * num_heads + head_idx) * head_dim + tx, o) + if s_merged_ptr: + tl.store( + s_merged_ptr + pos * num_heads + head_idx, state_get_lse(o, m, d) + ) diff --git a/python/flashinfer/triton/utils.py b/python/flashinfer/triton/utils.py new file mode 100644 index 00000000..799d5ab1 --- /dev/null +++ b/python/flashinfer/triton/utils.py @@ -0,0 +1,28 @@ +from typing import List + +import torch + + +def check_input(x: torch.Tensor): + assert x.is_cuda, f"{str(x)} must be a CUDA Tensor" + assert x.is_contiguous(), f"{str(x)} must be contiguous" + + +def check_dim(d, x: torch.Tensor): + assert x.dim() == d, f"{str(x)} must be a {d}D tensor" + + +def check_shape(a: torch.Tensor, b: torch.Tensor): + assert a.dim() == b.dim(), f"tensors should have same dim" + for i in range(a.dim()): + assert a.size(i) == b.size( + i + ), f"tensors shape mismatch, {a.size()} and {b.size()}" + + +def check_device(tensors: List[torch.Tensor]): + device = tensors[0].device + for t in tensors: + assert ( + t.device == device + ), f"All tensors should be on the same device, but got {device} and {t.device}" diff --git a/python/tests/test_triton_cascade.py b/python/tests/test_triton_cascade.py new file mode 100644 index 00000000..ae8f2b42 --- /dev/null +++ b/python/tests/test_triton_cascade.py @@ -0,0 +1,81 @@ +import pytest +import torch + +import flashinfer +import flashinfer.triton + + +@pytest.mark.parametrize("seq_len", [2048]) +@pytest.mark.parametrize("num_heads", [32]) +@pytest.mark.parametrize("head_dim", [128]) +def test_merge_state(seq_len, num_heads, head_dim): + va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") + sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") + vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") + sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") + v_merged, s_merged = 
flashinfer.triton.cascade.merge_state(va, sa, vb, sb) + v_merged_std, s_merged_std = flashinfer.merge_state(va, sa, vb, sb) + + assert torch.allclose(v_merged, v_merged_std, atol=1e-2) + assert torch.allclose(s_merged, s_merged_std, atol=1e-2) + + +@pytest.mark.parametrize("seq_len", [2048]) +@pytest.mark.parametrize("num_heads", [32]) +@pytest.mark.parametrize("head_dim", [128]) +def test_merge_state_in_place(seq_len, num_heads, head_dim): + v = torch.randn(seq_len, num_heads, head_dim).half() + v_std = v.clone() + v, v_std = v.to("cuda:0"), v_std.to("cuda:0") + s = torch.randn(seq_len, num_heads, dtype=torch.float32) + s_std = s.clone() + s, s_std = s.to("cuda:0"), s_std.to("cuda:0") + v_other = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0") + s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0") + flashinfer.merge_state_in_place(v_std, s_std, v_other, s_other) + flashinfer.triton.cascade.merge_state_in_place(v, s, v_other, s_other) + + assert torch.allclose(v, v_std, atol=1e-2) + assert torch.allclose(s, s_std, atol=1e-2) + + +@pytest.mark.parametrize("seq_len", [2048]) +@pytest.mark.parametrize("num_heads", [32]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("num_states", [100]) +def test_merge_states(seq_len, num_states, num_heads, head_dim): + v = torch.randn(seq_len, num_states, num_heads, head_dim).half().to("cuda:0") + s = torch.randn(seq_len, num_states, num_heads, dtype=torch.float32).to("cuda:0") + v_merged_std, s_merged_std = flashinfer.merge_states(v, s) + v_merged, s_merged = flashinfer.triton.cascade.merge_states(v, s) + + assert torch.allclose(v_merged, v_merged_std, atol=1e-2) + assert torch.allclose(s_merged, s_merged_std, atol=1e-2) + + +@pytest.mark.parametrize("seq_len", [2048]) +@pytest.mark.parametrize("num_heads", [32]) +@pytest.mark.parametrize("head_dim", [128]) +def test_variable_length_merge_states(seq_len, num_heads, head_dim): + max_index_sets = 512 + lengths = torch.randint(low=1, high=max_index_sets, size=(seq_len,)) + indptr = [0] + for i in range(seq_len): + indptr.append(indptr[-1] + lengths[i]) + v = torch.randn(indptr[-1], num_heads, head_dim).half().to("cuda:0") + s = torch.randn(indptr[-1], num_heads, dtype=torch.float32).to("cuda:0") + indptr = torch.tensor(indptr, dtype=torch.int32).to("cuda:0") + v_merged, s_merged = flashinfer.triton.cascade.variable_length_merge_states( + v, s, indptr + ) + for i in range(seq_len): + sub_v = v[indptr[i] : indptr[i + 1]] + sub_s = s[indptr[i] : indptr[i + 1]] + sub_v = torch.unsqueeze(sub_v, 0) + sub_s = torch.unsqueeze(sub_s, 0) + v_merged_std, s_merged_std = flashinfer.merge_states(sub_v, sub_s) + v_merged_std = torch.squeeze(v_merged_std, 0) + s_merged_std = torch.squeeze(s_merged_std, 0) + assert v_merged[i].shape == v_merged_std.shape + assert torch.allclose(v_merged[i], v_merged_std, atol=1e-2) + assert torch.allclose(s_merged[i], s_merged_std, atol=1e-2)
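
As a reading aid for the kernels above: state_merge, state_normalize, and state_get_lse implement the usual base-2 log-sum-exp combination of two partial attention states. Below is a minimal PyTorch sketch of the same math (ref_merge_state is an illustrative name, not part of the patch); it mirrors what merge_state_kernel computes per (position, head).

import torch


def ref_merge_state(v_a, s_a, v_b, s_b):
    # Illustrative reference (not part of the patch) for the merge done by
    # merge_state_kernel: each state is (V, S) with S a base-2 log-sum-exp.
    # The merged V is the exp2-weighted average of the two V's, and the merged
    # S is log2(2^s_a + 2^s_b), computed stably around the running maximum.
    s_max = torch.maximum(s_a, s_b)              # [seq_len, num_heads]
    w_a = torch.exp2(s_a - s_max)
    w_b = torch.exp2(s_b - s_max)
    d = w_a + w_b
    v_merged = (
        w_a.unsqueeze(-1) * v_a.float() + w_b.unsqueeze(-1) * v_b.float()
    ) / d.unsqueeze(-1)                          # [seq_len, num_heads, head_dim]
    s_merged = s_max + torch.log2(d)             # [seq_len, num_heads]
    return v_merged.to(v_a.dtype), s_merged

Under that reading, ref_merge_state(va, sa, vb, sb) should agree with flashinfer.triton.cascade.merge_state(va, sa, vb, sb) up to floating-point tolerance, which is the same property test_merge_state checks against flashinfer.merge_state.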