PythonでRLSアルゴリズムを実装する

今回はPythonでRLSアルゴリズムを実装してみます。

以前の記事で取り扱ったLMSアルゴリズムと同じく適応フィルタの一種です。

LMSアルゴリズムより計算量は多く、また使える場面も限られますが、値が振動せず収束してくれるというメリットがあります。

今回は、システムのパラメータを逐次同定するという課題設定のもとで実装をしました。

コードは以下のようになります

[In]

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
np.random.seed(0)

[In]

class Plant(object):
    """ 同定対象のシステム(4点移動平均フィルタ)を記述しているクラス
    """
    def __init__(self):
        self.k0 = 0.0
        self.k1 = 0.0
        self.k2 = 0.0
        self.k3 = 0.0
        return
    def get(self, signal):
        self.k3 = self.k2
        self.k2 = self.k1
        self.k1 = self.k0
        self.k0 = signal
        return ((self.k0 + self.k1 + self.k2 + self.k3) * 0.25) + np.random.normal(loc=0,scale = 0.1)

[In]

class RecursiveLeastSquares(object):
    def __init__(self):
        self.len = 4
        self.theta = np.zeros((self.len,1))
        self.fai = np.zeros((self.len,1))
        self.cov = np.identity(self.len) * 100
        return

    def update(self, x, y):
        self.fai[1:] = self.fai[0:-1]
        self.fai[0] = x
        gain = (self.cov @ self.fai) / (self.fai.T @ self.cov @ self.fai + 1)
        error = y - (self.fai.T @ self.theta)
        self.theta = self.theta + (gain @ error)
        self.cov = self.cov - (gain @ self.fai.T @ self.cov)
        return

[In]

plant = Plant()
time = np.arange(300)
input_signal_array = []
output_signal_array = []
for i in time:
    input_signal = np.sin(2.0 * np.pi * i /20.0) + np.random.uniform(-1, 1)
    output_signal = plant.get(input_signal)
    input_signal_array.append(input_signal)
    output_signal_array.append(output_signal)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(time, input_signal_array)
plt.title("Input signal")
plt.subplot(1, 2, 2)
plt.plot(time, output_signal_array)
plt.title("Output signal")
plt.show()

[Out]

[In]

log = []
rls = RecursiveLeastSquares()
for t in time:
    rls.update(input_signal_array[t], output_signal_array[t])
    log.append(rls.theta)
        
for i in range(rls.theta.shape[0]):
    plt.plot(time, [log[j][i] for j in time], label="g_" + str(i))
plt.grid(True)
plt.legend()
plt.ylim(-0.5, 1.0)
plt.show()

[Out]

今回はここまで。

LMSアルゴリズムの時と同様アルゴリズムの説明を省略して恐縮ですが、他のサイトで非常に詳しく解説しているものもたくさんありますため、是非他のWebサイトや書籍を参考にし、本記事は実際のコード例、数値的な確認に使っていただければと思います。