AtCoder Grand Contest 010: E - Rearranging

,

http://agc010.contest.atcoder.jp/tasks/agc010_e

適当にやればできるという感じがあったので流れで書いたらなんとなく通った。ただしそのような書き方の常として、不注意なバグを埋めて苦しんだ。

solution

連結成分ごとに順序を立てて、仕上げに挿入ソート。$O(N^2)$。

互いに素でない数同士を辺で結んでグラフを作る。これは非連結であり、制約の範囲は各連結成分内で閉じている。 連結成分内で最も端に持ってきたい数を決めて、それからのDFSの訪問順に並べれば上手くいく。 DFSでなくBFSとかだとだめで、反例としては2 6 30 10みたいな合流するもの。

後攻の行うのはただの整列なのでそのようにやる。挿入ソートがよい。ただし貪欲に見るだけだと不足することに注意。

implementation

#include <iostream>
#include <vector>
#include <algorithm>
#include <functional>
#define repeat(i,n) for (int i = 0; (i) < int(n); ++(i))
#define whole(f,x,...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;
int gcd(int a, int b) { while (a) { b %= a; swap(a, b); } return b; }
bool is_swappable(int a, int b) { return gcd(a, b) == 1; }
int main() {
    // input
    int n; cin >> n;
    vector<int> a(n); repeat (i,n) cin >> a[i];
    // rearrange
    whole(sort, a);
    vector<int> b;
    vector<bool> used(n);
    function<void (int)> go = [&](int i) {
        used[i] = true;
        b.push_back(a[i]);
        repeat (j,n) if (not used[j] and not is_swappable(a[i], a[j])) {
            go(j);
        }
    };
    repeat (i,n) if (not used[i]) {
        go(i);
    }
    // insertion sort
    repeat (i,n) {
        int j = i;
        for (int k = i-1; k >= 0 and is_swappable(b[k], b[i]); -- k) {
            if (b[k] < b[i]) j = k;
        }
        rotate(b.begin() + j, b.begin() + i, b.begin() + i + 1);
    }
    // output
    for (auto it : b) cout << it << ' '; cout << endl;
    return 0;
}

AtCoder Grand Contest 010: D - Decrementing

,

1000点にしてはかなり簡単だった。 この回はBもCも難しくて未提出撤退をした記憶があるのになあ。

solution

最大公約数で割る操作が勝敗に影響するのは偶数で割るときだけ。$O(N \log A_i)$。

まず最大公約数で割る操作を行わないとして考えよう。 この場合は単純で、$\sum_i (A_i - 1)$の偶奇が必勝手番を決定する。

最大公約数で割る操作はこの$\sum_i (A_i - 1)$の偶奇を入れ替えうる。 最大公約数$g$が奇数のとき、$kg - k = k(g - 1)$は偶数なので$\sum_i (A_i - 1)$の偶奇は不変。 よって$g$が偶数の場合だけに注目すればよい。 また、先手であれ後手であれ、相手と同様に数字を選んでいけば$g$が偶数の割る操作は回避できる。 例外として、先攻の初手においてはこれが発生しうる。

$\sum_i (A_i - 1)$の偶奇を確認し先手必勝でないなら偶数で割ることを試みる、を再帰的にやればよい。

implementation

#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>
#define repeat(i,n) for (int i = 0; (i) < int(n); ++(i))
#define whole(f,x,...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;
int gcd(int a, int b) { while (a) { b %= a; swap(a, b); } return b; }
bool solve(int n, vector<int> & a) {
    ll sum = whole(accumulate, a, 0ll);
    bool is_first = (sum - n) % 2;
    int even = whole(count_if, a, [&](int ai) { return ai % 2 == 0; });
    auto odd = whole(find_if,  a, [&](int ai) { return ai % 2 == 1; });
    if (not is_first and even == n-1 and *odd != 1) {
        -- (*odd);
        int d = a[0]; repeat (i,n) d = gcd(d, a[i]);
        repeat (i,n) a[i] /= d;
        return not solve(n, a);
    } else {
        return is_first;
    }
}
int main() {
    int n; cin >> n;
    vector<int> a(n); repeat (i,n) cin >> a[i];
    cout << (solve(n, a) ? "First" : "Second") << endl;
    return 0;
}

auxiliary vectorをdumpしてみる

,

プログラムがそのentry pointから実行を始めるとき、そのstack上にはargvenvpの中身の文字列が積まれている。 これらの他にauxiliary vectorという値が積まれている。 通常はlibcのgetauxval関数を用いて取得するが、プログラム開始時のstackの構造への理解のため直接これを出力させた。 なお、同様の出力はLD_SHOW_AUXV環境変数を用いても得られる。

結果

64bit

$ gcc a.c
$ ./a.out
AT_SYSINFO_EHDR : 0x7ffe0e5ef000
AT_HWCAP : 0xbfebfbff
AT_PAGESZ : 0x1000
AT_CLKTCK : 0x64
AT_PHDR : 0x400040
AT_PHENT : 0x38
AT_PHNUM : 0x9
AT_BASE : 0x7f4ad46ee000
AT_FLAGS : (nil)
AT_ENTRY : 0x4004c0
AT_UID : 0x3e8
AT_EUID : 0x3e8
AT_GID : 0x3e8
AT_EGID : 0x3e8
AT_SECURE : (nil)
AT_RANDOM : 0x7ffe0e5e8d59
AT_EXECFN : "./a.out"
AT_PLATFORM : "x86_64"
AT_NULL : (nil)

AT_EXECFN等の参照先はenvpのそれらと同様にstack中に存在する。

32bit

$ sed -i~ s/Elf64/Elf32/g aux.c
$ gcc -m32 a.c
$ ./a.out
AT_SYSINFO : 0xf7795be0
AT_SYSINFO_EHDR : 0xf7795000
AT_HWCAP : 0xbfebfbff
AT_PAGESZ : 0x1000
AT_CLKTCK : 0x64
AT_PHDR : 0x8048034
AT_PHENT : 0x20
AT_PHNUM : 0x9
AT_BASE : 0xf7796000
AT_FLAGS : (nil)
AT_ENTRY : 0x8048370
AT_UID : 0x3e8
AT_EUID : 0x3e8
AT_GID : 0x3e8
AT_EGID : 0x3e8
AT_SECURE : (nil)
AT_RANDOM : 0xfff1938b
AT_EXECFN : "./a.out"
AT_PLATFORM : "i686"
AT_NULL : (nil)

実装

手元の環境(Ubuntu 16.04)では、Elf*_auxv_t/usr/include/elf.hに、AT_*の定義は/usr/include/x86_64-linux-gnu/bits/auxv.hまたは/usr/include/linux/auxvec.hにあった。

