#pragma once // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] #include #include #include #include #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif namespace at { namespace vec { // See Note [CPU_CAPABILITY namespace] inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template <> class Vectorized> { private: __m256d values; public: using value_type = c10::complex; using size_type = int; static constexpr size_type size() { return 2; } Vectorized() {} Vectorized(__m256d v) : values(v) {} Vectorized(c10::complex val) { double real_value = val.real(); double imag_value = val.imag(); values = _mm256_setr_pd(real_value, imag_value, real_value, imag_value); } Vectorized(c10::complex val1, c10::complex val2) { values = _mm256_setr_pd(val1.real(), val1.imag(), val2.real(), val2.imag()); } operator __m256d() const { return values; } template static Vectorized> blend(const Vectorized>& a, const Vectorized>& b) { // convert c10::complex index mask to V index mask: xy -> xxyy static_assert (mask > -1 && mask < 4, "Unexpected mask value"); switch (mask) { case 0: return a; case 1: return _mm256_blend_pd(a.values, b.values, 0x03); case 2: return _mm256_blend_pd(a.values, b.values, 0x0c); case 3: break; } return b; } static Vectorized> blendv(const Vectorized>& a, const Vectorized>& b, const Vectorized>& mask) { // convert c10::complex index mask to V index mask: xy -> xxyy auto mask_ = _mm256_unpacklo_pd(mask.values, mask.values); return _mm256_blendv_pd(a.values, b.values, mask_); } template static Vectorized> arange(c10::complex base = 0., step_t step = static_cast(1)) { return Vectorized>(base, base + step); } static Vectorized> set(const Vectorized>& a, const Vectorized>& b, int64_t count = size()) { switch (count) { case 0: return a; case 1: return blend<1>(a, b); } return b; } static Vectorized> loadu(const void* ptr, int64_t count = size()) { if (count == size()) return _mm256_loadu_pd(reinterpret_cast(ptr)); __at_align__ double tmp_values[2*size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. for (const auto i : c10::irange(2*size())) { tmp_values[i] = 0.0; } std::memcpy( tmp_values, reinterpret_cast(ptr), count * sizeof(c10::complex)); return _mm256_load_pd(tmp_values); } void store(void* ptr, int count = size()) const { if (count == size()) { _mm256_storeu_pd(reinterpret_cast(ptr), values); } else if (count > 0) { double tmp_values[2*size()]; _mm256_storeu_pd(reinterpret_cast(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); } } const c10::complex& operator[](int idx) const = delete; c10::complex& operator[](int idx) = delete; Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { __at_align__ c10::complex tmp[size()]; store(tmp); for (const auto i : c10::irange(size())) { tmp[i] = f(tmp[i]); } return loadu(tmp); } __m256d abs_2_() const { auto val_2 = _mm256_mul_pd(values, values); // a*a b*b return _mm256_hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b } __m256d abs_() const { return _mm256_sqrt_pd(abs_2_()); // abs abs } Vectorized> abs() const { const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); return _mm256_and_pd(abs_(), real_mask); // abs 0 } __m256d angle_() const { //angle = atan2(b/a) auto b_a = _mm256_permute_pd(values, 0x05); // b a return Sleef_atan2d4_u10(values, b_a); // 90-angle angle } Vectorized> angle() const { const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); auto angle = _mm256_permute_pd(angle_(), 0x05); // angle 90-angle return _mm256_and_pd(angle, real_mask); // angle 0 } Vectorized> sgn() const { auto abs = abs_(); auto zero = _mm256_setzero_pd(); auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ); auto abs_val = Vectorized(abs); auto div = values / abs_val.values; // x / abs(x) return blendv(div, zero, mask); } __m256d real_() const { const __m256d real_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); return _mm256_and_pd(values, real_mask); } Vectorized> real() const { return real_(); } __m256d imag_() const { const __m256d imag_mask = _mm256_castsi256_pd(_mm256_setr_epi64x(0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, 0xFFFFFFFFFFFFFFFF)); return _mm256_and_pd(values, imag_mask); } Vectorized> imag() const { return _mm256_permute_pd(imag_(), 0x05); //b a } __m256d conj_() const { const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); return _mm256_xor_pd(values, sign_mask); // a -b } Vectorized> conj() const { return conj_(); } Vectorized> log() const { // Most trigonomic ops use the log() op to improve complex number performance. return map(std::log); } Vectorized> log2() const { const __m256d log2_ = _mm256_set1_pd(std::log(2)); return _mm256_div_pd(log(), log2_); } Vectorized> log10() const { const __m256d log10_ = _mm256_set1_pd(std::log(10)); return _mm256_div_pd(log(), log10_); } Vectorized> log1p() const { AT_ERROR("not supported for complex numbers"); } Vectorized> asin() const { // asin(x) // = -i*ln(iz + sqrt(1 -z^2)) // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) const __m256d one = _mm256_set1_pd(1); auto conj = conj_(); auto b_a = _mm256_permute_pd(conj, 0x05); //-b a auto ab = _mm256_mul_pd(conj, b_a); //-ab -ab auto im = _mm256_add_pd(ab, ab); //-2ab -2ab auto val_2 = _mm256_mul_pd(values, values); // a*a b*b auto re = _mm256_hsub_pd(val_2, _mm256_permute_pd(val_2, 0x05)); // a*a-b*b b*b-a*a re = _mm256_sub_pd(one, re); auto root = Vectorized(_mm256_blend_pd(re, im, 0x0A)).sqrt(); //sqrt(re + i*im) auto ln = Vectorized(_mm256_add_pd(b_a, root)).log(); //ln(iz + sqrt()) return Vectorized(_mm256_permute_pd(ln.values, 0x05)).conj(); //-i*ln() } Vectorized> acos() const { // acos(x) = pi/2 - asin(x) constexpr auto pi_2d = c10::pi / 2; const __m256d pi_2 = _mm256_setr_pd(pi_2d, 0.0, pi_2d, 0.0); return _mm256_sub_pd(pi_2, asin()); } Vectorized> atan() const; Vectorized> atan2(const Vectorized>&) const { AT_ERROR("not supported for complex numbers"); } Vectorized> erf() const { AT_ERROR("not supported for complex numbers"); } Vectorized> erfc() const { AT_ERROR("not supported for complex numbers"); } Vectorized> exp() const { //exp(a + bi) // = exp(a)*(cos(b) + sin(b)i) auto exp = Sleef_expd4_u10(values); //exp(a) exp(b) exp = _mm256_blend_pd(exp, _mm256_permute_pd(exp, 0x05), 0x0A); //exp(a) exp(a) auto sin_cos = Sleef_sincosd4_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)] auto cos_sin = _mm256_blend_pd(_mm256_permute_pd(sin_cos.y, 0x05), sin_cos.x, 0x0A); //cos(b) sin(b) return _mm256_mul_pd(exp, cos_sin); } Vectorized> expm1() const { AT_ERROR("not supported for complex numbers"); } Vectorized> sin() const { return map(std::sin); } Vectorized> sinh() const { return map(std::sinh); } Vectorized> cos() const { return map(std::cos); } Vectorized> cosh() const { return map(std::cosh); } Vectorized> ceil() const { return _mm256_ceil_pd(values); } Vectorized> floor() const { return _mm256_floor_pd(values); } Vectorized> hypot(const Vectorized> &) const { AT_ERROR("not supported for complex numbers"); } Vectorized> igamma(const Vectorized> &) const { AT_ERROR("not supported for complex numbers"); } Vectorized> igammac(const Vectorized> &) const { AT_ERROR("not supported for complex numbers"); } Vectorized> neg() const { auto zero = _mm256_setzero_pd(); return _mm256_sub_pd(zero, values); } Vectorized> nextafter(const Vectorized> &) const { AT_ERROR("not supported for complex numbers"); } Vectorized> round() const { return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); } Vectorized> tan() const { return map(std::tan); } Vectorized> tanh() const { return map(std::tanh); } Vectorized> trunc() const { return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); } Vectorized> sqrt() const { return map(std::sqrt); } Vectorized> reciprocal() const; Vectorized> rsqrt() const { return sqrt().reciprocal(); } Vectorized> pow(const Vectorized> &exp) const { __at_align__ c10::complex x_tmp[size()]; __at_align__ c10::complex y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (const auto i : c10::irange(size())) { x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); } return loadu(x_tmp); } // Comparison using the _CMP_**_OQ predicate. // `O`: get false if an operand is NaN // `Q`: do not raise if an operand is NaN Vectorized> operator==(const Vectorized>& other) const { return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ); } Vectorized> operator!=(const Vectorized>& other) const { return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ); } Vectorized> operator<(const Vectorized>&) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> operator<=(const Vectorized>&) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> operator>(const Vectorized>&) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> operator>=(const Vectorized>&) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> eq(const Vectorized>& other) const; Vectorized> ne(const Vectorized>& other) const; Vectorized> lt(const Vectorized>&) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> le(const Vectorized>&) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> gt(const Vectorized>&) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> ge(const Vectorized>&) const { TORCH_CHECK(false, "not supported for complex numbers"); } }; template <> Vectorized> inline operator+(const Vectorized> &a, const Vectorized> &b) { return _mm256_add_pd(a, b); } template <> Vectorized> inline operator-(const Vectorized> &a, const Vectorized> &b) { return _mm256_sub_pd(a, b); } template <> Vectorized> inline operator*(const Vectorized> &a, const Vectorized> &b) { //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); auto ac_bd = _mm256_mul_pd(a, b); //ac bd auto d_c = _mm256_permute_pd(b, 0x05); //d c d_c = _mm256_xor_pd(sign_mask, d_c); //d -c auto ad_bc = _mm256_mul_pd(a, d_c); //ad -bc auto ret = _mm256_hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc return ret; } template <> Vectorized> inline operator/(const Vectorized> &a, const Vectorized> &b) { //re + im*i = (a + bi) / (c + di) //re = (ac + bd)/abs_2() //im = (bc - ad)/abs_2() const __m256d sign_mask = _mm256_setr_pd(-0.0, 0.0, -0.0, 0.0); auto ac_bd = _mm256_mul_pd(a, b); //ac bd auto d_c = _mm256_permute_pd(b, 0x05); //d c d_c = _mm256_xor_pd(sign_mask, d_c); //-d c auto ad_bc = _mm256_mul_pd(a, d_c); //-ad bc auto re_im = _mm256_hadd_pd(ac_bd, ad_bc);//ac + bd bc - ad return _mm256_div_pd(re_im, b.abs_2_()); } // reciprocal. Implement this here so we can use multiplication. inline Vectorized> Vectorized>::reciprocal() const{ //re + im*i = (a + bi) / (c + di) //re = (ac + bd)/abs_2() = c/abs_2() //im = (bc - ad)/abs_2() = d/abs_2() const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0); auto c_d = _mm256_xor_pd(sign_mask, values); //c -d return _mm256_div_pd(c_d, abs_2_()); } inline Vectorized> Vectorized>::atan() const { // atan(x) = i/2 * ln((i + z)/(i - z)) const __m256d i = _mm256_setr_pd(0.0, 1.0, 0.0, 1.0); const Vectorized i_half = _mm256_setr_pd(0.0, 0.5, 0.0, 0.5); auto sum = Vectorized(_mm256_add_pd(i, values)); // a 1+b auto sub = Vectorized(_mm256_sub_pd(i, values)); // -a 1-b auto ln = (sum/sub).log(); // ln((i + z)/(i - z)) return i_half*ln; // i/2*ln() } template <> Vectorized> inline maximum(const Vectorized>& a, const Vectorized>& b) { auto abs_a = a.abs_2_(); auto abs_b = b.abs_2_(); auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_LT_OQ); auto max = _mm256_blendv_pd(a, b, mask); // Exploit the fact that all-ones is a NaN. auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); return _mm256_or_pd(max, isnan); } template <> Vectorized> inline minimum(const Vectorized>& a, const Vectorized>& b) { auto abs_a = a.abs_2_(); auto abs_b = b.abs_2_(); auto mask = _mm256_cmp_pd(abs_a, abs_b, _CMP_GT_OQ); auto min = _mm256_blendv_pd(a, b, mask); // Exploit the fact that all-ones is a NaN. auto isnan = _mm256_cmp_pd(abs_a, abs_b, _CMP_UNORD_Q); return _mm256_or_pd(min, isnan); } template <> Vectorized> inline operator&(const Vectorized>& a, const Vectorized>& b) { return _mm256_and_pd(a, b); } template <> Vectorized> inline operator|(const Vectorized>& a, const Vectorized>& b) { return _mm256_or_pd(a, b); } template <> Vectorized> inline operator^(const Vectorized>& a, const Vectorized>& b) { return _mm256_xor_pd(a, b); } inline Vectorized> Vectorized>::eq(const Vectorized>& other) const { return (*this == other) & Vectorized>(_mm256_set1_pd(1.0)); } inline Vectorized> Vectorized>::ne(const Vectorized>& other) const { return (*this != other) & Vectorized>(_mm256_set1_pd(1.0)); } #endif }}}