#pragma once #include #include #include namespace at { namespace cuda { template cudaDataType getCudaDataType() { TORCH_INTERNAL_ASSERT(false, "Cannot convert type ", typeid(scalar_t).name(), " to cudaDataType.") } template<> inline cudaDataType getCudaDataType() { return CUDA_R_16F; } template<> inline cudaDataType getCudaDataType() { return CUDA_R_32F; } template<> inline cudaDataType getCudaDataType() { return CUDA_R_64F; } template<> inline cudaDataType getCudaDataType>() { return CUDA_C_16F; } template<> inline cudaDataType getCudaDataType>() { return CUDA_C_32F; } template<> inline cudaDataType getCudaDataType>() { return CUDA_C_64F; } // HIP doesn't define integral types #ifndef USE_ROCM template<> inline cudaDataType getCudaDataType() { return CUDA_R_8U; } template<> inline cudaDataType getCudaDataType() { return CUDA_R_8I; } template<> inline cudaDataType getCudaDataType() { return CUDA_R_32I; } #endif #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 11000 template<> inline cudaDataType getCudaDataType() { return CUDA_R_16I; } template<> inline cudaDataType getCudaDataType() { return CUDA_R_64I; } template<> inline cudaDataType getCudaDataType() { return CUDA_R_16BF; } #endif inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) { switch (scalar_type) { // HIP doesn't define integral types #ifndef USE_ROCM case c10::ScalarType::Byte: return CUDA_R_8U; case c10::ScalarType::Char: return CUDA_R_8I; case c10::ScalarType::Int: return CUDA_R_32I; #endif case c10::ScalarType::Half: return CUDA_R_16F; case c10::ScalarType::Float: return CUDA_R_32F; case c10::ScalarType::Double: return CUDA_R_64F; case c10::ScalarType::ComplexHalf: return CUDA_C_16F; case c10::ScalarType::ComplexFloat: return CUDA_C_32F; case c10::ScalarType::ComplexDouble: return CUDA_C_64F; #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 11000 case c10::ScalarType::Short: return CUDA_R_16I; case c10::ScalarType::Long: return CUDA_R_64I; case c10::ScalarType::BFloat16: return CUDA_R_16BF; #endif default: TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.") } } } // namespace cuda } // namespace at