#include <stdlib.h>
#include <stdio.h>
#include <elf.h>
#include <x86_64-linux-gnu/bits/auxv.h> // <linux/auxvec.h>
const char *strauxvtype(uint64_t a_type) {
    switch (a_type) {
        case AT_NULL:           return "AT_NULL";
        case AT_IGNORE:         return "AT_IGNORE";
        case AT_EXECFD:         return "AT_EXECFD";
        case AT_PHDR:           return "AT_PHDR";
        case AT_PHENT:          return "AT_PHENT";
        case AT_PHNUM:          return "AT_PHNUM";
        case AT_PAGESZ:         return "AT_PAGESZ";
        case AT_BASE:           return "AT_BASE";
        case AT_FLAGS:          return "AT_FLAGS";
        case AT_ENTRY:          return "AT_ENTRY";
        case AT_NOTELF:         return "AT_NOTELF";
        case AT_UID:            return "AT_UID";
        case AT_EUID:           return "AT_EUID";
        case AT_GID:            return "AT_GID";
        case AT_EGID:           return "AT_EGID";
        case AT_CLKTCK:         return "AT_CLKTCK";
        case AT_PLATFORM:       return "AT_PLATFORM";
        case AT_HWCAP:          return "AT_HWCAP";
        case AT_FPUCW:          return "AT_FPUCW";
        case AT_DCACHEBSIZE:    return "AT_DCACHEBSIZE";
        case AT_ICACHEBSIZE:    return "AT_ICACHEBSIZE";
        case AT_UCACHEBSIZE:    return "AT_UCACHEBSIZE";
        case AT_IGNOREPPC:      return "AT_IGNOREPPC";
        case AT_SECURE:         return "AT_SECURE";
        case AT_BASE_PLATFORM:  return "AT_BASE_PLATFORM";
        case AT_RANDOM:         return "AT_RANDOM";
        case AT_HWCAP2:         return "AT_HWCAP2";
        case AT_EXECFN:         return "AT_EXECFN";
        case AT_SYSINFO:        return "AT_SYSINFO";
        case AT_SYSINFO_EHDR:   return "AT_SYSINFO_EHDR";
        case AT_L1I_CACHESHAPE: return "AT_L1I_CACHESHAPE";
        case AT_L1D_CACHESHAPE: return "AT_L1D_CACHESHAPE";
        case AT_L2_CACHESHAPE:  return "AT_L2_CACHESHAPE";
        case AT_L3_CACHESHAPE:  return "AT_L3_CACHESHAPE";
        default: { char *p = malloc(32); sprintf(p, "(%d)", a_type); return p; }
    }
}
Elf64_auxv_t *auxv_from_envp(char **envp) {
    char **p = envp;
    while (*p) ++ p;
    ++ p;
    return (Elf64_auxv_t *)p;
}
int main(int argc, char **argv, char **envp) {
    Elf64_auxv_t *auxv = auxv_from_envp(envp);
    while (1) {
        printf("%s : ", strauxvtype(auxv->a_type));
        printf(auxv->a_type == AT_EXECFN || auxv->a_type == AT_PLATFORM ? "\"%s\"\n" : "%p\n", auxv->a_un.a_val);
        if (auxv->a_type == AT_NULL) break;
        ++ auxv;
    }
    return 0;
}

参考


ELFの.interp sectionを書き換えてその挙動を確認してみる

,

ELFにおいて、共有libraryのlinkはOSでなくheader内で指定されてlinkerが行う。 このlinkerを指定する文字列を書き換え、その挙動を確認した。

準備

ELFのprogram headerのtypeとして、PT_INTERPがある。 これは単一のbinary中に高々$1$つまで存在し、そのsegment内の文字列としてinterpreterを指定する。 interpreterが指定されているとき、本体がloadされるより先にそのinterpreterがloadされる。 用途としては共有libraryの準備であり、その場合interpreterが本体プログラムをloadする。 INTERP segmentはたいてい.interp sectionを唯一のsegmentとして含む(ただしsection名は必ずしも.interpである必要はない)。

準備として、普通のプログラムを用意する。 例え陽にlibcの関数を呼んでいなかったとしても(例えば__libc_start_mainなどのために) libcは動的linkされていて.interp sectionが存在する。 今回、pathは/lib64/ld-linux-x86-64.so.2であった。

#include <stdio.h>
int main(void) {
    printf("Hello, world!\n");
    return 0;
}
$ gcc helloworld.c -o helloworld

$ ./helloworld
Hello, world!

$ objdump -s helloworld | grep interp -A 2
Contents of section .interp:
 400238 2f6c6962 36342f6c 642d6c69 6e75782d  /lib64/ld-linux-
 400248 7838362d 36342e73 6f2e3200           x86-64.so.2. 

上書き

interpreterとして指定するプログラムの処理内容は(指定するだけなら)なんでもよい。

この例ではinterruptedと表示して終了するプログラムを使用する。Hello, world!の代わりにこれが表示されれば成功である。

#include <stdio.h>
int main(int argc, char **argv) {
    printf("interrupted\n");
    return 1;
}

ただし再帰的にinterpreterを要求するのは許されないようで、静的linkする必要がある。

$ gcc -static interp.c -o interp

適当に.interpを編集する。null終端の文字列が認識されるので、後ろにゴミを残してもよい。

$ objdump -s helloworld | grep interp -A 1
Contents of section .interp:
 400238 2f746d70 2f696e74 65727000 6e75782d  /tmp/interp.nux-
 400248 7838362d 36342e73 6f2e3200           x86-64.so.2.

このような準備の元で、./helloworldを叩くとHello, world!でなくinterruptedと表示される。 これは期待される挙動である。

$ cp interp /tmp

$ ./helloworld
interrupted

引数とかもちゃんと渡ってきてたりする。(ld-linuxはargv = NULLでも仕事をするのでこれを読んでいるのではない。)

$ cat interp.c
#include <stdio.h>
#include <stdlib.h>
int main(int argc, char **argv) {
    for (int i = 0; i < argc; ++ i) {
        printf("argv[%d] : %s\n", i, argv[i]);
    }
    scanf("%*c");
}
$ ./helloworld foo bar
argv[0] : ./helloworld
argv[1] : foo
argv[2] : bar
^Z

$ ps aux | grep '[h]elloworld\|[i]nterp'
user     10136  0.0  0.0   1120     8 pts/18   T    00:57   0:00 ./helloworld foo bar

$ gdb -p `pidof helloworld`
...
gdb-peda$ vmmap
Start              End                Perm	Name
0x00400000         0x004c9000         r-xp	/tmp/interp
0x006c8000         0x006cb000         rw-p	/tmp/interp
0x006cb000         0x006cd000         rw-p	mapped
0x01510000         0x01533000         rw-p	[heap]
0x00007fff59965000 0x00007fff59987000 rw-p	[stack]
0x00007fff599f6000 0x00007fff599f8000 r--p	[vvar]
0x00007fff599f8000 0x00007fff599fa000 r-xp	[vdso]
0xffffffffff600000 0xffffffffff601000 r-xp	[vsyscall]



Codegate 2017 prequals: angrybird

,

solution

edit the binary + angr.

radare2 is useful to edit binaries.

