Learning Algorithms

アルゴリズムの勉強メモ

Atcoder Grand Contest 019 D. Shift and Flip

Atcoder Grand Contest 019 D. Shift and Flip

感想

結局やるだけだったけど、ひたすら補題の実装ができなかった。解説は自明なパートだけだらだらと書いていて最も知りたい部分が省略されていた。

解法

まず最終的に$\ a\ $のどの部分が$\ b\ $のどの部分に対応するのかを決定して$\ O(n)\ $。あとでreverseで逆の場合を考えれば、このshiftは左にするものと考えて良い。flipの回数も一意に決定でき、あとはすべてのflipさせるべき位置がbが1であるような位置を通過するように移動の仕方を考えれば良い。

そのために最初に各位置について左と右でそれぞれbが1であるような位置までの距離を求めておく。そして、これを$\ (l, r)\ $とすれば、左に$\ x\ $回shiftしたとすると、bが1であるような位置までの距離は$\ (max(l - x, 0), r)\ $になると考えてよい。すべてのflipさせるべき位置についてこれをまとめたものを$\ V\ $とする。

すると、次の補題を解けばよい。

すべての$\ v \in V\ $に対して$\ v.first \leq a\ $または$\ v.second \leq b\ $が成り立つような$\ (a, b)\ $について$\ a + b\ $の最小値を求めよ。

この部分だけ強い人のコードを参考に書いた。まず$\ V\ $を降順にソートする。sumが求めたい答えで、sum = INFならば当然条件を満たすのでそれで初期化する。maで今まで見た中でsecondの最大値をいれる。firstの降順に見ているので、ある時点においてそれ以降に登場するものはすべてfirstによってカバーできるはずである。逆に、それ以外はfirstによってはカバーできていないので、それまでに出てきたsecondの最大値以上のsecondとしての値が必要である。よって以下のコードでうまく動く。

sort(all(V));
reverse(all(V));
int sum = INF;
int ma = 0;
for (int i = 0; i < V.size(); i ++) {
        sum = min(sum, V[i].first + ma);
        ma = max(ma, V[i].second);
}
sum = min(sum, ma);

これを求めればあとは簡単で、このsumの回数だけ寄り道して戻るので、結局答えは、shiftした回数$\ +\ $flipした回数$\ +\ sum * 2\ $である。

ソート部分が最も計算量が大きいので全体で$\ O(n^2\log n)\ $。

実装
#include "bits/stdc++.h"
using namespace std;

#define all(x)  x.begin(), x.end()
#define mp      make_pair

static const int INF = 0x3f3f3f3f;

int solve(string a, string b) {
        int n = a.size();
        int left = 0, right = 0;
        for (int i = 0; i < n; i ++) {
                if (b[i] == '0') left ++;
                else break;
        }
        for (int i = n - 1; i >= 0; i --) {
                if (b[i] == '0') right ++;
                else break;
        }
        vector<int> one;
        for (int i = 0; i < n; i ++) if (b[i] == '1') one.push_back(i);
        int op = 0;
        vector<pair<int, int>> dis(n);
        for (int i = 0; i < n; i ++) {
                if (one[op] == i) { 
                        dis[i] = mp(0, 0);
                        op ++;
                } else if (op == 0) dis[i] = mp(i + 1 + right, one[op] - i);
                else if (op == one.size()) dis[i] = mp(i - one[op - 1], n - i + left);
                else dis[i] = mp(i - one[op - 1], one[op] - i);
        }
        int ans = INF;
        for (int shift = 0; shift <= n; shift ++) {
                int flip_cnt = 0;
                vector<bool> flip(n, false);
                for (int i = 0; i < n; i ++) {
                        if (a[i] != b[(i - shift + n) % n]) {
                                flip_cnt ++;
                                flip[i] = true;
                        }
                }
                vector<pair<int, int>> dis_modified;
                for (int i = 0; i < n; i ++) if (flip[i]) { 
                        dis_modified.emplace_back(max(dis[i].first - shift, 0), dis[i].second);
                }
                sort(all(dis_modified));
                reverse(all(dis_modified));
                int sum = INF;
                int ma = 0;
                for (int i = 0; i < dis_modified.size(); i ++) {
                        sum = min(sum, dis_modified[i].first + ma);
                        ma = max(ma, dis_modified[i].second);
                }
                sum = min(sum, ma);
                int res = shift + flip_cnt + 2 * sum;
                ans = min(ans, res);
        }
        return ans;
}
int main() {
        string a, b;
        cin >> a >> b;
        int n = a.size();
        if (b == string(n, '0')) {
                cout << (a == string(n, '0') ? 0 : -1) << endl;
                return 0;
        }
        int ans = solve(a, b);
        reverse(all(a));
        reverse(all(b));
        ans = min(ans, solve(a, b));
        cout << ans << endl;
        return 0;
}