Learning Algorithms

アルゴリズムの勉強メモ

第4回ドワンゴからの挑戦状予選 E. ニワンゴくんの家探し

第4回ドワンゴからの挑戦状予選 E. ニワンゴくんの家探し

E - ニワンゴくんの家探し

かなり好きな問題だったので、本番で通したかった。

解法

まず分割して考えていきたい気持ちになるので、重心に注目します。

重心が二つある場合は、その重心の $2$ 頂点でクエリを投げて解の存在範囲を二分できます。

問題は重心が一つの場合ですが、これも割と素直にやればよくて、重心の周りの部分木のうち大きい方から順に $2$ 個選んで、重心の隣の頂点を選んでそれらをクエリとして投げればよいです。

最悪ケースは重心の次数が $5$ で、それぞれの部分木の大きさがすべてほとんど同じで、かつクエリの結果が $0$ であったときで、このときは解の存在範囲が $ \frac {3}{5}$ になります(クエリとして選んだ部分木以外が存在範囲になるため)。

しかしこれでも $(\frac{3}{5})^{14} = 1.56728$ なので、$14$ 回の探索で十分なことが言えます。

実装は、まず何も考えずにこれを貼って、まだ探索範囲にある頂点数を求めるalive関数と、木を切ってしまうcut関数を作って、あとはそのままやりました。

実装
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cassert>
#include <functional>
using namespace std;

int ask(int u, int v) {
        u ++, v ++;
        printf("? %d %d\n", u, v);
        fflush(stdout);
        int get;
        scanf("%d", &get);
        get --;
        return get;
}

void answer(int ans) {
        printf("! %d\n", ans + 1);
}

int alive(int start, const vector<vector<int>> &g) {
        int n = 0;
        function<void (int, int)> get_sz = [&](int u, int prev) {
                n ++;
                for (auto v : g[u]) if (v != prev) {
                        get_sz(v, u);
                }
        };
        get_sz(start, -1);
        return n;
}

vector<int> Centroid(int start, const vector<vector<int>> &g) {
        int n = g.size();
        vector<int> sz(n);
        vector<int> centroid;
        int N = alive(start, g);
        function<void (int, int)> dfs = [&](int u, int prev) {
                sz[u] = 1;
                bool is_centroid = true;
                for (auto v : g[u]) if (v != prev) {
                        dfs(v, u);
                        sz[u] += sz[v];
                        if (sz[v] > N / 2) is_centroid = false;
                }
                if (N - sz[u] > N / 2) is_centroid = false;
                if (is_centroid) centroid.push_back(u);
        };
        dfs(start, -1);
        return centroid;
}

void cut(vector<vector<int>> &g, int u, int v) {
        for (int i = 0; i < g[u].size(); i ++) {
                if (g[u][i] == v) {
                        g[u].erase(g[u].begin() + i);
                }
        }
        for (int i = 0; i < g[v].size(); i ++) {
                if (g[v][i] == u) {
                        g[v].erase(g[v].begin() + i);
                }
        }
}

int main() {
        int n, q;
        scanf("%d%d", &n, &q);
        vector<vector<int>> g(n);
        for (int i = 0; i < n - 1; i ++) {
                int a, b;
                scanf("%d%d", &a, &b);
                a --, b --;
                g[a].push_back(b);
                g[b].push_back(a);
        }
        int start = 0;
        int ans = -1;
        while (true) {
                vector<int> centroid = Centroid(start, g);
                if (centroid.size() == 2) {
                        int u = centroid[0], v = centroid[1];
                        cut(g, u, v);
                        start = ask(u, v);
                        if (g[start].size() == 0) {
                                ans = start;
                                break;
                        }
                } else if (centroid.size() == 1) {
                        int c = centroid[0];
                        vector<int> sz(n);
                        function<void (int, int)> get_sz = [&](int u, int prev) {
                                sz[u] = 1;
                                for (auto v : g[u]) if (v != prev) {
                                        get_sz(v, u);
                                        sz[u] += sz[v];
                                }
                        };
                        get_sz(c, -1);
                        vector<pair<int, int>> sz_idx;
                        for (auto it : g[c]) {
                                sz_idx.emplace_back(sz[it], it);
                        }
                        if (sz_idx.size() == 0) {
                                ans = c;
                                break;
                        }
                        assert(sz_idx.size() >= 2);
                        sort(sz_idx.rbegin(), sz_idx.rend());
                        int u = sz_idx[0].second, v = sz_idx[1].second;
                        int get = ask(u, v);
                        if (get == u) {
                                cut(g, u, c);
                                start = get;
                        } else if (get == v){
                                cut(g, v, c);
                                start = get;
                        } else {
                                assert(get == -1);
                                cut(g, v, c);
                                cut(g, u, c);
                                start = c;
                        }
                } else {
                        assert(false);
                }
        }
        answer(ans);
        return 0;
}