$ diff <(objdump -d -M intel angrybird) <(objdump -d -M intel angrybird.modified)
2c2
< angrybird:     file format elf64-x86-64
---
> angrybird.modified:     file format elf64-x86-64
152,154c152,160
<   40071a:     67 8b 04 24             mov    eax,DWORD PTR [esp]
<   40071e:     83 f8 00                cmp    eax,0x0
<   400721:     0f 85 b9 fe ff ff       jne    4005e0 <exit@plt>
---
>   40071a:     b8 00 00 00 00          mov    eax,0x0
>   40071f:     90                      nop
>   400720:     90                      nop
>   400721:     90                      nop
>   400722:     90                      nop
>   400723:     90                      nop
>   400724:     90                      nop
>   400725:     90                      nop
>   400726:     90                      nop
163,167c169,173
<   40073a:     48 8b 45 f8             mov    rax,QWORD PTR [rbp-0x8]
<   40073e:     ba 05 00 00 00          mov    edx,0x5
<   400743:     be 8e 50 40 00          mov    esi,0x40508e
<   400748:     48 89 c7                mov    rdi,rax
<   40074b:     e8 30 fe ff ff          call   400580 <strncmp@plt>
---
>   40073a:     bf 38 60 60 00          mov    edi,0x606038
>   40073f:     be 8e 50 40 00          mov    esi,0x40508e
>   400744:     b9 05 00 00 00          mov    ecx,0x5
>   400749:     f3 a4                   rep movs BYTE PTR es:[rdi],BYTE PTR ds:[rsi]
>   40074b:     b8 00 00 00 00          mov    eax,0x0
183c189,194
<   40077b:     0f 84 5f fe ff ff       je     4005e0 <exit@plt>
---
>   40077b:     90                      nop
>   40077c:     90                      nop
>   40077d:     90                      nop
>   40077e:     90                      nop
>   40077f:     90                      nop
>   400780:     90                      nop
#!/usr/bin/env python2
import angr # angr-6.7.1.31
p = angr.Project('./angrybird.modified')
state = p.factory.entry_state()
goal = 0x404fdb
pathgroup = p.factory.path_group(state)
pathgroup.explore(find=goal)
for path in pathgroup.found:
    print repr(path.state.posix.dumps(0))
(angr) $ ./a.py
WARNING | 2017-02-10 18:56:15,305 | simuvex.plugins.symbolic_memory | Concretizing symbolic length. Much sad; think about implementing.
'Im_so_cute&pretty_:)\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n\n\n\n\n\n\n\n\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
./a.py  37.70s user 1.03s system 99% cpu 38.787 total

Yukicoder No.440 2次元チワワ問題

,

https://yukicoder.me/problems/no/440

Yukicoderにしてはちょっと定数倍厳しいなと思ったら非想定解だったぽい。 想定解の方が綺麗。

solution

segment木。$O(HW + Q (H \log W + W \log H))$。

縦横および向きで独立なので、ある($1$次元の)文字列の部分文字列からcwwを数える問題としてよい。 区間に対してその区間中のc w cw cwwの数をそれぞれ持っておけば、区間の情報の合成は$O(1)$、文字列の長さを$N$とするとsegment木で$O(\log N)$でcwwの数が取れる。 縦横および向きの$2H + 2W$個のそれぞれに対してこれを$O(\log W), O(\log H)$でやればクエリごとに$O(H \log W + W \log H)$であり、間に合う。 ただし定数倍が厳しいので適当する必要があるかも。

implementation

#include <iostream>
#include <vector>
#include <functional>
#include <cmath>
#define repeat(i,n) for (int i = 0; (i) < int(n); ++(i))
#define repeat_from(i,m,n) for (int i = (m); (i) < int(n); ++(i))
using ll = long long;
using namespace std;

template <typename T>
struct segment_tree { // on monoid
    int n;
    vector<T> a;
    function<T (T,T)> append; // associative
    T unit; // unit
    segment_tree() = default;
    segment_tree(int a_n, T a_unit, function<T (T,T)> a_append) {
        n = pow(2,ceil(log2(a_n)));
        a.resize(2*n-1, a_unit);
        unit = a_unit;
        append = a_append;
    }
    void point_update(int i, T z) {
        a[i+n-1] = z;
        for (i = (i+n)/2; i > 0; i /= 2) {
            a[i-1] = append(a[2*i-1], a[2*i]);
        }
    }
    T range_concat(int l, int r) {
        return range_concat(0, 0, n, l, r);
    }
    T range_concat(int i, int il, int ir, int l, int r) {
        if (l <= il and ir <= r) {
            return a[i];
        } else if (ir <= l or r <= il) {
            return unit;
        } else {
            return append(
                    range_concat(2*i+1, il, (il+ir)/2, l, r),
                    range_concat(2*i+2, (il+ir)/2, ir, l, r));
        }
    }
};
int nc2(int n) { return n*(n-1)/2; }

struct cww_t {
    int c, w, cw, wc, cww, wwc;
};
const cww_t unit = {};
cww_t append(cww_t const & x, cww_t const & y) {
    cww_t z;
    z.c = x.c + y.c;
    z.w = x.w + y.w;
    z.cw = x.cw + x.c * y.w + y.cw;
    z.wc = x.wc + x.w * y.c + y.wc;
    z.cww = x.cww + x.cw * y.w + x.c * nc2(y.w) + y.cww;
    z.wwc = x.wwc + nc2(x.w) * y.c + x.w * y.wc + y.wwc;
    return z;
}
int main() {
    int h, w; cin >> h >> w;
    vector<string> f(h); repeat (y,h) cin >> f[y];
    vector<segment_tree<cww_t> > segtree_hr(h, segment_tree<cww_t>(w, unit, append));
    vector<segment_tree<cww_t> > segtree_vt(w, segment_tree<cww_t>(h, unit, append));
    repeat (y,h) repeat (x,w) segtree_hr[y].point_update(x, (cww_t) { (f[y][x] == 'c'), (f[y][x] == 'w'), 0, 0 });
    repeat (x,w) repeat (y,h) segtree_vt[x].point_update(y, (cww_t) { (f[y][x] == 'c'), (f[y][x] == 'w'), 0, 0 });
    int q; cin >> q;
    while (q --) {
        int ly, lx, ry, rx; cin >> ly >> lx >> ry >> rx; -- ly; -- lx;
        ll ans = 0;
        repeat_from (y,ly,ry) { auto it = segtree_hr[y].range_concat(lx, rx); ans += it.cww + it.wwc; }
        repeat_from (x,lx,rx) { auto it = segtree_vt[x].range_concat(ly, ry); ans += it.cww + it.wwc; }
        cout << ans << endl;
    }
    return 0;
}

Yukicoder No.255 Splarrraaay スプラーレェーーイ

,

https://yukicoder.me/problems/no/255

前編であるNo.230 Splarraay スプラレェーイを解いた後にそのまま投げたら座圧が必要なの見落としててREが乱立したし、$10^{18}+9$で剰余取るのを忘れててWAで困ったりもした。 剰余取るのが必要なケースはひとつだけっぽいのと$10^{18}+9$という大きな値なので、自分の提出を含めて足せば落ちる提出はありそう。

solution

