CS Academy 058 E. Path Inversions
CS Academy 058 E. Path Inversions
解法
ある長さ $k$ のパスを固定して考えると、このパス上の転倒数の個数とこのパスを逆に進むようなパス上の転倒数の個数の和は、書かれている数に関わらず常に、${}_{k + 1} C _2$ となることがわかる。これは実験してみればわかるのだが、全ての数が相異なることから、ある任意の2数をとったときに、それらを端点とするパスの一方では必ず転倒しているものとして数えられ、もう一方では必ず転倒していないものとして数えられるためである。これを踏まえると、結局頂点数 $k + 1$ 個の中からペアを作るときの個数が転倒数の和になるはずである。
よって、向きを無視した長さ $k$ のパスが木の上で何個とれるかを数えれば、答えがわかる。
ここでは、デマテク(データ構造をマージする一般的なテク)によるすっきりとした解法を記す。
頂点 $0$ を根とする根付き木で考え、各頂点に対してmap
をもたせ、cnt[u][d]
をノード $u$ を親とする部分木の中で、深さが $d$ である頂点の個数と定義する。
この値を葉の方から決めていき、決まったらその値を親ノードにマージする。あるノード $u$ を親とする部分木 $S$ に部分木 $T$ をマージするときに $T$ のある頂点から $u$ を経由して作れる長さ $k$ のパスを計算し、答えに足していく。こうするとうまくすべてを数え上げることができる。
マージするときは当然サイズの小さい方を大きい方にマージする。 $O(n \log n)$ 。
ちなみにこの問題はアリ本に載っている類題とほとんど同じで、重心分解による分割統治法でやはり $O(n \log n)$ で解ける(ほとんど貼るだけだった)。
実装1
#include <cstdio> #include <algorithm> #include <vector> #include <map> using namespace std; const long long MOD = 1e9 + 7; vector<int> g[100000]; int n, k; map<int, int> cnt[100000]; long long ans; void merge(int u, int v, int d) { if (cnt[u].size() < cnt[v].size()) swap(cnt[u], cnt[v]); for (auto &sub : cnt[v]) { int dd = d + (k - (sub.first - d)); if (cnt[u].find(dd) != cnt[u].end()) { ans += (long long) cnt[u][dd] * sub.second; ans %= MOD; } } for (auto &sub : cnt[v]) cnt[u][sub.first] += sub.second; cnt[v].clear(); } void dfs(int u, int prev, int d) { cnt[u][d] ++; for (auto v : g[u]) if (v != prev) { dfs(v, u, d + 1); merge(u, v, d); } } int main() { scanf("%d%d", &n, &k); for (int i = 0; i < n - 1; i ++) { int a, b; scanf("%d%d", &a, &b); a --, b --; g[a].push_back(b); g[b].push_back(a); } ans = 0; dfs(0, -1, 0); printf("%lld\n", ((long long) k * (k + 1) / 2) % MOD * ans % MOD); return 0; }
実装2
#include <cstdio> #include <algorithm> #include <vector> #include <map> using namespace std; const int MOD = 1e9 + 7; static const int INF = 0x3f3f3f3f; struct edge { int to, cost; }; const int N = 100000; int k; vector<edge> g[N]; bool divided[N]; int subtree_size[N]; long long ans; int ComputeSubtreeSize(int u, int prev) { int cnt = 1; for (int i = 0; i < g[u].size(); i ++) { int v = g[u][i].to; if (v == prev || divided[v]) continue; cnt += ComputeSubtreeSize(v, u); } subtree_size[u] = cnt; return cnt; } pair<int, int> SearchCentroid(int u, int prev, int all_size) { pair<int, int> res = make_pair(INF, -1); int sum = 1, ma = 0; for (int i = 0; i < g[u].size(); i ++) { int v = g[u][i].to; if (v == prev || divided[v]) continue; res = min(res, SearchCentroid(v, u, all_size)); ma = max(ma, subtree_size[v]); sum += subtree_size[v]; } ma = max(ma, all_size - sum); res = min(res, make_pair(ma, u)); return res; } void EnumeratePaths(int u, int prev, int dist, vector<int> &ds) { ds.push_back(dist); for (int i = 0; i < g[u].size(); i ++) { int v = g[u][i].to; if (v == prev || divided[v]) continue; EnumeratePaths(v, u, dist + g[u][i].cost, ds); } } int CountPairs(vector<int> &ds) { int res1 = 0, res2 = 0; sort(ds.begin(), ds.end()); int j = ds.size(); for (int i = 0; i < ds.size(); i ++) { while (j > 0 && ds[i] + ds[j - 1] > k) j --; res1 += j - (j > i); } res1 /= 2; j = ds.size(); for (int i = 0; i < ds.size(); i ++) { while (j > 0 && ds[i] + ds[j - 1] >= k) j --; res2 += j - (j > i); } res2 /= 2; return res1 - res2; } void SolveSubproblem(int u) { ComputeSubtreeSize(u, -1); int centroid = SearchCentroid(u, -1, subtree_size[u]).second; divided[centroid] = true; for (int i = 0; i < g[centroid].size(); i ++) { int v = g[centroid][i].to; if (divided[v]) continue; SolveSubproblem(v); } vector<int> ds; ds.push_back(0); for (int i = 0; i < g[centroid].size(); i ++) { int v = g[centroid][i].to; if (divided[v]) continue; vector<int> ds_tmp; EnumeratePaths(v, centroid, g[centroid][i].cost, ds_tmp); ans -= CountPairs(ds_tmp); ds.insert(ds.end(), ds_tmp.begin(), ds_tmp.end()); } ans += CountPairs(ds); divided[centroid] = false; } void CentroidDecomposition() { ans = 0; SolveSubproblem(0); } int main() { int n; scanf("%d%d", &n, &k); for (int i = 0; i < n - 1; i ++) { int a, b; scanf("%d%d", &a, &b); a --, b --; g[a].push_back({b, 1}); g[b].push_back({a, 1}); } CentroidDecomposition(); printf("%lld\n", ((long long) k * (k + 1) / 2) % MOD * ans % MOD); return 0; }