Learning Algorithms

アルゴリズムの勉強メモ

Atcoder Regular Contest 084 F. XorShift

Atcoder Regular Contest 084 F. XorShift

F: XorShift - AtCoder Regular Contest 084 | AtCoder

解法

基底とする数は与えられる$\ n\ $個の数だけとして考えて良い。これらを任意の個数使って、シフトさせたり$\ xor\ $をとったりして数を作る。

ここで$\ xor\ $をとるという操作は、$\ mod\ 2\ $で各桁の差(あるいは和)をとるという操作に読み替えることができる。さらにシフトの操作を加えると、除算の筆算のような操作を繰り返していけることがわかる。

結局これらの操作によって作ることができる数というのは、これらの$\ gcd\ $の倍数全体であり、逆にそれ以外の数は作れないことが言える。これは整数で考えるとわかりやすくて、例えば$\ \{18, 42, 66\}\ $という数から適当な加減を繰り返すことによって$\ gcd(18, 42, 66) = 6\ $の倍数はすべて作ることができるし、逆にそれ以外はどうやっても作れない。

よって$\ n\ $個の数の$\ gcd\ $を求め、$\ X\ $以下の$\ gcd\ $の倍数の個数を求めればそれが答えである。 これはなんか頑張るとできる。

bitsetvectorによる実装をどちらもした。bitsetの方が、$\ xor\ $の操作がわかりやすかったり、stringを経由しなくても数を扱えるという点で優れていると思うが、実質の長さを知りたいときや、大小比較をしたいときに少し不便だったり、慣れの問題だが先頭(左端)のインデックスが$\ size - 1\ $であるためバグる確率が高い感じがあった。

実装1(vectorによる)
#include "bits/stdc++.h"
using namespace std;

static const int MOD = 998244353;

vector<int> BitsetGcd(vector<int> s, vector<int> t) {
        int n = s.size(), m = t.size();
        if (n < m) return BitsetGcd(t, s);
        if (!m) return s;
        for (int i = 0; i < m; i ++) s[i] ^= t[i];
        int p = n;
        for (int i = 0; i < n; i ++) {
                if (s[i]) {
                        p = min(p, i);
                        break;
                }
        }
        vector<int> a(n - p);
        for (int i = 0; i < n - p; i ++) a[i] = s[p + i];
        return BitsetGcd(t, a);
}

long long ModPow(long long x, long long n, long long m) {
        long long res = 1;
        while (n > 0) {
                if (n & 1) res = res * x % m;
                x = x * x % m;
                n >>= 1;
        }
        return res;
}

int main() {
        int n;
        scanf("%d", &n);
        string x;
        cin >> x;
        vector<int> g;
        for (int i = 0; i < n; i ++) {
                string s;
                cin >> s;
                int k = s.size();
                vector<int> b(k);
                for (int i = 0; i < k; i ++) b[i] = s[i] - '0';
                g = BitsetGcd(g, b);
        }
        n = x.size();
        vector<int> c(n), d(n);
        for (int i = 0; i < n; i ++) c[i] = x[i] - '0';
        int m = g.size();
        int ans = 0;
        for (int i = 0; i + m <= n; i ++) {
                if (c[i]) ans = (ans + ModPow(2, n - m - i, MOD)) % MOD;
                if (c[i] == d[i]) continue;
                for (int j = 0; j < m; j ++) d[i + j] ^= g[j];
        }
        printf("%d\n", (ans + (d > c ? 0 : 1)) % MOD);
        return 0;
}
実装2(bitsetによる)
#include "bits/stdc++.h"
using namespace std;

static const int MOD = 998244353;

#define N 4040

int Len(const bitset<N> &a) {
        for (int res = N - 1; res >= 0; res --) {
                if (a[res]) return res + 1;
        }
        return 0;
}

bitset<N> gcd(bitset<N> s, bitset<N> t) {
        int n = Len(s), m = Len(t);
        if (n < m) return gcd(t, s);
        if (t.none()) return s;
        int d = n - m;
        bitset<N> u = t;
        u <<= d;
        s ^= u;
        return gcd(t, s);
}

long long ModPow(long long x, long long n, long long m) {
        long long res = 1;
        while (n > 0) {
                if (n & 1) res = res * x % m;
                x = x * x % m;
                n >>= 1;
        }
        return res;
}

int main() {
        int q;
        scanf("%d", &q);
        bitset<N> x;
        cin >> x;
        bitset<N> g;
        for (int i = 0; i < q; i ++) {
                bitset<N> s;
                cin >> s;
                g = gcd(g, s);
        }
        int n = Len(x);
        bitset<N> d(0);
        int m = Len(g);
        int ans = 0;
        for (int i = n - 1; i >= m - 1; i --) {
                if (x[i]) ans = (ans + ModPow(2, i - m + 1, MOD)) % MOD;
                if (x[i] == d[i]) continue;
                d ^= (g << i - m + 1);
        }
        bool ok = true;
        for (int i = n - 1; i >= 0; i --) {
                if (d[i] < x[i]) break;
                if (d[i] > x[i]) ok = false;
        }
        printf("%d\n", (ans + ok) % MOD);
        return 0;
}