#pragma once #include #include #include #include namespace at { // For FP16 or BFloat16 inputs, ops should perform internal math in FP32. template struct OpMathType { using type = scalar_t; }; template <> struct OpMathType { using type = float; }; template <> struct OpMathType { using type = float; }; template <> struct OpMathType> { using type = c10::complex; }; template using opmath_type = typename OpMathType::type; namespace { inline c10::ScalarType toOpMathType(const c10::ScalarType type) { switch (type) { #define DEFINE_CASE(scalar_t, TypeNum) \ case ScalarType::TypeNum: \ return CppTypeToScalarType>::value; AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) #undef DEFINE_CASE default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type); } } } // namespace } // namespace at