#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if (defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
#include <sleef.h>
#endif

namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

template <> class Vectorized<double> {
private:
  static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
public:
  // values needs to be public for compilation with clang
  // as vec512.h uses it
  __m512d values;
  using value_type = double;
  using size_type = int;
  static constexpr size_type size() {
    return 8;
  }
  Vectorized() {}
  Vectorized(__m512d v) : values(v) {}
  Vectorized(double val) {
    values = _mm512_set1_pd(val);
  }
  Vectorized(double val1, double val2, double val3, double val4,
             double val5, double val6, double val7, double val8) {
    values = _mm512_setr_pd(val1, val2, val3, val4, val5, val6, val7, val8);
  }
  operator __m512d() const {
    return values;
  }
  template <int64_t mask>
  static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
    // Bit i of the immediate mask selects lane i from b; clear bits keep a.
    return _mm512_mask_blend_pd(mask, a.values, b.values);
  }
  static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b,
                                   const Vectorized<double>& mask) {
    // Convert the all-ones/all-zeros double mask into a compact __mmask8.
    auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF);
    auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask.values),
                                       all_ones, _MM_CMPINT_EQ);
    return _mm512_mask_blend_pd(mmask, a.values, b.values);
  }
  template <typename step_t>
  static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
    return Vectorized<double>(base, base + step, base + 2 * step, base + 3 * step,
                              base + 4 * step, base + 5 * step, base + 6 * step,
                              base + 7 * step);
  }
  static Vectorized<double> set(const Vectorized<double>& a, const Vectorized<double>& b,
                                int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
    }
    return b;
  }
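  // Illustrative example (added for clarity, not part of the original header):
  // set() maps its count to a blend whose mask has the low `count` bits set,
  // so, e.g.,
  //   set(a, b, 3) == blend<7>(a, b)   // mask 0b00000111
  // yields {b[0], b[1], b[2], a[3], a[4], a[5], a[6], a[7]}.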
  static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
    if (count == size())
      return _mm512_loadu_pd(reinterpret_cast<const double*>(ptr));

    __at_align__ double tmp_values[size()];
    // Ensure uninitialized memory does not change the output value.
    // See https://github.com/pytorch/pytorch/issues/32502 for more details.
    // We do not initialize arrays to zero using "={0}" because gcc would
    // compile it to two instructions while a loop would be compiled to one.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0.0;
    }
    std::memcpy(
        tmp_values,
        reinterpret_cast<const double*>(ptr),
        count * sizeof(double));
    return _mm512_load_pd(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm512_storeu_pd(reinterpret_cast<double*>(ptr), values);
    } else if (count > 0) {
      double tmp_values[size()];
      _mm512_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(double));
    }
  }
  const double& operator[](int idx) const = delete;
  double& operator[](int idx) = delete;
  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to
    // 1-bit and others are translated to 0-bit
    __mmask8 cmp = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_EQ_OQ);
    return static_cast<int32_t>(cmp);
  }
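  // Worked example (illustrative, not from the original header):
  //   Vectorized<double>(0., 5., 0., 1., 0., 0., 2., 3.).zero_mask()
  // returns 0b00110101: bit i is set exactly when lane i equals 0.0
  // (lanes 0, 2, 4, and 5 here). NaN lanes are never reported as zero,
  // since _CMP_EQ_OQ is false for unordered operands.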
  Vectorized<double> isnan() const {
    auto cmp_mask = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_UNORD_Q);
    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
                                                      0xFFFFFFFFFFFFFFFF));
  }
  Vectorized<double> map(double (*const f)(double)) const {
    __at_align__ double tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
  Vectorized<double> abs() const {
    auto mask = _mm512_set1_pd(-0.f);
    return _mm512_andnot_pd(mask, values);
  }
  Vectorized<double> angle() const {
    const auto zero_vec = _mm512_castsi512_pd(zero_vector);
    const auto nan_vec = _mm512_set1_pd(NAN);
    const auto not_nan_mask = _mm512_cmp_pd_mask(values, values, _CMP_EQ_OQ);
    const auto not_nan = _mm512_mask_set1_epi64(zero_vector, not_nan_mask,
                                                0xFFFFFFFFFFFFFFFF);
    const auto nan_mask = _mm512_cmp_pd_mask(_mm512_castsi512_pd(not_nan),
                                             zero_vec, _CMP_EQ_OQ);
    const auto pi = _mm512_set1_pd(c10::pi<double>);

    const auto neg_mask = _mm512_cmp_pd_mask(values, zero_vec, _CMP_LT_OQ);
    auto angle = _mm512_mask_blend_pd(neg_mask, zero_vec, pi);
    angle = _mm512_mask_blend_pd(nan_mask, angle, nan_vec);
    return angle;
  }
  Vectorized<double> real() const {
    return *this;
  }
  Vectorized<double> imag() const {
    return _mm512_set1_pd(0);
  }
  Vectorized<double> conj() const {
    return *this;
  }
  Vectorized<double> acos() const {
    return Vectorized<double>(Sleef_acosd8_u10(values));
  }
  Vectorized<double> asin() const {
    return Vectorized<double>(Sleef_asind8_u10(values));
  }
  Vectorized<double> atan() const {
    return Vectorized<double>(Sleef_atand8_u10(values));
  }
  Vectorized<double> atan2(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_atan2d8_u10(values, b));
  }
  Vectorized<double> copysign(const Vectorized<double> &sign) const {
    return Vectorized<double>(Sleef_copysignd8(values, sign));
  }
  Vectorized<double> erf() const {
    return Vectorized<double>(Sleef_erfd8_u10(values));
  }
  Vectorized<double> erfc() const {
    return Vectorized<double>(Sleef_erfcd8_u15(values));
  }
  Vectorized<double> erfinv() const {
    return map(calc_erfinv);
  }
  Vectorized<double> exp() const {
    return Vectorized<double>(Sleef_expd8_u10(values));
  }
  Vectorized<double> expm1() const {
    return Vectorized<double>(Sleef_expm1d8_u10(values));
  }
  Vectorized<double> fmod(const Vectorized<double>& q) const {
    return Vectorized<double>(Sleef_fmodd8(values, q));
  }
  Vectorized<double> hypot(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_hypotd8_u05(values, b));
  }
  Vectorized<double> i0() const {
    return map(calc_i0);
  }
  Vectorized<double> i0e() const {
    return map(calc_i0e);
  }
  Vectorized<double> igamma(const Vectorized<double> &x) const {
    __at_align__ double tmp[size()];
    __at_align__ double tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (const auto i : c10::irange(size())) {
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<double> igammac(const Vectorized<double> &x) const {
    __at_align__ double tmp[size()];
    __at_align__ double tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (const auto i : c10::irange(size())) {
      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<double> log() const {
    return Vectorized<double>(Sleef_logd8_u10(values));
  }
  Vectorized<double> log2() const {
    return Vectorized<double>(Sleef_log2d8_u10(values));
  }
  Vectorized<double> log10() const {
    return Vectorized<double>(Sleef_log10d8_u10(values));
  }
  Vectorized<double> log1p() const {
    return Vectorized<double>(Sleef_log1pd8_u10(values));
  }
  Vectorized<double> sin() const {
    return Vectorized<double>(Sleef_sind8_u10(values));
  }
  Vectorized<double> sinh() const {
    return Vectorized<double>(Sleef_sinhd8_u10(values));
  }
  Vectorized<double> cos() const {
    return Vectorized<double>(Sleef_cosd8_u10(values));
  }
  Vectorized<double> cosh() const {
    return Vectorized<double>(Sleef_coshd8_u10(values));
  }
  Vectorized<double> ceil() const {
    return _mm512_ceil_pd(values);
  }
  Vectorized<double> floor() const {
    return _mm512_floor_pd(values);
  }
  Vectorized<double> frac() const;
  Vectorized<double> neg() const {
    return _mm512_xor_pd(_mm512_set1_pd(-0.), values);
  }
  Vectorized<double> nextafter(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_nextafterd8(values, b));
  }
  Vectorized<double> round() const {
    return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  }
  Vectorized<double> tan() const {
    return Vectorized<double>(Sleef_tand8_u10(values));
  }
  Vectorized<double> tanh() const {
    return Vectorized<double>(Sleef_tanhd8_u10(values));
  }
  Vectorized<double> trunc() const {
    return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
  }
  Vectorized<double> lgamma() const {
    return Vectorized<double>(Sleef_lgammad8_u10(values));
  }
  Vectorized<double> sqrt() const {
    return _mm512_sqrt_pd(values);
  }
  Vectorized<double> reciprocal() const {
    return _mm512_div_pd(_mm512_set1_pd(1), values);
  }
  Vectorized<double> rsqrt() const {
    return _mm512_div_pd(_mm512_set1_pd(1), _mm512_sqrt_pd(values));
  }
  Vectorized<double> pow(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_powd8_u10(values, b));
  }
  // Comparison using the _CMP_**_OQ predicate.
  //   `O`: get false if an operand is NaN
  //   `Q`: do not raise if an operand is NaN
  Vectorized<double> operator==(const Vectorized<double>& other) const {
    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ);
    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
                                                      0xFFFFFFFFFFFFFFFF));
  }
  Vectorized<double> operator!=(const Vectorized<double>& other) const {
    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ);
    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
                                                      0xFFFFFFFFFFFFFFFF));
  }
  Vectorized<double> operator<(const Vectorized<double>& other) const {
    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LT_OQ);
    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
                                                      0xFFFFFFFFFFFFFFFF));
  }
  Vectorized<double> operator<=(const Vectorized<double>& other) const {
    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LE_OQ);
    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
                                                      0xFFFFFFFFFFFFFFFF));
  }
  Vectorized<double> operator>(const Vectorized<double>& other) const {
    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GT_OQ);
    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
                                                      0xFFFFFFFFFFFFFFFF));
  }
  Vectorized<double> operator>=(const Vectorized<double>& other) const {
    auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GE_OQ);
    return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask,
                                                      0xFFFFFFFFFFFFFFFF));
  }

  Vectorized<double> eq(const Vectorized<double>& other) const;
  Vectorized<double> ne(const Vectorized<double>& other) const;
  Vectorized<double> lt(const Vectorized<double>& other) const;
  Vectorized<double> le(const Vectorized<double>& other) const;
  Vectorized<double> gt(const Vectorized<double>& other) const;
  Vectorized<double> ge(const Vectorized<double>& other) const;
};
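// Note (added for clarity, not in the original header): the overloaded
// comparison operators above return lane-wise bitmasks (all-ones or all-zeros
// doubles, suitable for blendv), whereas eq/ne/lt/le/gt/ge, defined further
// below, AND those masks with 1.0 to produce numeric 0.0/1.0 results.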
template <>
Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm512_add_pd(a, b);
}

template <>
Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm512_sub_pd(a, b);
}

template <>
Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm512_mul_pd(a, b);
}

template <>
Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm512_div_pd(a, b);
}

// frac. Implement this here so we can use subtraction.
inline Vectorized<double> Vectorized<double>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
  auto zero_vec = _mm512_set1_epi64(0);
  Vectorized<double> max = _mm512_max_pd(a, b);
  auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q);
  auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask,
                                                          0xFFFFFFFFFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  return _mm512_or_pd(max, isnan);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
  auto zero_vec = _mm512_set1_epi64(0);
  Vectorized<double> min = _mm512_min_pd(a, b);
  auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q);
  auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask,
                                                          0xFFFFFFFFFFFFFFFF));
  // Exploit the fact that all-ones is a NaN.
  return _mm512_or_pd(min, isnan);
}

template <>
Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min,
                                const Vectorized<double>& max) {
  return _mm512_min_pd(max, _mm512_max_pd(min, a));
}

template <>
Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) {
  return _mm512_max_pd(min, a);
}

template <>
Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) {
  return _mm512_min_pd(max, a);
}

template <>
Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm512_and_pd(a, b);
}

template <>
Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm512_or_pd(a, b);
}

template <>
Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm512_xor_pd(a, b);
}

inline Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const {
  return (*this == other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const {
  return (*this != other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const {
  return (*this > other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const {
  return (*this >= other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const {
  return (*this < other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const {
  return (*this <= other) & Vectorized<double>(1.0);
}

template <>
inline void convert(const double* src, double* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
    _mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

template <>
Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b,
                                const Vectorized<double>& c) {
  return _mm512_fmadd_pd(a, b, c);
}

#endif

}}}
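// Usage sketch (illustrative only; this snippet is an assumption added for
// documentation, not part of the original header). With AVX-512 enabled it
// computes sqrt(x*x + x) elementwise over a block of eight doubles:
//
//   __at_align__ double buf[8] = {1., 2., 3., 4., 5., 6., 7., 8.};
//   auto v = at::vec::Vectorized<double>::loadu(buf);
//   auto w = (v * v + v).sqrt();  // elementwise via the overloads above
//   w.store(buf);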