Learning Algorithms

アルゴリズムの勉強メモ

Atcoder Regular Contest 069 E. Frequency

E - Frequency

他の人の実装を見るとなんだかあっさり実装されているので、より簡単な解法がありそうです。

解法

数列に出現する数 $x$ について、数列全体に $x$ 以上の数が何個あって、その一番左の数の $index$ が何であるのかがわかれば答えがわかりそうです。

それを求めるのは難しいので、構成する数列は必ず広義単調減少になることを利用します。今、位置 $idx$ の数を構成に使っているとすると、次に構成に使う数というのは、区間 $[0, idx)$ の中で最大の値を持つものです。これはセグメント木で求めることができます。ただし、最大の値だけでなく、その $index$ (複数ある場合は一番左の)も返すようにします。

左側にある最大値よりも大きい数が右側にある場合は少し注意が必要で、構成に使う数を今の位置に保持しながら、右の数を処理してから左の数に移動するようにします。これで最適な数列が構成できました。

実装
#include <cstdio>
#include <vector>
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <iostream>
#include <cassert>
#include <cmath>
using namespace std;

template<typename Type>
struct SegmentTree {
        vector<Type> data;
        int n;
        SegmentTree(int x) {
                n = pow(2, ceil(log2(x)));
                data.resize(2 * n - 1);
                for (int i = 0; i < 2 * n - 1; i ++) {
                        data[i] = make_pair(0, -1);
                }
        }
        Type merge(Type x, Type y) { //merge function (val, idx) maximum value and minimum idx
                if (x.first > y.first) {
                        return x;
                } else if (x.first < y.first) {
                        return y;
                } else {
                        return make_pair(x.first, min(x.second, y.second));
                }
        }
        void update(int k, Type p) { //update k-th value to p
                k += n - 1;
                data[k] = p;
                while (k > 0) {
                        k = (k - 1) / 2;
                        data[k] = merge(data[k * 2 + 1], data[k * 2 + 2]);
                }
        }
        // [l, r)
        Type query(int a, int b) { return query(a, b, 0, 0, n); }
        Type query(int a, int b, int k, int l, int r) {
                if (r <= a || b <= l) return make_pair(0, -1); //initial value
                if (a <= l && r <= b) return data[k];
                int m = (l + r) / 2;
                return merge(query(a, b, k * 2 + 1, l, m), query(a, b, k * 2 + 2, m, r));
        }
};

int main() {
        int n;
        scanf("%d", &n);
        vector<int> a(n);
        for (int i = 0; i < n; i ++) {
                scanf("%d", &a[i]);
        }
        SegmentTree<pair<int, int>> seg(n);
        for (int i = 0; i < n; i ++) {
                seg.update(i, make_pair(a[i], i));
        }
        map<int, int> cnt;
        for (int i = 0; i < n; i ++) {
                cnt[a[i]] ++;
        }
        sort(a.rbegin(), a.rend());
        a.erase(unique(a.begin(), a.end()), a.end());
        vector<long long> ans(n);
        int idx = n;
        int cur = 0;
        long long total = 0;
        for (;;) {
                pair<int, int> ma = seg.query(0, idx);
                total += cnt[a[cur]];
                if (cur + 1 < a.size()) ans[ma.second] += (long long) (a[cur] - a[cur + 1]) * total;
                else ans[ma.second] += (long long) a[cur] * total;
                cur ++;
                if (cur == a.size()) break;
                if (seg.query(0, ma.second).first == a[cur]) {
                        idx = ma.second;
                }
        }
        for (int i = 0; i < n; i ++) {
                printf("%lld\n", ans[i]);
        }
        return 0;
}