Learning Algorithms

アルゴリズムの勉強メモ

Atcoder ARC 030 D. グラフではない

Atcoder ARC 030 D. グラフではない

D: グラフではない - AtCoder Regular Contest 030 | AtCoder

解法

RBST(Randomized Binary Search Tree)として知られる永続データ構造を使う。ググって色々調べたり、他の人のコードから学んで実装した。

RBSTについて自分なりに少しまとめておく。

基本的にはTreapと同じで、配列を二分木によって表して、それぞれの根のポインタを上手く利用することで、区間クエリ(特にコピーや反転など)を高速に処理できるようにしたものである。Treapとの違いは、木のmergeの際に、Treapは優先度評価値によってどちらの根を新しい木の根にするかを決めるのに対し、RBSTはそれぞれの部分木のサイズの比率によって新しい根を決定する点である。

mergesplitは割と素直に(順番が崩れないように)実装する。また、部分木に関する情報は値の更新などがある度に更新する。

遅延評価をさせる場合は厄介だが、mergeの際に正しく子に伝播させるように実装すればよい。

コード中に整理のためにコメントで説明を書いておいた。

この実装さえできればあとはクエリ処理をてきとうに書くだけである。ちなみにshared_ptrを使わないと、$\ MLE\ $になった。

実装
#include "bits/stdc++.h"
using namespace std;

#include <random>
mt19937 mt(0); //Mersenne Twisterによる乱数                                                                                        

struct node_t {
        int sz;
        long long val, add, sum;
        shared_ptr<node_t> lchild, rchild;
        node_t(long long val, int sz, long long add, long long sum, shared_ptr<node_t> lchild, shared_ptr<node_t> rchild)
                     : val(val), sz(sz), add(add), sum(sum), lchild(lchild), rchild(rchild) {}
};

using node = shared_ptr<node_t>;

int size(node t) { return !t ? 0 : t->sz; }
long long sum(node t) { return !t ? 0 : t->sum; }

node add(node v, long long a) { //根の遅延評価値にaを足す。valはこの時点では変更していない。
        if (!v) return NULL;
        int sz = v->sz;
        return node(new node_t(v->val, sz, v->add + a, v->sum + a * sz, v->lchild, v->rchild));
}

node new_node(long long val) {
        return node(new node_t(val, 1, 0, val, NULL, NULL));
}

node make_node(long long val, node left, node right, long long a) {
        int sz = size(left) + size(right) + 1;
        return node(new node_t(val, sz, a, sum(left) + sum(right) + val + a * sz, left, right));
}

node merge(node left, node right) {
        if (!left || !right) return !left ? right : left;
        if (int(mt() % (size(left) + size(right)) < size(left))) { //ノード数に応じてどちらを根にするかを決める
                if (left->add == 0) { 
                        return make_node(left->val, left->lchild, merge(left->rchild, right), 0);
                }
                //遅延評価を子に伝播させ、自身のvalを更新して、add = 0としておく。
                node lv = add(left->lchild, left->add);
                node rv = merge(add(left->rchild, left->add), right);
                return make_node(left->val + left->add, lv, rv, 0); 
        } else {
                if (right->add == 0) {
                        return make_node(right->val, merge(left, right->lchild), right->rchild, 0);
                }
                node lv = merge(left, add(right->lchild, right->add));
                node rv = add(right->rchild, right->add);
                return make_node(right->val + right->add, lv, rv, 0);
        }
}

pair<node, node> split(node t, int k) { //[0, k), [k, n)
        if (k == 0) return pair<node, node>(NULL, t);
        if (k >= size(t)) return pair<node, node>(t, NULL);
        if (size(t->lchild) >= k) {
                auto left = split(t->lchild, k);
                auto rv = make_node(t->val, left.second, t->rchild, t->add);
                if (t->add == 0) return make_pair(left.first, rv);
                return make_pair(add(left.first, t->add), rv);
        } else {
                auto right = split(t->rchild, k - size(t->lchild) - 1);
                auto lv = make_node(t->val, t->lchild, right.first, t->add);
                if (t->add == 0) return make_pair(lv, right.second);
                return make_pair(lv, add(right.second, t->add));
        }
}

int main() {
        int n, q;
        cin >> n >> q;
        node tree = NULL;
        for (int i = 0; i < n; i ++) {
                int x;
                cin >> x;
                tree = merge(tree, new_node(x));
        }
        while (q --) {
                int type, a, b, c, d, v;
                cin >> type;
                if (type == 1) {
                        cin >> a >> b >> v;
                        auto y = split(tree, b);
                        auto x = split(y.first, a - 1);
                        tree = merge(x.first, merge(add(x.second, v), y.second));
                } else if (type == 2) {
                        cin >> a >> b >> c >> d;
                        auto mid = split(split(tree, d).first, c - 1).second;
                        auto x = split(tree, b);
                        auto left = split(x.first, a - 1).first;
                        auto right = x.second;
                        tree = merge(left, merge(mid, right));
                } else {
                        cin >> a >> b;
                        long long ans = sum(split(split(tree, b).first, a - 1).second);
                        cout << ans << endl;
                }
        }
        return 0;
}