Implementing Differentiable IIR in PyTorch#

In this section, we will implement the methods we have discussed in the section Differentiable Implementation of IIR Filters. We encourage the readers to read that section first to get a better understanding of the methods we will implement here.

Below are some high-level functions that we will be using in this section.

Naive IIR Implementation#

Before we start, let us import the necessary libraries and define some helper functions for plotting.

Hide code cell source
import torch
import scipy.signal as signal
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple
from timeit import timeit
from torch.profiler import profile, ProfilerActivity
Hide code cell source
def plot_tf(
    title: str,
    fs: int,
    t: np.ndarray,
    ys: List[np.ndarray],
    labels: List[str] = None,
    xlim: Tuple[float, float] = None,
):
    fig = plt.figure(figsize=(12, 5), dpi=200)
    plt.subplot(1, 2, 1)
    for y, label in zip(ys, labels) if labels is not None else zip(ys, [None] * len(ys)):
        plt.plot(t, y, label=label)
    plt.xlabel("Time [s]")
    plt.ylabel("Amplitude")
    plt.title(title)
    if label is not None:
        plt.legend()
    if xlim is not None:
        plt.xlim(*xlim)
    plt.subplot(1, 2, 2)
    for y, label in zip(ys, labels) if labels is not None else zip(ys, [None] * len(ys)):
        plt.magnitude_spectrum(y, Fs=fs, scale="dB", window=signal.windows.boxcar(len(y)), label=label)
    plt.xlabel("Frequency [Hz]")
    plt.ylabel("Magnitude [dB]")
    plt.xscale("log")
    plt.ylim(-60, 10)
    plt.title("Magnitude spectrum")
    if labels is not None:
        plt.legend()
    plt.show()

Let us first implement an IIR filter in PyTorch using the difference equation (1) directly and see how it performs.

@torch.jit.script
def naive_iir(x: torch.Tensor, a: torch.Tensor):
    """
    Naive implementation of IIR filter.
    """
    assert x.ndim == 1
    assert a.ndim == 1
    T = x.numel()
    M = a.numel()

    # apply initial rest condition
    # here, we use the utility function `new_zeros` to create a tensor
    # with the same dtype and device as `x` filled with zeros, so we
    # don't have to worry and create variables for the dtype and device
    y = [x.new_zeros(1)] * M
    for i in range(T):
        past_outputs = torch.cat(y[-1:-M-1:-1])
        y.append(x[i:i+1] -torch.dot(a, past_outputs))
    return torch.cat(y[M:])

For simplicity, the above implementation assumes \(y[n] = 0\) for \(n < 0\). In addition, We use the torch.jit.script decorator to utilise just-in-time compilation for speed improvement. Let us check whether the implementation is correct by comparing it with the scipy.signal.lfilter function.

The test signal is a chirp signal with a frequency that increases linearly from 1 Hz to 1000 Hz in 2 seconds. The sampling rate is 16000 Hz.

def chirp(T: int, start_freq: float, end_freq: float, fs: float):
    """
    Generate chirp signal.
    """
    angular_freq = np.linspace(start_freq, end_freq, T) / fs
    phase = np.cumsum(angular_freq)
    return np.sin(2 * np.pi * phase)

fs = 16000
T = 16000 * 2
start_freq = 1
end_freq = 1000
test_signal = chirp(T, start_freq, end_freq, fs)

plot_tf("Chirp signal", fs, np.arange(T) / fs, [test_signal])
../_images/7e1fcd59c4f0cb096c1f40fbbb9d54941e7d266e1b3d7f9090d51563efabf441.png

The filter we use is a second-order low pass filter (\(M\) = 2) with coefficients \(a_1 = -1.8\) and \(a_2 = 0.81\). Let us see how the filter performs on the chirp signal.

b = [1, 0]
a = [1, -0.9 * 2, 0.9 ** 2]
filtered_signal = signal.lfilter(b, a, test_signal)

plot_tf("Filtered Chirp signal", fs, np.arange(T) / fs, [filtered_signal])
../_images/62b5f17b19c5674d1a9f52f71e8b11f980ab0f05c9c1651c824a3812d79938f0.png

Is our PyTorch implementation correct? Let us compare it with the scipy.signal.lfilter function.

