
[Question] Sampling kernel only support FP32 now? #531

Open · yz-tang opened this issue Oct 15, 2024 · 1 comment

yz-tang (Contributor) commented Oct 15, 2024

I looked at test_sampling.cu and it only contains FP32 tests. I tried using FP16, but it does not work.

yzh119 (Collaborator) commented Oct 16, 2024

It's easy to add support for fp16.

In the call to sampling::SamplingFromProb shown below (and in all the other functions in this file), we currently cast all inputs to fp32 before launching the kernel:

```cpp
probs = probs.to(torch::kFloat32);
uniform_samples = uniform_samples.to(torch::kFloat32);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
cudaError_t status = sampling::SamplingFromProb(static_cast<float*>(probs.data_ptr()),
                                                static_cast<float*>(uniform_samples.data_ptr()),
```

To use the fp16 kernels, we just need to dispatch on the input data type using the dispatch macro:

```cpp
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
```
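
Roughly, a dispatched call could look like the sketch below. This is not the actual binding code: it assumes SamplingFromProb is templated on the element type, that the macro runs its body with `c_type` bound to the CUDA type matching `probs.scalar_type()`, and that the trailing kernel arguments (including the `vocab_size` name) mirror the existing fp32 call above.

```cpp
// Hypothetical sketch: dispatch on the probs dtype instead of casting to fp32.
// Assumes the macro binds `c_type` to the matching CUDA type (e.g. half for
// torch::kFloat16) and invokes the lambda; the trailing arguments are assumed
// to mirror the existing fp32 call.
bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(probs.scalar_type(), c_type, [&] {
  cudaError_t status = sampling::SamplingFromProb(
      static_cast<c_type*>(probs.data_ptr()),
      static_cast<c_type*>(uniform_samples.data_ptr()),
      static_cast<int*>(samples.data_ptr()),
      batch_size, vocab_size, torch_current_stream);
  return status == cudaSuccess;
});
TORCH_CHECK(success, "SamplingFromProb failed to dispatch dtype ", probs.scalar_type());
```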

But as you mentioned, fp16 might fail in some extreme cases because the fp16 probabilities might not sum to 1 anymore.
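
One possible workaround for that caveat (a hedged sketch using the PyTorch C++ API, not necessarily what the library should do) is to renormalize the fp16 probabilities in fp32 before sampling, since rounding a normalized distribution to half precision can leave the row sums slightly off from 1:

```cpp
#include <torch/torch.h>

// Hypothetical helper: accumulate/normalize in fp32, then round back to fp16,
// so each row of probabilities sums to 1 up to fp16 rounding.
torch::Tensor renormalize_fp16_probs(const torch::Tensor& probs_fp16) {
  auto probs_fp32 = probs_fp16.to(torch::kFloat32);
  auto row_sums = probs_fp32.sum(/*dim=*/-1, /*keepdim=*/true);
  return (probs_fp32 / row_sums).to(torch::kFloat16);
}
```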
