#pragma once #include #include #include #include C10_CLANG_DIAGNOSTIC_PUSH() #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") #endif #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") #endif namespace c10 { template struct needs_real { constexpr static bool value = (is_complex::value && !is_complex::value); }; template struct maybe_real { C10_HOST_DEVICE static inline src_t apply(src_t src) { return src; } }; template struct maybe_real { C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) { return src.real(); } }; // Note: deliberately ignores undefined behavior, consistent with NumPy. // PyTorch's type conversions can cause a variety of undefined behavior, // including float to integral overflow and signed to unsigned integer overflow. // Some of this undefined behavior is addressed below. template struct static_cast_with_inter_type { C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline dest_t apply( src_t src) { constexpr bool real = needs_real::value; auto r = maybe_real::apply(src); return static_cast(r); } }; // Partial template instantiation for casting to uint8. // Note: Converting from negative float values to unsigned integer types is // undefined behavior in C++, and current CPU and GPU compilers exhibit // divergent behavior. Casting from negative float values to signed // integer types and then to unsigned integer types is not undefined, // however, so this cast improves the consistency of type conversions // to uint8 across compilers. // Further note: Type conversions across compilers still have other undefined // and divergent behavior. template struct static_cast_with_inter_type { C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline uint8_t apply( src_t src) { constexpr bool real = needs_real::value; return static_cast( static_cast(maybe_real::apply(src))); } }; template <> struct static_cast_with_inter_type, c10::BFloat16> { C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< c10::Half> apply(c10::BFloat16 src) { return static_cast>(c10::complex{src}); } }; template <> struct static_cast_with_inter_type, c10::Half> { C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< c10::Half> apply(c10::Half src) { return static_cast>(c10::complex{src}); } }; template <> struct static_cast_with_inter_type< c10::complex, c10::complex> { C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< c10::Half> apply(c10::complex src) { return static_cast>( static_cast>(src)); } }; template C10_HOST_DEVICE To convert(From f) { return static_cast_with_inter_type::apply(f); } // Define separately to avoid being inlined and prevent code-size bloat C10_API void report_overflow(const char* name); template To checked_convert(From f, const char* name) { // Converting to bool can't overflow so we exclude this case from checking. if (!std::is_same::value && overflows(f)) { report_overflow(name); } return convert(f); } } // namespace c10 C10_CLANG_DIAGNOSTIC_POP() // Trigger tests for D25440771. TODO: Remove this line any time you want.