座標圧縮 + 遅延伝播segment木。$O(Q \log Q)$。

$N \le 10^{13}$と大きい。クエリを先読みして座標圧縮し、配列の要素には長さの属性を持たせる。

木に与えるクエリは合成ができないといけないが、このため厚みの属性と合成の可否の属性のふたつが必要。 合成の可否とは、「Aの色で厚み2で塗る」と「Aの色で厚み2で塗る」は「Aの色で厚み4で塗る」に合成できるが、「Aの色で厚み2で塗る」と「(Bの色で塗った後に)Aの色で厚み2で塗る」は合成しても「(Bの色で塗った後に)Aの色で厚み2で塗る」にしかならないという区別のため。

$10^{18}+9$での剰余を忘れないように。

implementation

#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>
#include <array>
#include <map>
#include <functional>
#include <cmath>
#include <cassert>
#define repeat(i,n) for (int i = 0; (i) < int(n); ++(i))
#define whole(f,x,...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;

template <typename M, typename Q>
struct lazy_propagation_segment_tree { // on monoids
    int n;
    vector<M> a;
    vector<Q> q;
    function<M (M,M)> append_m; // associative
    function<Q (Q,Q)> append_q; // associative, not necessarily commutative
    function<M (Q,M)> apply; // distributive, associative
    M unit_m; // unit
    Q unit_q; // unit
    lazy_propagation_segment_tree() = default;
    lazy_propagation_segment_tree(int a_n, M a_unit_m, Q a_unit_q, function<M (M,M)> a_append_m, function<Q (Q,Q)> a_append_q, function<M (Q,M)> a_apply) {
        n = pow(2,ceil(log2(a_n)));
        a.resize(2*n-1, a_unit_m);
        q.resize(max(0, 2*n-1-n), a_unit_q);
        unit_m = a_unit_m;
        unit_q = a_unit_q;
        append_m = a_append_m;
        append_q = a_append_q;
        apply = a_apply;
    }
    void range_apply(int l, int r, Q z) {
        assert (0 <= l and l <= r and r <= n);
        range_apply(0, 0, n, l, r, z);
    }
    void range_apply(int i, int il, int ir, int l, int r, Q z) {
        if (l <= il and ir <= r) {
            a[i] = apply(z, a[i]);
            if (i < q.size()) q[i] = append_q(z, q[i]);
        } else if (ir <= l or r <= il) {
            // nop
        } else {
            range_apply(2*i+1, il, (il+ir)/2, 0, n, q[i]);
            range_apply(2*i+1, il, (il+ir)/2, l, r, z);
            range_apply(2*i+2, (il+ir)/2, ir, 0, n, q[i]);
            range_apply(2*i+2, (il+ir)/2, ir, l, r, z);
            a[i] = append_m(a[2*i+1], a[2*i+2]);
            q[i] = unit_q;
        }
    }
    M range_concat(int l, int r) {
        assert (0 <= l and l <= r and r <= n);
        return range_concat(0, 0, n, l, r);
    }
    M range_concat(int i, int il, int ir, int l, int r) {
        if (l <= il and ir <= r) {
            return a[i];
        } else if (ir <= l or r <= il) {
            return unit_m;
        } else {
            return apply(q[i], append_m(
                    range_concat(2*i+1, il, (il+ir)/2, l, r),
                    range_concat(2*i+2, (il+ir)/2, ir, l, r)));
        }
    }
};
template <typename T>
map<T,int> coordinate_compression_map(vector<T> const & xs) {
    int n = xs.size();
    vector<int> ys(n);
    whole(iota, ys, 0);
    whole(sort, ys, [&](int i, int j) { return xs[i] < xs[j]; });
    map<T,int> f;
    for (int i : ys) {
        if (not f.count(xs[i])) { // make unique
            int j = f.size();
            f[xs[i]] = j; // f[xs[i]] has a side effect, increasing the f.size()
        }
    }
    return f;
}

const ll mod = ll(1e18)+9;
struct state_t {
    ll size;
    array<ll,5> acc;
};
struct query_t {
    enum { UNIT, LENGTH, FILL } type;
    ll arg1;
    int arg2;
    bool arg3;
};
int main() {
    // input
    ll n; int q; cin >> n >> q;
    vector<int> x(q); vector<ll> l(q), r(q); repeat (i,q) { cin >> x[i] >> l[i] >> r[i]; ++ r[i]; }
    // prepare
    map<ll,int> compress; {
        vector<ll> ps;
        ps.push_back(0);
        ps.push_back(n);
        repeat (i,q) {
            ps.push_back(l[i]);
            ps.push_back(r[i]);
        }
        compress = coordinate_compression_map(ps);
    }
    lazy_propagation_segment_tree<state_t,query_t> segtree(compress[n], (state_t) { 0, {} }, (query_t) { query_t::UNIT }, [&](state_t const & a, state_t const & b) {
        state_t c;
        c.size = a.size + b.size;
        repeat (i,5) c.acc[i] = a.acc[i] + b.acc[i];
        return c;
    }, [&](query_t q, query_t p) {
        if (q.type == query_t::UNIT) return p;
        if (p.type == query_t::UNIT) return q;
        assert (q.type == query_t::FILL);
        assert (p.type == query_t::FILL);
        if (q.arg1 == p.arg1) { // if same color
            if (not q.arg3) { // if not reset
                q.arg2 += p.arg2;
                q.arg3 = p.arg3;
            }
        } else {
            q.arg3 = true; // reset
        }
        return q;
    }, [&](query_t p, state_t a) {
        if (p.type == query_t::UNIT) return a;
        if (p.type == query_t::LENGTH) return (state_t) { p.arg1, {} };
        int color = p.arg1;
        int depth = p.arg2;
        bool reset = p.arg3;
        state_t b = {};
        b.size = a.size;
        b.acc[color] = ((reset ? 0 : a.acc[color]) + depth * a.size % mod) % mod;
        return b;
    });
    for (auto cur = compress.begin(), nxt = ++ compress.begin(); nxt != compress.end(); ++ cur, ++ nxt) {
        assert (cur->second + 1 == nxt->second);
        segtree.range_apply(cur->second, nxt->second, (query_t) { query_t::LENGTH, nxt->first - cur->first });
        segtree.range_concat(cur->second, nxt->second);
    }
    // solve
    ll acc[5] = {};
    repeat (i,q) {
        if (x[i] == 0) {
            state_t it = segtree.range_concat(compress[l[i]], compress[r[i]]);
            int j = whole(max_element, it.acc) - it.acc.begin();
            if (whole(count, it.acc, it.acc[j]) == 1) {
                acc[j] = (acc[j] + it.acc[j]) % mod;
            }
        } else {
            segtree.range_apply(compress[l[i]], compress[r[i]], (query_t) { query_t::FILL, x[i]-1, 1, false });
        }
    }
    state_t it = segtree.range_concat(compress[0], compress[n]);
    repeat (i,5) acc[i] = (acc[i] + it.acc[i]) % mod;
    // output
    repeat (i,5) cout << acc[i] << ' '; cout << endl;
    return 0;
}

