题解 - [Luogu P6613] 一阶微分方程

题目链接

原始题面

题目背景

题目中 \(F'(x)\) 右侧的式子可以换成其它的,这里为了方便测试,是固定的

题目描述

已知多项式 \(F(x),A(x),B(x)\), 满足:

\[ \frac{\text dF(x)}{\text dx} \equiv A(x)\text e^{F(x)-1}+B(x) \pmod{x^n} \]

\(F(0)=1\)

给定 \(A(x),B(x)\), 请求出 \(F(x)\) 的前 \(n\) 次项系数

答案对 \(998244353\) 取模

输入格式

第一行一个正整数 \(n\), 表示 \(A(x),B(x)\) 的次数.
第二行 \(n+1\) 个整数,由低到高表示 \(A(x)\) 的系数.
第三行 \(n+1\) 个整数,由低到高表示 \(B(x)\) 的系数

输出格式

输出一行 \(n+1\) 个整数,由低到高表示 \(F(x)\) 的系数

样例 #1

样例输入 #1

9
2 9 8 7 3 6 5 4 1 12
23 9 8 7 4 6 1 3 2 5

样例输出 #1

1 25 34 332748429 124783260 22560 624092696 904826719 284383572 50973515

提示

数据规模与约定

对于 \(30\%\) 的数据,\(1\le n \le 5000\);
对于 \(100\%\) 的数据,\(1\le n \le 10^5\)

保证所有输入都在 \([0,998244353)\) 范围内

解题思路

简单的一阶非线性 ODE, 稍微仔细推一下

为简化公式,在不引起歧义的情况下省略自变量,所解方程为

\[ F'=A\exp(F-1)+B\tag{1} \]

一个容易想到的尝试是令 \(u=\exp(F-1)\), 则 \(u'=\exp(F-1)F'\), 进而方程 \((1)\) 变为

\[ u'=(Au+B)u\tag{2} \]

整理一下,有

\[ u'-Bu=Au^2\tag{2'} \]

此为 Bernoulli 微分方程 (\(n=2\)), 解法如下:

首先两边同除 \(u^n=u^2\), 即

\[ u^{-2}u'-Bu^{-1}=A \]

之后令 \(v=u^{1-n}=u^{-1}\), 则 \(v'=(1-n)u^{-n}u'=-u^{-2}u'\), 即

\[ v'+Bv=-A\tag{3} \]

这样只需解一个简单的一阶线性 ODE 即可

取函数 \(\mu(x)\) 满足 \(\mu B=\mu'\), 显然 \(\mu\) 是存在的,稍后给出具体形式

\((3)\) 式两边同乘 \(\mu\) 后代入,得

\[ \mu v'+\mu' v=-\mu A\tag{4} \]

注意到 \((4)\) 式左边为 \((\mu v)'\), 进而

\[ v=-\mu^{-1}\left(\int\mu A\mathrm{d}x+C\right)\tag{5} \]

其中 \(C\) 为常数,由 \(F\) 初值确定

接下来我们考虑 \(\mu\) 的形式,显然 \((\ln\mu)'=\mu'/\mu=B\), 即 \(\mu=\exp\left(\int B\mathrm{d}x+C'\right)\), 其中 \(C'\) 为常数,不妨取为 \(0\)

最后我们将 \(v=u^{-1}=\exp(1-F)\) 和 \(\mu=\exp\int B\mathrm{d}x\) 代入 \((5)\) 式并取对数,最终结果即为

\[ F=1+\int B\mathrm{d}x-\ln\left(C-\int \left(\exp\int B(s)\mathrm{d}s\right)A\mathrm{d}x\right)\tag{6} \]

将 \(F(0)=1\) 代入 \((6)\) 式 (其中的 \(C\) 已吸收 \((5)\) 式中的负号,且两个积分的常数项均取 \(0\)), 有 \(\ln C=0\), 即 \(C=1\)

时间复杂度

\(O(n\log n)\)

代码参考

Show code

Luogu_P6613view raw
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
/*
* @Author: Tifa
* @Description: From <https://github.com/Tiphereth-A/CP-archives>
* !!! ATTENEION: All the context below is licensed under a
* GNU Affero General Public License, Version 3.
* See <https://www.gnu.org/licenses/agpl-3.0.txt>.
*/
#include <bits/stdc++.h>
// Polynomial arithmetic over Z / MOD (MOD = 998244353) built on an NTT.
// All coefficients are expected to be stored reduced, i.e. in [0, MOD).
namespace Polynomial {
using data_t = int32_t;   // one coefficient
using ldata_t = int64_t;  // wide type for products of two coefficients
const size_t N = 1 << 17 | 500;
const data_t MOD = 998244353;
using udata_t = std::make_unsigned<data_t>::type;
using ludata_t = std::make_unsigned<ldata_t>::type;
// Capacity of every fixed-size table / NTT buffer below.
const size_t DEG_LIMIT = N << 1;

namespace Helper {
// Binary exponentiation: returns a^b mod `mod`, in [0, mod).
// `b` is expected to be non-negative (a negative `b` is treated like 0) and
// `mod` positive and small enough that (mod - 1)^2 fits in ldata_t.
constexpr ldata_t qpow(ldata_t a, ldata_t b, const ldata_t &mod) {
  // Reduce the base first so that a * a can never overflow.
  a %= mod;
  if (a < 0) a += mod;
  ldata_t res = 1 % mod;  // 0 when mod == 1
  for (; b > 0; b >>= 1) {
    if (b & 1) res = res * a % mod;
    a = a * a % mod;
  }
  return res;
}

// Extended Euclid on (mod, n). When gcd(n, mod) == 1 the result is the
// modular inverse of n in [0, mod). `n` is expected to be non-negative.
constexpr ldata_t inverse(ldata_t n, const ldata_t &mod) {
  ldata_t b = mod, m0 = 0;
  for (ldata_t q = 0, _ = 0, m1 = 1; n;) {
    _ = b - n * (q = b / n);
    b = n;
    n = _;
    _ = m0 - m1 * q;
    m0 = m1;
    m1 = _;
  }
  // Here b == gcd; lift a negative coefficient back into range.
  return (m0 + (m0 < 0 ? mod / b : 0)) % mod;
}

// Smallest primitive root of the prime `m`. Common NTT moduli are answered
// from a table; otherwise the distinct prime factors of m - 1 are collected
// and candidates g = 2, 3, ... are tested against each of them.
constexpr data_t proot_impl_(data_t m) {
  if (m == 2) return 1;
  if (m == 3 || m == 5) return 2;
  if (m == 104857601 || m == 167772161 || m == 469762049) return 3;
  if (m == 754974721) return 11;
  if (m == 998244353 || m == 1004535809) return 3;
  data_t divs[20] = {2};
  data_t cnt = 1, x = (m - 1) / 2;
  while (!(x & 1)) x >>= 1;
  for (data_t i = 3; (ldata_t)i * i <= x; i += 2)
    if (x % i == 0) {
      // Store at the next free slot: divs[0 .. cnt) must all be real primes,
      // because the test loop below divides by every one of them.
      divs[cnt++] = i;
      while (x % i == 0) x /= i;
    }
  if (x > 1) divs[cnt++] = x;
  for (data_t g = 2;; ++g) {
    bool ok = true;
    for (data_t i = 0; i < cnt; ++i)
      if (qpow(g, (m - 1) / divs[i], m) == 1) {
        ok = false;
        break;
      }
    if (ok) return g;
  }
}

// Compile-time primitive root of the prime M.
template <data_t M>
constexpr data_t proot = proot_impl_(M);

// Quadratic-residue symbol (a / p) computed with the binary reciprocity
// algorithm; returns 1, -1 or 0. `p` is expected to be odd.
constexpr int legendre_symbol(uint64_t a, uint64_t p) noexcept {
  if (a == 0) return 0;
  int s = 1, _ctz = 0;
  while (a > 1) {
    if (a == p || a == 0 || p < 2) return 0;
    _ctz = __builtin_ctzll(a);
    // (2 / p) == -1 exactly when p is neither +1 nor -1 modulo 8.
    if (((p - 1) & 7) && ((p + 1) & 7) && (_ctz & 1)) s = -s;
    if ((a >>= _ctz) == 1) break;
    // Reciprocity flips the sign when both a and p are 3 modulo 4.
    if ((((p - 1) & 7) * (a - 1)) & 7) s = -s;
    std::swap(p %= a, a);
  }
  return s;
}

// Element real + imag * w of Z_mod[w] / (w^2 - i_sqr), used by quad_residue.
struct GaussInt {
  data_t real, imag;
  const data_t i_sqr, mod;
  constexpr GaussInt &operator*=(GaussInt rhs) {
    const ldata_t _r = real, _i = imag;
    real =
      (data_t)((_r * rhs.real % mod + i_sqr * _i % mod * rhs.imag % mod) % mod);
    imag = (data_t)((_i * rhs.real % mod + _r * rhs.imag % mod) % mod);
    return *this;
  }
};

// Random source for quad_residue, seeded from the wall clock.
std::mt19937 eng__(time(nullptr));

// Square root of n modulo the odd prime p (Cipolla style): returns the
// smaller of the two roots, or -1 when n is not a quadratic residue.
data_t quad_residue(data_t n, data_t p) {
  // Only 0 and 1 are their own square roots. p - 1 is NOT: (p - 1)^2 == 1,
  // so it has to go through the general path like any other value.
  if (n == 0 || n == 1) return n;
  if (legendre_symbol(n, p) != 1) return -1;
  std::uniform_int_distribution<ldata_t> u(2, p - 1);
  ldata_t a = u(eng__);
  // Retry until a^2 - n stops being a non-zero quadratic residue.
  while (legendre_symbol((a * a % p + p - n) % p, p) == 1) a = u(eng__);
  // The real part of (a + w)^((p + 1) / 2) with w^2 = a^2 - n is a root.
  data_t ret = [](GaussInt a, udata_t b) {
    GaussInt res{1, 0, a.i_sqr, a.mod};
    for (; b; b >>= 1, a *= a)
      if (b & 1) res *= a;
    return res.real;
  }(GaussInt{(data_t)a, 1, (data_t)(a * a % p + p - n) % p, p}, (p + 1) / 2);
  return std::min(ret, p - ret);
}

// Table of modular inverses: operator[](i) is i^{-1} mod MOD for
// 1 <= i < DEG_LIMIT, filled with inv[i] = inv[MOD % i] * (MOD - MOD / i).
template <size_t DEG_LIMIT, data_t MOD>
class INV_ {
protected:
  data_t data[DEG_LIMIT];

public:
  constexpr INV_() {
    data[0] = 0;
    data[1] = 1;
    for (size_t i = 2; i < DEG_LIMIT; ++i)
      data[i] = (data_t)((ldata_t)data[MOD % i] * (MOD - MOD / i) % MOD);
  }
  constexpr const data_t &operator[](size_t idx) const { return data[idx]; }
};

// In-place number-theoretic transform of length n (a power of two that is at
// most DEG_LIMIT) over Z / MOD.
template <size_t DEG_LIMIT, data_t MOD>
class NTT_ {
  static constexpr data_t G = proot<MOD>, IG = inverse(G, MOD);

protected:
  data_t root[DEG_LIMIT];  // bit-reversal permutation for the cached length
  size_t rsz_ = 0;         // length `root` was last built for (0 = never)
  ludata_t f[DEG_LIMIT], w[DEG_LIMIT];  // work buffer / twiddle factors

  // Rebuilds the bit-reversal table when the transform length changes.
  constexpr void root_init(size_t n) {
    if (rsz_ == n) return;
    rsz_ = n;
    // Seed explicitly instead of relying on zero-initialised storage.
    root[0] = 0;
    for (size_t i = 1; i < n; ++i)
      root[i] = (root[i >> 1] >> 1) | (data_t)((i & 1) * (n >> 1));
  }

public:
  NTT_() = default;

  // Transforms g[0 .. n) in place; `inv` selects the inverse transform,
  // which also divides by n.
  constexpr void operator()(data_t *g, size_t n, bool inv = false) {
    root_init(n);
    w[0] = 1;
    // The added multiple of MOD also maps negative inputs into [0, MOD).
    for (size_t i = 0; i < n; ++i)
      f[i] = (((ldata_t)MOD << 5) + g[root[i]]) % MOD;
    for (size_t l = 1; l < n; l <<= 1) {
      ludata_t tG = qpow(inv ? IG : G, (MOD - 1) / (l + l), MOD);
      for (size_t i = 1; i < l; ++i) w[i] = w[i - 1] * tG % MOD;
      for (size_t k = 0; k < n; k += l + l)
        for (size_t p = 0; p < l; ++p) {
          // Butterfly outputs are left unreduced and grow by at most MOD
          // per level.
          ldata_t _ = w[p] * f[k | l | p] % MOD;
          f[k | l | p] = f[k | p] + (MOD - _);
          f[k | p] += _;
        }
      // One intermediate reduction keeps w[p] * f[...] inside 64 bits.
      if (l == (1 << 10))
        for (size_t i = 0; i < n; ++i) f[i] %= MOD;
    }
    if (inv) {
      ludata_t in = inverse(n, MOD);
      for (size_t i = 0; i < n; ++i) g[i] = (data_t)(f[i] % MOD * in % MOD);
    } else
      for (size_t i = 0; i < n; ++i) g[i] = (data_t)(f[i] % MOD);
  }
};

const INV_<DEG_LIMIT, MOD> inv;
NTT_<DEG_LIMIT, MOD> NTT;
}  // namespace Helper

using Helper::inverse;
using Helper::NTT;
using Helper::qpow;

// Dense polynomial / truncated power series over Z / MOD; data[i] is the
// coefficient of x^i. Every do_xxx() works in place modulo x^size() and has
// a by-value friend xxx(poly) counterpart.
class Poly {
protected:
  std::vector<data_t> data;

  // Shared recursion that builds the first n coefficients of a series in
  // `ans`: n == 1 seeds the constant term val1, odd n extends the (n - 1)
  // solution by one coefficient through fodd, even n doubles the (n / 2)
  // solution through feven. Requires n >= 1.
  template <class Fodd, class Feven>
  void expand_base__(
    Poly &ans, size_t n, data_t val1, Fodd &&fodd, Feven &&feven) const {
    if (n == 1) {
      ans.data.push_back(val1);
      return;
    }
    if (n & 1) {
      expand_base__(ans, n - 1, val1, fodd, feven);
      fodd(ans, n);
      return;
    }
    expand_base__(ans, n / 2, val1, fodd, feven);
    feven(ans, n);
  }

  // ans = 1 / *this mod x^n.
  void inv_(Poly &ans, size_t n) const {
    expand_base__(
      ans,
      n,
      (data_t)inverse(data[0], MOD),
      // One more coefficient: b_m = -(sum_{i < m} b_i * a_{m - i}) / a_0.
      [this](Poly &ans, size_t n) -> void {
        --n;
        ldata_t _ = 0;
        for (size_t i = 0; i < n; ++i)
          _ = (_ + (ldata_t)ans[i] * data[n - i]) % MOD;
        ans.data.push_back((data_t)(_ * inverse(MOD - data[0], MOD) % MOD));
      },
      // Doubling step: b <- 2b - b^2 * a (mod x^n).
      [this](Poly &ans, size_t n) -> void {
        Poly sA = *this;
        sA.resize(n);
        ans = ans * 2 - (ans * ans * sA).resize(n);
      });
  }

  // ans = exp(*this) mod x^n; the seed value 1 corresponds to data[0] == 0.
  void exp_(Poly &ans, size_t n) const {
    expand_base__(
      ans,
      n,
      1,
      // One more coefficient from g' = f' * g:
      // (m + 1) * g_{m + 1} = sum_{i <= m} (i + 1) * f_{i + 1} * g_{m - i}.
      [this](Poly &ans, size_t n) -> void {
        n -= 2;
        ldata_t _ = 0;
        for (size_t i = 0; i <= n; ++i)
          _ = (_ + (i + 1) * data[i + 1] % MOD * ans[n - i] % MOD) % MOD;
        ans.data.push_back((data_t)(_ * Helper::inv[n + 1] % MOD));
      },
      // Doubling step: g <- g * (1 + f - log g) (mod x^n).
      [this](Poly &ans, size_t n) -> void {
        Poly ans_log = ans;
        ans_log.resize(n);
        ans_log.do_log();
        for (size_t i = 0; i < ans_log.size(); ++i)
          ans_log[i] = (MOD + data[i] - ans_log[i]) % MOD;
        ++ans_log[0];
        (ans *= ans_log).resize(n);
      });
  }

  // ans = sqrt(*this) mod x^n; asserts that data[0] is a quadratic residue.
  void sqrt_(Poly &ans, size_t n) const {
    if (n == 1) {
      auto &&qres = Helper::quad_residue(data[0], MOD);
      assert(qres != -1);
      ans.data.push_back(qres);
      return;
    }
    sqrt_(ans, (n + 1) / 2);
    Poly sA = *this;
    sA.resize(n);
    ans.resize(ans.size() * 2);
    // Doubling step: s <- (a + s^2) / (2s).
    ans = (sA + (ans * ans).resize(n)) * inverse(ans * 2);
    ans.resize(n);
  }

public:
  // Zero polynomial with min(sz, DEG_LIMIT) coefficients.
  explicit Poly(decltype(DEG_LIMIT) sz = 0): data(std::min(DEG_LIMIT, sz)) {}
  explicit Poly(const std::initializer_list<data_t> &v): data(v) {}
  explicit Poly(const std::vector<data_t> &v): data(v) {}

  // Reads exactly size() coefficients, lowest degree first.
  friend std::istream &operator>>(std::istream &is, Poly &poly) {
    for (auto &val : poly.data) is >> val;
    return is;
  }
  // Writes the coefficients separated by single spaces, without a trailing
  // separator; an empty polynomial writes nothing.
  friend std::ostream &operator<<(std::ostream &os, const Poly &poly) {
    if (poly.size() == 0) return os;
    for (size_t i = 1; i < poly.size(); ++i) os << poly[i - 1] << ' ';
    return os << poly.data.back();
  }

  data_t &operator[](size_t x) { return data[x]; }
  const data_t &operator[](size_t x) const { return data[x]; }
  size_t size() const { return data.size(); }
  Poly &resize(size_t size) {
    data.resize(size);
    return *this;
  }

  // Clamps to DEG_LIMIT coefficients and drops trailing zeros; the zero
  // polynomial is kept as the single coefficient 0.
  Poly &strip() {
    if (size() > DEG_LIMIT) resize(DEG_LIMIT);
    // Check for emptiness first: back() on an empty vector is undefined.
    while (!data.empty() && !data.back()) data.pop_back();
    if (data.empty()) data.push_back(0);
    return *this;
  }

  // Scalar multiplication.
  Poly &operator*=(const data_t &c) {
    for (data_t &val : data) val = (data_t)((ldata_t)val * c % MOD);
    return *this;
  }
  friend Poly operator*(Poly poly, const data_t &c) { return poly *= c; }
  friend Poly operator*(const data_t &c, Poly poly) { return poly *= c; }

// OOCR_ / OO_ define the compound assignment `op=` (taking rhs by const
// reference / by value) from the given body, plus the matching binary
// friend operator.
#define OOCR_(op, ...)                                  \
  Poly &operator op##=(const Poly &rhs) __VA_ARGS__     \
  friend Poly operator op(Poly lhs, const Poly &rhs) {  \
    return lhs op##= rhs;                               \
  }
#define OO_(op, ...)                                    \
  Poly &operator op##=(Poly rhs) __VA_ARGS__            \
  friend Poly operator op(Poly lhs, const Poly &rhs) {  \
    return lhs op##= rhs;                               \
  }

  // Coefficient-wise sum; the result has max(size(), rhs.size()) terms.
  OOCR_(+, {
    resize(std::max(size(), rhs.size()));
    for (size_t i = 0; i < rhs.size(); ++i) {
      data[i] += rhs[i];
      data[i] -= data[i] >= MOD ? MOD : 0;
    }
    return *this;
  })
  // Coefficient-wise difference; same size rule as operator+=.
  OOCR_(-, {
    resize(std::max(size(), rhs.size()));
    for (size_t i = 0; i < rhs.size(); ++i) {
      data[i] += MOD - rhs[i];
      data[i] -= data[i] >= MOD ? MOD : 0;
    }
    return *this;
  })
  // Full product with size() + rhs.size() - 1 coefficients, computed with
  // the NTT. An empty operand yields an empty product. `rhs` may alias
  // *this.
  OOCR_(*, {
    static data_t a__[N << 1], b__[N << 1];
    // Take both sizes before touching `data`, so that p *= p is safe.
    const size_t lsz_ = size(), rsz_ = rhs.size();
    if (lsz_ == 0 || rsz_ == 0) {
      data.clear();
      return *this;
    }
    const size_t sz_ = lsz_ + rsz_ - 1;
    // Smallest power of two >= max(2, sz_).
    size_t n = 2;
    while (n < sz_) n <<= 1;
    assert(n <= (N << 1));
    std::copy(data.begin(), data.end(), a__);
    std::copy(rhs.data.begin(), rhs.data.end(), b__);
    data.assign(sz_, 0);
    NTT(a__, n);
    NTT(b__, n);
    for (size_t i = 0; i < n; ++i)
      a__[i] = (data_t)((ldata_t)a__[i] * b__[i] % MOD);
    NTT(a__, n, true);
    std::copy(a__, a__ + sz_, data.begin());
    // The shared buffers must be all zero again for the next call.
    memset(a__, 0, sizeof(a__[0]) * (n));
    memset(b__, 0, sizeof(b__[0]) * (n));
    return *this;
  })
  // Euclidean quotient, computed on the reversed coefficient lists. `rhs`
  // is expected to be non-empty with a non-zero leading coefficient. A
  // dividend shorter than the divisor gives the zero polynomial.
  OO_(/, {
    size_t n_ = size(), m_ = rhs.size();
    if (n_ < m_) {
      // n_ - m_ + 1 would wrap around; the quotient is simply zero.
      data.assign(1, 0);
      return *this;
    }
    std::reverse(data.begin(), data.end());
    std::reverse(rhs.data.begin(), rhs.data.end());
    rhs.resize(n_ - m_ + 1);
    *this *= rhs.do_inverse();
    resize(n_ - m_ + 1);
    std::reverse(data.begin(), data.end());
    return *this;
  })
  // Euclidean remainder, resized to rhs.size() - 1 coefficients.
  OOCR_(%, {
    auto &&prod_ = rhs * (*this / rhs);
    return (*this -= prod_).resize(rhs.size() - 1);
  })
#undef OO_
#undef OOCR_

  // Returns {quotient, remainder} of lhs / rhs.
  friend std::pair<Poly, Poly> divmod(const Poly &lhs, const Poly &rhs) {
    auto &&div_ = lhs / rhs;
    return {div_, (lhs - rhs * div_).resize(rhs.size() - 1)};
  }

  // Divides by x^offset, dropping the lowest `offset` coefficients.
  Poly &shift_left(size_t offset) {
    if (offset == 0) return *this;
    if (offset >= size()) {
      data.clear();
      return *this;
    }
    data.erase(std::move(data.begin() + offset, data.end(), data.begin()),
               data.end());
    return *this;
  }
  // Multiplies by x^offset, growing the polynomial by `offset` terms.
  Poly &shift_right(size_t offset) {
    if (offset == 0) return *this;
    resize(size() + offset);
    std::fill(data.begin(),
              std::move_backward(data.begin(), data.end() - offset, data.end()),
              0);
    return *this;
  }

// FUNC_ / FUNCP2_ define the in-place member do_<name>() from the given
// body, plus a friend <name>(poly) that applies it to a copy.
#define FUNC_(name, ...)                       \
  Poly &do_##name() __VA_ARGS__                \
  friend Poly name(Poly poly) {                \
    return poly.do_##name();                   \
  }
#define FUNCP2_(name, type1, var1, type2, var2, ...)      \
  Poly &do_##name(type1 var1, type2 var2) __VA_ARGS__     \
  friend Poly name(Poly poly, type1 var1, type2 var2) {   \
    return poly.do_##name(var1, var2);                    \
  }

  // Multiplicative inverse mod x^size(); data[0] must be invertible.
  FUNC_(inverse, {
    // The recursion needs at least one coefficient.
    if (data.empty()) return *this;
    Poly ret;
    inv_(ret, size());
    return *this = ret;
  })
  // Formal derivative; the result has one coefficient fewer.
  FUNC_(derivative, {
    if (data.empty()) return *this;
    for (size_t i = 1; i < size(); ++i)
      data[i - 1] = (data_t)((ldata_t)data[i] * i % MOD);
    data.pop_back();
    return *this;
  })
  // Formal integral with constant term 0; the result has one coefficient
  // more.
  FUNC_(integral, {
    data.push_back(0);
    for (size_t i = size() - 1; i; --i)
      data[i] = (data_t)((ldata_t)data[i - 1] * Helper::inv[i] % MOD);
    data.front() = 0;
    return *this;
  })
  // Logarithm mod x^size(), computed as the integral of f' / f; the
  // constant term of the result is always 0.
  FUNC_(log, {
    size_t sz_ = size();
    *this = (derivative(*this) * inverse(*this)).do_integral();
    resize(sz_);
    return *this;
  })
  // Exponential mod x^size(); see exp_ for the constant-term convention.
  FUNC_(exp, {
    if (data.empty()) return *this;
    Poly ret;
    exp_(ret, size());
    return *this = ret;
  })
  // Square root mod x^size(); asserts that data[0] is a quadratic residue.
  FUNC_(sqrt, {
    if (data.empty()) return *this;
    Poly ret;
    sqrt_(ret, size());
    return *this = ret;
  })
  // sin f = (e^{-if} - e^{if}) * (i / 2), where i = g^((MOD - 1) / 4) is a
  // square root of -1 modulo MOD.
  FUNC_(sin, {
    size_t sz_ = size();
    data_t i = qpow(Helper::proot<MOD>, (MOD - 1) / 4, MOD);
    *this *= i;
    *this = (exp(*this * (MOD - 1)) - exp(*this)) *
            (data_t)(i * inverse(2, MOD) % MOD);
    resize(sz_);
    return *this;
  })
  // cos f = (e^{if} + e^{-if}) / 2.
  FUNC_(cos, {
    size_t sz_ = size();
    data_t i = qpow(Helper::proot<MOD>, (MOD - 1) / 4, MOD);
    *this *= i;
    *this = (exp(*this) + exp(*this * (MOD - 1))) * (data_t)inverse(2, MOD);
    resize(sz_);
    return *this;
  })
  // tan f = -i * (e^{2if} - 1) / (e^{2if} + 1).
  FUNC_(tan, {
    size_t sz_ = size();
    const data_t i = (data_t)qpow(Helper::proot<MOD>, (MOD - 1) / 4, MOD);
    // The exponent uses 2i, but the outer factor below is -i (not -2i).
    *this *= (data_t)(2 * (ldata_t)i % MOD);
    do_exp();
    Poly _1 = *this, _2 = *this;
    _1[0] = (_1[0] + MOD - 1) % MOD;
    _2[0] = (_2[0] + 1) % MOD;
    *this = _1 * _2.do_inverse() * (data_t)(MOD - i);
    resize(sz_);
    return *this;
  })
  // Integral of f' / sqrt(1 - f^2), constant term 0.
  FUNC_(asin, {
    size_t sz_ = size();
    Poly _1 = (*this * *this * (MOD - 1)).resize(sz_);
    ++_1[0];
    *this =
      (derivative(*this) * _1.do_sqrt().do_inverse()).resize(sz_).do_integral();
    resize(sz_);
    return *this;
  })
  // Integral of -f' / sqrt(1 - f^2), constant term 0.
  FUNC_(acos, {
    size_t sz_ = size();
    Poly _1 = (*this * *this * (MOD - 1)).resize(sz_);
    ++_1[0];
    *this = (derivative(*this) * _1.do_sqrt().do_inverse() * (MOD - 1))
              .resize(sz_)
              .do_integral();
    resize(sz_);
    return *this;
  })
  // Integral of f' / (1 + f^2), constant term 0.
  FUNC_(atan, {
    size_t sz_ = size();
    Poly _1 = (*this * *this).resize(sz_);
    ++_1[0];
    *this = (derivative(*this) * _1.do_inverse()).resize(sz_).do_integral();
    resize(sz_);
    return *this;
  })
  // f^y mod x^size(). `y` is the exponent itself (used for the degree shift
  // and, reduced mod MOD, as the multiplier of log f); `y_mod_phiMOD` is the
  // exponent applied to the lowest non-zero coefficient.
  FUNCP2_(pow, ludata_t, y, ludata_t, y_mod_phiMOD, {
    size_t k_ = 0, sz_ = data.size();
    if (sz_ == 0) return *this;
    // k_ = number of leading zero coefficients.
    for (; k_ < sz_; ++k_)
      if (data[k_]) break;
    // Same condition as k_ * y >= sz_, written so it cannot overflow.
    if (k_ > 0 && y >= (sz_ + k_ - 1) / k_) {
      std::fill(data.begin(), data.end(), 0);
      return *this;
    }
    shift_left(k_);
    resize(sz_ - k_ * y);
    data_t c_ = data[0], inv_c_ = (data_t)inverse(c_, MOD),
           c_y_ = (data_t)qpow(c_, y_mod_phiMOD, MOD);
    // Normalise the constant term to 1 so that log is defined.
    *this *= inv_c_;
    // Reduce y modulo MOD instead of truncating it to 32 bits.
    *this = (log(*this) * (data_t)(y % MOD)).do_exp() * c_y_;
    shift_right(k_ * y);
    return *this;
  })
#undef FUNC_
#undef FUNCP2_
};
}  // namespace Polynomial
using Polynomial::Poly;
using namespace std;
using Polynomial::qpow, Polynomial::MOD, Polynomial::Helper::inv;
auto solve([[maybe_unused]] int t_ = 0) -> void {
int n;
cin >> n;
Poly a(n + 1), b(n + 1);
cin >> a >> b;
auto &&ib = integral(b);
auto &&mu = exp(ib);
auto &&g = integral(mu * a) * (Polynomial::MOD - 1);
g[0] = (g[0] + 1) % Polynomial::MOD;
g.resize(n + 1);
auto &&f = ib - log(g);
f[0] = (f[0] + 1) % Polynomial::MOD;
cout << f.resize(n + 1);
}
// Entry point: switches the standard streams to fast, untied mode and runs
// the single test case.
int main() {
  std::ios_base::sync_with_stdio(false);
  std::cin.tie(nullptr);

  solve();
}