Commit 2496f5b (1 parent: 68c3719)
Showing 5 changed files with 421 additions and 0 deletions.

@@ -0,0 +1 @@

from . import cascade

@@ -0,0 +1,152 @@

from typing import Optional

import torch

from .kernels.cascade import (
    merge_state_in_place_kernel,
    merge_state_kernel,
    merge_states_kernel,
    variable_length_merge_states_kernel,
)
from .utils import check_device, check_dim, check_input, check_shape


def merge_state(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
    """Merge two attention states for the same query positions.
    s_a/s_b hold per-head log-sum-exp values (base 2, matching the
    exp2/log2 arithmetic in the kernels)."""
    check_input(v_a)
    check_input(s_a)
    check_input(v_b)
    check_input(s_b)
    check_device([v_a, s_a, v_b, s_b])
    check_dim(3, v_a)
    check_dim(2, s_a)
    check_dim(3, v_b)
    check_dim(2, s_b)
    check_shape(v_a, v_b)
    check_shape(s_a, s_b)
    assert v_a.size(0) == s_a.size(0)
    assert v_a.size(1) == s_a.size(1)
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    seq_len = v_a.size(0)
    num_heads = v_a.size(1)
    head_dim = v_a.size(2)
    # check_device already guarantees all inputs share a device, so the
    # outputs can be allocated there directly.
    v_merged = torch.empty_like(v_a)
    s_merged = torch.empty((seq_len, num_heads), dtype=torch.float32, device=s_a.device)
    bdx = head_dim
    bdy = num_heads

    merge_state_kernel[lambda meta: (seq_len,)](
        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
    )

    return v_merged, s_merged
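
A minimal usage sketch for merge_state (the shapes, dtypes, and values below are illustrative assumptions, not part of the commit): it combines the attention states of two disjoint KV chunks into the state of their union.

# Hypothetical sizes: 32 query positions, 8 heads, head_dim 64.
seq_len, num_heads, head_dim = 32, 8, 64
v_a = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
s_a = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32)
v_b = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
s_b = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32)
v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)
# v_merged: [32, 8, 64] in v_a's dtype; s_merged: [32, 8] float32 log-sum-exp.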


def merge_state_in_place(
    v: torch.Tensor,
    s: torch.Tensor,
    v_other: torch.Tensor,
    s_other: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
):
    """Fold (v_other, s_other) into (v, s), overwriting v and s.
    Positions where mask is 0 (if given) are left untouched."""
    check_input(v)
    check_input(s)
    check_input(v_other)
    check_input(s_other)
    check_device([v, s, v_other, s_other])
    check_dim(3, v)
    check_dim(2, s)
    check_dim(3, v_other)
    check_dim(2, s_other)
    check_shape(v, v_other)
    check_shape(s, s_other)
    assert v.size(0) == s.size(0)
    assert v.size(1) == s.size(1)
    assert s.dtype == torch.float32
    assert s_other.dtype == torch.float32
    if mask is not None:
        check_dim(1, mask)
        assert v.size(0) == mask.size(0)
        assert mask.device == v.device
    seq_len = v.size(0)
    num_heads = v.size(1)
    head_dim = v.size(2)

    bdx = head_dim
    bdy = num_heads
    merge_state_in_place_kernel[(seq_len,)](
        v, s, v_other, s_other, num_heads, head_dim, mask, bdx=bdx, bdy=bdy
    )
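
A sketch of the in-place variant, reusing the assumed tensors from above: it folds (v_b, s_b) into the running state (v_merged, s_merged) and returns nothing; the optional mask restricts the merge to selected positions.

mask = torch.ones(seq_len, dtype=torch.bool, device="cuda")
mask[seq_len // 2:] = False  # leave the second half of the positions untouched
merge_state_in_place(v_merged, s_merged, v_b, s_b, mask=mask)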


def merge_states(v: torch.Tensor, s: torch.Tensor):
    """Merge num_index_sets attention states per position:
    v is [seq_len, num_index_sets, num_heads, head_dim] and
    s is [seq_len, num_index_sets, num_heads]."""
    check_input(v)
    check_input(s)
    check_device([v, s])
    check_dim(4, v)
    check_dim(3, s)
    assert v.size(0) == s.size(0)
    assert v.size(1) == s.size(1)
    assert v.size(2) == s.size(2)
    seq_len = v.size(0)
    num_index_sets = v.size(1)
    num_heads = v.size(2)
    head_dim = v.size(3)
    s = s.to(torch.float32)
    v_merged = torch.empty(
        (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device
    )
    s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device)

    bdx = head_dim
    bdy = num_heads
    merge_states_kernel[(seq_len,)](
        v,
        s,
        v_merged,
        s_merged,
        num_index_sets,
        num_heads,
        head_dim,
        bdx=bdx,
        bdy=bdy,
    )
    return v_merged, s_merged
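
merge_states generalizes the pairwise merge to any number of index sets stacked along dim 1; a sketch under the same assumed sizes:

num_index_sets = 4
v = torch.randn(seq_len, num_index_sets, num_heads, head_dim, device="cuda", dtype=torch.float16)
s = torch.randn(seq_len, num_index_sets, num_heads, device="cuda", dtype=torch.float32)
v_out, s_out = merge_states(v, s)  # [32, 8, 64] and [32, 8]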


def variable_length_merge_states(
    v: torch.Tensor, s: torch.Tensor, indptr: torch.Tensor
):
    """Like merge_states, but output position i merges the
    variable-length slice indptr[i]:indptr[i + 1] of the packed
    states (v, s)."""
    check_input(v)
    check_input(s)
    check_device([v, s])
    check_dim(3, v)
    check_dim(2, s)
    assert v.size(0) == s.size(0)
    assert v.size(1) == s.size(1)
    seq_len = indptr.size(0) - 1
    num_heads = v.size(1)
    head_dim = v.size(2)
    s = s.to(torch.float32)
    indptr = indptr.to(torch.int32)
    v_merged = torch.empty(
        (seq_len, num_heads, head_dim), dtype=v.dtype, device=v.device
    )
    s_merged = torch.empty((seq_len, num_heads), dtype=s.dtype, device=s.device)

    bdx = head_dim
    bdy = num_heads
    variable_length_merge_states_kernel[(seq_len,)](
        v,
        s,
        indptr,
        v_merged,
        s_merged,
        num_heads,
        head_dim,
        bdx=bdx,
        bdy=bdy,
    )
    return v_merged, s_merged
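
A sketch of the variable-length variant (illustrative values): indptr is CSR-style, so output position i merges the packed rows indptr[i]:indptr[i + 1]; here the three output positions own 3, 1, and 2 input states.

indptr = torch.tensor([0, 3, 4, 6], device="cuda", dtype=torch.int32)
v_packed = torch.randn(6, num_heads, head_dim, device="cuda", dtype=torch.float16)
s_packed = torch.randn(6, num_heads, device="cuda", dtype=torch.float32)
v_out, s_out = variable_length_merge_states(v_packed, s_packed, indptr)
# v_out: [3, num_heads, head_dim]; s_out: [3, num_heads]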

@@ -0,0 +1,159 @@

import triton
import triton.language as tl


@triton.jit
def state_merge(o, m, d, other_o, other_m, other_d):
    # Combine two (o, m, d) softmax states with the running-max trick:
    # both sides are rescaled by exp2(m - m_max) so nothing overflows.
    m_max = tl.maximum(m, other_m)
    d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max)
    o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max)
    return o, m_max, d


@triton.jit
def state_normalize(o, m, d):
    # Divide the weighted-value accumulator by the softmax denominator.
    o = o / d
    return o, m, d


@triton.jit
def state_get_lse(o, m, d):
    # Recover the (base-2) log-sum-exp from the running max and denominator.
    return m + tl.log2(d)
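
The identity these helpers implement, written as a plain-PyTorch reference (a checking sketch, not part of the commit; it assumes s is a base-2 log-sum-exp, which is what the exp2/log2 calls above imply):

import torch

def merge_state_reference(v_a, s_a, v_b, s_b):
    # v = (v_a * 2**s_a + v_b * 2**s_b) / (2**s_a + 2**s_b), computed with the
    # shared max subtracted so the exponentials cannot overflow.
    s_max = torch.maximum(s_a, s_b)
    w_a = torch.exp2(s_a - s_max)
    w_b = torch.exp2(s_b - s_max)
    d = w_a + w_b
    v = (v_a * w_a[..., None] + v_b * w_b[..., None]) / d[..., None]
    s = s_max + torch.log2(d)  # log2 of the combined denominator
    return v, s

Up to floating-point error, this should agree with merge_state on float32 inputs.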


@triton.jit
def merge_state_kernel(
    v_a_ptr,
    s_a_ptr,
    v_b_ptr,
    s_b_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    # One program per query position; loop over head_dim (bdx) and heads (bdy).
    pos = tl.program_id(axis=0)
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx)
            s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx)

            offsets = (pos * num_heads + head_idx) * head_dim + tx
            v_a = tl.load(v_a_ptr + offsets)
            v_b = tl.load(v_b_ptr + offsets)

            # Each input is an already-normalized state (d = 1) whose
            # denominator is carried in its log-sum-exp; merge, renormalize.
            v_merged, s_max, d = state_merge(
                o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1
            )
            v_merged, s_max, d = state_normalize(v_merged, s_max, d)
            tl.store(v_merged_ptr + offsets, v_merged)

            if s_merged_ptr:
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx,
                    tl.log2(d) + s_max,
                )


@triton.jit
def merge_state_in_place_kernel(
    v_ptr,
    s_ptr,
    v_other_ptr,
    s_other_ptr,
    num_heads,
    head_dim,
    mask_ptr,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    pos = tl.program_id(axis=0)
    # Skip this position entirely when the optional mask is 0.
    if mask_ptr:
        if tl.load(mask_ptr + pos) == 0:
            return

    for head_idx in tl.range(bdy):
        s_val = tl.load(s_ptr + pos * num_heads + head_idx)
        s_other_val = tl.load(s_other_ptr + pos * num_heads + head_idx)
        # Stable convex-combination weights for the two states.
        s_max = tl.maximum(s_val, s_other_val)
        s_val = tl.exp2(s_val - s_max)
        s_other_val = tl.exp2(s_other_val - s_max)
        scale = s_val / (s_val + s_other_val)
        other_scale = s_other_val / (s_val + s_other_val)
        for tx in tl.range(bdx):
            offset = (pos * num_heads + head_idx) * head_dim + tx
            v_vec = tl.load(v_ptr + offset)
            v_other_vec = tl.load(v_other_ptr + offset)
            v_vec = scale * v_vec + other_scale * v_other_vec
            tl.store(v_ptr + offset, v_vec)
        if s_ptr:
            tl.store(
                s_ptr + pos * num_heads + head_idx,
                tl.log2(s_val + s_other_val) + s_max,
            )


@triton.jit
def merge_states_kernel(
    v_ptr,
    s_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_index_sets,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    pos = tl.program_id(axis=0)

    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            # Empty starting state; m = -5e4 acts as -inf (2**-5e4 underflows).
            o, m, d = 0.0, -5e4, 1.0
            for i in tl.range(num_index_sets):
                s = tl.load(
                    s_ptr + (pos * num_index_sets + i) * num_heads + head_idx
                )
                v = tl.load(
                    v_ptr
                    + ((pos * num_index_sets + i) * num_heads + head_idx) * head_dim
                    + tx
                )
                o, m, d = state_merge(o, m, d, v, s, 1)
            o, m, d = state_normalize(o, m, d)
            tl.store(v_merged_ptr + (pos * num_heads + head_idx) * head_dim + tx, o)
            if s_merged_ptr:
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx, state_get_lse(o, m, d)
                )


@triton.jit
def variable_length_merge_states_kernel(
    v_ptr,
    s_ptr,
    indptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    pos = tl.program_id(axis=0)
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            o, m, d = 0.0, -5e4, 1.0
            # Merge the packed states in the interval indptr[pos]:indptr[pos + 1].
            for i in tl.range(tl.load(indptr + pos), tl.load(indptr + pos + 1)):
                s = tl.load(s_ptr + i * num_heads + head_idx)
                v = tl.load(v_ptr + (i * num_heads + head_idx) * head_dim + tx)
                o, m, d = state_merge(o, m, d, v, s, 1)
            o, m, d = state_normalize(o, m, d)
            tl.store(v_merged_ptr + (pos * num_heads + head_idx) * head_dim + tx, o)
            if s_merged_ptr:
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx, state_get_lse(o, m, d)
                )
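
For intuition, a plain-PyTorch reference for the reduction in merge_states_kernel (a sketch under the same base-2 log-sum-exp convention; not part of the commit):

import torch

def merge_states_reference(v: torch.Tensor, s: torch.Tensor):
    # v: [seq_len, num_index_sets, num_heads, head_dim]
    # s: [seq_len, num_index_sets, num_heads]
    s = s.to(torch.float32)
    s_max = s.max(dim=1, keepdim=True).values  # per-position running max
    w = torch.exp2(s - s_max)                  # weight of each index set
    d = w.sum(dim=1)                           # combined denominator
    v_merged = (v.float() * w.unsqueeze(-1)).sum(dim=1) / d.unsqueeze(-1)
    s_merged = s_max.squeeze(1) + torch.log2(d)
    return v_merged.to(v.dtype), s_merged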

@@ -0,0 +1,28 @@

from typing import List

import torch


def check_input(x: torch.Tensor):
    assert x.is_cuda, f"{str(x)} must be a CUDA Tensor"
    assert x.is_contiguous(), f"{str(x)} must be contiguous"


def check_dim(d, x: torch.Tensor):
    assert x.dim() == d, f"{str(x)} must be a {d}D tensor"


def check_shape(a: torch.Tensor, b: torch.Tensor):
    assert (
        a.dim() == b.dim()
    ), f"tensors should have the same dim, got {a.dim()} and {b.dim()}"
    for i in range(a.dim()):
        assert a.size(i) == b.size(
            i
        ), f"tensors shape mismatch, {a.size()} and {b.size()}"


def check_device(tensors: List[torch.Tensor]):
    device = tensors[0].device
    for t in tensors:
        assert (
            t.device == device
        ), f"All tensors should be on the same device, but got {device} and {t.device}"