Yukicoder No.230 Splarraay スプラレェーイ

,

http://yukicoder.me/problems/no/230

以前挑んで失敗した跡があったが、ライブラリ整備してれば流れで書くだけ。

solution

遅延伝播segment木。$O(N + Q \log N)$。

参考: https://kimiyuki.net/blog/2017/01/17/segment-tree-requirements/

implementation

#include <iostream>
#include <vector>
#include <numeric>
#include <array>
#include <functional>
#include <cmath>
#include <cassert>
#define repeat(i,n) for (int i = 0; (i) < int(n); ++(i))
#define whole(f,x,...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;

template <typename M, typename Q>
struct lazy_propagation_segment_tree { // on monoids
    int n;
    vector<M> a;
    vector<Q> q;
    function<M (M,M)> append_m; // associative
    function<Q (Q,Q)> append_q; // associative, not necessarily commutative
    function<M (Q,M)> apply; // distributive, associative
    M unit_m; // unit
    Q unit_q; // unit
    lazy_propagation_segment_tree() = default;
    lazy_propagation_segment_tree(int a_n, M a_unit_m, Q a_unit_q, function<M (M,M)> a_append_m, function<Q (Q,Q)> a_append_q, function<M (Q,M)> a_apply) {
        n = pow(2,ceil(log2(a_n)));
        a.resize(2*n-1, a_unit_m);
        q.resize(max(0, 2*n-1-n), a_unit_q);
        unit_m = a_unit_m;
        unit_q = a_unit_q;
        append_m = a_append_m;
        append_q = a_append_q;
        apply = a_apply;
    }
    void range_apply(int l, int r, Q z) {
        assert (0 <= l and l <= r and r <= n);
        range_apply(0, 0, n, l, r, z);
    }
    void range_apply(int i, int il, int ir, int l, int r, Q z) {
        if (l <= il and ir <= r) {
            a[i] = apply(z, a[i]);
            if (i < q.size()) q[i] = append_q(z, q[i]);
        } else if (ir <= l or r <= il) {
            // nop
        } else {
            range_apply(2*i+1, il, (il+ir)/2, 0, n, q[i]);
            range_apply(2*i+1, il, (il+ir)/2, l, r, z);
            range_apply(2*i+2, (il+ir)/2, ir, 0, n, q[i]);
            range_apply(2*i+2, (il+ir)/2, ir, l, r, z);
            a[i] = append_m(a[2*i+1], a[2*i+2]);
            q[i] = unit_q;
        }
    }
    M range_concat(int l, int r) {
        assert (0 <= l and l <= r and r <= n);
        return range_concat(0, 0, n, l, r);
    }
    M range_concat(int i, int il, int ir, int l, int r) {
        if (l <= il and ir <= r) {
            return a[i];
        } else if (ir <= l or r <= il) {
            return unit_m;
        } else {
            return apply(q[i], append_m(
                    range_concat(2*i+1, il, (il+ir)/2, l, r),
                    range_concat(2*i+2, (il+ir)/2, ir, l, r)));
        }
    }
};

int main() {
    int n; cin >> n;
    lazy_propagation_segment_tree<array<int,3>,int> segtree(n, {}, -1, [&](array<int,3> a, array<int,3> b) {
        array<int,3> c;
        repeat (i,3) c[i] = a[i] + b[i];
        return c;
    }, [&](int q, int p) {
        if (q == -1) return p;
        return q;
    }, [&](int p, array<int,3> a) {
        if (p == -1) return a;
        array<int,3> b = {};
        if (p == 0) {
            b[0] = 1;
        } else {
            b[p] = whole(accumulate, a, 0);
        }
        return b;
    });
    repeat (i,n) segtree.range_apply(i, i+1, 0);
    ll a = 0, b = 0;
    int q; cin >> q;
    while (q --) {
        int x, l, r; cin >> x >> l >> r; ++ r;
        if (x == 0) {
            array<int,3> it = segtree.range_concat(l, r);
            if (it[1] > it[2]) {
                a += it[1];
            } else if (it[1] < it[2]) {
                b += it[2];
            }
        } else {
            segtree.range_apply(l, r, x);
        }
    }
    array<int,3> it = segtree.range_concat(0, n);
    a += it[1];
    b += it[2];
    cout << a << ' ' << b << endl;
    return 0;
}

割り込みによるスタックの上方の値の予期せぬ書き換えについて

,

CTFやbrainfuck golfではスタックの進む先の空間を普通の領域として利用することがある。 この領域が勝手に書き変わる場合について、具体例としてsignalによる割り込みを思い付いたため検証した。 結論としては、signalが飛び自分で設定したhandlerが走ると壊れることがあるということである。

設定

以下のようなC言語のプログラムを考える。 整数を入力させをれをそのまま出力するだけのプログラムである。 ただし、スタックの進む先のアドレスに対し、そこへ一瞬だけ書き込んで読み出す。

int main(void) {
    int x, *p;
    p = &x - 0x10;
    scanf("%d", &x);
    *p = x;
    x = *p;
    printf("%d\n", x);
}

これは間違いなく規格違反だろうが、実際のところ何事もなかったかのように動く。

$ gcc --version
gcc (Ubuntu 5.4.1-2ubuntu1~16.04) 5.4.1 20160904
Copyright (C) 2015 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

$ gcc a.c
$ ./a.out
1234
1234

signalが絡むとこのようなプログラムが失敗しうることを確認する。

確認

先に通常の場合を確認する。

この上の例をdisassembleすると以下のようになる。 *p = x;からx = *p;の間には特に他の命令はないため、ほぼ間違いなく値は保存されると言ってよいだろう。 この間にはpを触らない命令なら他に何を入れてもよいが、もちろん関数をcallするなどすれば値は壊れうる。

