Learning Algorithms

アルゴリズムの勉強メモ

Atcoder AGC 003 D. Anticube

Atcoder AGC 003 D. Anticube

D: Anticube - AtCoder Grand Contest 003 | AtCoder

感想

素因数分解したときに出てくる3乗は消しても良いことには割とすぐに気づき、共存できない数同士のペアが一意に定まるところまでは落とし込めた。が、実装方法と計算量の落とし方はなんだかさっぱりだったので、結局nuipさんのコードをジロジロ見ながら実装した。恥ずかしながら、unordered_mapを初めて使った。

解法

まずいかなる3乗も1と等価と見なしてよい。すると、どの数も素因数をそれぞれ高々2個までしかもたないような数に書き換えることができる(調べて見るとこれを正規化というらしい?)。

次に、この操作によって定まる数には、(1を除いて)それと共存できないそれとは異なる数が一意に定まる。これは割と自明で、例えば上の操作によって、$\ x = a^1 * b^2 * c ^ 1\ $のような感じになったとすると、これと共存できない数は、(上の操作を行なった結果)$\ y = a ^ 2 * b ^ 1 * c ^ 2\ $となるような数しかない。実装も基本的にはこのように片方の指数$\ x\ $に対してもう一方の指数は$\ 3 - x\ $とする。

これらの2数のうち出現した回数が多い方を採用していけば良いことになる。ただし、上の操作によって1になるような数が1つでも存在すれば、そのような数のうち1個だけは追加できるので、1を足す。

素数の列挙は、$\ \sqrt[3]{10^{10}}\ $以下の数までで抑えないと$\ TLE\ $する。ペアの数がどのような数であるかを理解していればそんなに難しくなかった。

ちなみに下から5行目の不等式の評価は、そのままパクったのだが、個数が同じときにでもどちらかの個数を1回だけ足さなければいけないため、ペアの値が常に相異なることを利用して、pairで比較して大きい方のものを取る、という風にうまく実装されている。

実装
#include "bits/stdc++.h"
using namespace std;

#define pll     pair<long long, long long>
#define ll      long long

//N以下の素数列挙。O(N log log N)
int N = 2180; //N * N * N >= 10^10
vector<int> prime;
void init() {
        vector<bool> is_prime(N + 1, true);
        is_prime[0] = is_prime[1] = false;
        for (int i = 2; i <= N; i ++) {
                if (is_prime[i]) {
                        prime.push_back(i);
                        for (int j = i + i; j <= N; j += i) is_prime[j] = false;
                }
        }
}

int main() {
        init();
        int n;
        cin >> n;
        vector<ll> a(n);
        for (int i = 0; i < n; i ++) cin >> a[i];
        unordered_map<ll, pll> cnts; //その値, 個数, その値のペアの値
        for (auto v : a) {
                ll val = 1, another = 1;
                for (auto p : prime) {
                        if (v % p != 0) continue;
                        int cnt = 0;
                        while (v % p == 0) {
                                cnt ++;
                                v /= p;
                        }
                        if (cnt % 3 == 1) {
                                val *= p;
                                another *= p * p;
                        } else if (cnt % 3 == 2) {
                                val *= p * p;
                                another *= p;
                        }
                }
                if (v > 1) {
                        ll tmp = sqrt(v);
                        ll root = 0;
                        for (int i = -1; i < 2; i ++) if ((tmp + i) * (tmp + i) == v) root = tmp + i;
                        val *= v;
                        if (root) another *= root;
                        else another *= v * v;
                }
                cnts[val].second = another;
                cnts[val].first ++;
        }
        auto tmp = cnts;
        ll ans = 0;
        for (auto u : tmp) if (u.first != 1 && cnts[u.second.second] < u.second) ans += u.second.first;
        if (cnts[1].first != 0) ans ++;
        cout << ans << endl;
        return 0;
}