从SSM到S4
- 开源代码
- 2025-09-05 08:51:02

参考视频:Mamba 超超超详细解说 | 1、对 SSM 的透彻理解_哔哩哔哩_bilibili
以下内容为上述视频的学习记录,详情可看视频。
一、SSM:State Space Model 状态空间模型(State Space Model,SSM)是一种数学框架,用于描述动态系统的行为。
通过描述系统的状态变量及其动态变化规律,建模系统的状态随时间的演变以及观测数据的变化。
1.1 组成部分 状态方程(State Equation):描述系统的内部状态如何随时间演变。输出方程(Output Equation):描述外部可观测的输出如何与系统状态相关联。
1.2 什么是状态State
状态是系统当前所处的条件或信息,用以确定其未来行为。
1.3 连续SSM--Continuous State Space Model
时间是连续的,状态随时间的变化用微分方程表示。
状态方程:
:状态变量的时间导数,表示系统随时间的变化率。:状态变量,表示系统在时刻 t 的状态。:输入变量,表示系统的外部控制输入。:状态转移矩阵,描述状态随时间变化的动态关系。:控制矩阵,描述输入变量对状态变化的影响。输出方程:
:输出变量,表示系统的观测值。:输出矩阵,描述状态对系统输出的影响。:直连矩阵,描述输入变量直接对输出的影响。1.4 离散SSM--Discrete State Space Model
时间是离散的,状态随时间的变化用差分方程表示。
状态方程:
:系统在离散时刻 k的状态。:系统在下一时刻 k+1的状态。:在时刻 k 的输入变量。:状态转移矩阵,描述状态从 k到 k+1 的转移关系。:控制矩阵。
输出方程:
:系统在时刻 k 的输出。:输出矩阵,描述状态对系统输出的影响。:直连矩阵,描述输入对输出的直接影响。1.5 连续SSM和离散SSM的区别
1.6 SSM举例(理解系统与状态) 1.6.1 离散SSM--养鱼场例子 1.6.2 连续SSM--弹簧系统例子
1.7 SSM的离散化 1.7.1 为什么要离散化 对于连续的系统,计算机没办法处理,计算机处理数据和执行算法时通常使用离散的时间步长。离散化将连续时间系统转换为离散时间系统,使其能够在计算机上进行数值模拟和计算。通过 连续SSM 的状态方程,我们可以知道任意时刻:输入、状态、状态微分(状态对时间的导数) 的关系,但没办法根据上一时刻的状态推测下一时刻的状态。离散化后,可以将系统表示为 状态的递推公式,可逐步递推系统状态。 1.7.2 离散化方法
离散化即使用方法近似公式(3)中的
1.7.2.1 前向欧拉法 1.7.2.2 后向欧拉法 1.7.2.3 梯形法 1.7.2.4 零阶保持 Zero-Order Hold零阶保持(Zero-Order Hold,ZOH)用于将离散时间信号转换为连续时间信号。这种方法假设输入信号在每个采样周期内保持恒定不变,即在每个采样点之后,直到下一个采样点到来之前,信号的值不再变化。
1.7.3 弹簧系统离散化代码示例 1.7.3.1 定义状态空间模型SSM函数 example_mass(k, b, m)
def example_mass(k, b, m): A = np.array([[0, 1], [-k / m, -b / m]]) B = np.array([[0], [1.0 / m]]) C = np.array([[1.0, 0]]) return A, B, C 输入参数:弹簧常数 k,阻尼系数 b,质量 m。矩阵定义: A:状态转移矩阵,描述系统的动态性质。B:控制矩阵,描述输入对状态的影响。C:输出矩阵,描述状态如何影响输出函数 example_force(t)
@partial(np.vectorize, signature='()->()') def example_force(t): x = np.sin(10 * t) return x * (x > 0.5) 描述:定义一个时间函数 u(t),该函数为正弦函数的变体,仅在 sin(10 * t) 大于 0.5 时有效。作用:为系统提供外部输入 u(t)。1.7.3.2 离散化
函数 discretize(A, B, C, step)
def discretize(A, B, C, step): I = np.eye(A.shape[0]) BL = inv(I - (step / 2.0) * A) Ab = (I + (step / 2.0) * A) @ BL Bb = (BL @ step) @ B return Ab, Bb, C 输入参数:连续时间矩阵 A、B、C 和离散时间步长 step。输出:离散时间的矩阵 Ab、Bb 和 C。方法:使用双线性变换(梯形积分法)进行离散化,这种方法可以提高离散化的准确性。 Ab:离散化后的状态转移矩阵。Bb:离散化后的控制矩阵。1.7.3.3 运行状态空间模型
函数 run_SSM(A, B, C, u)
def run_SSM(A, B, C, u): L = u.shape[0] N = A.shape[0] Ab, Bb, Cb = discretize(A, B, C, step=1.0 / L) return scan_SSM(Ab, Bb, Cb, u[:, np.newaxis], np.zeros((N,)))[1] 输入参数:状态空间矩阵 A、B、C 和输入 u。主要步骤: 计算离散化的矩阵 Ab、Bb 和 Cb。调用函数 scan_SSM 进行递归计算,返回系统的输出。辅助函数 scan_SSM(Ab, Bb, Cb, u, x0)
def scan_SSM(Ab, Bb, Cb, u, x0): def step(x_k, u_k): x_k = Ab @ x_k + Bb @ u_k y_k = Cb @ x_k return x_k, y_k return jax.lax.scan(step, x0, u) 描述:使用 jax.lax.scan 实现递归计算,模拟离散时间状态空间系统的动态行为。step 函数:在每个时间步长上更新状态 x_k 和计算输出 y_k。1.7.3.4 运行示例 def example_ssm(): ssm = example_mass(k=40, b=5, m=1) # Samples of u(t) L = 100 step = 1.0 / L ks = np.arange(L) u = example_force(ks * step) # Approximation of y(t) y = run_SSM(*ssm, u) # Plotting import matplotlib.pyplot as plt import seaborn from celluloid import Camera seaborn.set_context("paper") fig, (ax1, ax2, ax3) = plt.subplots(3) camera = Camera(fig) ax1.set_title("Force $u_k$") ax2.set_title("Position $y_k$") ax3.set_title("Object") ax1.set_xticks([], []) ax2.set_xticks([], []) ax3.set_xticks([], []) # Animate plot over time for k in range(0, L, 2): ax1.plot(ks[:k], u[:k], color="red") ax2.plot(ks[:k], y[:k], color="blue") ax3.boxplot( [[y[k], -0.04, y[k], 0], [y[k], 0, y[k], 0.04]], showcaps=False, whis=False, vert=False, widths=0.1, ) camera.snap() anim = camera.animate() anim.save("images/line.gif", dpi=150, writer="imagemagick") if __name__ == "__main__": example_ssm() 二、S4--Structured State Space for Sequences
Structured State Space for Sequences (S4)模型在训练和推理使用了不同形式,并且设计了Hippo矩阵作为SSM方程中的矩阵A。
由上述内容可知,离散SSM的公式如下:
2.1 SSM的RNN表示按照timesteps展开表示:
如下图,可发现SSM和RNN的表达形式基本一致:
2.2 SSM的Convolution表示因为文字序列是一维的,它的一维卷积表示如下:
卷积形式只有一个表达式:
2.3 RNN表示和Convolution表示的使用 2.4 Hippo矩阵 2.5 S4的参数化NPLR(Normal Plus LowRank):正规矩阵(Normal Matrix) + 低秩矩阵(LowRank Matrix)