Learning Algorithms

アルゴリズムの勉強メモ

Atcoder Petrozavodsk Contest 001 H. Generalized Insertion Sort

H - Generalized Insertion Sort

解法

まず制約に注目すると、$25000 \geq n \log n$ にしか見えないので、そのような計算回数を実現するアルゴリズムを考えたい気持ちになります。

木に関するアルゴリズムであって、計算量に $\log$ が現れるものは僕は現時点ではそんなに知らなくて、$LCA$ や $HL$ 分解などでした。すると、自然に $HL$ 分解のときに使うパスに近いものを考える発想に至ります。

木の頂点であって、その頂点の部分木がパスであるようなものをパス頂点と呼ぶことにします。パス頂点であって、同一のパスを共有する頂点同士はまとめてソートできないでしょうか。

もしこれができれば、それ以降その木のパス頂点はいじる必要がなくなるので、破壊してしまってよさそうです。パス頂点になる条件を考えるとこの破壊の回数は $\log n$ 回で終了します。

さて、まず明らかにどの操作も根が関わることになるので、根の値に注目します。根の値がいずれかのパス頂点の $index$ と一致するならば、それをそのパスに移動させることができます。ここで、パス全体をまとめてソートするために、最終的な値の列と相対的な順序関係が同じになるように丁寧にソートをしていきます。つまり、パスの下の方から見ていって、挿入しようとする値と、今見ている値を比較して適切な位置に挿入します。もしどの $index$ とも一致しないならば、適当に根にもってきたい値を探してその値の位置に適当に挿入します。

計算量自体は $n$ が小さいことから、すべて適当に書くだけでも通ります。

実装
#include <cstdio>
#include <vector>
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <iostream>
#include <cassert>
#include <cmath>
using namespace std;

int main() {
        int n;
        scanf("%d", &n);
        vector<vector<int>> g(n);
        vector<int> par(n);
        for (int i = 1; i < n; i ++) {
                int p;
                scanf("%d", &p);
                g[p].push_back(i);
                par[i] = p;
        }
        vector<int> a(n);
        for (int i = 0; i < n; i ++) {
                scanf("%d", &a[i]);
        }
        vector<int> color(n, 0); //2:need to complete, 1:determined, 0:no considering, -1:dead
        vector<int> ans;
        for (;;) {
                if (color[0] == -1) break;
                vector<bool> path(n);
                function<void (int)> check_path = [&](int u) {
                        path[u] = true;
                        int cnt = 0;
                        for (auto v : g[u]) if (color[v] != -1) {
                                cnt ++;
                                check_path(v);
                                path[u] = path[u] && path[v];
                        }
                        if (cnt > 1) path[u] = false;
                };
                check_path(0);
                set<int> paths;
                for (int i = 0; i < n; i ++) {
                        if (path[i]) {
                                paths.insert(i);
                        }
                }
                for (int i = 0; i < n; i ++) {
                        if (paths.count(a[i])) {
                                color[i] = 2;
                        }
                }
                function<int (int)> find_bottom = [&](int u) {
                        assert(path[u]);
                        for (auto v : g[u]) if (color[v] != -1) {
                                return find_bottom(v);
                        }
                        return u;
                };
                function<void (int, int, int)> insert = [&](int u, int val, int col) {
                        int tmpa = a[u];
                        int tmpc = color[u];
                        a[u] = val;
                        color[u] = col;
                        if (u != 0) insert(par[u], tmpa, tmpc);
                };
                function<bool (int, int, int)> compare = [&](int pos, int in_num, int cur_num) {
                        while (true) {
                                if (pos == in_num) {
                                        return false;
                                } else if (pos == cur_num) {
                                        return true;
                                }
                                if (pos == 0) {
                                        break;
                                }
                                pos = par[pos];
                        }
                        return false;
                };
                vector<int> depth(n);
                function<void (int, int)> get_depth = [&](int u, int d) {
                        depth[u] = d;
                        for (auto v : g[u]) if (color[v] != -1) {
                                get_depth(v, d + 1);
                        }
                };
                while (true) {
                        bool ok = true;
                        for (int i = 0; i < n; i ++) {
                                ok = ok && a[i] == i;
                        }
                        if (ok) { 
                                break;
                        }
                        if (color[0] == 2) {
                                int target = a[0];
                                int cur = find_bottom(target);
                                while (color[cur] == 1 && compare(cur, target, a[cur])) {
                                        cur = par[cur];
                                }
                                insert(cur, target, 1);
                                ans.push_back(cur);
                        } else {
                                get_depth(0, 0);
                                int idx = -1;
                                int ma = -1;
                                for (int i = 0; i < n; i ++) {
                                        if (ma < depth[i] && color[i] == 2) {
                                                ma = depth[i];
                                                idx = i;
                                        }
                                }
                                if (idx == -1) {
                                        break;
                                }
                                insert(idx, a[0], color[0]);
                                ans.push_back(idx);
                        }
                }
                for (int i = 0; i < n; i ++) {
                        if (path[i]) {
                                color[i] = -1;
                        }
                }
        }
        printf("%d\n", (int) ans.size());
        for (int i = 0; i < (int) ans.size(); i ++) {
                printf("%d\n", ans[i]);
        }
        return 0;
}