Learning Algorithms

アルゴリズムの勉強メモ

Atcoder Petrozavodsk Contest 001 I. Simple APSP Problem

I - Simple APSP Problem

$2000$ 点の問題ですが、以下のアルゴリズムで使う発想ができれば、解くことができます。

Hirschberg's Algorithmについて - Learning Algorithms

English version is available here.

解法

グリッドを適当に二つに分割することを考えてみます。すると、その $2$ つのグリッドそれぞれに端点を持つ最短経路というのは、$Hirschberg's\ Algorithm$ でも見たように、そのグリッド間の境界を必ず一回だけ通ります。この分割を、白しか存在しない $2$ 行(列)において行うと、その境界を通るものだけまとめて計算することができます。その値は、それぞれのグリッドに含まれる白のマスの個数の積に等しくなります。

以降、この行(列)は考慮しなくて良いので、縮約してしまってよさそうです。実はグリッドのほとんどのマスは白なので、これを繰り返すことで、グリッドのサイズを $O(n) * O(n)$ にまで落とすことができます。

したがって縮約したグリッド上で全点対最短距離を $O(n^4)$ で求めればよさそうです。さて、これでこの問題が解けました。と、言いたいところですが、一つ注意があって、例えば以下のグリッドを縮約するとします。

f:id:KokiYamaguchi:20180205160002j:plain

この場合、画像左に示した部分を縮約することになります。したがって画像右のように、マスに重み(縮約されているマスの個数)がついていなければいけないはずです。

f:id:KokiYamaguchi:20180205160017j:plain

これは、行と列それぞれ独立に、白マスだけが続く長さと黒マスの出現を保存していき、新しくマッピングし直した座標で、それらの長さの積をとったものを重みとすることで実装できます。

縮約したグリッド上でのマス $A$ (重み $W_A$)とマス $B$ (重み $W_B$)の最短経路 $d$ は、縮約前の $A$ に含まれるすべてのマスから、縮約前の $B$ に含まれるすべてのマスへの最短距離になっているので、$d * W_A * W_B$ を全点対について計算して総和をとります。最後に重複を除くために $2$ の逆元をかけて、上で求めていた結果と足し合わせるとそれが答えです。

実装
#include <cstdio>
#include <vector>
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <iostream>
#include <cassert>
#include <cmath>
#include <queue>
using namespace std;

const int MOD = 1e9 + 7;

struct state { int y, x, step; };
static const int dx[] = {1, 0, -1, 0}, dy[] = {0, 1, 0, -1};

int main() {
        long long h, w;
        scanf("%lld %lld", &h, &w);
        int n;
        scanf("%d", &n);
        vector<int> w_cnt(w, 0), h_cnt(h, 0);
        vector<bool> w_exist(w, false), h_exist(h, false);
        vector<pair<int, int>> black;
        for (int i = 0; i < n; i ++) {
                int y, x;
                scanf("%d %d", &y, &x);
                w_cnt[x] ++;
                h_cnt[y] ++;
                w_exist[x] = true;
                h_exist[y] = true;
                black.emplace_back(x, y);
        }
        for (int i = 1; i < w; i ++) w_cnt[i] += w_cnt[i - 1];
        for (int i = 1; i < h; i ++) h_cnt[i] += h_cnt[i - 1];
        long long ans = 0;
        //precalc
        for (int i = 0; i < w - 1; i ++) {
                if (!w_exist[i] && !w_exist[i + 1]) {
                        long long left = (long long) (i + 1) * h % MOD - w_cnt[i];
                        long long right = (long long) (w - (i + 1)) * h % MOD - (n - w_cnt[i]);
                        ans += left * right;
                        ans %= MOD;
                }
        }
        for (int i = 0; i < h - 1; i ++) {
                if (!h_exist[i] && !h_exist[i + 1]) {
                        long long left = (long long) (i + 1) * w % MOD - h_cnt[i];
                        long long right = (long long) (h - (i + 1)) * w % MOD - (n - h_cnt[i]);
                        ans += left * right;
                        ans %= MOD;
                }
        }
        //compress
        map<int, int> newx, newy;
        vector<pair<long long, bool>> widths, heights; //(length, is_white)
        {
                int cnt = 0;
                for (int i = 0; i < w; i ++) {
                        if (!w_exist[i]) {
                                cnt ++;
                        } else {
                                if (cnt) {
                                        widths.emplace_back(cnt, true);
                                        cnt = 0;
                                }
                                newx[i] = (int) widths.size();
                                widths.emplace_back(1, false);
                        }
                }
                if (cnt) widths.emplace_back(cnt, true);
        }
        {
                int cnt = 0;
                for (int i = 0; i < h; i ++) {
                        if (!h_exist[i]) {
                                cnt ++;
                        } else {
                                if (cnt) {
                                        heights.emplace_back(cnt, true);
                                        cnt = 0;
                                }
                                newy[i] = (int) heights.size();
                                heights.emplace_back(1, false);
                        }
                }
                if (cnt) heights.emplace_back(cnt, true);
        }
        //re-write the grid
        int neww = (int) widths.size();
        int newh = (int) heights.size();
        vector<vector<long long>> s(newh, vector<long long> (neww, 1)); //-1 when it's black, weight when it's white
        for (auto b : black) {
                s[newy[b.second]][newx[b.first]] = -1;
        }
        for (int i = 0; i < newh; i ++) {
                for (int j = 0; j < neww; j ++) {
                        if (heights[i].second || widths[j].second) {
                                s[i][j] = heights[i].first * widths[j].first % MOD;
                        }
                }
        }
        //BFS
        long long sum = 0;
        for (int sy = 0; sy < newh; sy ++) {
                for (int sx = 0; sx < neww; sx ++) {
                        if (s[sy][sx] == -1) continue;
                        long long res = s[sy][sx];
                        vector<vector<bool>> used(newh, vector<bool>(neww, false));
                        queue<state> q;
                        q.push({sy, sx, 0});
                        used[sy][sx] = true;
                        while (!q.empty()) {
                                state p = q.front(); q.pop();
                                if (p.y != sy || p.x != sx) {
                                        assert(s[p.y][p.x] != -1);
                                        sum += (long long) p.step * res % MOD * s[p.y][p.x] % MOD;
                                        sum %= MOD;
                                }
                                for (int d = 0; d < 4; d ++) {
                                        int xx = p.x + dx[d], yy = p.y + dy[d];
                                        if (xx < 0 || xx >= neww || yy < 0 || yy >= newh) continue;
                                        if (used[yy][xx] || s[yy][xx] == -1) continue;
                                        used[yy][xx] = true;
                                        q.push({yy, xx, p.step + 1});
                                }
                        }
                }
        }
        sum *= (MOD + 1) / 2;
        sum %= MOD;
        ans += sum;
        ans %= MOD;
        printf("%lld\n", ans);
        return 0;
}