长短期记忆网络LSTM

长短期记忆网络(LSTM,Long Short-Term Memory)是一种时间循环神经网络,是为了解决一般的RNN(循环神经网络)存在的长期依赖问题而专门设计出来的,所有的RNN都具有一种重复神经网络模块的链式形式。在标准RNN中,这个重复的结构模块只有一个非常简单的结构,例如一个tanh层。

背景

为了缓解循环神经网络引起的梯度问题,有学者提出一种控制信息积累的方式。本小节中,主要介绍一种循环神经网络的变种:长短期记忆网络LSTM。

原理

LSTM网络拓展了RNN记忆的长度,LSTM利用控制网络权重的方式,帮助网络输入新信息、忘记信息或赋予其足够的重要性以影响输出。LSTM使RNN能够长时间记住输入,因为LSTM将信息存储在记忆中,可以从中读取、添加和剔除信息。这个记忆可以看成是一个门控单元,门控单元基于赋予信息的重要程度来确定是否添加或剔除信息。重要程度的分配是通过权重决定的,同时也是模型需要学习的地方。随着模型的训练,权重系统不断变化,信息的重要程度也逐步确定。

在LSTM网络中,有三种特殊的门,分别为输入门、遗忘门和输出门。输入门负责决定是否允许新的输入,确定应该使用输入中的哪些值;遗忘门负责确定从网络块中要丢弃的值;输出门负责决定当前时刻的输出结果。其具体的循环单元结构如图所示。

0-2022-08-20-22-33-59

1-2022-08-20-22-34-09

2-2022-08-20-22-34-18

通过输出门对当前的状态信息 进行有选择的输出。LSTM通过上述的方式对信息进行选择和遗忘,经过重复类似的过程,模型能取得较好的效果。

代码

以历史轨迹数据为例,使用LSTM网络进行预测未来航迹走向。模型代码如下:

import torch
import torch.nn as nn

input_dim = 3

class My_Net(nn.Module):
    """Encoder is part of both TrajectoryGenerator and
    TrajectoryDiscriminator"""
    def __init__(
        self, seq_len=8, embedding_dim=64, h_dim=64, mlp_dim=1024, num_layers=1,
        dropout=0.0, use_cuda=0
    ):
        super(My_Net, self).__init__()

        self.mlp_dim = mlp_dim
        self.h_dim = h_dim
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.use_cuda = use_cuda
        self.seq_len = seq_len

        self.encoder = nn.LSTM(
            embedding_dim, h_dim, num_layers, dropout=dropout
        )

        self.decoder = nn.LSTM(
            embedding_dim, h_dim, num_layers, dropout=dropout
        )

        self.hidden2pos = nn.Linear(h_dim, input_dim)

        self.spatial_embedding = nn.Linear(input_dim, embedding_dim)

    def init_hidden(self, batch):
        state0 = torch.zeros(self.num_layers, batch, self.h_dim)
        state1 = torch.zeros(self.num_layers, batch, self.h_dim)

        if self.use_cuda == 1:
            state0 = state0.cuda()
            state1 = state1.cuda()

        return (state0, state1)

    def forward(self, obs_traj):
        """
        Inputs:
        - obs_traj: Tensor of shape (obs_len, batch, 3)
        Output:
        obs_traj shape is  torch.Size([8, 64, 3])
        obs_traj_embedding shape is  torch.Size([512, 64])
        # obs_traj_embedding shape is  torch.Size([512, 3])
        obs_traj_embedding shape is  torch.Size([8, 64, 64])
        ouput shape is  torch.Size([8, 64, 64])
        cur_pos shape is torch.Size([512, 3])
        cur_pos shape is torch.Size([8, 64, 3])
        - final_h: Tensor of shape (self.num_layers, batch, self.h_dim)
        """
        # Encode observed Trajectory
        result = []

        batch = obs_traj.size(1)
        # print("obs_traj shape is ",obs_traj.shape)
        # print("obs_traj shape is ",obs_traj.contiguous().view(-1, input_dim).shape)
        obs_traj_embedding = self.spatial_embedding(obs_traj.contiguous().view(-1, input_dim))
        # print("obs_traj_embedding shape is ",obs_traj_embedding.shape)
        obs_traj_embedding = obs_traj_embedding.view(
            -1, batch, self.embedding_dim
        )
        # print("obs_traj_embedding shape is ",obs_traj_embedding.shape)
        encoder_state_tuple = self.init_hidden(batch)
        output, state = self.encoder(obs_traj_embedding, encoder_state_tuple)

        cur_pos = self.hidden2pos(output.view(-1, self.h_dim))
        # print("cur_pos shape is {}".format(cur_pos.shape))
        cur_pos = cur_pos.view(-1, batch, input_dim)
        # print("cur_pos shape is {}".format(cur_pos.shape))

        return cur_pos

完整代码参见Github链接:https://github.com/ironartisan/trajectory-prediction