Gated Recurrent Unit (GRU)

To mitigate the vanishing-gradient problem of standard RNNs, the GRU introduces two gating vectors, the update gate and the reset gate, which decide what information should be passed along to the output. A GRU can retain information from far back in the sequence while discarding irrelevant information.

Background

The Gated Recurrent Unit (GRU) was proposed by Cho et al. in 2014. Compared with the LSTM, the GRU cell is structurally simpler. Although the LSTM alleviates the gradient problems that long-term dependencies cause in recurrent neural networks, it has three separate gates and therefore more parameters, which makes it harder to train. The GRU has only two gating structures, yet in most cases the two models perform comparably.

Principle

The structure of the GRU is shown in the figure below.

(Figures: GRU cell structure and the step-by-step gate computations; original images unavailable.)
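In equation form, a GRU cell computes the following at each timestep (written here in the convention of Cho et al., 2014):

```latex
\begin{aligned}
z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) && \text{update gate} \\
r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) && \text{reset gate} \\
\tilde{h}_t &= \tanh\bigl(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h\bigr) && \text{candidate state} \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t && \text{new hidden state}
\end{aligned}
```

The reset gate $r_t$ controls how much of the previous state enters the candidate $\tilde{h}_t$, and the update gate $z_t$ interpolates between the old state and the candidate. Note that some implementations, including PyTorch's `nn.GRU`, swap the roles of $z_t$ and $1 - z_t$ in the last equation; the behavior is equivalent up to that relabeling.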

Code

Taking historical trajectory data as an example, we use a GRU network to predict the future course of a trajectory. The model code is as follows:

import torch
import torch.nn as nn
import torchsnooper
input_dim = 3

class GRUModel(nn.Module):
    """GRU encoder that maps an observed trajectory to predicted positions."""
    def __init__(
        self, seq_len=8, embedding_dim=64, h_dim=64, mlp_dim=1024, num_layers=1,
        dropout=0.0, use_cuda=0
    ):
        super(GRUModel, self).__init__()

        self.mlp_dim = mlp_dim
        self.h_dim = h_dim
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.use_cuda = use_cuda
        self.seq_len = seq_len

        self.encoder = nn.GRU(
            embedding_dim, h_dim, num_layers, dropout=dropout
        )

        self.hidden2pos = nn.Linear(h_dim, input_dim)

        self.spatial_embedding = nn.Linear(input_dim, embedding_dim)

    def init_hidden(self, batch):
        state = torch.zeros(self.num_layers, batch, self.h_dim)
        if self.use_cuda == 1:
            state = state.cuda()
        return state

    # torchsnooper prints tensor shapes and devices line by line for debugging
    @torchsnooper.snoop()
    def forward(self, obs_traj):
        """
        Inputs:
        - obs_traj: tensor of shape (obs_len, batch, input_dim)
        Output:
        - cur_pos: tensor of shape (obs_len, batch, input_dim)
        """
        batch = obs_traj.size(1)
        # Embed each 3-D point: (obs_len * batch, input_dim) -> (obs_len * batch, embedding_dim)
        obs_traj_embedding = self.spatial_embedding(
            obs_traj.contiguous().view(-1, input_dim)
        )
        # Restore the sequence layout: (obs_len, batch, embedding_dim)
        obs_traj_embedding = obs_traj_embedding.view(
            -1, batch, self.embedding_dim
        )
        # Run the GRU over the embedded trajectory
        encoder_state = self.init_hidden(batch)
        output, state = self.encoder(obs_traj_embedding, encoder_state)

        # Project every hidden state back to coordinate space
        cur_pos = self.hidden2pos(output.view(-1, self.h_dim))
        cur_pos = cur_pos.view(-1, batch, input_dim)

        return cur_pos
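
To see the shapes flow end to end, here is a minimal, self-contained sketch of the same embed-GRU-project pipeline. `TinyGRUPredictor` is a simplified stand-in for the `GRUModel` above (it drops the `torchsnooper` decorator and the unused parameters), not the author's exact model:

```python
import torch
import torch.nn as nn

class TinyGRUPredictor(nn.Module):
    """Simplified stand-in for GRUModel: embed -> GRU -> linear head."""
    def __init__(self, input_dim=3, embedding_dim=64, h_dim=64, num_layers=1):
        super().__init__()
        self.embedding = nn.Linear(input_dim, embedding_dim)
        self.gru = nn.GRU(embedding_dim, h_dim, num_layers)
        self.head = nn.Linear(h_dim, input_dim)

    def forward(self, obs_traj):          # (obs_len, batch, input_dim)
        emb = self.embedding(obs_traj)    # (obs_len, batch, embedding_dim)
        out, _ = self.gru(emb)            # (obs_len, batch, h_dim)
        return self.head(out)             # (obs_len, batch, input_dim)

model = TinyGRUPredictor()
obs_traj = torch.randn(8, 64, 3)          # 8 timesteps, batch of 64, (x, y, z)
with torch.no_grad():
    pred = model(obs_traj)
print(pred.shape)  # torch.Size([8, 64, 3])
```

Because `nn.Linear` applies to the last dimension, the embedding and output projection here act pointwise per timestep without the explicit `view(-1, input_dim)` reshape used above; both formulations are equivalent.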

The full code is available on GitHub: https://github.com/ironartisan/trajectory-prediction