#pragma once // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] #include #include #include #include #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif namespace at { namespace vec { // See Note [CPU_CAPABILITY namespace] inline namespace CPU_CAPABILITY { #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template <> class Vectorized> { private: __m256 values; public: using value_type = c10::complex; using size_type = int; static constexpr size_type size() { return 4; } Vectorized() {} Vectorized(__m256 v) : values(v) {} Vectorized(c10::complex val) { float real_value = val.real(); float imag_value = val.imag(); values = _mm256_setr_ps(real_value, imag_value, real_value, imag_value, real_value, imag_value, real_value, imag_value ); } Vectorized(c10::complex val1, c10::complex val2, c10::complex val3, c10::complex val4) { values = _mm256_setr_ps(val1.real(), val1.imag(), val2.real(), val2.imag(), val3.real(), val3.imag(), val4.real(), val4.imag() ); } operator __m256() const { return values; } template static Vectorized> blend(const Vectorized>& a, const Vectorized>& b) { // convert c10::complex index mask to V index mask: xy -> xxyy static_assert(mask > -1 && mask < 16, "Unexpected mask range"); switch (mask) { case 0: return a; case 1: return _mm256_blend_ps(a.values, b.values, 0x03); //b0000 0001 = b0000 0011 case 2: return _mm256_blend_ps(a.values, b.values, 0x0C); //b0000 0010 = b0000 1100 case 3: return _mm256_blend_ps(a.values, b.values, 0x0F); //b0000 0011 = b0000 1111 case 4: return _mm256_blend_ps(a.values, b.values, 0x30); //b0000 0100 = b0011 0000 case 5: return _mm256_blend_ps(a.values, b.values, 0x33); //b0000 0101 = b0011 0011 case 6: return _mm256_blend_ps(a.values, b.values, 0x3C); //b0000 0110 = b0011 1100 case 7: return _mm256_blend_ps(a.values, b.values, 0x3F); //b0000 0111 = b0011 1111 case 8: return _mm256_blend_ps(a.values, b.values, 0xC0); //b0000 1000 = b1100 0000 case 9: return _mm256_blend_ps(a.values, b.values, 0xC3); //b0000 1001 = b1100 0011 case 10: return _mm256_blend_ps(a.values, b.values, 0xCC); //b0000 1010 = b1100 1100 case 11: return _mm256_blend_ps(a.values, b.values, 0xCF); //b0000 1011 = b1100 1111 case 12: return _mm256_blend_ps(a.values, b.values, 0xF0); //b0000 1100 = b1111 0000 case 13: return _mm256_blend_ps(a.values, b.values, 0xF3); //b0000 1101 = b1111 0011 case 14: return _mm256_blend_ps(a.values, b.values, 0xFC); //b0000 1110 = b1111 1100 default: break; } return b; } static Vectorized> blendv(const Vectorized>& a, const Vectorized>& b, const Vectorized>& mask) { // convert c10::complex index mask to V index mask: xy -> xxyy auto mask_ = _mm256_unpacklo_ps(mask.values, mask.values); return _mm256_blendv_ps(a.values, b.values, mask_); } template static Vectorized> arange(c10::complex base = 0., step_t step = static_cast(1)) { return Vectorized>(base, base + step, base + c10::complex(2)*step, base + c10::complex(3)*step); } static Vectorized> set(const Vectorized>& a, const Vectorized>& b, int64_t count = size()) { switch (count) { case 0: return a; case 1: return blend<1>(a, b); case 2: return blend<3>(a, b); case 3: return blend<7>(a, b); } return b; } static Vectorized> loadu(const void* ptr, int64_t count = size()) { if (count == size()) return _mm256_loadu_ps(reinterpret_cast(ptr)); __at_align__ float tmp_values[2*size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. for (const auto i : c10::irange(2*size())) { tmp_values[i] = 0.0; } std::memcpy( tmp_values, reinterpret_cast(ptr), count * sizeof(c10::complex)); return _mm256_load_ps(tmp_values); } void store(void* ptr, int count = size()) const { if (count == size()) { _mm256_storeu_ps(reinterpret_cast(ptr), values); } else if (count > 0) { float tmp_values[2*size()]; _mm256_storeu_ps(reinterpret_cast(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); } } const c10::complex& operator[](int idx) const = delete; c10::complex& operator[](int idx) = delete; Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { __at_align__ c10::complex tmp[size()]; store(tmp); for (const auto i : c10::irange(size())) { tmp[i] = f(tmp[i]); } return loadu(tmp); } __m256 abs_2_() const { auto val_2 = _mm256_mul_ps(values, values); // a*a b*b auto ret = _mm256_hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b return _mm256_permute_ps(ret, 0xD8); } __m256 abs_() const { return _mm256_sqrt_ps(abs_2_()); // abs abs } Vectorized> abs() const { const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); return _mm256_and_ps(abs_(), real_mask); // abs 0 } __m256 angle_() const { //angle = atan2(b/a) auto b_a = _mm256_permute_ps(values, 0xB1); // b a return Sleef_atan2f8_u10(values, b_a); // 90-angle angle } Vectorized> angle() const { const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); auto angle = _mm256_permute_ps(angle_(), 0xB1); // angle 90-angle return _mm256_and_ps(angle, real_mask); // angle 0 } Vectorized> sgn() const { auto abs = abs_(); auto zero = _mm256_setzero_ps(); auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ); auto abs_val = Vectorized(abs); auto div = values / abs_val.values; // x / abs(x) return _mm256_blendv_ps(div, zero, mask); } __m256 real_() const { const __m256 real_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); return _mm256_and_ps(values, real_mask); } Vectorized> real() const { return real_(); } __m256 imag_() const { const __m256 imag_mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF)); return _mm256_and_ps(values, imag_mask); } Vectorized> imag() const { return _mm256_permute_ps(imag_(), 0xB1); //b a } __m256 conj_() const { const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); return _mm256_xor_ps(values, sign_mask); // a -b } Vectorized> conj() const { return conj_(); } Vectorized> log() const { // Most trigonomic ops use the log() op to improve complex number performance. return map(std::log); } Vectorized> log2() const { const __m256 log2_ = _mm256_set1_ps(std::log(2)); return _mm256_div_ps(log(), log2_); } Vectorized> log10() const { const __m256 log10_ = _mm256_set1_ps(std::log(10)); return _mm256_div_ps(log(), log10_); } Vectorized> log1p() const { AT_ERROR("not supported for complex numbers"); } Vectorized> asin() const { // asin(x) // = -i*ln(iz + sqrt(1 -z^2)) // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) const __m256 one = _mm256_set1_ps(1); auto conj = conj_(); auto b_a = _mm256_permute_ps(conj, 0xB1); //-b a auto ab = _mm256_mul_ps(conj, b_a); //-ab -ab auto im = _mm256_add_ps(ab, ab); //-2ab -2ab auto val_2 = _mm256_mul_ps(values, values); // a*a b*b auto re = _mm256_hsub_ps(val_2, _mm256_permute_ps(val_2, 0xB1)); // a*a-b*b b*b-a*a re = _mm256_permute_ps(re, 0xD8); re = _mm256_sub_ps(one, re); auto root = Vectorized(_mm256_blend_ps(re, im, 0xAA)).sqrt(); //sqrt(re + i*im) auto ln = Vectorized(_mm256_add_ps(b_a, root)).log(); //ln(iz + sqrt()) return Vectorized(_mm256_permute_ps(ln.values, 0xB1)).conj(); //-i*ln() } Vectorized> acos() const { return map(std::acos); } Vectorized> atan() const; Vectorized> atan2(const Vectorized>& /*b*/) const { AT_ERROR("not supported for complex numbers"); } Vectorized> erf() const { AT_ERROR("not supported for complex numbers"); } Vectorized> erfc() const { AT_ERROR("not supported for complex numbers"); } Vectorized> exp() const { //exp(a + bi) // = exp(a)*(cos(b) + sin(b)i) auto exp = Sleef_expf8_u10(values); //exp(a) exp(b) exp = _mm256_blend_ps(exp, _mm256_permute_ps(exp, 0xB1), 0xAA); //exp(a) exp(a) auto sin_cos = Sleef_sincosf8_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)] auto cos_sin = _mm256_blend_ps(_mm256_permute_ps(sin_cos.y, 0xB1), sin_cos.x, 0xAA); //cos(b) sin(b) return _mm256_mul_ps(exp, cos_sin); } Vectorized> expm1() const { AT_ERROR("not supported for complex numbers"); } Vectorized> sin() const { return map(std::sin); } Vectorized> sinh() const { return map(std::sinh); } Vectorized> cos() const { return map(std::cos); } Vectorized> cosh() const { return map(std::cosh); } Vectorized> ceil() const { return _mm256_ceil_ps(values); } Vectorized> floor() const { return _mm256_floor_ps(values); } Vectorized> hypot(const Vectorized>& /*b*/) const { AT_ERROR("not supported for complex numbers"); } Vectorized> igamma(const Vectorized>& /*x*/) const { AT_ERROR("not supported for complex numbers"); } Vectorized> igammac(const Vectorized>& /*x*/) const { AT_ERROR("not supported for complex numbers"); } Vectorized> neg() const { auto zero = _mm256_setzero_ps(); return _mm256_sub_ps(zero, values); } Vectorized> nextafter(const Vectorized>& /*b*/) const { AT_ERROR("not supported for complex numbers"); } Vectorized> round() const { return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); } Vectorized> tan() const { return map(std::tan); } Vectorized> tanh() const { return map(std::tanh); } Vectorized> trunc() const { return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); } Vectorized> sqrt() const { return map(std::sqrt); } Vectorized> reciprocal() const; Vectorized> rsqrt() const { return sqrt().reciprocal(); } Vectorized> pow(const Vectorized> &exp) const { __at_align__ c10::complex x_tmp[size()]; __at_align__ c10::complex y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (const auto i : c10::irange(size())) { x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); } return loadu(x_tmp); } // Comparison using the _CMP_**_OQ predicate. // `O`: get false if an operand is NaN // `Q`: do not raise if an operand is NaN Vectorized> operator==(const Vectorized>& other) const { return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ); } Vectorized> operator!=(const Vectorized>& other) const { return _mm256_cmp_ps(values, other.values, _CMP_NEQ_UQ); } Vectorized> operator<(const Vectorized>& /*other*/) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> operator<=(const Vectorized>& /*other*/) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> operator>(const Vectorized>& /*other*/) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> operator>=(const Vectorized>& /*other*/) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> eq(const Vectorized>& other) const; Vectorized> ne(const Vectorized>& other) const; Vectorized> lt(const Vectorized>& /*other*/) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> le(const Vectorized>& /*other*/) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> gt(const Vectorized>& /*other*/) const { TORCH_CHECK(false, "not supported for complex numbers"); } Vectorized> ge(const Vectorized>& /*other*/) const { TORCH_CHECK(false, "not supported for complex numbers"); } }; template <> Vectorized> inline operator+(const Vectorized> &a, const Vectorized> &b) { return _mm256_add_ps(a, b); } template <> Vectorized> inline operator-(const Vectorized> &a, const Vectorized> &b) { return _mm256_sub_ps(a, b); } template <> Vectorized> inline operator*(const Vectorized> &a, const Vectorized> &b) { //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); auto ac_bd = _mm256_mul_ps(a, b); //ac bd auto d_c = _mm256_permute_ps(b, 0xB1); //d c d_c = _mm256_xor_ps(sign_mask, d_c); //d -c auto ad_bc = _mm256_mul_ps(a, d_c); //ad -bc auto ret = _mm256_hsub_ps(ac_bd, ad_bc); //ac - bd ad + bc ret = _mm256_permute_ps(ret, 0xD8); return ret; } template <> Vectorized> inline operator/(const Vectorized> &a, const Vectorized> &b) { //re + im*i = (a + bi) / (c + di) //re = (ac + bd)/abs_2() //im = (bc - ad)/abs_2() const __m256 sign_mask = _mm256_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0); auto ac_bd = _mm256_mul_ps(a, b); //ac bd auto d_c = _mm256_permute_ps(b, 0xB1); //d c d_c = _mm256_xor_ps(sign_mask, d_c); //-d c auto ad_bc = _mm256_mul_ps(a, d_c); //-ad bc auto re_im = _mm256_hadd_ps(ac_bd, ad_bc);//ac + bd bc - ad re_im = _mm256_permute_ps(re_im, 0xD8); return _mm256_div_ps(re_im, b.abs_2_()); } // reciprocal. Implement this here so we can use multiplication. inline Vectorized> Vectorized>::reciprocal() const { //re + im*i = (a + bi) / (c + di) //re = (ac + bd)/abs_2() = c/abs_2() //im = (bc - ad)/abs_2() = d/abs_2() const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); auto c_d = _mm256_xor_ps(sign_mask, values); //c -d return _mm256_div_ps(c_d, abs_2_()); } inline Vectorized> Vectorized>::atan() const { // atan(x) = i/2 * ln((i + z)/(i - z)) const __m256 i = _mm256_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); const Vectorized i_half = _mm256_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5); auto sum = Vectorized(_mm256_add_ps(i, values)); // a 1+b auto sub = Vectorized(_mm256_sub_ps(i, values)); // -a 1-b auto ln = (sum/sub).log(); // ln((i + z)/(i - z)) return i_half*ln; // i/2*ln() } template <> Vectorized> inline maximum(const Vectorized>& a, const Vectorized>& b) { auto abs_a = a.abs_2_(); auto abs_b = b.abs_2_(); auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_LT_OQ); auto max = _mm256_blendv_ps(a, b, mask); // Exploit the fact that all-ones is a NaN. auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); return _mm256_or_ps(max, isnan); } template <> Vectorized> inline minimum(const Vectorized>& a, const Vectorized>& b) { auto abs_a = a.abs_2_(); auto abs_b = b.abs_2_(); auto mask = _mm256_cmp_ps(abs_a, abs_b, _CMP_GT_OQ); auto min = _mm256_blendv_ps(a, b, mask); // Exploit the fact that all-ones is a NaN. auto isnan = _mm256_cmp_ps(abs_a, abs_b, _CMP_UNORD_Q); return _mm256_or_ps(min, isnan); } template <> Vectorized> inline operator&(const Vectorized>& a, const Vectorized>& b) { return _mm256_and_ps(a, b); } template <> Vectorized> inline operator|(const Vectorized>& a, const Vectorized>& b) { return _mm256_or_ps(a, b); } template <> Vectorized> inline operator^(const Vectorized>& a, const Vectorized>& b) { return _mm256_xor_ps(a, b); } inline Vectorized> Vectorized>::eq( const Vectorized>& other) const { return (*this == other) & Vectorized>(_mm256_set1_ps(1.0f)); } inline Vectorized> Vectorized>::ne( const Vectorized>& other) const { return (*this != other) & Vectorized>(_mm256_set1_ps(1.0f)); } #endif }}}