Online FFT의 $\mathcal{O}(n \sqrt {n})$ 구현

수열 $\{a_n\}$, $\{b_n\}$의 컨볼루션 $c_n = \sum_{i=0}^{n} a_i b_{n-i}$를 구하는 것은 FFT를 이용하여 $\mathcal{O}(n \log {n})$에 계산할 수 있습니다.

그런데 $\{a_n\}$, $\{b_n\}$의 각 원소가 online으로 주어진다면, 즉 예를 들어 $c_{n-1}$의 값을 알아야만 $a_n$$b_n$을 알 수 있다면 FFT를 이용하는 방법을 그대로 사용할 수 없습니다. 이러한 경우 $\{c_n\}$을 계산하는 문제를 Online FFT라고 부릅니다. Online FFT를 사용할 수 있는 대표적인 사례로 어떤 함수 $f$가 존재해 $a_n = f \left (n, \sum_{i=0}^{n-1} a_i b_{n-1-i} \right )$ 꼴의 점화식을 가지는 수열을 계산할 수 있습니다.

EDIT (2022-10-15): Online FFT의 정식 명칭은 relaxed multiplication인 것 같습니다.

Online FFT의 구현 방법으로 다음과 같은 방법들이 알려져 있습니다.

이 글에서는 위 링크에서 소개된 $\mathcal{O}(n \sqrt{n \log n})$ 방법과 유사한(?) $\mathcal{O}(n \sqrt{n})$ 방법을 소개합니다.


Convolution Blackbox

컨볼루션을 FFT로 계산할 때 컨볼루션을 해주는 함수를 통째로 blackbox로 생각해 문제를 풀곤 합니다. 이 blackbox를 살짝만 뜯어봅시다.

우선 기호를 정의합니다. 수열 $a$, $b$에 대해 다음과 같은 연산을 정의합니다.

  • 원소 합(element-wise sum): $(a + b)_i = a_i + b_i$
  • 원소 곱(element-wise product): $(a \times b)_i = a_i b_i$
  • 컨볼루션(convolution): $(a * b)_i = \sum_{j=0}^{i} a_j b_{i-j}$

FFT는 유한 수열 $a$로부터 또 다른 어떤 유한 수열 $\mathcal{F}(a)$을 계산하는 정방향 FFT와, 반대로 $\mathcal{F}(a)$으로부터 $\mathcal{F}^{-1}(\mathcal{F}(a))=a$를 계산하는 역방향 FFT가 있습니다. 이 글에서는 FFT를 수행하는 것은 blackbox로 남겨둡니다. 정방향 FFT, 역방향 FFT는 모두 변환하는 수열의 길이를 $n$이라 할 때 $\mathcal{O}(n \log n)$에 계산할 수 있습니다.

유한 수열 $a$, $b$에 대해 다음과 같은 FFT와 관련된 성질이 성립합니다.

  • $\mathcal{F}(a + b) = \mathcal{F}(a) + \mathcal{F}(b)$
  • $\mathcal{F}(a * b) = \mathcal{F}(a) \times \mathcal{F}(b)$

유한수열 $a$$b$의 컨볼루션 $c = a * b$을 FFT를 통해서 계산하는 방법은 다음과 같습니다.

  • $a$, $b$ 각각에 대해 정방향 FFT를 수행한 결과인 $\mathcal{F}(a)$, $\mathcal{F}(b)$를 계산합니다.
  • $\mathcal{F}(a)$$\mathcal{F}(b)$에 대해 같은 위치의 원소를 곱해(element-wise multiplication) $\mathcal{F}(c) = \mathcal{F}(a * b) = \mathcal{F}(a) \times \mathcal{F}(b)$를 계산합니다.
  • $\mathcal{F}(c)$의 역방향 FFT를 계산해 $c$를 계산합니다.

이때 정방향 FFT를 수행한 결과는 재활용할 수 있다는 점을 기억해둡시다. 즉, 예를 들어, $a$를 두 수열 $b$, $b'$과 convolution 한 결과 $a * b$$a * b'$를 계산한다고 하면, $\mathcal{F}(a)$는 한 번만 계산하면 됩니다.


Online FFT in $\mathcal{O}(n \sqrt{n})$

$c = a * b$를 online FFT로 계산한다고 합시다. 컨볼루션될 수열 $a$, $b$를 블록 사이즈 $k$의 블록으로 나눕니다. 블록 번호를 $0$부터 시작하도록 하면 $t$번째 블록은 수열의 $t k$번째 항부터 $(t + 1) k - 1$번째 항까지를 포함합니다. $a$, $b$$t$번째 블록을 각각 $A_t$, $B_t$라고 부릅시다. 즉, $0 \leq i < k$에 대해 $(A_t)_i = a_{tk + i}$, $(B_t)_i = b_{tk + i}$입니다.

$c_n$으로부터 이미 계산된 블록들의 컨볼루션에서 찾을 수 있는 항들을 최대한 모아봅시다. $n = tk + r$ ($t \geq 1$, $0 \leq r < k$)이라 하면, 다음과 같습니다.

