Treap 是 BST(二元搜尋樹)跟 Heap(堆積)的結合體
該資料結構同時具有以下性質:
- BST
- 節點的 value 必 > 左子樹所有點的 value
- 節點的 value 必 < 右子樹所有點的 value
相同的值可以合併至同一節點
- Heap
- 節點的 priority 必 <= 子樹所有點的 priority (這裡以 min_heap 為範例,當然 max_heap 也是可以的)
權重:priority(隨機值)
值:value
Treap 使用隨機值作為 priority 維護堆的性質,來維持整棵樹的平衡
維護堆的性質#
樹的性質可以在插入時進行維護
堆的性質有兩種方法可以進行維護,分別為
- 旋轉:旋轉
- 無旋:分裂、合併
旋轉 Treap#
正常情況下,旋轉 Treap 常數較小
以下為範例程式碼
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define fi first
#define se second
#define INF LONG_LONG_MAX/1000
#define WA() cin.tie(0)->sync_with_stdio(0)
#define all(x) (x).begin(), (x).end()
#define int long long
#define PII pair<int, int>
random_device rd;
mt19937 gen(rd());
struct Node;
inline int _getSiz(Node *x);
struct Node {
Node* son[2]; // 0 為左子節點,1 為右節點
int val, pri, cnt, siz;
Node(int x) : val(x), cnt(1), siz(1), pri(gen()), son{nullptr, nullptr} {}
void pull() {
siz = cnt + _getSiz(son[0]) + _getSiz(son[1]);
}
};
inline int _getSiz(Node *x) {
return (x ? x->siz : 0);
}
enum rot {
L = 1, R = 0 // 左旋為將根節點的右子節點為新的根結點,反之亦然
};
void _rotate(Node *&x, rot dir) {
Node *y = x->son[dir];
x->son[dir] = y->son[!dir];
y->son[!dir] = x;
x->pull(), y->pull();
x = y;
}
void _insert(Node *&x, int val) {
if (!x) return x = new Node(val), void();
else if (val == x->val) x->cnt++, x->siz++;
else if (val < x->val) {
_insert(x->son[0], val);
if (x->son[0]->pri < x->pri) _rotate(x, R);
x->pull();
}
else {
_insert(x->son[1], val);
if (x->son[1]->pri < x->pri) _rotate(x, L);
x->pull();
}
}
void _delete(Node *&x, int val) {
if (!x) return;
if (val < x->val) _delete(x->son[0], val);
else if (val > x->val) _delete(x->son[1], val);
else {
if (x->cnt > 1) x->cnt--, x->siz--;
else {
Node *y = x;
if (x->son[0] && x->son[1]) {
rot dir = (x->son[0]->pri < x->son[1]->pri ? R : L);
_rotate(x, dir);
_delete(x->son[!dir], val);
}
else if (x->son[0]) {
x = y->son[0];
delete y;
}
else if (x->son[1]) {
x = y->son[1];
delete y;
}
else {
delete x;
x = nullptr;
}
}
if (x) x->pull();
}
}
int queryRank(Node *x, int key) {
if (!x) return -1;
if (key == x->val) return _getSiz(x->son[0]) + 1;
if (key < x->val) return queryRank(x->son[0], key);
int qr = queryRank(x->son[1], key);
if (~qr) return _getSiz(x->son[0]) + x->cnt + qr; // if (key > x->val)
return -1;
}
int querykth(Node *x, int rank) {
if (!x) return -1;
if (x->son[0]) {
if (x->son[0]->siz >= rank) return querykth(x->son[0], rank);
if (x->son[0]->siz + x->cnt >= rank) return x->val;
}
else if (x->cnt >= rank) return x->val;
return querykth(x->son[1], rank - _getSiz(x->son[0]) - x->cnt);
}
signed main() { WA();
}
無旋 Heap#
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define fi first
#define se second
#define INF LONG_LONG_MAX/1000
#define WA() cin.tie(0)->sync_with_stdio(0)
#define all(x) (x).begin(), (x).end()
#define int long long
#define PII pair<int, int>
/*
val 值
pri 維護 Heap 的權重
siz 以該節點的樹的總結點數,也就是 l->siz + cnt + r->siz
cnt 紀錄 val 值有幾個
l、r 為左右子樹
此範例為 min_heap
*/
mt19937 rng(time(0));
struct Node;
inline int _getSiz(Node *x);
struct Node {
int val, pri, siz, cnt;
Node *l, *r;
Node(int x) : val(x), pri(rng()), siz(1), cnt(1), l(nullptr), r(nullptr) {}
void pull() {
siz = cnt + _getSiz(l) + _getSiz(r);
}
};
inline int _getSiz(Node *x) {
return (x ? x->siz : 0);
}
/*
_split 將原 Treap x 分割成兩個 Treap,為 a 跟 b
其中所有的值符合
a:< key
b:>= key
*/
pair<Node *, Node *> _split(Node *x, int key) { // 使用方法:auto [a, b] = split(root, key);
// base case: 當前節點為空節點
if (!x) return {nullptr, nullptr};
/*
如果當前節點為 x
x->val < key,同時 x->l 所有的值 < 當前節點的值
因此當前節點及左子樹都會被分到 a
並且無法確定所有大於 x 的值也就是 x->r 也全部 >= key,所以要向下遞迴從 x->r 中找出 < key 的部份分到 a,成為 x 的右子樹
並且將真正 >= key 的值放到 b
*/
if (x->val < key) {
auto [a, b] = _split(x->r, key); // 往 x->r 遞迴
x->r = a; // 遞迴切出來的 a 為大於 x 但是 <= key 的值
x->pull();
return {x, b};
}
// 相反同理
auto [a, b] = _split(x->l, key);
x->l = b;
x->pull();
return {a, x};
}
/*
遞迴分割小於 rk、等於 rk、大於 rk
*/
tuple<Node *, Node *, Node *> splitByRank(Node *x, int rk) {
if (!x) return {nullptr, nullptr, nullptr};
int ls = (x->l ? x->l->siz : 0);
if (rk <= ls) {
auto [ll, mid, rr] = splitByRank(x->l, rk);
x->l = rr; x->pull();
return {ll, mid, x};
}
else if (rk <= ls + x->cnt) {
Node *ll = x->l, *rr = x->r;
x->l = x->r = nullptr;
x->pull();
return {ll, x, rr};
}
else {
auto [ll, mid, rr] = splitByRank(x->r, rk);
x->r = ll; x->pull();
return {x, mid, rr};
}
}
/*
前提:Treap a 的值 < Treap b 的值
所以在合併的時候只需要維護堆性質
*/
Node *_merge(Node *a, Node *b) {
// 直到該左右子樹為空節點,則直接連上另一個 Treap
if (!a) return b;
if (!b) return a;
/*
如果 a 的權重 < b 的權重
當前為 min_heap,所以將權重較小的 a 設為父節點,放在上方
並且因為 b 所有的值比 a 所有的值大,所以按照樹的性質,要將 b 放在 a 的右子樹
*/
if (a->pri < b->pri) {
a->r = _merge(a->r, b); // 記得遵守前提
a->pull();
return a; // a 為父節點
}
// 相反同理
b->l = _merge(a, b->l); // 記得遵守前提
b->pull();
return b;
}
/*
先從 x 中切出 a: < val, b: >= val
再從 b 中切出 bl: < val+1, br: >= val+1
其中 bl 同時滿足 bl >= val 跟 bl < val+1,也就是說 bl 就是 val 本人
如果該節點存在,也就是說已經有了儲存 val 的節點,就將 cnt + 1 即可
否則需要創建一個新的節點
最後再將分割出來的三個 Treap 合併回去
*/
void _insert(Node *&x, int val) {
auto [a, b] = _split(x, val);
auto [bl, br] = _split(b, val+1);
Node *k;
if (bl) {
bl->cnt++;
bl->pull();
}
else bl = new Node(val);
x = _merge(a, _merge(bl, br));
}
/*
跟 insert 同個概念
基本上想要對特定節點進行操作,流程大概就是這樣
*/
void _delete(Node *&x, int val) {
auto [a, b] = _split(x, val);
auto [bl, br] = _split(b, val+1);
if (bl) { // 找得到該節點再進行 cnt 的判斷
if (bl->cnt > 1) {
bl->cnt--;
bl->pull();
}
else {
delete bl;
bl = nullptr;
}
}
x = _merge(a, _merge(bl, br));
}
int queryRankByVal(Node *&x, int val) {
auto [a, b] = _split(x, val);
int ret = (a ? a->siz : -2) + 1;
x = _merge(a, b);
return ret;
}
int queryValByRank(Node *&x, int rk) {
auto [ll, mid, rr] = splitByRank(x, rk);
int ret = (mid ? mid->val : 0);
x = _merge(_merge(ll, mid), rr);
return ret;
}
signed main() { WA();
}
無旋 Heap 進行區間操作#
捨棄了使用 val 維護二元樹的性質,改成使用 index 來維護
使得當前節點代表的區間為:左子樹區間 + 當前節點 + 右子樹區間
跟線段樹有些相似
可以將其想像成能夠解決符合線段數性質但帶有序列操作的題目的資料結構
以下為範例程式碼
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define fi first
#define se second
#define INF LONG_LONG_MAX/1000
#define WA() cin.tie(0)->sync_with_stdio(0)
#define all(x) (x).begin(), (x).end()
#define int long long
#define PII pair<int, int>
mt19937 gen(time(0));
struct Node;
inline int _getSiz(Node *x);
inline int _getSum(Node *x);
struct Node {
int val, sum, pri, siz, _add, _set_tag, _set_val, rev;
Node *l, *r;
Node(int x): val(x), sum(x), pri(gen()), siz(1), _add(0), _set_tag(0), _set_val(0), rev(0), l(nullptr), r(nullptr) {}
void pull() {
siz = _getSiz(l) + _getSiz(r) + 1;
sum = _getSum(l) + _getSum(r) + val;
}
void set(int k) {
_set_tag = 1;
_set_val = k;
_add = 0;
val = k;
sum = siz * k;
}
void add(int k) {
if (_set_tag) _set_val += k;
else _add += k;
val += k;
sum += siz * k;
}
void push() {
if (_set_tag) {
if (l) l->set(_set_val);
if (r) r->set(_set_val);
_set_tag = 0;
}
if (rev) {
swap(l, r);
if (l) l->rev ^= 1;
if (r) r->rev ^= 1;
rev = 0;
}
if (_add) {
if (l) l->add(_add);
if (r) r->add(_add);
_add = 0;
}
}
};
inline int _getSiz(Node *x) {
return (x ? x->siz : 0);
}
inline int _getSum(Node *x) {
return (x ? x->sum : 0);
}
/*
直接使用 split by size
因為 size 就是 index
*/
pair<Node *, Node *> _split(Node *x, int sz) {
if (!x) return {nullptr, nullptr};
x->push();
if (sz <= _getSiz(x->l)) {
auto [a, b] = _split(x->l, sz);
x->l = b;
x->pull();
return {a, x};
}
auto [a, b] = _split(x->r, sz - _getSiz(x->l) - 1);
x->r = a;
x->pull();
return {x, b};
}
Node *_merge(Node *a, Node *b) {
if (!a) return b;
if (!b) return a;
if (a->pri < b->pri) {
a->push();
a->r = _merge(a->r, b);
a->pull();
return a;
}
b->push();
b->l = _merge(a, b->l);
b->pull();
return b;
}
Node *_build_rec(vector<int> &v, int l, int r) {
if (l > r) return nullptr;
int mid = (l+r) / 2;
Node *x = new Node(v[mid]);
x->l = _build_rec(v, l, mid-1);
x->r = _build_rec(v, mid+1, r);
x->pull();
return x;
}
void _traversal(Node *x) {
if(!x) return;
x->push();
if (x->l) _traversal(x->l);
cout << x->val << ' ';
if (x->r) _traversal(x->r);
}
void _seg_rev(Node *&x, int l, int r) {
auto [a, b] = _split(x, l);
auto [ba, bb] = _split(b, r-l+1);
if (ba) ba->rev ^= 1;
x = _merge(a, _merge(ba, bb));
}
void _seg_add(Node *&x, int l, int r, int val) {
auto [a, b] = _split(x, l);
auto [ba, bb] = _split(b, r-l+1);
if (ba) ba->add(val);
x = _merge(a, _merge(ba, bb));
}
void _seg_set(Node *&x, int l, int r, int val) {
auto [a, b] = _split(x, l);
auto [ba, bb] = _split(b, r-l+1);
if (ba) ba->set(val);
x = _merge(a, _merge(ba, bb));
}
signed main() { WA();
}
