AOJ 2239: Nearest Station

,

http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=2239

  • abs付け忘れてる箇所があって$1$WA
  • `for (int k = 0; k < int(n); ++ k)` と書いていて、`int(n)` へのキャストで $n$ が overflow して$1$WA。

solution

半分全列挙。チケットで進める距離は指数的に増え、$2m$ 程度を超えるチケットは使う意味がないので、実質的に $n \le 40 \approx \log m$ と仮定できる。計算量は $O(\sqrt{m} \log m)$。 $m$ を通り過ぎてから歩いて戻る場合が存在するなど、綺麗な半分全列挙にはならないので注意が必要。 また $a = b = 1$ の場合はチケットの距離が増えずこの仮定が成り立たないので、例外として別に処理する。

implementation

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>
// repeat(i, n): loop with an int counter `i` running over 0, 1, ..., int(n) - 1.
// Note that `n` is narrowed to int, so it must fit in an int.
#define repeat(i, n) for (int i = 0; (i) < int(n); ++(i))
// repeat_from(i, m, n): loop with an int counter `i` running over m, m + 1, ..., int(n) - 1.
#define repeat_from(i, m, n) for (int i = (m); (i) < int(n); ++(i))
// whole(f, x, extra...): calls f(begin(x), end(x), extra...) on the whole range `x`.
// The range is passed to an immediately-invoked lambda as decltype((x)), so an lvalue
// container is taken by reference. Uses the `, ## __VA_ARGS__` form (a GNU extension)
// so that the trailing comma disappears when no extra arguments are given.
#define whole(f, x, ...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
// Shorthand for the 64-bit integer type used throughout the solution.
using ll = long long;
// The code below refers to standard-library names (min, vector, sort, ...) unqualified.
using namespace std;
// Lowers `a` to `b` when `b` compares less than `a`; otherwise `a` keeps its value.
template <class T>
inline void setmin(T & a, T const & b) {
    if (b < a) {
        a = b;
    }
}

// Returns the minimum distance that remains to station m after riding any
// subset of the n tickets, where ticket k (0 <= k < n) advances exactly
// p * a^k + q * b^k stations and each ticket may be used at most once.
//
// Parameters:
//   n          number of tickets
//   m          target station, measured from the start
//   a, b, p, q ticket k advances p * a^k + q * b^k stations
// Precondition: a, b >= 1 and p + q >= 1 (the code divides by a, b and p + q).
//
// Approach: when a == b == 1 every ticket advances the same distance and the
// answer is found directly. Otherwise ticket distances grow geometrically, so
// only the first few tickets stay within 2 * m + 3; the subset sums of those
// tickets are searched with meet-in-the-middle.
long long solve(long long n, long long m, long long a, long long b, long long p, long long q) {
    if (a == 1 && b == 1) {
        // Every ticket advances `delta`; using c tickets lands on c * delta.
        // The best c is floor(m / delta) or ceil(m / delta), capped at n.
        const long long delta = p + q;
        const long long below = std::min(n, m / delta);
        const long long above = std::min(n, (m + delta - 1) / delta);
        long long best = m;
        best = std::min(best, std::abs(m - below * delta));
        best = std::min(best, std::abs(m - above * delta));
        return best;
    }

    // Tickets longer than `limit` are dropped: riding one of them alone already
    // overshoots m by more than m, which is worse than not riding at all.
    const long long limit = 2 * m + 3;

    // Advances a term p * a^k to p * a^(k+1) in exact integer arithmetic,
    // saturating at limit + 1 so the product can never overflow.
    const auto next_term = [limit](long long term, long long base) {
        return term > limit / base ? limit + 1 : term * base;
    };

    std::vector<long long> tickets;
    long long pa = p;  // p * a^k, saturated at limit + 1
    long long qb = q;  // q * b^k, saturated at limit + 1
    for (long long k = 0; k < n; ++k) {
        const long long ticket = pa + qb;
        if (limit < ticket) break;  // distances never decrease, so later tickets are too long as well
        tickets.push_back(ticket);
        pa = next_term(pa, a);
        qb = next_term(qb, b);
    }

    if (tickets.empty()) {
        return m;
    }
    if (tickets.size() == 1) {
        return std::min(m, std::abs(m - tickets[0]));
    }

    // Sorted, duplicate-free list of all subset sums of tickets[first, last).
    const auto subset_sums = [&tickets](std::size_t first, std::size_t last) {
        std::vector<long long> sums(1, 0);
        for (std::size_t k = first; k < last; ++k) {
            const long long ticket = tickets[k];
            const std::size_t old_size = sums.size();
            sums.reserve(2 * old_size);
            for (std::size_t i = 0; i < old_size; ++i) {
                sums.push_back(sums[i] + ticket);
            }
            // Both halves are already sorted (the new half is the old one shifted
            // by `ticket`), so a linear merge is enough.
            const auto middle = sums.begin() + static_cast<std::vector<long long>::difference_type>(old_size);
            std::inplace_merge(sums.begin(), middle, sums.end());
            sums.erase(std::unique(sums.begin(), sums.end()), sums.end());
        }
        return sums;
    };

    const std::size_t half = tickets.size() / 2;
    const std::vector<long long> lower = subset_sums(0, half);
    const std::vector<long long> upper = subset_sums(half, tickets.size());

    long long best = m;
    for (const long long x : lower) {
        // The upper-half sum closest to m - x is either the first element that is
        // >= m - x (possibly overshooting m) or the element just before it.
        const auto it = std::lower_bound(upper.begin(), upper.end(), m - x);
        if (it != upper.end()) {
            best = std::min(best, std::abs(m - (x + *it)));
        }
        if (it != upper.begin()) {
            best = std::min(best, std::abs(m - (x + *(it - 1))));
        }
    }
    return best;
}

// Reads the six integers "n m a b p q" from standard input, runs solve() on
// them and prints the resulting distance followed by a newline.
// Exits with status 1 when the input does not contain six integers, so that
// solve() is never called with uninitialized values.
int main() {
    ll n, m, a, b, p, q;
    if (scanf("%lld%lld%lld%lld%lld%lld", &n, &m, &a, &b, &p, &q) != 6) {
        return 1;
    }
    ll result = solve(n, m, a, b, p, q);
    printf("%lld\n", result);
    return 0;
}