a_torch = torch.tensor(a[1:], dtype=torch.float64)
test_signal_torch = torch.tensor(test_signal, dtype=torch.float64)
filtered_signal_torch = naive_iir(test_signal_torch, a_torch)

torch.allclose(filtered_signal_torch, torch.from_numpy(filtered_signal))
True

Horay! It pass the test. Let us now see how fast it is.

Hide code cell source
scipy_time = timeit(
    stmt="signal.lfilter(b, a, test_signal)", setup="from scipy import signal", 
    globals={
        "b": b,
        "a": a,
        "test_signal": test_signal
    },
    number=1000
) / 1000

torch_time = timeit(
    stmt="naive_iir(test_signal_torch, a_torch)",
    globals={
        "a_torch": a_torch,
        "test_signal_torch": test_signal_torch,
        "naive_iir": naive_iir
    },
    number=10
) / 10

print(f"scipy.signal.lfilter: {scipy_time:.5f} [s]")
print(f"naive_iir: {torch_time:.4f} [s]")
scipy.signal.lfilter: 0.00022 [s]
naive_iir: 0.1630 [s]

The direct implementation is roughly 1000 times slower than the scipy.signal.lfilter function! What is the reason? One reason is that the core part of scipy.signal.lfilter is implemented in C and is highly optimised. Let us use torch.profiler to profile the code and see if we can make our implementation faster.

Hide code cell source
with profile(activities=[ProfilerActivity.CPU]) as prof:
    naive_iir(test_signal_torch[:], a_torch)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
STAGE:2024-03-10 19:30:24 3792:3792 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-03-10 19:30:25 3792:3792 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-03-10 19:30:25 3792:3792 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
--------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
--------------------  ------------  ------------  ------------  ------------  ------------  ------------  
           naive_iir        42.20%     116.566ms        99.99%     276.183ms     276.183ms             1  
           aten::cat        18.16%      50.171ms        18.16%      50.171ms       1.568us         32001  
           aten::dot        13.07%      36.113ms        14.96%      41.319ms       1.291us         32000  
         aten::slice        12.29%      33.953ms        12.42%      34.292ms       1.072us         32001  
           aten::sub        12.25%      33.833ms        12.25%      33.833ms       1.057us         32000  
         aten::empty         1.89%       5.212ms         1.89%       5.212ms       0.163us         32001  
    aten::as_strided         0.12%     340.000us         0.12%     340.000us       0.011us         32001  
     aten::new_zeros         0.00%      12.000us         0.01%      26.000us      26.000us             1  
     aten::new_empty         0.00%       6.000us         0.00%      13.000us      13.000us             1  
         aten::zero_         0.00%       1.000us         0.00%       1.000us       1.000us             1  
         aten::fill_         0.00%       1.000us         0.00%       1.000us       0.000us         32000  
--------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 276.208ms

The above table shows that a lot of CPU time were spent on aten:sub, aten::cat, and aten::dot. These operations are not optimised for small tensors, but we call them in each loop. In addition, each operation creates a new tensor, and with more than ten thousands of loops, the memory allocations becomes a bottleneck. We recommend readers to check issue 1238 of TorchAudio for more details on this.

One thing we can do is to first allocate a tensor for all the past output samples and reuse it in each loop. This will reduce the number of memory allocations. Let us see how much faster it is.

@torch.jit.script
def inplace_iir(x: torch.Tensor, a: torch.Tensor):
    """
    Inplace implementation of IIR filter.
    """
    assert x.ndim == 1
    assert a.ndim == 1
    T = x.numel()
    M = a.numel()

    # allocate memory for output
    y = x.new_zeros(M + T)
    y[M:] = x
    a_flip = a.flip(0)
    for i in range(T):
        past_outputs = y[i:i+M]
        y[i+M] -= torch.dot(a_flip, past_outputs)
    return y[M:]
Hide code cell source
assert torch.allclose(inplace_iir(test_signal_torch, a_torch), torch.from_numpy(filtered_signal))

torch_time = timeit(
    stmt="iir(test_signal_torch, a_torch)",
    globals={
        "a_torch": a_torch,
        "test_signal_torch": test_signal_torch,
        "iir": inplace_iir
    },
    number=10
) / 10

print(f"inplace_iir: {torch_time:.4f} [s]")
inplace_iir: 0.1205 [s]

We reduce around 20% of the CPU time! However, it is still much slower than the scipy counterpart. The remaining time are mostly related to the framework overhead and cannot be further reduced using the Python frontend.

Moreover, if you backpropagate the gradient through inplace_iir, it will throw a runtime error. PyTorch has very limited support in autograd for inplace operations. In our case, this limits calcaulating the gradients with respect to the filter coefficients (if you set requires_grad=False in the first line of the following code block, it will work). This issue was raised in the early version of TorchAudio (<0.9). For interested readers, you can check issue 704 to know more history about this.

a_torch_params = a_torch.clone().requires_grad_(True) # try setting this line to false, and see what happens
test_signal_torch_params = test_signal_torch.clone().requires_grad_(True)

filtered_signal_torch_params = inplace_iir(test_signal_torch_params, a_torch_params)
loss = torch.sum(filtered_signal_torch_params.abs())

try:
    loss.backward()
except RuntimeError as e:
    print(e)
else:
    print("No error!")
one of the variables needed for gradient computation has been modified by an inplace operation: [torch.DoubleTensor [2]], which is output 0 of AsStridedBackward0, is at version 32001; expected version 32000 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Differentiable IIR using Frequency Sampling#

Let us make another version of the IIR function that uses the frequency sampling method to compute the output samples.

@torch.jit.script
def freq_iir(x: torch.Tensor, a: torch.Tensor):
    """
    Frequency sampling of IIR filter.
    """
    assert x.ndim == 1
    assert a.ndim == 1
    T = x.numel()
    M = a.numel()

    X = torch.fft.rfft(x)
    A = torch.fft.rfft(torch.cat([a.new_ones(1), a]), n=T)
    mag = torch.abs(A)
    phase = torch.angle(A)
    H = torch.nan_to_num(mag.reciprocal()) * torch.exp(-1j * phase)
    Y = X * H
    y = torch.fft.irfft(Y)
    return y

def get_fir_from_iir(a: torch.Tensor, n: int):
    """
    Get FIR filter from IIR filter.
    """
    assert a.ndim == 1
    A = torch.fft.rfft(torch.cat([a.new_ones(1), a]), n=n)
    mag = torch.abs(A)
    phase = torch.angle(A)
    H = torch.nan_to_num(mag.reciprocal()) * torch.exp(-1j * phase)
    h = torch.fft.irfft(H)
    return h
Hide code cell source
filtered_signal_freq = freq_iir(test_signal_torch, a_torch)

print(
    f"Is the filtered signal the same? {torch.allclose(filtered_signal_freq, torch.from_numpy(filtered_signal))}"
)

plot_tf(
    "Filtered Chirp signal",
    fs,
    np.arange(T) / fs,
    [filtered_signal_freq.numpy(), filtered_signal],
    ["freq_iir", "scipy.signal.lfilter"],
    xlim=(0, 0.1),
)
Is the filtered signal the same? False
../_images/0a92bbc9aa04dd304e222edeec31ed3b960a24122849b0a8229174d17f943e3c.png

The above code shows that this approximation does not generate the same output. Although the frequency response looks very close to the original, it does not pass the test. This is because performing the discrete time Fourier transform (DFT) of IIR system equals sampling the system response on discrete frequencies, thus truncate the infinity impulse response to the same length FFT. If you need some refreshment on this topic, you can revisit the section Frequency Sampling Method. Below is the truncated \(h_k\) we used in the above example. The sampling process also makes \(h_k\) become periodic, so the outputs is actually a circular convolution of the input and \(h_k\). This difference is most obvious in our filtered signal at \(t \to 0\).

Hide code cell source
h = get_fir_from_iir(a_torch, T)

plot_tf("FIR filter", fs, np.arange(T) / fs, [h.numpy()], xlim=(0, 0.01))
../_images/e48c0b40542f05da6eaa093bdc540d8deb435af3a34c5d412ad73d1d4161e031.png

Let us benchmark the performance of this approximation.

Hide code cell source
torch_time = timeit(
    stmt="iir(test_signal_torch, a_torch)",
    globals={
        "a_torch": a_torch,
        "test_signal_torch": test_signal_torch,
        "iir": freq_iir
    },
    number=100
) / 100

print(f"freq_iir: {torch_time:.4f} [s]")
freq_iir: 0.0036 [s]

Uohhhhhh! It is super duper fast! Although it is still slower than the scipy.signal.lfilter function, it is much faster than the direct implementation. In addition, it is differentiable and can be used in gradient-based methods. Moreover, in many cases, we only care about the frequency response of the filter and not its convolution result. You can compute the loss on the transfer function directly without converting it back to the time domain. Although FFT needs more memory than the direct implementation, the gain in speed outweighs the memory cost.

Custom Backward Method#

To define a custom operator in PyTorch, we have to use the torch.autograd.Function class. The forward method defines how to compute the output from the input, and the backward method defines how to compute the gradient of the loss with respect to the input. After defining the custom operator, we can use by calling its class method apply (not the forward method!).

class IIR(torch.autograd.Function):
    """
    IIR filter as a torch.autograd.Function.
    """
    @staticmethod
    def forward(ctx, x: torch.Tensor, a: torch.Tensor):
        assert x.ndim == 1
        assert a.ndim == 1
        M = a.numel()

        assert x.is_cpu & a.is_cpu, "Sorry, CPU only for now :("

        a_scipy = [1] + a.numpy().tolist()
        y_scipy = signal.lfilter([1] + [0] * M, a_scipy, x.numpy())
        y = torch.from_numpy(y_scipy).to(x.dtype)

        # remember to save necessary tensors for backward pass
        ctx.save_for_backward(a, y)
        return y

    @staticmethod
    def backward(ctx, y_grad: torch.Tensor):
        a, y = ctx.saved_tensors
        M = a.numel()
        T = y.numel()

        # compute gradient wrt a
        shift_y = torch.cat([y.new_zeros(M), y[:-1]])
        dyda = IIR.apply(-shift_y, a).unfold(0, T, 1).flip(0)
        a_grad = dyda @ y_grad

        # compute gradient wrt x
        x_grad = IIR.apply(y_grad.flip(0), a).flip(0)
        return x_grad, a_grad

In the above sample, we reuse scipy.signal.lfilter as our core runner as writing native C++ would be too much in this short notebook :) Another option is to use Numba to write the core function in Python and compile it to machine code. Numba also support CUDA and can be used to write GPU kernels. Noting that we call IIR.apply inside backward instead of scipy.signal.lfilter. Why? PyTorch by default does not disable gradient computation for the backward pass, which means the backward computational graph is differentiable if we use differentiable operators inside backward. The benefit of this is we can get the second order gradient by backpropagating the gradient through the backward pass, and get the third order gradient by backpropagating through the backward pass of the backward … you got the idea. Higher order gradient is useful in some applications, and as implementers, it is always good to provide the option for potential users.

Let us check if our custom backward method is correct using torch.autograd.gradcheck and torch.autograd.gradgradcheck. This is important as we want to make sure the gradient is correct before using it in our model.

assert torch.autograd.gradcheck(IIR.apply, (test_signal_torch_params[:1000], a_torch_params))
assert torch.autograd.gradgradcheck(IIR.apply, (test_signal_torch_params[:1000], a_torch_params), atol=1e-4)

Congratulations! We have successfully implemented a differentiable IIR filter that is at least as fast as the scipy.signal.lfilter function. You can use it in your model and backpropagate the gradient through it. In the next notebook, we are going to look at some examples on using differentiable IIR filters.

Additional Resources#

Although we have just walked step-by-step through the implementation of a differentiable IIR filter, there are already some implementations available online thanks to the open-source community. Here are some of them:

  • torchaudio.functional.lfilter: you can imagine this as a PyTorch version of scipy.signal.lfilter. It is implemented in C++ and CUDA and is highly optimised.

  • torchlpc: The implementations we have discussed so far only consider the case of time-invartiant IIR filters. This library implements the custom backward method and extends it to time-varying IIR filters.

  • diffsptk.AllPoleDigitalFilter: diffsptk is a differentiable version of the well-known speech processing toolkit, and it also provides a differentiable All-Pole filter, which actually depends on torchlpc.