Learning Algorithms

アルゴリズムの勉強メモ

CS Academy 062 E. Trees Partition

CS Academy 062 E. Tree Partition

CS Academy

解法

なんらかの方法によって、各部分木の頂点集合とその補集合が即座にわかればよさそうです。

しかし各ノードに部分木の頂点集合のsetを持たせるのはさすがに厳しいです。そこで、各頂点に乱数を割り振って、そこから頂点集合に対するあるハッシュ値を作ることを考えます。

色々な演算の中で特に簡単に使えそうなものは $xor$ で、部分木内の各頂点の乱数の $xor$ をその頂点集合のハッシュ値とします。全体の $xor$ である $all$ を求めておけば、(選んだ頂点集合のハッシュ値) $xor\ all$ によって残りの頂点集合のハッシュ値も即座に計算可能です。

これが重複なく分散していると仮定すれば、あとは両方の木の各部分木でこのハッシュ値をてきとうに求めておき、setなどで一致するものの個数を数えれば、それが答えです。

精度についてはまだよくわかっていません。

実装
#include <cstdio>
#include <algorithm>
#include <vector>
#include <set>
#include <functional>
using namespace std;

unsigned long xor128() {
        static unsigned long x = 123456789, y = 362436069, z = 521288629, w = 88675123;
        unsigned long t = (x ^ (x << 11));
        x = y; y = z; z = w;
        return (w = (w ^ (w >> 19)) ^ (t ^ (t >> 8)));
}

int main() {
        int n;
        scanf("%d", &n);
        vector<vector<int>> g1(n), g2(n);
        for (int i = 1; i < n; i ++) {
                int p;
                scanf("%d", &p);
                p --;
                g1[p].push_back(i);
        }
        for (int i = 1; i < n; i ++) {
                int p;
                scanf("%d", &p);
                p --;
                g2[p].push_back(i);
        }
        vector<unsigned long> hash(n);
        for (int i = 0; i < n; i ++) hash[i] = xor128();
        unsigned long all = 0;
        for (int i = 0; i < n; i ++) all ^= hash[i];
        vector<unsigned long> res1(n), res2(n);
        set<unsigned long> st;
        function<void (int)> dfs1 = [&](int u) {
                for (auto v : g1[u]) {
                        dfs1(v);
                        res1[u] ^= res1[v];
                }
                res1[u] ^= hash[u];
        };
        dfs1(0);
        function<void (int)> dfs2 = [&](int u) {
                for (auto v : g2[u]) {
                        dfs2(v);
                        res2[u] ^= res2[v];
                }
                res2[u] ^= hash[u];
                st.insert(res2[u]);
        };
        dfs2(0);
        int ans = 0;
        for (int i = 1; i < n; i ++) ans += st.count(res1[i]) || st.count(res1[i] ^ all);
        printf("%d\n", ans);
        return 0;
}