模板 - 可持久化线段树 (主席树)

基于 C++14 标准，实现了初始化、单点修改与单点查询

仅在 GCC 下测试过

笔者对该数据结构的最新实现存放于 https://cplib.tifa-233.com/src/code/ds/persistent_segtree.hpp ，建议前往此处查看相关代码

代码

Show code

Persistable_seg_tree.hpp (view raw)
namespace Persistable_seg_tree {

/**
 * @brief Persistent segment tree over positions [1, size] with point
 *        assignment and point query on any stored version.
 *
 * A version is identified by its index in the root list: version 0 is the
 * initial tree, and every modify() or query() call appends one more version.
 *
 * Nodes live in a single std::vector and are linked by index instead of by
 * iterator. Growing the vector (reallocation) and copying/moving the whole
 * tree therefore keep every link valid, and the implicit copy/move members
 * are correct.
 *
 * Node 0 is a shared sentinel whose value is Zero and whose children are
 * itself. A tree created with the size-only constructor therefore reads as
 * all-Zero, and modify() can branch off it without building the full tree.
 *
 * @tparam Tp          element type (must be usable as a non-type template
 *                     parameter, because of Zero)
 * @tparam Memory_rate nodes reserved per element up front; a capacity hint
 *                     only, exceeding it is safe
 * @tparam Zero        value reported for positions that were never set
 */
template <typename Tp, std::size_t Memory_rate = 24, Tp Zero = 0>
class persistable_seg_tree {
 public:
  using self = persistable_seg_tree<Tp, Memory_rate, Zero>;
  using data_t = Tp;

 protected:
  using index_t = std::size_t;
  // Child links are 32-bit node indices: half the size of a 64-bit iterator,
  // which matters under tight memory limits. Caps the pool at UINT_MAX nodes.
  using link_t = unsigned int;
  using pointer = link_t;

  struct node_t {
    data_t data;
    link_t l = 0;
    link_t r = 0;

    node_t(const data_t &value = Zero) : data(value) {}
  };

  using nodes_t = std::vector<node_t>;
  using roots_t = std::vector<pointer>;

  // Appends `proto` to the node pool and returns its index.
  link_t _new_node(const node_t &proto) {
    nodes.push_back(proto);
    return static_cast<link_t>(nodes.size() - 1);
  }

  // Creates the sentinel node 0 (value Zero, children 0) if the pool is empty.
  void _ensure_sentinel() {
    if (nodes.empty()) nodes.emplace_back();
  }

  // Builds the subtree covering [l, r] from a[l..r] in pre-order (a parent
  // gets a smaller index than its children); returns the subtree root.
  link_t _build(index_t l, index_t r, const data_t *const a) {
    if (l == r) return _new_node(node_t(a[l]));
    const link_t now = _new_node(node_t());
    const index_t mid = l + ((r - l) >> 1);
    const link_t lc = _build(l, mid, a);
    const link_t rc = _build(mid + 1, r, a);
    // Index into `nodes` only after the recursive calls, which may reallocate.
    nodes[now].l = lc;
    nodes[now].r = rc;
    return now;
  }

  // Path-copies from `pre` (a node covering [l, r]) down to leaf `pos`, sets
  // that leaf to k, and returns the root of the copied path. Subtrees off the
  // path are shared with `pre`.
  link_t _modify(index_t l, index_t r, link_t pre, index_t pos, const data_t &k) {
    const node_t old = nodes[pre];  // local copy: push_back may reallocate
    const link_t now = _new_node(old);
    if (l == r) {
      nodes[now].data = k;
      return now;
    }
    const index_t mid = l + ((r - l) >> 1);
    if (pos <= mid) {
      const link_t lc = _modify(l, mid, old.l, pos, k);
      nodes[now].l = lc;
    } else {
      const link_t rc = _modify(mid + 1, r, old.r, pos, k);
      nodes[now].r = rc;
    }
    return now;
  }

  // Descends from `now` (covering [l, r]) to leaf `pos` and returns its value.
  data_t &_query(link_t now, index_t l, index_t r, index_t pos) {
    while (l < r) {
      const index_t mid = l + ((r - l) >> 1);
      if (pos <= mid) {
        now = nodes[now].l;
        r = mid;
      } else {
        now = nodes[now].r;
        l = mid + 1;
      }
    }
    return nodes[now].data;
  }

 public:
  /// Empty tree with no versions; call init() before use.
  persistable_seg_tree() = default;

  /// Tree of `_size` positions, all Zero; version 0 is the sentinel root.
  explicit persistable_seg_tree(index_t _size) : data_size(_size) {
    nodes.reserve(data_size * Memory_rate);
    _ensure_sentinel();
    roots.push_back(0);
  }

  /// Tree whose version 0 holds data_array[1.._size] (data_array[0] unused).
  persistable_seg_tree(const data_t *const data_array, index_t _size)
      : persistable_seg_tree(_size) {
    if (data_size != 0) roots.back() = _build(1, data_size, data_array);
  }

  /// Removes every node and version and resets the size to 0.
  self &clear() {
    nodes.clear();
    roots.clear();
    data_size = 0;
    return *this;
  }