00000000004005f6 <main>:
  4005f6:       55                      push   rbp
  4005f7:       48 89 e5                mov    rbp,rsp
  4005fa:       48 83 ec 20             sub    rsp,0x20
  4005fe:       64 48 8b 04 25 28 00    mov    rax,QWORD PTR fs:0x28
  400605:       00 00 
  400607:       48 89 45 f8             mov    QWORD PTR [rbp-0x8],rax
  40060b:       31 c0                   xor    eax,eax
  # p = &x - 0x10;
  40060d:       48 8d 45 ec             lea    rax,[rbp-0x14]
  400611:       48 83 e8 40             sub    rax,0x40
  400615:       48 89 45 f0             mov    QWORD PTR [rbp-0x10],rax
  # scanf("&d", &x);
  400619:       48 8d 45 ec             lea    rax,[rbp-0x14]
  40061d:       48 89 c6                mov    rsi,rax
  400620:       bf f4 06 40 00          mov    edi,0x4006f4
  400625:       b8 00 00 00 00          mov    eax,0x0
  40062a:       e8 b1 fe ff ff          call   4004e0 <__isoc99_scanf@plt>
  # *p = x;
  40062f:       8b 55 ec                mov    edx,DWORD PTR [rbp-0x14]
  400632:       48 8b 45 f0             mov    rax,QWORD PTR [rbp-0x10]
  400636:       89 10                   mov    DWORD PTR [rax],edx
  # x = *p;
  400638:       48 8b 45 f0             mov    rax,QWORD PTR [rbp-0x10]
  40063c:       8b 00                   mov    eax,DWORD PTR [rax]
  40063e:       89 45 ec                mov    DWORD PTR [rbp-0x14],eax
  # printf("%d\n", x);
  400641:       8b 45 ec                mov    eax,DWORD PTR [rbp-0x14]
  400644:       89 c6                   mov    esi,eax
  400646:       bf f7 06 40 00          mov    edi,0x4006f7
  40064b:       b8 00 00 00 00          mov    eax,0x0
  400650:       e8 6b fe ff ff          call   4004c0 <printf@plt>
  # return 0;
  400655:       b8 00 00 00 00          mov    eax,0x0
  40065a:       48 8b 4d f8             mov    rcx,QWORD PTR [rbp-0x8]
  40065e:       64 48 33 0c 25 28 00    xor    rcx,QWORD PTR fs:0x28
  400665:       00 00 
  400667:       74 05                   je     40066e <main+0x78>
  400669:       e8 42 fe ff ff          call   4004b0 <__stack_chk_fail@plt>
  40066e:       c9                      leave  
  40066f:       c3                      ret    

signal

disassemble結果としてはまったく問題なくても値が保存されない例として、signalが考えられる。 つまり、*p = x;からx = *p;の間で何らかのsignalが飛びそのhandlerが呼ばれた場合、このhandlerは値を壊しうる。

狙った位置でsignalを飛ばすためbusy waitを入れ、検証コードは以下のようになった。

#include <stdio.h>
#include <string.h>
#include <signal.h>
#include <unistd.h>

void func(void) {
    long long x;
    long long *p;
    p = &x - 0x20;
    scanf("%lld", &x);
    *p = x;
    while (x --) ; // busy wait
    x = *p;
    printf("%lld\n", x);
}

void handler(int sig) {
    char buf[4096];
    memset(buf, 0, sizeof(buf));
}

int main(void) {
    signal(SIGALRM, handler);
    alarm(4);
    func();
    return 0;
}

実際、実行すると以下のようになる。 signalによる割り込みが発生した場合では、結果が壊れていることが分かる。

$ echo 123456789 | time ./a.out
123456789
0.25s 1424KB
$ echo 12345678999 | time ./a.out
0
22.47s 1424KB

ただし、手元の環境では、SIGSTOPSIGCONTのdefault handlerではstackの破壊は起きなかった。


AtCoder Grand Contest 010: C - Cleaning

,

http://agc010.contest.atcoder.jp/tasks/agc010_c

Aを書いた後それをBかCが書けたら投げようと思ったら、BもCも最後まで分からなかったので無提出rating変化なし。 あまり良くないとは思うのだけれど、短期的なratingを考えると十分選択肢に入る戦略なのでこうなってしまう。

木DPだとは思っていたが、各部分木を全てちょうど葉と同一視できるのに気付かず、その頂点が要求するパスの数の加減$L_i$と上限$R_i$で木DPしようとしていた。

solution

木DP。部分木は葉と同一視できるので根が$A_i = 0$な葉と見做せるかを答える。$O(N)$。

ある葉でない部分木$i$を葉と同一視するとして持つ石の数$A’_i$を考える。 子は全て葉としてよい。 子の持つ石の総和$\mathrm{sum} = \sum_{j \in \mathrm{children}(i)} A_j$とする。 この部分木の中で閉じたパスの数を$b$、外へ出ていくパスの数を$c$とすると、$A_i = b + c$かつ$\mathrm{sum} = 2b + c$である。 これは連立方程式として解けて$b = \mathrm{sum} - A_i$。一意に定まることに注意。 また$c = A_i - b$は求めたい$A’_i$と等しい。 ここでそのような$(b, c)$がパスの張り方として有効かどうかの確認が必要。 $0 \le b, 0 \le c$に加えて、内部で$b$本張れるかを見れば十分。 子で$A_j$が最も大きいものを$j$として、$A_j \ge \frac{\mathrm{sum}}{2}$であればこの子が制限をして$b \le \mathrm{sum} - A_j$である必要がある。 そうでないとすると、単に$\mathrm{sum}$が制限をして$b \le \frac{\mathrm{sum}}{2}$を確認すればよい。

implementation

