#pragma once // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] #include #include #include #include #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) #include #endif namespace at { namespace vec { // See Note [CPU_CAPABILITY namespace] inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) template <> class Vectorized> { private: __m512d values; static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; public: using value_type = c10::complex; using size_type = int; static constexpr size_type size() { return 4; } Vectorized() {} Vectorized(__m512d v) : values(v) {} Vectorized(c10::complex val) { double real_value = val.real(); double imag_value = val.imag(); values = _mm512_setr_pd(real_value, imag_value, real_value, imag_value, real_value, imag_value, real_value, imag_value); } Vectorized(c10::complex val1, c10::complex val2, c10::complex val3, c10::complex val4) { values = _mm512_setr_pd(val1.real(), val1.imag(), val2.real(), val2.imag(), val3.real(), val3.imag(), val4.real(), val4.imag()); } operator __m512d() const { return values; } template static Vectorized> blend(const Vectorized>& a, const Vectorized>& b) { // convert c10::complex index mask to V index mask: xy -> xxyy // NOLINTNEXTLINE(clang-diagnostic-warning) switch (mask) { case 0: return a; case 1: return _mm512_mask_blend_pd(0x03, a.values, b.values); //b0000 0001 = b0000 0011 case 2: return _mm512_mask_blend_pd(0x0C, a.values, b.values); //b0000 0010 = b0000 1100 case 3: return _mm512_mask_blend_pd(0x0F, a.values, b.values); //b0000 0011 = b0000 1111 case 4: return _mm512_mask_blend_pd(0x30, a.values, b.values); //b0000 0100 = b0011 0000 case 5: return _mm512_mask_blend_pd(0x33, a.values, b.values); //b0000 0101 = b0011 0011 case 6: return _mm512_mask_blend_pd(0x3C, a.values, b.values); //b0000 0110 = b0011 1100 case 7: return _mm512_mask_blend_pd(0x3F, a.values, b.values); //b0000 0111 = b0011 1111 case 8: return _mm512_mask_blend_pd(0xC0, a.values, b.values); //b0000 1000 = b1100 0000 case 9: return _mm512_mask_blend_pd(0xC3, a.values, b.values); //b0000 1001 = b1100 0011 case 10: return _mm512_mask_blend_pd(0xCC, a.values, b.values); //b0000 1010 = b1100 1100 case 11: return _mm512_mask_blend_pd(0xCF, a.values, b.values); //b0000 1011 = b1100 1111 case 12: return _mm512_mask_blend_pd(0xF0, a.values, b.values); //b0000 1100 = b1111 0000 case 13: return _mm512_mask_blend_pd(0xF3, a.values, b.values); //b0000 1101 = b1111 0011 case 14: return _mm512_mask_blend_pd(0xFC, a.values, b.values); //b0000 1110 = b1111 1100 case 15: return _mm512_mask_blend_pd(0xFF, a.values, b.values); //b0000 1111 = b1111 1111 } return b; } static Vectorized> blendv(const Vectorized>& a, const Vectorized>& b, const Vectorized>& mask) { // convert c10::complex index mask to V index mask: xy -> xxyy auto mask_ = _mm512_unpacklo_pd(mask.values, mask.values); auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask_), all_ones, _MM_CMPINT_EQ); return _mm512_mask_blend_pd(mmask, a.values, b.values); } template static Vectorized> arange(c10::complex base = 0., step_t step = static_cast(1)) { return Vectorized>(base, base + c10::complex(1)*step, base + c10::complex(2)*step, base + c10::complex(3)*step); } static Vectorized> set(const Vectorized>& a, const Vectorized>& b, int64_t count = size()) { switch (count) { case 0: return a; case 1: return blend<1>(a, b); case 2: return blend<3>(a, b); case 3: return blend<7>(a, b); } return b; } static Vectorized> loadu(const void* ptr, int64_t count = size()) { if (count == size()) return _mm512_loadu_pd(reinterpret_cast(ptr)); __at_align__ double tmp_values[2*size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. for (const auto i : c10::irange(2*size())) { tmp_values[i] = 0.0; } std::memcpy( tmp_values, reinterpret_cast(ptr), count * sizeof(c10::complex)); return _mm512_load_pd(tmp_values); } void store(void* ptr, int count = size()) const { if (count == size()) { _mm512_storeu_pd(reinterpret_cast(ptr), values); } else if (count > 0) { double tmp_values[2*size()]; _mm512_storeu_pd(reinterpret_cast(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); } } const c10::complex& operator[](int idx) const = delete; c10::complex& operator[](int idx) = delete; Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { __at_align__ c10::complex tmp[size()]; store(tmp); for (const auto i : c10::irange(size())) { tmp[i] = f(tmp[i]); } return loadu(tmp); } // AVX512 doesn't have horizontal add & horizontal sub instructions. // TODO: hadd_pd() & hsub_pd() may have scope for improvement. static inline __m512d hadd_pd(__m512d a, __m512d b) { __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); return _mm512_add_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b), _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); } static inline __m512d hsub_pd(__m512d a, __m512d b) { __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); return _mm512_sub_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b), _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); } __m512d abs_2_() const { auto val_2 = _mm512_mul_pd(values, values); // a*a b*b return hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b } __m512d abs_() const { return _mm512_sqrt_pd(abs_2_()); // abs abs } Vectorized> abs() const { const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); return _mm512_and_pd(abs_(), real_mask); // abs 0 } __m512d angle_() const { //angle = atan2(b/a) auto b_a = _mm512_permute_pd(values, 0x55); // b a return Sleef_atan2d8_u10(values, b_a); // 90-angle angle } Vectorized> angle() const { const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); auto angle = _mm512_permute_pd(angle_(), 0x55); // angle 90-angle return _mm512_and_pd(angle, real_mask); // angle 0 } Vectorized> sgn() const { auto abs = abs_(); auto zero = _mm512_setzero_pd(); auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ); auto mask_vec = _mm512_mask_set1_epi64(_mm512_castpd_si512(zero), mask, 0xFFFFFFFFFFFFFFFF); auto abs_val = Vectorized(abs); auto div = values / abs_val.values; // x / abs(x) return blendv(div, zero, _mm512_castsi512_pd(mask_vec)); } __m512d real_() const { const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); return _mm512_and_pd(values, real_mask); } Vectorized> real() const { return real_(); } __m512d imag_() const { const __m512d imag_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF)); return _mm512_and_pd(values, imag_mask); } Vectorized> imag() const { return _mm512_permute_pd(imag_(), 0x55); //b a } __m512d conj_() const { const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); return _mm512_xor_pd(values, sign_mask); // a -b } Vectorized> conj() const { return conj_(); } Vectorized> log() const { // Most trigonomic ops use the log() op to improve complex number performance. return map(std::log); } Vectorized> log2() const { const __m512d log2_ = _mm512_set1_pd(std::log(2)); return _mm512_div_pd(log(), log2_); } Vectorized> log10() const { const __m512d log10_ = _mm512_set1_pd(std::log(10)); return _mm512_div_pd(log(), log10_); } Vectorized> log1p() const { AT_ERROR("not supported for complex numbers"); } Vectorized> asin() const { // asin(x) // = -i*ln(iz + sqrt(1 -z^2)) // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) const __m512d one = _mm512_set1_pd(1); auto conj = conj_(); auto b_a = _mm512_permute_pd(conj, 0x55); //-b a auto ab = _mm512_mul_pd(conj, b_a); //-ab -ab auto im = _mm512_add_pd(ab, ab); //-2ab -2ab auto val_2 = _mm512_mul_pd(values, values); // a*a b*b auto re = hsub_pd(val_2, _mm512_permute_pd(val_2, 0x55)); // a*a-b*b b*b-a*a re = _mm512_sub_pd(one, re); auto root = Vectorized(_mm512_mask_blend_pd(0xAA, re, im)).sqrt(); //sqrt(re + i*im) auto ln = Vectorized(_mm512_add_pd(b_a, root)).log(); //ln(iz + sqrt()) return Vectorized(_mm512_permute_pd(ln.values, 0x55)).conj(); //-i*ln() } Vectorized> acos() const { // acos(x) = pi/2 - asin(x) constexpr auto pi_2d = c10::pi / 2; const __m512d pi_2 = _mm512_setr_pd(pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0); return _mm512_sub_pd(pi_2, asin()); } Vectorized> atan() const; Vectorized> atan2(const Vectorized> &b) const { AT_ERROR("not supported for complex numbers"); } Vectorized> erf() const { AT_ERROR("not supported for complex numbers"); } Vectorized> erfc() const { AT_ERROR("not supported for complex numbers"); } Vectorized> exp() const { //exp(a + bi) // = exp(a)*(cos(b) + sin(b)i) auto exp = Sleef_expd8_u10(values); //exp(a) exp(b) exp = _mm512_mask_blend_pd(0xAA, exp, _mm512_permute_pd(exp, 0x55)); //exp(a) exp(a) auto sin_cos = Sleef_sincosd8_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)] auto cos_sin = _mm512_mask_blend_pd(0xAA, _mm512_permute_pd(sin_cos.y, 0x55), sin_cos.x); //cos(b) sin(b) return _mm512_mul_pd(exp, cos_sin); } Vectorized> expm1() const { AT_ERROR("not supported for complex numbers"); } Vectorized> sin() const { return map(std::sin); } Vectorized> sinh() const { return map(std::sinh); } Vectorized> cos() const { return map(std::cos); } Vectorized> cosh() const { return map(std::cosh); } Vectorized> ceil() const { return _mm512_ceil_pd(values); } Vectorized> floor() const { return _mm512_floor_pd(values); } Vectorized> hypot(const Vectorized> &b) const { AT_ERROR("not supported for complex numbers"); } Vectorized> igamma(const Vectorized> &x) const { AT_ERROR("not supported for complex numbers"); } Vectorized> igammac(const Vectorized> &x) const { AT_ERROR("not supported for complex numbers"); } Vectorized> neg() const { auto zero = _mm512_setzero_pd(); return _mm512_sub_pd(zero, values); } Vectorized> nextafter(const Vectorized> &b) const { AT_ERROR("not supported for complex numbers"); } Vectorized> round() const { return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); } Vectorized> tan() const { return map(std::tan); } Vectorized> tanh() const { return map(std::tanh); } Vectorized> trunc() const { return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); } Vectorized> sqrt() const { return map(std::sqrt); } Vectorized> reciprocal() const; Vectorized> rsqrt() const { return sqrt().reciprocal(); } Vectorized> pow(const Vectorized> &exp) const { __at_align__ c10::complex x_tmp[size()]; __at_align__ c10::complex y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (const auto i : c10::irange(size())) { x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); } return loadu(x_tmp); } // Comparison using the _CMP_**_OQ predicate. // `O`: get false if an operand is NaN // `Q`: do not raise if an operand is NaN Vectorized> operator==(const Vectorized>& other) const { auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF)); } Vectorized> operator!=(const Vectorized>& other) const { auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ); return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF)); } Vectorized> operator<(const Vectorized>& other) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> operator<=(const Vectorized>& other) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> operator>(const Vectorized>& other) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> operator>=(const Vectorized>& other) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> eq(const Vectorized>& other) const; Vectorized> ne(const Vectorized>& other) const; Vectorized> lt(const Vectorized>& other) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> le(const Vectorized>& other) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> gt(const Vectorized>& other) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> ge(const Vectorized>& other) const { TORCH_CHECK(false, "not supported for complex numbers"); } }; template <> Vectorized> inline operator+(const Vectorized> &a, const Vectorized> &b) { return _mm512_add_pd(a, b); } template <> Vectorized> inline operator-(const Vectorized> &a, const Vectorized> &b) { return _mm512_sub_pd(a, b); } template <> Vectorized> inline operator*(const Vectorized> &a, const Vectorized> &b) { //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); auto ac_bd = _mm512_mul_pd(a, b); //ac bd auto d_c = _mm512_permute_pd(b, 0x55); //d c d_c = _mm512_xor_pd(sign_mask, d_c); //d -c auto ad_bc = _mm512_mul_pd(a, d_c); //ad -bc auto ret = Vectorized>::hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc return ret; } template <> Vectorized> inline operator/(const Vectorized> &a, const Vectorized> &b) { //re + im*i = (a + bi) / (c + di) //re = (ac + bd)/abs_2() //im = (bc - ad)/abs_2() const __m512d sign_mask = _mm512_setr_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0); auto ac_bd = _mm512_mul_pd(a, b); //ac bd auto d_c = _mm512_permute_pd(b, 0x55); //d c d_c = _mm512_xor_pd(sign_mask, d_c); //-d c auto ad_bc = _mm512_mul_pd(a, d_c); //-ad bc auto re_im = Vectorized>::hadd_pd(ac_bd, ad_bc);//ac + bd bc - ad return _mm512_div_pd(re_im, b.abs_2_()); } // reciprocal. Implement this here so we can use multiplication. inline Vectorized> Vectorized>::reciprocal() const{ //re + im*i = (a + bi) / (c + di) //re = (ac + bd)/abs_2() = c/abs_2() //im = (bc - ad)/abs_2() = d/abs_2() const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); auto c_d = _mm512_xor_pd(sign_mask, values); //c -d return _mm512_div_pd(c_d, abs_2_()); } inline Vectorized> Vectorized>::atan() const { // atan(x) = i/2 * ln((i + z)/(i - z)) const __m512d i = _mm512_setr_pd(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); const Vectorized i_half = _mm512_setr_pd(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5); auto sum = Vectorized(_mm512_add_pd(i, values)); // a 1+b auto sub = Vectorized(_mm512_sub_pd(i, values)); // -a 1-b auto ln = (sum/sub).log(); // ln((i + z)/(i - z)) return i_half*ln; // i/2*ln() } template <> Vectorized> inline maximum(const Vectorized>& a, const Vectorized>& b) { auto zero_vec = _mm512_set1_epi64(0); auto abs_a = a.abs_2_(); auto abs_b = b.abs_2_(); auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_LT_OQ); auto max = _mm512_mask_blend_pd(mask, a, b); // Exploit the fact that all-ones is a NaN. auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF); return _mm512_or_pd(max, _mm512_castsi512_pd(isnan)); } template <> Vectorized> inline minimum(const Vectorized>& a, const Vectorized>& b) { auto zero_vec = _mm512_set1_epi64(0); auto abs_a = a.abs_2_(); auto abs_b = b.abs_2_(); auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_GT_OQ); auto min = _mm512_mask_blend_pd(mask, a, b); // Exploit the fact that all-ones is a NaN. auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, 0xFFFFFFFFFFFFFFFF); return _mm512_or_pd(min, _mm512_castsi512_pd(isnan)); } template <> Vectorized> inline operator&(const Vectorized>& a, const Vectorized>& b) { return _mm512_and_pd(a, b); } template <> Vectorized> inline operator|(const Vectorized>& a, const Vectorized>& b) { return _mm512_or_pd(a, b); } template <> Vectorized> inline operator^(const Vectorized>& a, const Vectorized>& b) { return _mm512_xor_pd(a, b); } inline Vectorized> Vectorized>::eq(const Vectorized>& other) const { return (*this == other) & Vectorized>(_mm512_set1_pd(1.0)); } inline Vectorized> Vectorized>::ne(const Vectorized>& other) const { return (*this != other) & Vectorized>(_mm512_set1_pd(1.0)); } #endif }}}