国庆假期结束了。
上一次,写了一个DFT函数,用它计算一个10万大小的输入信号,结果等了很久然后显示“内存溢出”。看来DFT算法确实不行。上次的DFT代码如下:
def myDFT(inputSig,isPrintW=False):
'''
@author:zengwei
输入:
inputSig:指输入信号,我希望它是array格式。
isPrintW:可以指定是否输出那个W矩阵
输出:
DFT计算出的结果
'''
sigLen = len(inputSig)
n = np.arange(0,sigLen).reshape(1,sigLen)
k = n.reshape(sigLen,1)
base = np.exp(-1j*2*np.pi/sigLen) # 基低
w = np.dot(k,n) # 指数矩阵
W = base**w # W矩阵
outputSig = np.dot(W,inputSig)
if isPrintW:
return outputSig,W
else:
return outputSig
这一次,准备尝试一下快速傅里叶算法,参考的是《数字信号处理》(胡广书)的4.1和4.2节——时间抽取基2FFT算法。我的代码如下:
def MyDIT_FFT(inputA):
'''
@author:zengwei
思路:
码位倒置+分组+分级+蝶形单位+旋转因子
符号说明:
getSub:获得码位倒置序列的函数;
inputA:输入序列,希望是2^n长度;
N:输入序列的长度,N=2^n;
m:分组计算的组数;M:级数
X1:FFT变换后的结果
'''
def getSub(intNum,M):
return int( bin(intNum)[2:].zfill(M)[::-1],2 )
N = len(inputA)
n = np.log2(N)
sub = [getSub(i,int(n)) for i in np.arange(N)]
X0 = np.array([inputA[k] for k in sub],dtype = complex) # 初始序列
for M in np.arange(int(n)): # 遍历每一级
m = N//(2**(M+1)) # 分组数
groupM = np.arange(N).reshape(m,N//m) # 进行分组
r = np.arange(2**M)*(2**(n-1-M)) # 旋转因子的指数
W = np.exp(-1j*2*np.pi/N) # 旋转因子底数
Wr = (W**r).tolist() # 旋转因子
if len(Wr) < (N//2): # 旋转因子长度补长
Wr = np.array(Wr * (N//2//len(Wr)))
X1 = np.zeros(len(inputA),dtype = complex) # 存放输出序列
for i,p in enumerate(groupM[:,0:N//m//2].reshape(N//2)): # 遍历每一组
q = p + 2**M
X1[p] = X0[p] + Wr[i]*X0[q]
X1[q] = X0[p] - Wr[i]*X0[q]
X0 = X1
return X1
经过测试,其结果与numpy内置的fft函数结果一致,这里就不放测试部分代码了。这里对比一下我的FFT和内置的FFT的运行速度。测试信号的长度从2的3次方到2的20次方(大约100万)。
from timeit import default_timer as timer
timeMyFFt = []
timeFFt = []
for i in np.arange(3,21):
testdata = np.random.randint(0,10**4,2**i)
startmyfft = timer()
MyDIT_FFT(testdata)
endmyfft = timer()
timeMyFFt.append(endmyfft-startmyfft)
startfft = timer()
np.fft.fft(testdata)
endfft = timer()
timeFFt.append(endfft-startfft)
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure()
plt.scatter(np.arange(3,21),timeMyFFt)
plt.plot(np.arange(3,21),timeMyFFt)
plt.scatter(np.arange(3,21),timeFFt)
plt.plot(np.arange(3,21),timeFFt)
plt.legend(["MyDIT_FFT","FFT"])
plt.xlim(3,20)
plt.xlabel('2^Index')
plt.ylabel('time')
plt.show()
测试结果如下:
运行时间对比.png
可以看到,我的FFT虽然比DFT算法好很多了,但还有不小的优化空间。
寒假还会远吗。
网友评论