  index_t get_data_size() const { return data_size; }
  index_t get_node_size() const { return nodes.size(); }
  index_t get_version_size() const { return roots.size(); }

  /// Raw access to the node pool / version roots (for debugging).
  nodes_t &data_nodes() { return nodes; }
  const nodes_t &data_nodes() const { return nodes; }
  roots_t &data_roots() { return roots; }
  const roots_t &data_roots() const { return roots; }

  /**
   * @brief Builds a tree from data_array[1.._size] and appends it as a new
   *        version; on a fresh or clear()-ed tree that version is 0.
   */
  self &init(const data_t *const data_array, index_t _size) {
    data_size = _size;
    nodes.reserve(data_size * Memory_rate);
    _ensure_sentinel();
    roots.push_back(data_size != 0 ? _build(1, data_size, data_array) : 0);
    return *this;
  }

  /**
   * @brief Appends a new version equal to `version` except that position
   *        `pos` holds `data`.
   * @pre version < get_version_size(), 1 <= pos <= get_data_size()
   */
  template <class Up,
            std::enable_if_t<std::is_convertible<Up, data_t>::value> * = nullptr>
  self &modify(index_t version, index_t pos, Up &&data) {
    const data_t value = std::forward<Up>(data);
    const link_t base = roots[version];
    roots.push_back(_modify(1, data_size, base, pos, value));
    return *this;
  }

  /**
   * @brief Returns position `pos` of `version` and also appends a copy of
   *        `version` as a new version (the P3919 example relies on this).
   * @pre version < get_version_size(), 1 <= pos <= get_data_size()
   * @note The reference points into the node pool: a later modify()/init()
   *       may invalidate it, and writing through it changes every version
   *       sharing that node.
   */
  data_t &query(index_t version, index_t pos) {
    const link_t root = roots[version];
    roots.push_back(root);
    return _query(root, 1, data_size, pos);
  }

  /// Bytes reserved by the node pool and the root list (for debugging).
  std::size_t memory_used() const {
    return sizeof(node_t) * nodes.capacity() +
           sizeof(pointer) * roots.capacity();
  }

 protected:
  index_t data_size = 0;
  nodes_t nodes;
  roots_t roots;
};

}  // namespace Persistable_seg_tree
using Persistable_seg_tree::persistable_seg_tree;

示例

  • 洛谷 P3919 【模板】可持久化线段树 1(可持久化数组)

    由于 64 位系统下迭代器占 8 字节，每个节点（int 数据加两个迭代器，含对齐）占 24 字节，故该模板在部分测试点会 MLE。解决方案是把 index_t 换成 uint32_t，并将节点间的链接由迭代器改为数组下标

    Show code

    Persistable_seg_tree_exp.cpp (view raw)
    #include <bits/stdc++.h>
    #include "Persistable_seg_tree.hpp"

    using namespace std;

    // Positions are 1-based, so slot 0 of `a` is never read.
    const size_t N = 1e6 + 5;
    int a[N];

    // Luogu P3919 driver. Each operation is "v op loc [value]": an odd op
    // derives a new version from v with a[loc] = value, an even op prints
    // a[loc] of version v.
    int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
        cout.tie(nullptr);

        int n, m;
        cin >> n >> m;
        for (int pos = 1; pos <= n; ++pos) cin >> a[pos];

        persistable_seg_tree<int> tree(a, n);
        for (int step = 0; step < m; ++step) {
            int version, op, loc;
            cin >> version >> op >> loc;
            if ((op & 1) == 0) {
                cout << tree.query(version, loc) << '\n';
                continue;
            }
            int value;
            cin >> value;
            tree.modify(version, loc, value);
        }
        return 0;
    }

    附：大数据生成器（来源：https://www.luogu.com.cn/discuss/354067）

    Show code

    data_gen.cpp (view raw)
    #include <bits/stdc++.h>
    using namespace std;

    // Fixed seed, so the generated case is reproducible.
    mt19937 rd(1);
    // Random values are reduced modulo inf, so they lie strictly between
    // -inf and inf.
    const int inf = 1e8;

    // Prints a maximal P3919-style case: n = m = 1e6, then 999990 modify
    // operations (the i-th based on version i - 1), then 10 queries on
    // version 999990.
    int main() {
        ios::sync_with_stdio(0);
        cin.tie(0);
        int n = 1e6, m = 1e6;
        cout << n << ' ' << m << '\n';
        for (int i = 1; i <= n; ++i) cout << static_cast<int>(rd()) % inf << ' ';
        cout << '\n';
        for (int i = 1; i <= 999990; ++i) {
            // Draw loc before value in separate statements: inside a single
            // `<<` chain the two rd() calls are only ordered left-to-right
            // since C++17, so under C++14 the output would be unspecified.
            const auto loc = rd() % n + 1;
            const int value = static_cast<int>(rd()) % inf;
            cout << i - 1 << ' ' << 1 << ' ' << loc << ' ' << value << '\n';
        }
        for (int i = 1; i <= 10; ++i)
            cout << 999990 << ' ' << 2 << ' ' << rd() % n + 1 << '\n';
        return 0;
    }