Yukicoder No.202 1円玉投げ

,

http://yukicoder.me/problems/no/202

_mm_mullo_epi32を知らず_mm_mul_epi32と書いててWAで苦しんだ。 積は幅が倍になるのでmul mullo mulhiと$3$種ある。言われてみれば当然。

手元でclangの最適化全マシなら無修正の愚直コードが余裕で通る速度になるが、yukicoderはgccだけなので頑張る必要があった。

solution

愚直 + 定数倍高速化。$O(N)$。

SSEで書き直すだけで通る。速度は$3,4$倍ぐらいになる。 threadでの並列でもよいかもしれない。

implementation

http://yukicoder.me/submissions/157852

#pragma GCC optimize "O3"
#pragma GCC target "tune=native"
#pragma GCC target "avx"
#include <cstdio>
#include <algorithm>
#include <immintrin.h>
#define repeat(i,n) for (int i = 0; (i) < int(n); ++(i))
using namespace std;
template <class T> inline T sq(T x) { return x*x; }
constexpr int max_n = 100000;
constexpr int r = 10;
__attribute__((aligned(32))) int32_t x[max_n];
__attribute__((aligned(32))) int32_t y[max_n];
__attribute__((aligned(32))) int32_t z[max_n]; // is_removed
int main() {
    int n; scanf("%d", &n);
    repeat (i,n) scanf("%d%d", &x[i], &y[i]);
    int i = 0;
    for (; i+3 < n; i += 4) {
        const __m128i xi = _mm_load_si128((__m128i *)(x + i));
        const __m128i yi = _mm_load_si128((__m128i *)(y + i));
        __m128i zi = _mm_setzero_si128();
        repeat (j,i) {
            __m128i xj = _mm_set1_epi32(x[j]);
            __m128i yj = _mm_set1_epi32(y[j]);
            __m128i zj = _mm_set1_epi32(not z[j]);
            xj = _mm_sub_epi32(xj, xi);
            yj = _mm_sub_epi32(yj, yi);
            xj = _mm_mullo_epi32(xj, xj);
            yj = _mm_mullo_epi32(yj, yj);
            const __m128i d = _mm_add_epi32(xj, yj);
            const __m128i e = _mm_set1_epi32(sq(r + r));
            __m128i cmp = _mm_cmplt_epi32(d, e);
            zi = _mm_add_epi32(zi, _mm_mullo_epi32(cmp, zj));
        }
        _mm_store_si128((__m128i *)(z + i), zi);
        repeat (di, 4) {
            if (z[i+di]) {
                z[i+di] = 1;
            } else {
                repeat (dj, di) {
                    if (not z[i+dj] and sq(x[i+dj] - x[i+di]) + sq(y[i+dj] - y[i+di]) < sq(r + r) ) {
                        z[i+di] = 1;
                        break;
                    }
                }
            }
        }
    }
    for (; i < n; ++ i) {
        repeat (j,i) {
            if (not z[j] and sq(x[j] - x[i]) + sq(y[j] - y[i]) < sq(r + r) ) {
                z[i] = 1;
                break;
            }
        }
    }
    int cnt = count(z, z + n, 0);
    printf("%d\n", cnt);
    return 0;
}