Learning Algorithms

アルゴリズムの勉強メモ

Atcoder Grand Contest 008 F. Black Radius

Atcoder Grand Contest 008 F. Black Radius

F: Black Radius - AtCoder Grand Contest 008 | AtCoder

感想

こちらのei1333さんの記事†全方位木DP†について - ei1333の日記で全方位木dpの概念を学んだところだったので、そこまで苦労しなかった。わかりやすい記事をありがとうございます。

解法

まず重複しないようにすべての配色を数えるために、各頂点について、それが中心になるような配色としてありえるものを数えるという方針をとる。これを辺についても同じように考える。これらの頂点や辺が異なれば、必ずその配色も異なることは明らかである。

まず頂点を$\ v\ $に固定して考える。もしこの点がお気に入りの点であるならば、その半径$\ r\ $とすると、$\ 0 \leq r \leq d_1\ $の範囲で連続的に$\ r\ $を動かしていくことができる。そしてこれらは明らかに異なる配色になる。この上限$\ d_1\ $というのは、絵を書いてみればわかるように、$\ v\ $が中心になるようにするためには、$\ v\ $からのびる部分木の中で、深さが2番目に大きいものでなければいけない。 

この点がお気に入りの点でない場合は、その半径を$\ r\ $とすると、$\ d_2 \leq r \leq d_1\ $の範囲で連続的に変化させることができる。ただし、$\ d_1\ $は上と同じである。一方下限の方は、まずその色を塗るためのお気に入りの点が含まれている部分木がある必要がある。その頂点から、$\ v\ $を通り越して別の部分木に同じだけの$\ r\ $の範囲を塗るものでなければいけないので、その部分木というのは十分深さの小さいものが好ましい。お気に入りの点を含む部分木の中で、深さが最小のもののその深さを$\ d\ $とすると、その部分木内はすべて黒に塗り尽くしてさらに$\ v\ $を通り越して$\ r\ $だけ塗ることになるのだが、結局この$\ r\ $というのは$\ d\ $に一致しているので、この$\ d\ $が下限$\ d_2\ $である。

次に辺が中心になる場合を考える。少し考察すれば、ある辺が中心になるような配色というのは高々1個しかない。なぜなら、その辺が接続する頂点を$\ u, v\ $として、$\ u, v\ $からそれぞれの部分木を同時に塗っていくとき、そのどちらか一方が塗り終わるまでは明らかにどの頂点を選んでもうまく塗ることができない。そして一方が塗り終わったその時に限り、その辺が中心になるようなバランスを崩すことなく、その塗り終えた部分木の中からある頂点を選んでそこから上手く塗ることができる。よってこのようなことができるどうかの判定は、$\ u, v\ $を根とする部分木のうち深さが小さい方にお気に入りの点が含まれているかどうかを見るだけで良い。

実装は、まず1回目のdfs1では、$\ 0\ $を根とする根付き木について、各部分木の最大の深さfarとその部分木にお気に入りの点が含まれているかどうかのin_favを順に求めておく。

次に、2回目のdfs2をする。親の自分を除く部分木についての最大の深さと、親と自分以外の部分木の中にお気に入りの点が存在するかどうか、という情報を伝播させている。

そのような部分木をchildrenにいれ、その中でもお気に入りの点を含むような部分木を特にfav_childrenにいれる。そしてchildrenについてはその中で2番目に大きいものを記録し、fav_childrenについては最も浅いものを記録する。

辺に関する探索は各頂点から次の頂点に移動する、その直前にちょうど必要な情報が揃っているので、それでぱぱっと答えに足しておく。

実装自体は書いてみたあとでは意外と単純なことに気がつく。全方位木dpについて注意することは、とにかく「次に探索しようとする部分木の情報(自分自身の情報)」が入らないようにして伝播させること。そして根の次数が$\ 1\ $である場合などにうまく適切な値を伝播させることができるように、最初に適切な初期値をchildrenなどの中にいれておくこと。

実装
#include "bits/stdc++.h"
using namespace std;
 
#define all(x)  x.begin(), x.end()
#define mp      make_pair
#define pii     pair<int, int>
 
int n;
vector<int> g[202020];
bool fav[202020];
int far[202020];
int second_farthest[202020];
int closest[202020];
bool in_fav[202020];
long long ans = 0;
 
static const int INF = 0x3f3f3f3f;
 
void dfs1(int v, int prev) {
        if (fav[v]) in_fav[v] = true;
        for (auto u : g[v]) if (u != prev) {
                dfs1(u, v);
                in_fav[v] |= in_fav[u];
                far[v] = max(far[v], far[u] + 1);
        }
}
 
void dfs2(int v, int ma, bool par_fav, int prev) {
        vector<pair<pii, bool>> children;
        children.push_back(mp(mp(0, -1), false));
        for (auto u : g[v]) {
                if (u == prev) children.push_back(mp(mp(ma + 1, u), par_fav));
                else children.push_back(mp(mp(far[u] + 1, u), in_fav[u]));
        }
        vector<pii> fav_children;
        for (auto c : children) if (c.second) fav_children.push_back(mp(c.first.first, c.first.second));
        sort(all(fav_children));
        if (fav_children.empty()) closest[v] = INF;
        else closest[v] = fav_children[0].first;
        sort(all(children), greater<pair<pii, bool>>());
        second_farthest[v] = children[1].first.first;
        for (auto u : g[v]) if (u != prev) { 
                int maa = children[u == children[0].first.second ? 1 : 0].first.first;
                bool par_f = fav[v] || par_fav;
                if (!in_fav[u]) par_f |= !fav_children.empty();
                else par_f |= fav_children.size() > 1;
                ans += (maa < far[u] && par_f) || (maa > far[u] && in_fav[u]) || (maa == far[u] && (par_f || in_fav[u]));
                dfs2(u, maa, par_f, v);
        }
}
 
int main() {
        cin >> n;
        for (int i = 0; i < n - 1; i ++) {
                int a, b;
                cin >> a >> b;
                a --, b --;
                g[a].push_back(b);
                g[b].push_back(a);
        }
        string s;
        cin >> s;                                            
        for (int i = 0; i < n; i ++) fav[i] = s[i] == '1' ? true : false;
        dfs1(0, -1);
        dfs2(0, 0, fav[0], -1);
        for (int i = 0; i < n; i ++) {
                if (fav[i]) ans += 1 + second_farthest[i];
                else ans += max(0, 1 + (second_farthest[i] - closest[i]));
        }
        cout << ans << endl;
        return 0;
}