Atcoder Grand Contest 019 D. Shift and Flip
Atcoder Grand Contest 019 D. Shift and Flip
感想
結局やるだけだったけど、ひたすら補題の実装ができなかった。解説は自明なパートだけだらだらと書いていて最も知りたい部分が省略されていた。
解法
まず最終的に$\ a\ $のどの部分が$\ b\ $のどの部分に対応するのかを決定して$\ O(n)\ $。あとでreverse
で逆の場合を考えれば、このshiftは左にするものと考えて良い。flipの回数も一意に決定でき、あとはすべてのflipさせるべき位置がbが1であるような位置を通過するように移動の仕方を考えれば良い。
そのために最初に各位置について左と右でそれぞれbが1であるような位置までの距離を求めておく。そして、これを$\ (l, r)\ $とすれば、左に$\ x\ $回shiftしたとすると、bが1であるような位置までの距離は$\ (max(l - x, 0), r)\ $になると考えてよい。すべてのflipさせるべき位置についてこれをまとめたものを$\ V\ $とする。
すると、次の補題を解けばよい。
すべての$\ v \in V\ $に対して$\ v.first \leq a\ $または$\ v.second \leq b\ $が成り立つような$\ (a, b)\ $について$\ a + b\ $の最小値を求めよ。
この部分だけ強い人のコードを参考に書いた。まず$\ V\ $を降順にソートする。sum
が求めたい答えで、sum = INF
ならば当然条件を満たすのでそれで初期化する。ma
で今まで見た中でsecond
の最大値をいれる。first
の降順に見ているので、ある時点においてそれ以降に登場するものはすべてfirst
によってカバーできるはずである。逆に、それ以外はfirst
によってはカバーできていないので、それまでに出てきたsecond
の最大値以上のsecond
としての値が必要である。よって以下のコードでうまく動く。
sort(all(V)); reverse(all(V)); int sum = INF; int ma = 0; for (int i = 0; i < V.size(); i ++) { sum = min(sum, V[i].first + ma); ma = max(ma, V[i].second); } sum = min(sum, ma);
これを求めればあとは簡単で、このsum
の回数だけ寄り道して戻るので、結局答えは、shiftした回数$\ +\ $flipした回数$\ +\ sum * 2\ $である。
ソート部分が最も計算量が大きいので全体で$\ O(n^2\log n)\ $。
実装
#include "bits/stdc++.h" using namespace std; #define all(x) x.begin(), x.end() #define mp make_pair static const int INF = 0x3f3f3f3f; int solve(string a, string b) { int n = a.size(); int left = 0, right = 0; for (int i = 0; i < n; i ++) { if (b[i] == '0') left ++; else break; } for (int i = n - 1; i >= 0; i --) { if (b[i] == '0') right ++; else break; } vector<int> one; for (int i = 0; i < n; i ++) if (b[i] == '1') one.push_back(i); int op = 0; vector<pair<int, int>> dis(n); for (int i = 0; i < n; i ++) { if (one[op] == i) { dis[i] = mp(0, 0); op ++; } else if (op == 0) dis[i] = mp(i + 1 + right, one[op] - i); else if (op == one.size()) dis[i] = mp(i - one[op - 1], n - i + left); else dis[i] = mp(i - one[op - 1], one[op] - i); } int ans = INF; for (int shift = 0; shift <= n; shift ++) { int flip_cnt = 0; vector<bool> flip(n, false); for (int i = 0; i < n; i ++) { if (a[i] != b[(i - shift + n) % n]) { flip_cnt ++; flip[i] = true; } } vector<pair<int, int>> dis_modified; for (int i = 0; i < n; i ++) if (flip[i]) { dis_modified.emplace_back(max(dis[i].first - shift, 0), dis[i].second); } sort(all(dis_modified)); reverse(all(dis_modified)); int sum = INF; int ma = 0; for (int i = 0; i < dis_modified.size(); i ++) { sum = min(sum, dis_modified[i].first + ma); ma = max(ma, dis_modified[i].second); } sum = min(sum, ma); int res = shift + flip_cnt + 2 * sum; ans = min(ans, res); } return ans; } int main() { string a, b; cin >> a >> b; int n = a.size(); if (b == string(n, '0')) { cout << (a == string(n, '0') ? 0 : -1) << endl; return 0; } int ans = solve(a, b); reverse(all(a)); reverse(all(b)); ans = min(ans, solve(a, b)); cout << ans << endl; return 0; }