Learning Algorithms

アルゴリズムの勉強メモ

Codeforces 458 E. Palindromes in a Tree

Problem - E - Codeforces

pekempeyさんが書かれたコードから学びました。わかりやすすぎてありがたすぎます。

解法

まず、ある文字の集合を適切に並べると回文になるという条件は、出現回数が奇数である文字の種類が $1$ 個以下である、という条件に言い換えることができます。

これは $26$ ビット表現で表すと都合がよいです。すなわちあるパスについて、そのパスに書かれた文字を並び替えたときにそれが回文になるという条件は、すべての頂点について各文字の位置にビットを立てたものの $XOR$ をとった結果できる値について、$1$ となっているビットの数が $1$ 個以下であるという条件に言い換えることができます(この問題ではさりげなく、与えられる文字は'a'から't'までであると書いてあるので $20$ ビットあれば十分です)。

ちなみにこの辺のテクを理解していると以下の問題も典型的に見えるかもしれません。

D - Yet Another Palindrome Partitioning

さて、これで準備ができました。問題はパスの数え上げをしたいということなので、重心分解することを考えるとよさそうです。

せっかくなので、以下の記事にまとめた重心分解による分割統治法の一般形を使って解きたいと思います。

重心分解による分割統治法の一般形について - Learning Algorithms

まず重心 $C$ によって分解したとします。このとき $C$ を使わないパスの数え上げについては、再帰的に勝手に計算されるので、無視してよいです。

次に $C$ を含むパスの数え上げをします。重心 $C$ の答えにだけではなく、数え上げるときに通った頂点の答えにも足し込んでいかないといけないのがポイントです。

まず頂点 $C$ から $DFS$ をして、 $C$ を端点とするパスの $XOR$ の結果の出現回数をcntに記録していきます。これを $C$ から見た各部分木について計算します。

次に、$C$ を端点とする、条件を満たすパスの数え上げをして $C$ に足し込みます。これは上で述べた条件より、二進数で書くとcnt[00...000] + cnt[00...001] + cnt[00...010] + ...を足すということです。これはすなわち以下のように書けます。

long long res = cnt[0];
for (int i = 0; i < 20; i ++) {
        res += cnt[1 << i];
}

次に $C$ をまたぐパスについて考えます。ある部分木について一旦そこでカウントしものをすべて戻して考えると、「ある頂点からスタートして、$C$ を通過して、自分自身に戻るパス」を取り除けます。部分木内の各頂点について、一方の端点はその頂点自身であり、もう一方の端点は別の部分木のある頂点であるようなパスの数え上げをしていきます。これらは木の性質から、通過した頂点すべてに足し込んでいく必要があります。部分木内のすべての計算を終えたら、この部分木についてのカウントは一旦 $0$ に戻していたはずなので、それをもう一度カウントしてから次の部分木に移って同じことを繰り返します。

最後にこの得られた値の半分( $C$ をまたぐパスはすべて $2$ 回カウントされているはずなので)を足し込んで、$C$ に関するカウントをすべて元に戻して終わりです。

以上のことは次のように実装できます。書き換えたのはやはり、上の記事で言うところの//compute something with the centroidの部分だけです。

ところで、重心をちゃんと $1$ 個または $2$ 個見つけて、vectorにいれてそれを返してcentroid[0]で使う、ってやっている部分がやっぱり時間かかってそうなので、重心分解のときは重心を $1$ 個見つけたらすぐにそれを返す関数を作る方が良さそうですね(それはそう)
追記:重心を一つ見つけたらすぐに返す関数に書き換えたので、かなり高速化されました。

実装
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <string>
#include <vector>
#include <map>
#include <functional>
using namespace std;

int OneCentroid(int root, const vector<vector<int>> &g, const vector<bool> &dead) {
        static vector<int> sz(g.size());
        function<void (int, int)> get_sz = [&](int u, int prev) {
                sz[u] = 1;
                for (auto v : g[u]) if (v != prev && !dead[v]) {
                        get_sz(v, u);
                        sz[u] += sz[v];
                }
        };
        get_sz(root, -1);
        int n = sz[root];
        function<int (int, int)> dfs = [&](int u, int prev) {
                for (auto v : g[u]) if (v != prev && !dead[v]) {
                        if (sz[v] > n / 2) {
                                return dfs(v, u);
                        }
                }
                return u;
        };
        return dfs(root, -1);
}

vector<long long> CentroidDecomposition(const vector<vector<int>> &g, const vector<int> &word) {
        int n = (int) g.size();
        vector<long long> ans(n);
        vector<bool> dead(n, false);
        vector<long long> cnt(1 << 20, 0);
        function<void (int)> rec = [&](int start) {
                int c = OneCentroid(start, g, dead);
                dead[c] = true;
                //compute something within a subtree alone (without the centroid)
                for (auto u : g[c]) if (!dead[u]) {
                        rec(u);
                }
                //compute something with the centroid
                function<void (int, int, int, int)> add_sub = [&](int u, int prev, int val, bool is_add) {
                        val ^= word[u];
                        cnt[val] += (is_add ? 1 : -1);
                        for (auto v : g[u]) if (v != prev && !dead[v]) {
                                add_sub(v, u, val, is_add);
                        }
                };
                function<long long (int, int, int)> calc = [&](int u, int prev, int val) {
                        val ^= word[u];
                        long long res = cnt[val];
                        for (int i = 0; i < 20; i ++) {
                                res += cnt[(1 << i) ^ val];
                        }
                        for (auto v : g[u]) if (v != prev && !dead[v]) {
                                res += calc(v, u, val);
                        }
                        ans[u] += res;
                        return res;
                };
                add_sub(c, -1, 0, true);
                long long res = cnt[0];
                for (int i = 0; i < 20; i ++) {
                        res += cnt[1 << i];
                }
                for (auto u : g[c]) if (!dead[u]) {
                        add_sub(u, c, word[c], false);
                        res += calc(u, c, 0);
                        add_sub(u, c, word[c], true);
                }
                ans[c] += res / 2;
                add_sub(c, -1, 0, false);
                //
                dead[c] = false;
        };
        rec(0);
        return ans;
}

int main() {
        int n;
        scanf("%d", &n);
        vector<vector<int>> g(n);
        for (int i = 0; i < n - 1; i ++) {
                int a, b;
                scanf("%d%d", &a, &b);
                a --, b --;
                g[a].push_back(b);
                g[b].push_back(a);
        }
        string s;
        cin >> s;
        vector<int> word(n);
        for (int i = 0; i < s.size(); i ++) {
                word[i] = 1 << (s[i] - 'a');
        }
        vector<long long> ans = CentroidDecomposition(g, word);
        for (int i = 0; i < n; i ++) {
                printf("%lld%c", ans[i] + 1, i == n - 1 ? '\n' : ' ');
        }
        return 0;
}