「算法笔记」多项式求逆

多项式求逆通过递归求得某个多项式在模 $x$ 的若干次幂下的逆元。


概述

所谓多项式求逆,就是给定一个 $n - 1$ 次多项式 $A(x)$,你需要求出一个多项式 $B(x)$ 满足 $A(x)B(x) \equiv 1\pmod{x ^ n}$。


思路

我们考虑一个子问题:假如已经求出了多项式 $A(x)$ 在模 $x ^ {\left\lceil\frac{n}{2}\right\rceil}$ 意义下的逆元 $B'(x)$,那么有:

$$ A(x)B'(x) \equiv 1\pmod{x ^ \left\lceil\frac{n}{2}\right\rceil} $$

又因为:

$$ A(x)B(x) \equiv 1\pmod{x ^ \left\lceil\frac{n}{2}\right\rceil} $$

将两式相减得到:

$$ A(x)[B(x) - B'(x)] \equiv 0\pmod{x ^ \left\lceil\frac{n}{2}\right\rceil} $$

由于 $A(x)\not\mid x ^ {\left\lceil\frac{n}{2}\right\rceil}$,那么我们可以把 $A(x)$ 除掉得到:

$$ B(x) - B'(x) \equiv 0\pmod{x ^ \left\lceil\frac{n}{2}\right\rceil} $$

接下来将等式两边平方得到:

$$ [B(x) - B'(x)] ^ 2 \equiv 0\pmod{x ^ \left\lceil\frac{n}{2}\right\rceil} $$

分析一下平方后的多项式有什么特点。设 $P(x) = B(x) - B'(x)$,那么对于 $P(x)$ 任意的 $i \in \left[0, \left\lceil\frac{n}{2}\right\rceil \right)$,第 $i$ 项的系数均为 $0$。考虑将其平方后,得到系数 $a'_i = \sum_{j = 0} ^ i a_j\times a_{i - j}$,对于任意的 $i \in \left[0, 2\times\left\lceil\frac{n}{2}\right\rceil\right)$,$i$ 或 $i - j$ 中必有一个值小于 $\left\lceil\frac{n}{2}\right\rceil$,那么 $a_i$ 和 $a_{i - j}$ 中必有一项值为 $0$。换言之,我们可以得到结论:$P(x) ^ 2$ 在模 $x ^ n$ 的意义下与 $0$ 同余。

$$ B ^ 2(x) +B' ^ 2(x) - 2 B(x) B'(x)\equiv 0 \pmod{x ^ n} $$

我们在两边同时乘上 $A(x)$ 得到:

$$ A(x)B ^ 2(x) + A(x)B' ^ 2(x) - 2A(x)B(x)B'(x) \equiv 0\pmod{x ^ n} $$

通过逆元的定义 $A(x)B(x)\equiv 1\pmod{x ^ n}​$ 可以化简为:

$$ B(x) + A(x)B' ^ 2(x) - 2B'(x)\equiv 0\pmod{x ^ n} $$

移项得到:

$$ B(x) = 2B'(x) - A(x)B' ^ 2(x)\pmod{x ^ n} $$

时间复杂度:$\mathcal O(n \log n)$。


实现

实现主要由如下两种方式:

  1. 递归:通过实现部分的前几句话,我们就可以知道可以递归求解。递归边界为 $n = 1$,此时答案为常数项的逆元
  2. 迭代:枚举迭代长度也可以求得答案。迭代实现的常数较小,但是细节较多。

考虑证明复杂度,$T(n) = T(\frac{n}{2}) + \mathcal O(n\log n)$。通过主定理可以得到复杂度为 $\mathcal O(n\log n)$。


代码

此处只给出迭代实现的代码。

Vec operator ~ (Vec A) {
    int n = A.size(), N = extend(n);
    A.resize(N);
    Vec I(N, 0);
    I[0] = inv(A[0]);
    for (int l = 2; l <= N; l <<= 1) {
        Vec P(l), Q(l);
        std::copy(A.begin(), A.begin() + l, P.begin());
        std::copy(I.begin(), I.begin() + l, Q.begin());
        int L = l << 1;
        P.resize(L), DFT(P);
        Q.resize(L), DFT(Q);
        for (int i = 0; i < L; i++) {
            P[i] = 1LL * Q[i] * (2 - 1LL * P[i] * Q[i] % MOD + MOD) % MOD;
        }
        IDFT(P), P.resize(l);
        std::copy(P.begin(), P.begin() + l, I.begin());
    }
    I.resize(n);
    return I;
}

发表评论