「みんなのプロコン 2018」: D - XOR XorY

,

https://beta.atcoder.jp/contests/yahoo-procon2018-qual/tasks/yahoo_procon2018_qual_d

本番中では間に合わなかったが解説なしで解くのは解けた。

solution

  • 制約は$a_i \oplus a_j \oplus A_{i, j} \in \{ X, Y \}$と書ける。これは$A_{i, j}’ = A_{i, j} \oplus X$と$Y’ = Y \oplus X$を考えることで $a_i \oplus a_j \oplus A_{i, j}’ \in { 0, Y’ }$とできる。つまり$X = 0$と仮定してよい。
  • 条件を細かく分けると次の$4$個。これらを全て満たせばよい。
    • 反射性みたいなもの: 対角成分 $A_{i, i} \in \{ 0, Y \}$ でなければならない。
    • 対称性みたいなもの: 対角成分 $A_{i, j} = A_{j, i}$ でなければならない。
    • 推移性みたいなもの: $a_i \oplus a_j \oplus A_{i, j} \in \{ 0, Y \}$と$a_j \oplus a_k \oplus A_{j, k} \in \{ 0, Y \}$が成り立っているとする。このとき $a_i \oplus a_k \oplus A_{i, k} \in \{ 0, Y \}$ が常成り立つことが容易に示せる。なので気にしなくてよい。
    • 生成元みたいなもの: $a_1$は固定されたものとする。任意の$j \ge 2$について$a_1 \oplus a_j \oplus A_{1, j} \in \{ 0, Y \}$でなければならない。他は適当にできるので、これだけを考えればよい。
  • $a_1$について総当たり。$a_2, \dots, a_N$が存在しうるかの判定は$x_\ast$の各元を同値関係 $a \sim b \iff a \oplus b \in \{ 0, Y \}$ で割った上でやる。その後 同値関係で割った部分を戻して数え上げる。数え上げ部分は愚直でよい。
  • 計算量は$O(N^2)$。

implementation

#include <bits/stdc++.h>
#define REP(i, n) for (int i = 0; (i) < int(n); ++ (i))
#define REP3(i, m, n) for (int i = (m); (i) < int(n); ++ (i))
#define ALL(x) begin(x), end(x)
using ll = long long;
using namespace std;
template <typename X, typename T> auto vectors(X x, T a) { return vector<T>(x, a); }
template <typename X, typename Y, typename Z, typename... Zs> auto vectors(X x, Y y, Z z, Zs... zs) { auto cont = vectors(y, z, zs...); return vector<decltype(cont)>(x, cont); }

ll powmod(ll x, ll y, ll m) {
    assert (0 <= x and x < m);
    assert (0 <= y);
    ll z = 1;
    for (ll i = 1; i <= y; i <<= 1) {
        if (y & i) z = z * x % m;
        x = x * x % m;
    }
    return z;
}
ll modinv(ll x, ll p) {
    assert (x % p != 0);
    return powmod(x, p - 2, p);
}
template <int32_t MOD>
int32_t fact(int n) {
    static vector<int32_t> memo(1, 1);
    while (n >= memo.size()) {
        memo.push_back(memo.back() *(int64_t) memo.size() % MOD);
    }
    return memo[n];
}
template <int32_t PRIME>
int32_t inv_fact(int n) {
    static vector<int32_t> memo(1, 1);
    while (n >= memo.size()) {
        memo.push_back(memo.back() *(int64_t) modinv(memo.size(), PRIME) % PRIME);
    }
    return memo[n];
}
template <int MOD>
int choose(int n, int r) {
    if (n < r) return 0;
    return fact<MOD>(n) *(ll) inv_fact<MOD>(n - r) % MOD *(ll) inv_fact<MOD>(r) % MOD;
}

constexpr int mod = 1e9 + 7;
int solve(int n, int k, int x, int y, vector<int> & z, vector<vector<int> > & a) {
    // prepare
    sort(ALL(z));
    REP (i, k) REP (j, k) {
        a[i][j] ^= x;
    }
    y ^= x;
    // count
    // the condition: ai ^ aj ^ Aij' \in { 0, Y' }
    REP (i, k) {
        if (a[i][i] != 0 and a[i][i] != y) {
            return 0;
        }
    }
    REP (i, k) REP (j, k) {
        int delta = a[i][j] ^ a[j][i];
        if (delta != 0 and delta != y) {
            return 0;
        }
    }
    map<int, int> freq; for (int z_i : z) freq[min(z_i, z_i ^ y)] += 1;
    map<int, int> merged; for (int z_i : z) if ((z_i ^ y) < z_i) merged[z_i ^ y] += 1;
    ll result = 0;
    for (auto it : freq) {
        vector<int> b(n, -1);
        b[0] = it.first;
        map<int, int> used;
        used[b[0]] += 1;
        REP3 (j, 1, k) {
            int c = a[0][j] ^ b[0];
            c = min(c, c ^ y);
            if (used[c] == freq[c]) {
                goto next;
            }
            used[c] += 1;
        }
        {
            ll acc = 1;
            for (auto it : used) {
                int n = it.second;
                int r1 = freq[it.first] - merged[it.first];
                int r2 = merged[it.first];
                ll acc2 = 0;
                REP (r, n + 1) if (r <= r1 and n - r <= r2) {
                    acc2 += choose<mod>(n, r);
                    acc2 %= mod;
                }
                acc *= acc2;
                acc %= mod;
            }
            result += acc;
            result %= mod;
        }
next: ;
    }
    return result;
}

int main() {
    // input
    int n, k, x, y; cin >> n >> k >> x >> y;
    vector<int> z(n);
    REP (i, n) cin >> z[i];
    auto a = vectors(k, k, int());
    REP (i, k) REP (j, k) cin >> a[i][j];
    // solve
    int result = solve(n, k, x, y, z, a);
    // output
    cout << result << endl;
    return 0;
}