PythonでLMSアルゴリズムを実装する

今回はPythonでLMSアルゴリズムを実装してみます。

LMSアルゴリズムとは適応フィルタの一種で、ノイズキャンセリングなどに応用されているようなものです。

今回は、システムのパラメータを逐次同定するという課題設定のもとで実装をしました。

コードは以下のようになります

[In]

"""
4点移動平均フィルタのシステムについて、入力信号(白色雑音)と出力信号を用いてLMSアルゴリズムにより逐次同定を行う
"""
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
np.random.seed(0)

[In]

# 同定対象のシステム(4点移動平均フィルタ)を記述しているクラス
class Plant(object):
    def __init__(self):
        self.k0 = 0.0
        self.k1 = 0.0
        self.k2 = 0.0
        self.k3 = 0.0
        return
    def get(self, signal):
        self.k3 = self.k2
        self.k2 = self.k1
        self.k1 = self.k0
        self.k0 = signal
        return ((self.k0 + self.k1 + self.k2 + self.k3) * 0.25) + np.random.normal(loc=0,scale = 0.1)

[In]

# LMSアルゴリズムを実装しているクラス
class LeastMeanSquare(object):
    def __init__(self, alpha):
        self.len = 4
        self.regressor = np.zeros(self.len)
        self.parameter = np.zeros(self.len)
        self.alpha = alpha
        return
    
    def next(self, sig_in, sig_out):
        self.regressor[1:] = self.regressor[0:-1]
        self.regressor[0] = sig_in
        error = (self.regressor @ self.parameter) - sig_out
        self.parameter = self.parameter - (2 * self.alpha * self.regressor * error)
        return

[In]

plant = Plant()
time = np.arange(300)
input_signal_array = []
output_signal_array = []
for i in time:
    input_signal = np.random.normal() 
    output_signal = plant.get(input_signal)
    input_signal_array.append(input_signal)
    output_signal_array.append(output_signal)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(time, input_signal_array)
plt.subplot(1, 2, 2)
plt.plot(time, output_signal_array)

[Out]

※左が入力信号で右が出力信号

[In]

alpha_list = [0.005, 0.01, 0.025, 0.05]
plt.figure(figsize=(20, 4))
for j, a in enumerate(alpha_list):
    lms = LeastMeanSquare(alpha=a)
    log = []
    for t in time:
        lms.next(input_signal_array[t], output_signal_array[t])
        log.append(lms.parameter)
        
    plt.subplot(1, len(alpha_list), j+1)
    for i in range(lms.parameter.shape[0]):
        plt.plot(time, [log[j][i] for j in time])
    plt.grid(True)
    plt.ylim(0, 0.4)

[Out]

※左から、α=0.005, 0.01, 0.025, 0.05 の時のパラメータ推定の推移

今回はここまで。 コードを載せただけでロクLMSアルゴリズムの説明をしていなくて恐縮ですが、説明は他のWebサイトなり書籍を参考にしてください。