#include <iostream>
#include <vector>
#include <algorithm>
#include <functional>
#include <cassert>
#define repeat(i,n) for (int i = 0; (i) < int(n); ++(i))
#define whole(f,x,...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;
bool solve(int n, vector<int> & a, vector<vector<int> > & g) {
    vector<int> dp(n);
    auto is_leaf = [&](int i) { return g[i].size() == 1; };
    function<bool (int, int)> go = [&](int i, int parent) {
        if (is_leaf(i)) {
            dp[i] = a[i];
        } else {
            ll sum_dp = 0;
            for (int j : g[i]) if (j != parent) {
                if (not go(j, i)) return false;
                sum_dp += dp[j];
            }
            ll b = sum_dp - a[i];
            ll c = a[i] - b;
            if (b < 0 or c < 0) return false;
            int argmax_dp = *whole(max_element, g[i], [&](int j, int k) { return make_pair(j != i, dp[j]) < make_pair(k != i, dp[k]); });
            ll max_dp = dp[argmax_dp];
            if (max_dp > sum_dp / 2) {
                if (sum_dp - max_dp < b) return false;
            } else {
                if (sum_dp / 2 < b) return false;
            }
            dp[i] = c;
        }
        return true;
    };
    assert (n >= 2);
    if (n == 2) {
        return a[0] == a[1];
    } else {
        int i = 0;
        while (is_leaf(i)) ++ i;
        if (not go(i, -1)) return false;
        return dp[i] == 0;
    }
}
int main() {
    int n; cin >> n;
    vector<int> a(n); repeat (i,n) cin >> a[i];
    vector<vector<int> > g(n);
    repeat (i,n-1) {
        int x, y; cin >> x >> y; -- x; -- y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    cout << (solve(n, a, g) ? "YES" : "NO") << endl;
    return 0;
}

AtCoder Grand Contest 010: B - Boxes

,

http://agc010.contest.atcoder.jp/tasks/agc010_b

この手の嘘解法は好きなのだが、yukicoderだったら実質WAだし、そうでなくてもペナルティが厳しい。

ところで、最初に常にYESを投げてYES NOの偏り$p$を計測してそれに応じて投げれば最初の計測と合わせてもちょっと有利になるというのを思い付いたので使っていきたい。 例えば、今回私の実装を投げたら$3$TLEだったので、YES NOを$\frac{1}{2}$ずつで返すと期待値$8$回だが、このテクを使えば最初の計測で外した場合(なお今回は常にYESが正解だった)でも$1 + \frac{3^3}{2^2\cdot 1} = 7.75$と有利。

solution

嘘解法。貪欲 + 定数倍高速化 + 時間計測乱択。貪欲部分は$\frac{N\sum A_i}{{}_NC_2}$で$O(\max A_i)$。

貪欲は$i \in \{ i \mid A_i = \min A_j \}$を始点として愚直に$N$回引くことを繰り返すもの。 未証明だが、十分数の乱択ケースで検証したのでたぶん大丈夫。

定数倍高速化について。 (i+j)%ni+j<n?i+j:i+j-nにして剰余を除去すること、clangを使ってSIMDによる最適化をしてもらうことが重要。 ${}_NC_2 \nmod \sum A_i$なら即座にNOを返すのも重要。

implementation

#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>
#include <random>
#include <chrono>
#define repeat(i,n) for (int i = 0; (i) < int(n); ++(i))
#define repeat_from(i,m,n) for (int i = (m); (i) < int(n); ++(i))
#define whole(f,x,...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;
bool solve(vector<int> & a) {
    auto start = chrono::system_clock::now();
    int n = a.size();
    const ll nc2 = n*(n+1ll)/2;
    const ll sum_a = whole(accumulate, a, 0ll);
    if (sum_a % nc2 != 0) return false;
    repeat (q, sum_a / nc2) {
        int i = whole(min_element, a) - a.begin();
        if (a[i] <= 0) break;
        repeat      (k,i)   a[k] -= (n+k)-i+1;
        repeat_from (k,i,n) a[k] -=    k -i+1;
        if (q % 10 == 0) {
            auto end = chrono::system_clock::now();
            double t = chrono::duration<double>(end - start).count();
            if (t > 1.9) {
                random_device device;
                uniform_int_distribution<int> dist(0, 1);
                return dist(device);
            }
        }
    }
    return whole(count, a, 0) == n;
}
int main() {
    int n; cin >> n;
    vector<int> a(n); repeat (i,n) cin >> a[i];
    cout << (solve(a) ? "YES" : "NO") << endl;
    return 0;
}

AtCoder Grand Contest 010: A - Addition

,

http://agc010.contest.atcoder.jp/tasks/agc010_a

solution

偶奇 + コーナーケース。主に入力に$O(N)$。

数$A_i$はそれぞれ偶奇のみ見ればよい。 偶数はいくつあっても単一の偶数に潰せ、奇数はふたつで偶数ひとつになる。 よって奇数が奇数個ある場合がNO。ただし単一の奇数のみで偶数もない場合はYES

solution

#!/usr/bin/env python3
_ = int(input())
even = 0
odd = 0
for a in map(int, input().split()):
    if a % 2 == 0:
        even += 1
    else:
        odd += 1
ans = odd % 2 == 0 or (even == 0 and odd == 1)
print(['NO', 'YES'][ans])

Yukicoder No.172 UFOを捕まえろ

,

http://yukicoder.me/problems/no/172

solution

$|x| + |y| + \lceil r\sqrt{2} \rceil$が答え。 $r$が小さい正整数で$\sqrt{2}$が無理数なので$r\sqrt{2}$は常に整数ではなく、天井ではなく床を取って$1$足せばよい。

implementation

$28$byte bash $\to$ dc。

tr -d -|dc -e9k?2v*++0k1/1+p

直後にtailsさんに$16$byteのを出された: http://yukicoder.me/submissions/147326。悲しいね。


バイナリ中のalarm関数の呼び出しを自動で除去させてみる

,

設定

与えられたバイナリを直接編集してalarm関数の呼び出しを除去する。特にこれを自動で行うプログラムを書く。

例えば次のようなC言語のコードから生成されるバイナリを考える。

#include <stdio.h>
#include <stdlib.h>
#include <signal.h>
#include <unistd.h>

void handler(int sig) {
    printf("SIGALRM recieved\n");
    exit(1);
}

int main(void) {
    signal(SIGALRM, handler);
    alarm(1);
    system("sleep 2");
    printf("Congratulations!\n");
    return 0;
}

これを以下のようにコンパイルすると、main関数内である$0x40067d$においてcall <alarm@plt>命令が見つかる。 実行するとsleep 2による待機中にalarm(1);によるSIGALRMが発生するため、Congratulations!の表示は行なわれず、先にSIGALRM recievedが出力され終了する。

このcall <alarm@plt>を自動で除去し、Congratulations!と表示されるように自動で修正させるのが目標である。

$ gcc foo.c

$ objdump -d -M intel a.out | grep ' <main>:' -A 16
0000000000400665 <main>:
  400665:	55                   	push   rbp
  400666:	48 89 e5             	mov    rbp,rsp
  400669:	be 46 06 40 00       	mov    esi,0x400646
  40066e:	bf 0e 00 00 00       	mov    edi,0xe
  400673:	e8 a8 fe ff ff       	call   400520 <signal@plt>
  400678:	bf 03 00 00 00       	mov    edi,0x3
  40067d:	e8 7e fe ff ff       	call   400500 <alarm@plt>
  400682:	bf 35 07 40 00       	mov    edi,0x400735
  400687:	e8 64 fe ff ff       	call   4004f0 <system@plt>
  40068c:	bf 3d 07 40 00       	mov    edi,0x40073d
  400691:	e8 4a fe ff ff       	call   4004e0 <puts@plt>
  400696:	b8 00 00 00 00       	mov    eax,0x0
  40069b:	5d                   	pop    rbp
  40069c:	c3                   	ret    
  40069d:	0f 1f 00             	nop    DWORD PTR [rax]

$ ./a.out
SIGALRM recieved

準備

今回はPython 3で記述し、また以下のふたつのライブラリを用いる。

Capstoneはdisassemblerであり、pyelftoolsはコンテナであるELFのparserである。

今回は利用しないが、emulationをしたいならUnicorn、assemblerが欲しいならKeystone、PEやMach-Oに対応させたいならpefileやmacholibがよいだろう。

なおCapstone,Keystone,Unicornは全てC言語+各種bindingsという形であり、Pythonに限らず利用できる。

実装

先に実装の全体を示す。 x86/x86_64 ELFの普通のバイナリに対して動く。$80$行とあまり長くない長さである。

#!/usr/bin/env python3
from elftools.elf.elffile import ELFFile
from capstone import *
from capstone.x86 import *

def find_call_alarm(path):
    # load elf
    print('[*] open: %s' % path)
    elf = ELFFile(open(path, 'rb'))

    # load disassembler
    if elf.header.e_machine == 'EM_X86_64':
        md = Cs(CS_ARCH_X86, CS_MODE_64)
    elif elf.header.e_machine == 'EM_386':
        md = Cs(CS_ARCH_X86, CS_MODE_32)
    else:
        assert False
    md.detail = True

    # get alarm@got
    relx_plt = elf.get_section_by_name('.rela.plt') or elf.get_section_by_name('.rel.plt')
    dynsym = elf.get_section_by_name('.dynsym')
    for reloc in relx_plt.iter_relocations():
        symbol = dynsym.get_symbol(reloc.entry.r_info_sym)
        if symbol.name == 'alarm':
            alarm_got = reloc.entry.r_offset
    print('[+] alarm@got = %#x' % alarm_got)

    # guess alarm@plt
    plt = elf.get_section_by_name('.plt')
    for insn in md.disasm(plt.data(), plt.header.sh_addr):
        if insn.mnemonic == 'jmp':
            value = None
            for op in insn.operands:
                if op.type == X86_OP_MEM:
                    if insn.reg_name(op.mem.base) == 'rip' and op.mem.index == 0:
                        value = insn.address + insn.size + op.mem.disp
                    elif op.mem.base == 0 and op.mem.index == 0:
                        value = op.mem.disp
            if value == alarm_got:
                alarm_plt = insn.address
    print('[+] alarm@plt = %#x' % alarm_plt)

    # find all "call alarm@plt"
    xref = []
    text = elf.get_section_by_name('.text')
    for insn in md.disasm(text.data(), text.header.sh_addr):
        if insn.mnemonic == 'call':
            for op in insn.operands:
                value = None
                if op.type == X86_OP_IMM:
                    value = op.imm
                if value == alarm_plt:
                    offset = insn.address - text.header.sh_addr + text.header.sh_offset
                    xref += [ { 'offset': offset, 'length': insn.size } ]
                    print('[*] %#x: call alarm@plt  (offset = %d)' % (insn.address, offset))

    return xref

def overwrite_with_nop(path, xref):
    # overwrite them with "nop"
    print('[*] overwrite: %s' % path)
    with open(path, 'rb+') as fh:
        for it in sorted(xref, key=lambda it: it['offset']):
            fh.seek(it['offset'] - fh.tell())
            fh.write(b'\x90' * it['length'])
    print('[+] done')


def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('path', nargs='?', default='a.out')
    args = parser.parse_args()

    xref = find_call_alarm(args.path)
    overwrite_with_nop(args.path, xref)

if __name__ == '__main__':
    main()

実行例:

$ gcc foo.c

$ objdump -d -M intel a.out | grep ' <main>:' -A 16
0000000000400665 <main>:
  400665:	55                   	push   rbp
  400666:	48 89 e5             	mov    rbp,rsp
  400669:	be 46 06 40 00       	mov    esi,0x400646
  40066e:	bf 0e 00 00 00       	mov    edi,0xe
  400673:	e8 a8 fe ff ff       	call   400520 <signal@plt>
  400678:	bf 01 00 00 00       	mov    edi,0x1
  40067d:	e8 7e fe ff ff       	call   400500 <alarm@plt>
  400682:	bf 35 07 40 00       	mov    edi,0x400735
  400687:	e8 64 fe ff ff       	call   4004f0 <system@plt>
  40068c:	bf 3d 07 40 00       	mov    edi,0x40073d
  400691:	e8 4a fe ff ff       	call   4004e0 <puts@plt>
  400696:	b8 00 00 00 00       	mov    eax,0x0
  40069b:	5d                   	pop    rbp
  40069c:	c3                   	ret    
  40069d:	0f 1f 00             	nop    DWORD PTR [rax]

$ python3 kill-alarm.py
[*] open: a.out
[+] alarm@got = 0x601028
[+] alarm@plt = 0x400500
[*] 0x40067d: call alarm@plt  (offset = 1661)
[*] overwrite: a.out
[+] done

$ objdump -d -M intel a.out | grep ' <main>:' -A 20
0000000000400665 <main>:
  400665:	55                   	push   rbp
  400666:	48 89 e5             	mov    rbp,rsp
  400669:	be 46 06 40 00       	mov    esi,0x400646
  40066e:	bf 0e 00 00 00       	mov    edi,0xe
  400673:	e8 a8 fe ff ff       	call   400520 <signal@plt>
  400678:	bf 01 00 00 00       	mov    edi,0x1
  40067d:	90                   	nop
  40067e:	90                   	nop
  40067f:	90                   	nop
  400680:	90                   	nop
  400681:	90                   	nop
  400682:	bf 35 07 40 00       	mov    edi,0x400735
  400687:	e8 64 fe ff ff       	call   4004f0 <system@plt>
  40068c:	bf 3d 07 40 00       	mov    edi,0x40073d
  400691:	e8 4a fe ff ff       	call   4004e0 <puts@plt>
  400696:	b8 00 00 00 00       	mov    eax,0x0
  40069b:	5d                   	pop    rbp
  40069c:	c3                   	ret    
  40069d:	0f 1f 00             	nop    DWORD PTR [rax]

$ ./a.out
Congratulations!

解説

実装の詳細について解説する。

main

始めはmain関数。 find_call_alarm関数でalarmの呼び出しを列挙し、これをoverwrite_with_nop関数で破壊的に潰すという構成。

find_call_alarm

find_call_alarm関数について。

header

まずpyelftoolsを用いてELFFile(open(path, 'rb'))とファイルを読み、 その情報からCs(CS_ARCH_X86, CS_MODE_64)等としてCapstoneを呼び出し。 Capstoneは純粋なdisassemblerなので、ELFやPEのようなコンテナには関与しないことに注意。

got

次にGOT内でのalarmのentryのaddressである、alarm@gotの取得。 これは.rela.plt/.rel.plt.symtab.dynsymを読めばよい。

.rela.plt/.rel.pltはrelocation情報のtableである。 .dynsymはsymbol table、.dynstrはこれから参照される文字列 tableである。 GOTは(.interpで指定される)外部のlinkerにより実行時に操作する必要があるため、(実行時には不要な)他のsymbolが格納されている.symtab,.strtabとは違うsectionとなっている。

.rel.pltにあるのは以下のようなaddressとsymbolの対である。.rela.pltはここに加数r_addend(symbolで引いてきた値に加える値)を加えたもので、併存も可能だが基本的にどちらか一方だけだろう。

typedef struct {
    Elf64_Addr r_offset;
    uint64_t   r_info;
} Elf64_Rel;

これをなめてsymbol alarmを指すもののr_offsetalarm@gotである。 pyelftoolsは薄いので自分でそのようになめる。

plt

alarm@pltの推測。 linkerが動的に操作する必要のあるGOTと違ってその結果を勝手に見に行くだけであるPLTはELF内にsymbolを残す必要がなく、GOTとの対応等から推測する必要がある。

.plt内の命令を眺め、.got.plt内のalarm@gotを参照している位置を探すのがよいだろう。 これにはCapstoneを用いる。 emulatorであるUnicornを加えて持ってきてもよいが、今回は対象が固定的なので、jmp [rip + 0x12345678]jmp ds:0x12345678の形式をしている命令に関して手で参照先を計算する。 jmp [$base + $index * scale + disp]となっている。

text

最後にcall <alarm@plt>を列挙。 .textを開いてなめる。

.pltでの場合と同様に、call 0x12345678の形の命令についてalarm@pltとの一致を確認する。

overwrite_with_nop

これは素直にやる。 実行時のaddressとファイル内でのoffsetを混同しないように注意する。

所感

  • 自動化は楽しい
  • asmの操作はいいけどELFがつらい

資料

ELFについて: