Friday, February 27, 2015

py29. Discrete Cosine Transform in Python

The Discrete Cosine Transform (DCT) is a real transform: a real input gives real coefficients, whereas the fft of a real input is in general complex. The relevant equations are given, for example, on Wikipedia. Type 2 is the dct and type 3 is the idct. We have to convert these equations into a form that can be computed with fft and ifft.
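As a reference point, the orthonormal type 2 DCT (the convention used by scipy with norm='ortho') and its type 3 inverse can be computed directly from their defining sums. The sketch below is a plain O(N^2) matrix version written under that assumption; the names dct2_direct and dct3_direct are hypothetical helpers, not part of ex29.py.

```python
import numpy as np

def dct2_direct(x):
    """Orthonormal DCT-II straight from the definition (O(N^2)).

    y[k] = s(k) * sum_n x[n] * cos(pi * k * (2n + 1) / (2N)),
    with s(0) = sqrt(1/N) and s(k) = sqrt(2/N) for k > 0.
    """
    x = np.asarray(x, dtype=float)
    N = len(x)
    n = np.arange(N)
    k = n.reshape(-1, 1)
    C = np.cos(np.pi * k * (2 * n + 1) / (2 * N))  # C[k, n]
    scale = np.full(N, np.sqrt(2 / N))
    scale[0] = np.sqrt(1 / N)
    return scale * (C @ x)

def dct3_direct(y):
    """Orthonormal DCT-III, the inverse of dct2_direct.

    x[n] = sum_k s(k) * y[k] * cos(pi * k * (2n + 1) / (2N)).
    """
    y = np.asarray(y, dtype=float)
    N = len(y)
    n = np.arange(N).reshape(-1, 1)
    k = np.arange(N)
    C = np.cos(np.pi * k * (2 * n + 1) / (2 * N))  # C[n, k]
    scale = np.full(N, np.sqrt(2 / N))
    scale[0] = np.sqrt(1 / N)
    return C @ (scale * y)
```

For a float array x, dct2_direct(x) should agree with sdct(x, norm='ortho') up to round-off; the fft route computes the same numbers in O(N log N).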


The program below computes the DCT and the inverse DCT using fft and ifft. In both cases we solve a problem of length 2N, where N is the length of x. For the DCT, for example, x is joined with its reverse before taking the fft. Besides the fft and ifft, we have to apply phase terms and the proper scaling constants.
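Why joining x with its reverse works can be checked numerically. After multiplying the fft of the length-2M extension by the phase term exp(-1j*pi*k/N), with N = 2M, every bin becomes real and the first M bins equal twice the unnormalized DCT-II sum. Scaling by 1/sqrt(N) (and by 1/sqrt(2N) for k = 0) then gives the orthonormal coefficients, which is what f1 and f0 do in dct. The test sequence below is arbitrary.

```python
import numpy as np

# A small real test sequence (any real values work)
x = np.array([1.0, 3.0, -2.0, 5.0])
M = len(x)
N = 2 * M                                   # length of the extended problem
xe = np.concatenate([x, x[::-1]])           # x joined with its reverse
k = np.arange(N)

# fft of the extension times the phase term exp(-1j*pi*k/N)
Y = np.fft.fft(xe) * np.exp(-1j * np.pi * k / N)

# The imaginary part vanishes up to round-off ...
imag_max = np.abs(Y.imag).max()

# ... and the first M bins equal 2 * sum_n x[n] * cos(pi*k*(2n+1)/(2M))
n = np.arange(M)
direct = np.array([2 * np.sum(x * np.cos(np.pi * kk * (2 * n + 1) / (2 * M)))
                   for kk in range(M)])
matches = np.allclose(Y[:M].real, direct)
print(imag_max, matches)
```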


The example x has length 16. We plot x and its DCT. The main code runs only when the file is executed directly, not when it is imported by another Python program.


For comparison, we use the dct and idct functions from the scipy.fftpack module, imported under the aliases sdct and sidct, with norm='ortho' so that their scaling matches ours.

# ex29.py

from __future__ import print_function, division
import numpy as np
from scipy.fftpack import dct as sdct, idct as sidct
from numpy.fft import fft, ifft
import matplotlib.pyplot as plt

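def dct(X):
    # Orthonormal DCT-II of X via an fft of length 2*len(X)
    x = np.concatenate([X,np.flipud(X)])    # X joined with its reverse
    N = 2*len(X)
    f0 = 1/np.sqrt(2*N)                     # scale for the k = 0 term
    f1 = 1/np.sqrt(N)                       # scale for k > 0
    Y = fft(x)
    kin = np.arange(N)
    Mul = np.cos(np.pi/N*kin)-1j*np.sin(np.pi/N*kin)   # phase term exp(-1j*pi*k/N)
    Y *= Mul
    Y *= f1
    Y[0] *= f0/f1
    return Y[:N//2].real                    # keep the first len(X) coefficients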

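def idct(X):
    # Orthonormal DCT-III (inverse of dct) via an ifft of length 2*len(X)
    N = len(X)
    f0 = 1/np.sqrt(4*N)
    f1 = 1/np.sqrt(2*N)
    kin = np.arange(N)
    # restore the phase term exp(1j*pi*k/(2N)) removed in dct
    FTR = X*np.cos(.5*np.pi/N*kin)
    FTI = X*np.sin(.5*np.pi/N*kin)
    # undo the scaling applied in dct
    FTR /= f1
    FTI /= f1
    FTR[0] *= f1/f0
    FTI[0] *= f1/f0
    # rebuild the conjugate-symmetric spectrum of length 2N (bin N is zero)
    FTR1 = np.concatenate((FTR,np.zeros(1),np.flipud(FTR[1:])))
    FTI1 = np.concatenate((FTI,np.zeros(1),-1*np.flipud(FTI[1:])))
    y = ifft(FTR1+1j*FTI1)
    return y[:N].real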

if __name__ == '__main__':
    x = np.array(list(range(4))+list(range(2,5))+list(range(10,1,-1)))   # 16 samples
    y = dct(x)
    y1 = sdct(x.astype('float'), norm = 'ortho')
    z = idct(y)
    z1 = sidct(y1, norm = 'ortho')
    err1 = (z-x).std()
    err2 = (z1-x).std()
    print ('err1 =',err1)
    print ('err2 =',err2)
    plt.plot(x,'ro--', markerfacecolor='w')
    plt.plot(y,'bo--', markerfacecolor='w')
    plt.xlabel('n')
    plt.ylabel('x,y')
    plt.title('x (red), y = dct(x) (blue)')
    plt.xlim(-.5,15.5)
    plt.show()

# err1 = 1.43874580946e-15
# err2 = 4.52626554495e-16

Output: plot of x (red) and y = dct(x) (blue) versus n.
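The idct function reverses these steps, and it only needs N coefficients to build a spectrum of length 2N. For the symmetric extension of a real sequence, the middle bin of the fft is zero and the upper half is the mirrored complex conjugate of the lower half. The numpy sketch below (the test sequence is arbitrary) rebuilds the full spectrum in the same layout idct uses, first half, a zero, then the mirrored conjugate, and inverts it.

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0, 2.0, 3.0])
M = len(x)
xe = np.concatenate([x, x[::-1]])   # symmetric extension, length 2M
Y = np.fft.fft(xe)

# Bin M (the middle bin) of the symmetric extension is zero ...
mid_is_zero = np.isclose(abs(Y[M]), 0.0)
# ... and a real input gives conjugate symmetry Y[2M-k] = conj(Y[k]).
conj_sym = np.allclose(Y[M + 1:], np.conj(Y[1:M][::-1]))

# So bins 0..M-1 are enough: rebuild the full spectrum the way idct does
# (first half, a zero, then the mirrored conjugate) and invert it.
Yfull = np.concatenate([Y[:M], [0.0], np.conj(Y[1:M][::-1])])
x_back = np.fft.ifft(Yfull)[:M].real
print(mid_is_zero, conj_sym, np.allclose(x_back, x))
```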