$$ \begin{align*} c_n &= (a * b)_n \\ &= \sum_{j=0}^{n} a_j b_{n-j} \\ % &= \sum_{i=0}^{t} \sum_{j=0}^{r} a_{ik+j} b_{(t-i)k+r-j} + \sum_{i=0}^{t-1} \sum_{j=r+1}^{k-1} a_{ik+j} b_{(t-i-1)k+r+k-j} &= \sum_{i=0}^{t-1} \sum_{j=0}^{k-1} a_{ik+j} b_{(t-i)k+r-j} + \sum_{j=0}^{r} a_{tk+j} b_{r-j} \\ &= \sum_{i=0}^{t-1} \sum_{j=0}^{r} a_{ik+j} b_{(t-i)k+r-j} + \sum_{i=0}^{t-1} \sum_{j=r+1}^{k-1} a_{ik+j} b_{(t-i)k+r-j} + \sum_{j=0}^{r} a_{tk+j} b_{r-j} \\ &= \sum_{j=0}^{r} a_{j} b_{tk+r-j} + \sum_{i=1}^{t-1} \sum_{j=0}^{r} a_{ik+j} b_{(t-i)k+r-j} + \sum_{i=0}^{t-1}\sum_{j=r+1}^{k-1} a_{ik+j} b_{(t-1-i)k+r+k-j} + \sum_{j=0}^{r} a_{tk+j} b_{r-j} \\ &= \sum_{j=0}^{r} a_{j} b_{tk+r-j} + \sum_{i=1}^{t-1} (A_{i} * B_{t-i})_{r} + \sum_{i=0}^{t-1}(A_{i} * B_{t-1-i})_{r+k} + \sum_{j=0}^{r} a_{tk+j} b_{r-j} \end{align*} $$

맨 앞과 맨 뒤 합은 $\mathcal{O}\left(\frac{n}{k} k^2\right) = \mathcal{O}(n k)$에 계산할 수 있습니다. 2번째 합은 다음과 같이 계산할 수 있습니다. 3번째 합도 비슷하게 계산합니다.

$$ \begin{align*} &\sum_{i=1}^{t-1} (A_{i} * B_{t-i})_{r} \\ =&\left(\sum_{i=1}^{t-1} (A_{i} * B_{t-i})\right)_{r} \\ =&\mathcal{F}^{-1} \left(\mathcal{F} \left(\sum_{i=1}^{t-1} (A_{i} * B_{t-i}) \right)\right)_{r} \\ =&\mathcal{F}^{-1} \left(\sum_{i=1}^{t-1} ( \mathcal{F} \left(A_{i}\right) \times \mathcal{F} \left(B_{t-i}\right) ) \right)_{r} \\ \end{align*} $$

$\mathcal{F} (A)$, $\mathcal{F} (B)$들을 각각 한번씩 계산해 다시 사용한다고 하면 전체 정변환, 역변환의 수는 $\mathcal{O}\left(\frac{n}{k}\right)$이므로 FFT에 소요되는 총 시간복잡도는 $\mathcal{O}(n \log k)$입니다. 한편 $\sum ( \mathcal{F} (A) \times \mathcal{F} (B) )$의 계산에 소요되는 총 시간복잡도는 $\mathcal{O}\left(\left(\frac{n}{k}\right)^2 k \right) = \mathcal{O}\left(\frac{n^2}{k}\right)$입니다.

전체 시간복잡도는 $\mathcal{O}\left(nk + n \log k + \frac{n^2}{k}\right) = \mathcal{O}\left(nk + \frac{n^2}{k}\right)$이고, $k = \sqrt{n}$으로 잡으면 $\mathcal{O}(n \sqrt{n})$가 됩니다.


구현 시 주의사항 및 코드

구현 시 주의사항을 생각나는 대로 적을 예정입니다

  • 놀랍게도 DFT를 $\mathcal{O}(k \log k)$가 아닌 $\mathcal{O}(k^2)$에 계산해도 전체 시간복잡도는 $\mathcal{O}(n \sqrt{n})$로 동일합니다! 다만 FFT를 쓸 수 있다면 써야 상수가 작아지므로 쓰는 게 좋겠습니다.
  • 2번째 합과 3번째 합은 거의 같은 수열의 앞 $k$개와 뒤 $k$개이므로 실제 구현에서는 대부분의 값이 중복으로 계산되어 함께 계산할 수 있습니다.
  • 문제에서 주어진 모듈로가 FFT를 직접 돌릴 수 없는 경우에도 계산 가능합니다. 다만 이 때는 FFT를 3개(또는 그 이상)의 서로 다른 소수로 계산해서 저장해 두어야 하기 때문의 $ \mathcal{O}\left(\frac{n^2}{k}\right)$ 부분의 상수가 커집니다. 따라서 블록 사이즈를 조금 더 크게 잡아야 더 빠르게 동작할 것입니다. 원래도 FFT를 수행하는 부분이 병목이 아니기 때문에 CRT가 조금 오래 걸려도 시간상 큰 문제는 없을 것 같습니다. garner CRT를 상수 모듈로 최적화를 받을 수 없는 형태로 짜면 꽤 오래 걸립니다. 적당히 빠를 정도로는 구현합시다...
  • a와 b가 서로 같을 경우(auto-convolution(?)) 잘 구현하면 상수가 $\frac{1}{2}$로 줄어듭니다. 또, 둘 중 하나(wlog $b$)가 고정된 수열인 경우는 맨 앞 합을 FFT로 계산하면 $\mathcal{O}(n k)$의 상수가 $\frac{1}{2}$로 줄어듭니다.

다음은 BOJ 1067번을 Online FFT로 푸는 코드입니다: http://boj.kr/2f5016b239d8486f907f3f0d1fa00352

다음은 같은 문제를 $10^9 + 7$을 모듈로로 하여 Online FFT로 푸는 코드입니다: http://boj.kr/216a39e758644c23883f4ab13ae18bcd


댓글 댓글 쓰기