There have been several blog posts that walk through encoder-decoder mechanisms and annotate their implementation. In this blog post we will do likewise and annotate “Plan, Attend, Generate: Planning for Sequence-to-Sequence Models”, which was part of NIPS 2017.

Although this isn’t a particularly famous or popular paper (with 5 citations), it is a generalisation of an RL approach called STRAW from “Strategic Attentive Writer for Learning Macro-Actions”, which was part of NIPS 2016.

This post is written and based heavily on the implementation and post “The Annotated Encoder-Decoder with Attention”, whose underlying attention model was in turn the basis for “Plan, Attend, Generate” (PAG).

Although this isn’t the official implementation, we will try our best to adhere to the implementation and notes within the original paper, and I will talk through the assumptions I make when implementing the model. There is an official implementation which I will reference; however, I have found it difficult to align all parts of its code together.

Note: we won’t have a usable end-to-end implementation here; it might be an item that I consider in the future.

Model Architecture

The model architecture is shown below, as presented in the paper. The overarching goal is for the model to learn to plan and execute alignments. It differs from a standard sequence-to-sequence model with attention in that it produces a plan for future alignments and decides whether to follow that plan using a separate commitment vector.

[Figure: the PAG model architecture, as presented in the paper]

Encoder

The encoder used is the same encoder as in the attention-based neural machine translation paper (Attention-NMT). It is a bidirectional RNN; we will use a bi-GRU in this scenario.

For each input position $i$, an annotation vector $h_i$ is created by concatenating the forward and backward encoder states; it contains the full context for position $i$, carrying information on both preceding and following tokens. Typically the inputs would first be mapped to embedding vectors as well, so that the model can exploit tokens which are similar.
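
As a minimal sketch (with purely illustrative sizes), the annotations are simply the output of a bidirectional GRU, with the two directions concatenated along the feature dimension:

import torch
import torch.nn as nn

# toy bi-GRU annotations: h_i = [fwd_i ; bwd_i] (sizes are illustrative only)
rnn = nn.GRU(8, 16, batch_first=True, bidirectional=True)
x = torch.randn(2, 5, 8)   # [batch, time, embedding]
h, _ = rnn(x)
print(h.shape)             # torch.Size([2, 5, 32]): both directions concatenated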

Decoder

On first glance, the decoder is very much the same as in the Attention-NMT approach, whereby the decoder takes in some additional information called the context:

$$s_t = f_{\text{decode}}(s_{t-1}, y_t, \psi_t)$$

where $y_t$ is the previously generated token, and $\psi_t$ is the context obtained by a weighted sum of encoder annotations, $\psi_t = \sum_{i} \alpha_{ti} h_i$.
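
In code, the context is just a batched matrix product of the attention weights with the annotations; a minimal sketch with assumed shapes:

import torch

B, M, D = 2, 7, 32                                    # batch, source length, annotation size
alphas = torch.softmax(torch.randn(B, 1, M), dim=-1)  # attention weights alpha_t
h = torch.randn(B, M, D)                              # encoder annotations h_i
psi = torch.bmm(alphas, h)                            # context psi_t: [B, 1, D]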

Where does the difference lie?

In the attention mechanism! Rather than simply creating a mechanism with attention, we also add a planning mechanism, which is generated at each time step of the decoder, leading to the name “Plan, Attend, Generate”.

Attend

The “attend” (or alignment) mechanism is formed by generating a candidate alignment plan $A_t$, based on the number of steps we “look ahead”. For the $i$th look-ahead step, it is generated from

$$A_t[i] = f_{\text{align}}(\textbf{s}_{t-1}, \textbf{h}_j, \beta_t^i, \textbf{y}_t)$$

where the alignment vector for the first planning step (i.e. $A_t[0]$) is the attention mechanism of interest. In this scenario $f_{\text{align}}$ is presented as an MLP, with $\beta_t^i$ being the summary of the alignment matrix at the $i$th planning step for time $t-1$.
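
To make this concrete, here is a toy $f_{\text{align}}$ as an MLP over concatenated inputs. The sizes and the exact way the inputs are combined are my assumptions for illustration; the paper does not pin these down:

import torch
import torch.nn as nn

B, M, H, k = 2, 7, 16, 5                        # batch, source length, hidden, look-ahead steps
f_align = nn.Sequential(nn.Linear(3 * H, H), nn.Tanh(), nn.Linear(H, k))
s_prev = torch.randn(B, 1, H).expand(B, M, H)   # decoder state s_{t-1}, broadcast over positions
h = torch.randn(B, M, H)                        # encoder annotations h_j
beta = torch.randn(B, M, H)                     # summary of the previous alignment plan
A_t = f_align(torch.cat([s_prev, h, beta], dim=-1))  # alignment plan: [B, M, k]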

Planning

If we simply generate a brand new alignment vector without any context of what was considered before, then we’ve done no better than simply using attention. So we must have a way to plan. To plan, the model uses an MLP with the Gumbel-softmax trick. Note that this is recomputed at every time step to determine whether or not we need to re-plan!

$$c_t = f_{\text{plan}}(s_{t-1})$$

The output is (approximately) a one-hot encoded vector which indicates when the next “planning” step should occur.
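
A minimal sketch of the Gumbel-softmax trick (the temperature is chosen arbitrarily here): a low temperature pushes the sample towards a one-hot vector while keeping the operation differentiable.

import torch
import torch.nn.functional as F

logits = torch.randn(1, 5)              # stand-in for f_plan(s_{t-1})
c = F.gumbel_softmax(logits, tau=0.01)  # near one-hot sample, still differentiable
print(c, c.sum(-1))                     # each row sums to 1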

Generate

Finally, how is all of this combined? As per the planning stage, the first entry of the commitment plan, $g_t = c_t[0]$, acts as the commitment switch. When $g_t = 0$, this indicates that we proceed as planned, and everything (decoder and plan) is moved forward using the shift operator. When $g_t = 1$, we update our information and interpolate the previous alignment plan with a newly computed one. They are combined in an additive manner (basically a weighted sum, hence the sigmoid activation) via an update gate $\mathbf{u}_{ti}$, which is also learned:

$$\mathbf{u}_{ti} = f_{\text{update}}(\mathbf{h}_i, \mathbf{s}_{t-1})$$

$$\bar{A}_t[:, i] = (1-\mathbf{u}_{ti}) \cdot A_{t-1}[:, i] + \mathbf{u}_{ti} \cdot A_{t}[:, i]$$
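
A toy version of this interpolation, with assumed shapes (the gate broadcasts over the plan steps):

import torch

u = torch.sigmoid(torch.randn(2, 1, 7))  # update gate u_ti in (0, 1), one per source position
A_old = torch.randn(2, 5, 7)             # previous alignment plan A_{t-1}
A_new = torch.randn(2, 5, 7)             # freshly computed plan A_t
A_bar = (1 - u) * A_old + u * A_new      # interpolated plan, column-wise per position i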

Pseudo-code

for j in input_token:
  for t in output_token:
    if g == 1:
      compute commitment plan: c
      update alignment plan through interpolation: A = (1 - u) * A[old] + u * A[new]
    if g == 0:
      shift commitment plan: c
      shift alignment plan: A
    compute alignment: alpha = A[0]

Coding!

Prelim

We shall use PyTorch (at the time of writing it is version 1.3.X)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

Model Class

The EncoderDecoder class is almost identical to Attention-NMT. The differences are around how we retain some semblance of state between subsequent decoder calls.

class EncoderDecoder(nn.Module):
    """
    A (mostly) standard Encoder-Decoder architecture. Base for this and many 
    other models.
    """

    def __init__(self, encoder, decoder, src_embed, trg_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.generator = generator

    def forward(
        self,
        src,
        trg,
        src_mask,
        trg_mask,
        src_lengths,
        trg_lengths,
        previous_hidden=None,
        action_plan=None,
        previous_out=None,
        commit_plan=None,
    ):
        """Take in and process masked src and target sequences."""
        encoder_hidden, encoder_final = self.encode(src, src_mask, src_lengths)

        decoder_states, hidden, pre_output_vectors, action_plan, commit_plan = self.decode(
            encoder_hidden,
            encoder_final,
            src_mask,
            trg,
            trg_mask,
            hidden=previous_hidden,
            action_plan=action_plan,
            previous_out=previous_out,
            commit_plan=commit_plan,
        )

        return decoder_states, hidden, pre_output_vectors, action_plan, commit_plan

    def encode(self, src, src_mask, src_lengths):
        src = self.src_embed(src)
        # print("embed src", src.shape)
        return self.encoder(src, src_mask, src_lengths)

    def decode(
        self,
        encoder_hidden,
        encoder_final,
        src_mask,
        trg,
        trg_mask,
        hidden=None,
        action_plan=None,
        previous_out=None,
        commit_plan=None,
    ):
        # print("embed trg pre", trg.shape)
        trg = self.trg_embed(trg)
        # print("embed trg", trg.shape)
        return self.decoder(
            trg,
            encoder_hidden,
            encoder_final,
            src_mask,
            trg_mask,
            hidden=hidden,
            action_plan=action_plan,
            previous_out=previous_out,
            commit_plan=commit_plan,
        )

We also make use of our embedding and generator modules, which are simply single-layer items. Of course, for embeddings you could use the built-in nn.Embedding as well if the sequences are word tokens.

class ContinuousEmbedding(nn.Module):
    def __init__(self, emb_size):
        super(ContinuousEmbedding, self).__init__()
        self.embedding = nn.Linear(1, emb_size, bias=False)

    def forward(self, x):
        x = x.unsqueeze(-1)
        x = self.embedding(x)
        return x


class Generator(nn.Module):
    """Define standard linear + softmax generation step."""

    def __init__(self, hidden_size, output_size):
        super(Generator, self).__init__()
        self.proj = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x):
        x = self.proj(x)
        return x

Encoder

The encoder is exactly the same as Attention-NMT. As PyTorch provides much of this out of the box, there’s very little to worry about. We make use of PyTorch’s utility functions to handle padding for varying sentence lengths as well.

class Encoder(nn.Module):
    """Encodes a sequence of word embeddings
    """

    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        self.rnn = nn.GRU(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )

    def forward(self, x, mask, lengths):
        """
        Applies a bidirectional GRU to sequence of embeddings x.
        The input mini-batch x needs to be sorted by length.
        x should have dimensions [batch, time, dim].

        The pack/pad utility functions ensure that the RNN correctly
        handles the padded positions of the variable-length inputs.
        """
        packed = pack_padded_sequence(x, lengths, batch_first=True)
        output, final = self.rnn(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)

        # we need to manually concatenate the final states for both directions
        fwd_final = final[0 : final.size(0) : 2]
        bwd_final = final[1 : final.size(0) : 2]
        final = torch.cat([fwd_final, bwd_final], dim=2)  # [num_layers, batch, 2*dim]

        return output, final
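
A quick shape check of the encoder (hypothetical sizes; note the batch must be sorted by decreasing length for pack_padded_sequence):

enc = Encoder(input_size=8, hidden_size=16)
x = torch.randn(3, 5, 8)  # [batch, time, dim]
lengths = [5, 4, 2]       # sorted in decreasing order
output, final = enc(x, mask=None, lengths=lengths)
print(output.shape)       # torch.Size([3, 5, 32]) -> [batch, time, 2*hidden]
print(final.shape)        # torch.Size([1, 3, 32]) -> [num_layers, batch, 2*hidden]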

Decoder

The decoder is a conditional GRU. Again we don’t change things too much from the Attention-NMT approach, as the difference lies within the PAG mechanism. The largest difference is the requirement to maintain some form of state.


class Decoder(nn.Module):
    """A conditional RNN decoder with attention."""

    def __init__(
        self, emb_size, hidden_size, attention, num_layers=1, dropout=0.5, bridge=True
    ):
        super(Decoder, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.attention = attention
        self.dropout = dropout

        self.rnn = nn.GRU(
            emb_size + 2 * hidden_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout,
        )

        # to initialize from the final encoder state
        self.bridge = (
            nn.Linear(2 * hidden_size, hidden_size, bias=True) if bridge else None
        )

        self.dropout_layer = nn.Dropout(p=dropout)
        self.pre_output_layer = nn.Linear(
            hidden_size + 2 * hidden_size + emb_size, hidden_size, bias=False
        )

    def forward_step(
        self,
        prev_embed,
        encoder_hidden,
        src_mask,
        proj_key,
        hidden,
        beta=None,
        previous_out=None,
        prev_action_plan=None,
        prev_commit_plan=None,
    ):
        """Perform a single decoder step (1 word)
        prev_embed, encoder_hidden, src_mask, proj_key, hidden
        """

        # compute context vector using attention mechanism
        query = hidden[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]

        context, attn_probs, action_plan, commit_plan = self.attention(
            query=query,
            proj_key=proj_key,
            value=encoder_hidden,
            mask=src_mask,
            beta=beta,  # bug fix: was hard-coded to None, dropping the plan summary
            previous_out=previous_out,
            prev_action_plan=prev_action_plan,
            prev_commit_plan=prev_commit_plan,
        )

        # update rnn hidden state
        rnn_input = torch.cat([prev_embed, context], dim=2)
        output, hidden = self.rnn(rnn_input, hidden)

        pre_output = torch.cat([prev_embed, output, context], dim=2)
        pre_output = self.dropout_layer(pre_output)
        pre_output = self.pre_output_layer(pre_output)

        return output, hidden, pre_output, action_plan, commit_plan

    def forward(
        self,
        trg_embed,
        encoder_hidden,
        encoder_final,
        src_mask,
        trg_mask,
        hidden=None,
        max_len=None,
        action_plan=None,
        previous_out=None,
        commit_plan=None,
    ):
        """Unroll the decoder one step at a time.
        
        trg_embed       = encoder_hidden,
        encoder_hidden  = encoder_final,
        encoder_final   = src_mask[[0], :, :],
        src_mask        = prev_y,
        trg_mask        = prev_mask,
        hidden          = hidden
        """

        # the maximum number of steps to unroll the RNN
        if max_len is None:
            max_len = trg_mask.size(-1)

        # initialize decoder hidden state
        if hidden is None:
            hidden = self.init_hidden(encoder_final)

        # pre-compute projected encoder hidden states
        # (the "keys" for the attention mechanism)
        # this is only done for efficiency

        # TODO fix this hack... should this be created under self.attention,
        # as it is shared with the generate layer as well?
        proj_key = self.attention.attend.key_layer(encoder_hidden)
        if action_plan is not None:
            prev_action_plan = action_plan.clone().detach()
        else:
            prev_action_plan = None

        if commit_plan is not None:
            prev_commit_plan = commit_plan.clone().detach()
        else:
            prev_commit_plan = None

        if self.attention.attend.summary_layer is not None and action_plan is not None:
            beta_summary = self.attention.attend.summary_layer(action_plan)
        else:
            beta_summary = None

        if (
            self.attention.attend.prev_summary_layer is not None
            and previous_out is not None
        ):
            prev_summary = self.attention.attend.prev_summary_layer(previous_out)
        else:
            prev_summary = None

        # here we store all intermediate hidden states and pre-output vectors
        decoder_states = []
        pre_output_vectors = []

        # unroll the decoder RNN for max_len steps
        for i in range(max_len):
            prev_embed = trg_embed[:, i].unsqueeze(1)
            output, hidden, pre_output, action_plan, commit_plan = self.forward_step(
                prev_embed=prev_embed,
                encoder_hidden=encoder_hidden,
                src_mask=src_mask,
                proj_key=proj_key,
                hidden=hidden,
                beta=beta_summary,
                previous_out=prev_summary,
                prev_action_plan=prev_action_plan,
                prev_commit_plan=prev_commit_plan,
            )
            decoder_states.append(output)
            pre_output_vectors.append(pre_output)

        decoder_states = torch.cat(decoder_states, dim=1)
        pre_output_vectors = torch.cat(pre_output_vectors, dim=1)

        return (
            decoder_states,
            hidden,
            pre_output_vectors,
            action_plan,
            commit_plan,
        )  # [B, N, D]

    def init_hidden(self, encoder_final):
        """Returns the initial decoder state,
        conditioned on the final encoder state."""

        if encoder_final is None:
            return None  # start with zeros

        return torch.tanh(self.bridge(encoder_final))

Plan, Attend, Generate

To implement the PAG mechanism, we take some liberties relative to how it was done in the original paper. The original paper does not specify in great detail how each component was constructed, though we note the similarities between the Attention-NMT and PAG approaches.

Starting with the update layer: this is one of the easier ones, as it only needs to combine information on $\mathbf{h}$ and $\mathbf{s}$. To combine them effectively, an embedding of each input is created so that they can be safely combined through addition. Both embeddings use tanh so that the two inputs live on a common scale, and then a sigmoid is applied to generate the weights for the weighted sum.

class PAGGenerate(nn.Module):
    """
    Implements the update layer as part of PAG
    this is part of the "generate" step
    """

    def __init__(self, hidden_size, plan_length=1, query_size=None, score_size=None):
        super(PAGGenerate, self).__init__()
        query_size = hidden_size if query_size is None else query_size
        self.score_size = hidden_size if score_size is None else score_size
        self.query_layer = nn.Linear(query_size, hidden_size, bias=False)
        self.energy_layer = nn.Linear(hidden_size, plan_length, bias=False)

        # build embedding of the value (hidden)
        self.value_embedding = nn.Linear(hidden_size * 2, plan_length, bias=False)

    def forward(
        self,
        query=None,
        proj_key=None,
        value=None,
        mask=None,
        beta=None,
        previous_out=None,
    ):
        # TODO remove beta and previous_out later.

        assert mask is not None, "mask is required"

        # We first project the query (the decoder state).
        # The projected keys (the encoder states) were already pre-computed.
        decoder_state = self.query_layer(query)

        # Calculate scores.
        scores = self.energy_layer(torch.tanh(decoder_state + proj_key))
        scores = scores.permute(0, 2, 1)
        value_embed = self.value_embedding(torch.tanh(value))
        value_embed = value_embed.permute(0, 2, 1)

        update = value_embed + scores
        return torch.sigmoid(update)

In a similar way, the commitment plan can be generated in a rather straightforward fashion:

class PAGPlan(nn.Module):
    """
    Implements the commitment plan layer as part of PAG
    This is part of the "plan" step
    """

    def __init__(self, hidden_size, plan_length=1, query_size=None, score_size=None):
        super(PAGPlan, self).__init__()
        query_size = hidden_size if query_size is None else query_size
        self.score_size = hidden_size if score_size is None else score_size
        self.commit_layer = nn.Linear(query_size, plan_length, bias=False)

    def forward(
        self,
        query=None,
        proj_key=None,
        value=None,
        mask=None,
        beta=None,
        previous_out=None,
    ):
        # TODO remove beta and previous_out later.

        # We first project the query (the decoder state).
        # The projected keys (the encoder states) were already pre-computed.
        decoder_state = self.commit_layer(query)

        log_proba = F.log_softmax(decoder_state, dim=-1)
        commit_vector = F.gumbel_softmax(log_proba, 0.01)
        return commit_vector

The attend mechanism borrows heavily from the attention mechanism in Attention-NMT; the only difference is that if the $\beta$ and previous-token summaries are available, their embeddings are likewise added before the final alignment plan is generated.

class PAGAttend(nn.Module):
    """Implements Plan Attend Generate (MLP) attention as per here:
    This performs "attend" step
    
    Francis Dutil, Caglar Gulcehre, Adam Trischler, Yoshua Bengio, Plan, Attend, Generate: 
    Planning for Sequence-to-Sequence Models (NIPS 2017)

    If no history (previous action plan) or previous tokens are available, it
    degenerates to Bahdanau attention:
    "Neural Machine Translation by Jointly Learning to Align and Translate"
    """

    def __init__(
        self,
        hidden_size,
        plan_length=1,
        key_size=None,
        query_size=None,
        summary_layer=None,
        pred_embedding=None,
    ):
        super(PAGAttend, self).__init__()
        # We assume a bi-directional encoder so key_size is 2*hidden_size
        key_size = 2 * hidden_size if key_size is None else key_size
        query_size = hidden_size if query_size is None else query_size

        self.key_layer = nn.Linear(key_size, hidden_size, bias=False)

        self.query_layer = nn.Linear(query_size, hidden_size, bias=False)
        # used for alignment plan.
        self.energy_layer = nn.Linear(hidden_size, plan_length, bias=False)

        # generates the betas
        self.summary_layer = summary_layer
        self.prev_summary_layer = pred_embedding
        # this is used instead of the energy layer?
        # self.alignment_plan = nn.Linear(hidden_size, plan_length, bias=False)

        # to store attention scores
        self.alphas = None

    def forward(
        self,
        query=None,
        proj_key=None,
        value=None,
        mask=None,
        beta=None,
        previous_out=None,
    ):
        # mask is prev_y
        assert mask is not None, "mask is required"
        # https://github.com/Dutil/PAG/blob/3e9f9beac6072cdc1aabaa5f015574402b12929c/planning.py#L685

        # We first project the query (the decoder state).
        # The projected keys (the encoder states) were already pre-computed.
        decoder_state = self.query_layer(query)

        if beta is None:
            beta = 0
        if previous_out is None:
            previous_out = 0

        # print("decoder", decoder_state.shape)
        # beta = torch.tanh(self.summary_layer(self.action_plan))
        # decoder_token_proj = torch.tanh(self.proj_pred(mask))
        # decoder_state + beta + proj_key underneath

        # Calculate scores.
        scores = self.energy_layer(
            torch.tanh(decoder_state + proj_key + beta + previous_out)
        )
        scores = scores.permute(0, 2, 1)

        # Masking of invalid positions is skipped here, since our toy inputs
        # are continuous sequences; for padded word sequences you would apply
        # scores.data.masked_fill_(mask == 0, -float("inf")) before the softmax.

        # Turn scores to probabilities.
        # only take first column for calculating
        # alpha_t = softmax(A_t[0])
        alphas = F.softmax(scores[:, [0], :], dim=-1)
        self.action_plan = scores
        self.alphas = alphas

        # The context vector is the weighted sum of the values.

        # context shape: [B, 1, 2D], alphas shape: [B, 1, M]
        context = torch.bmm(alphas, value)
        return context, alphas, scores.data

Putting it all together

To put it all together, we have to make use of the shift operation:

def shift(tensor, dim=1):
    """
    Shifts tensor along this dimension by one spot and fills in shifted
    with zeros; quickest way is to pad the target dim and then slice
    """
    pad = []
    for idx in range(len(tensor.shape)):
        if idx != dim:
            pad.extend([0, 0])
        else:
            pad.extend([1, 0])
    pad = pad[::-1]
    slice_ls = [slice(None) for _ in range(len(tensor.shape))]
    slice_ls[dim] = slice(1, None)
    tensor_pad = F.pad(tensor, pad=pad)
    return tensor_pad[slice_ls]
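
A quick sanity check of what shift does: it advances the plan one step along the given dimension, dropping the first entry and zero-filling the end.

t = torch.arange(6.0).view(1, 3, 2)
print(shift(t, dim=1))
# tensor([[[2., 3.],
#          [4., 5.],
#          [0., 0.]]])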

Rather than having some fancy for-loop, I implemented the switch by taking a weighted sum over a binary representation of $g$: that is, I create a binary matrix and compute $g \cdot X_{\text{recompute}} + (1-g) \cdot X_{\text{shift}}$. I find that this works “well enough” for my purposes.
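
As a small illustration of the switch (shapes assumed), each example in the batch either keeps its freshly recomputed plan or its shifted old plan:

A_new = torch.ones(2, 3, 4)            # freshly recomputed alignment plans
A_old = torch.randn(2, 3, 4)           # previous alignment plans
g = torch.tensor([1.0, 0.0])           # per-example commitment switch
g = g.view(-1, 1, 1).expand_as(A_new)  # binary representation of g
A = g * A_new + (1 - g) * shift(A_old, dim=1)
print(A[0].unique())                   # tensor([1.]): example 0 was recomputed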

class PAGAttention(nn.Module):
    """
    Performs full attention as documented (minus a couple of shortcuts for now...)
    """

    @staticmethod
    def shift(tensor, dim=1):
        """
        Shifts tensor along this dimension by one spot and fills in shifted
        with zeros; quickest way is to pad the target dim and then slice
        """
        pad = []
        for idx in range(len(tensor.shape)):
            if idx != dim:
                pad.extend([0, 0])
            else:
                pad.extend([1, 0])
        pad = pad[::-1]
        slice_ls = [slice(None) for _ in range(len(tensor.shape))]
        slice_ls[dim] = slice(1, None)
        tensor_pad = F.pad(tensor, pad=pad)
        return tensor_pad[slice_ls]

    @staticmethod
    def create_binary_repr(tensor, shape):
        for _ in range(len(shape) - 1):
            tensor = tensor.unsqueeze(1)
        shape = list(shape)
        shape[0] = -1
        return tensor.expand(*shape)

    def __init__(self, plan, attend, generate):
        super(PAGAttention, self).__init__()
        self.plan = plan
        self.attend = attend
        self.generate = generate

    def forward(
        self,
        query=None,
        proj_key=None,
        value=None,
        mask=None,
        beta=None,
        previous_out=None,
        prev_action_plan=None,
        prev_commit_plan=None,
    ):
        """
        The central tenent in this approach is to reduce the amount of data
        which flows into the downstream layers (assuming that the evaluation)
        is expensive.

        We will operate using the commitment plan here to skip - for implementation
        reasons we won't actually "skip" but instead zero out parts of the solution
        This should operate correctly as we don't add "biases" to our data.
        """

        if prev_commit_plan is not None:
            next_action_vector = prev_commit_plan[:, 0, 0].round()
        else:
            next_action_vector = None

        if prev_action_plan is not None:
            shift_action_plan = self.shift(prev_action_plan, 1)
            # print("-----------")
            # print(prev_action_plan.shape)
            # print(self.shift(prev_action_plan, 1).shape)
            # print(prev_action_plan[0, :, 0])
            # print(self.shift(prev_action_plan, 1)[0, :, 0])
            # print("-----------")

        if prev_commit_plan is not None:
            shift_commit_plan = self.shift(prev_commit_plan, 2)

        # TODO check the previous commit weights
        # and move everything along - no need to eval if not needed here...

        # this already gracefully handles scenario when beta is None or 0
        context, attn_probs, action_plan = self.attend(
            query=query,
            proj_key=proj_key,
            value=value,
            mask=mask,
            beta=beta,
            previous_out=previous_out,
        )

        # repeat alignment uses context - set it to be None for now.
        commit_weights = self.plan(
            query=query,
            proj_key=proj_key,
            value=value,
            mask=mask,
            beta=beta,
            previous_out=previous_out,
        )

        # if in repeat - then weights don't update! Simple!
        update_weights = self.generate(
            query=query,
            proj_key=proj_key,
            value=value,
            mask=mask,
            beta=beta,
            previous_out=previous_out,
        )
        if (
            update_weights is not None
            and action_plan is not None
            and prev_action_plan is not None
        ):
            # interpolate as per the paper: the gate mixes the new plan with
            # the previous one, i.e. u * A_new + (1 - u) * A_old
            action_plan = (update_weights * action_plan) + (
                (1 - update_weights) * prev_action_plan
            )

        # finally, double check whether it really should be updated via the commitment switch
        if next_action_vector is not None:
            action_plan_binary = self.create_binary_repr(
                next_action_vector, prev_action_plan.shape
            )
            commit_binary = self.create_binary_repr(
                next_action_vector, commit_weights.shape
            )
            action_plan = (action_plan_binary * action_plan) + (
                (1 - action_plan_binary) * shift_action_plan
            )
            commit_weights = (commit_binary * commit_weights) + (
                (1 - commit_binary) * shift_commit_plan
            )

        return context, attn_probs, action_plan, commit_weights

Putting it all together

To put it all together, it looks something like this:

# set up network; hidden_size, seq, score_size and trg_y are assumed to be
# defined elsewhere (e.g. from the dataset and batch at hand).
# note: the summary/prediction layers are instantiated once here - creating
# them inside the lambdas would re-initialise their weights on every call.
# they are also not registered as submodules, so remember to pass their
# parameters to the optimiser.
summary_linear = nn.Linear(score_size[1], score_size[1], bias=False)
pred_linear = nn.Linear(trg_y.size(1), hidden_size, bias=False)

attention = PAGAttention(
    plan=PAGPlan(hidden_size, seq),
    attend=PAGAttend(
        hidden_size,
        seq,
        summary_layer=lambda x: torch.tanh(summary_linear(x)),
        pred_embedding=lambda x: torch.tanh(pred_linear(x)),
    ),
    generate=PAGGenerate(hidden_size, seq, score_size=score_size[1]),
)

net = EncoderDecoder(
    Encoder(emb_size, hidden_size, num_layers=num_layers, dropout=dropout),
    Decoder(emb_size, hidden_size, attention, num_layers=num_layers, dropout=dropout),
    ContinuousEmbedding(emb_size),
    ContinuousEmbedding(emb_size),
    Generator(hidden_size, pred_shape),
)

At a later stage I’ll try running this over a proper dataset to see the performance!

Update 21 Feb 2020

We can create a lighter-weight version based on a TensorFlow port and tutorial.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Encoder(nn.Module):
    """Encodes a sequence of word embeddings
    """

    def __init__(self, input_size, hidden_size, embed_size=None, num_layers=1):
        super(Encoder, self).__init__()
        embed_size = hidden_size if embed_size is None else embed_size
        self.num_layers = num_layers
        self.embed = nn.Linear(input_size, embed_size)
        self.rnn = nn.GRU(
            embed_size, hidden_size, num_layers, batch_first=True, bidirectional=True
        )

    def forward(self, x, hidden=None):
        """
        Applies a bidirectional GRU to sequence of embeddings x.
        The input mini-batch x needs to be sorted by length.
        x should have dimensions [batch, time, dim].

        what does the util functions do? Ensure that all inputs have
        the same length!
        """
        x = self.embed(x)
        output, final = self.rnn(x, hidden)

        # we need to manually concatenate the final states for both directions
        fwd_final = final[0 : final.size(0) : 2]
        bwd_final = final[1 : final.size(0) : 2]
        final = torch.cat([fwd_final, bwd_final], dim=2)  # [num_layers, batch, 2*dim]
        final = final.permute(1, 0, 2)

        return output, final


class PlanAttention(nn.Module):
    """Implements "Plan Attend Generate" attention"""

    @staticmethod
    def shift(tensor, dim=1):
        """
        Shifts tensor along this dimension by one spot and fills in shifted
        with zeros; quickest way is to pad the target dim and then slice
        """
        pad = []
        for idx in range(len(tensor.shape)):
            if idx != dim:
                pad.extend([0, 0])
            else:
                pad.extend([1, 0])
        pad = pad[::-1]
        slice_ls = [slice(None) for _ in range(len(tensor.shape))]
        slice_ls[dim] = slice(1, None)
        tensor_pad = F.pad(tensor, pad=pad)
        return tensor_pad[slice_ls]

    @staticmethod
    def create_binary_repr(tensor, shape):
        for _ in range(len(shape) - 1):
            tensor = tensor.unsqueeze(1)
        shape = list(shape)
        shape[0] = -1
        return tensor.expand(*shape)

    def __init__(self, hidden_size, units, plan_length, output_size):
        super(PlanAttention, self).__init__()

        # We assume a bi-directional encoder so key_size is 2*hidden_size
        key_size = 2 * hidden_size
        self.plan_length = plan_length
        self.units = units

        self.w1 = nn.Linear(key_size, units * plan_length, bias=False)
        self.w2 = nn.Linear(key_size, units * plan_length, bias=False)
        self.w3 = nn.Linear(plan_length, 1)
        self.energy_layer = nn.Linear(units, 1, bias=False)
        self.action_plan_layer = nn.Linear(units, plan_length, bias=False)
        self.prev_decoder_layer = nn.Linear(output_size, units, bias=False)

        # commitment
        self.commit_layer = nn.Linear(key_size, plan_length, bias=False)

        # update weights
        self.update_w1 = nn.Linear(key_size, units * plan_length, bias=False)
        self.update_w2 = nn.Linear(key_size, units * plan_length, bias=False)
        self.update_context = nn.Linear(key_size, units * plan_length, bias=False)
        self.update_layer = nn.Linear(units, 1, bias=False)

    def forward(self, query=None, values=None, commit_plan=None, action_plan=None):
        # preliminary planning actions...
        if commit_plan is not None:
            next_action_vector = commit_plan[:, 0].round()
            shift_commit_plan = self.shift(commit_plan, 1)
        else:
            next_action_vector = None

        if action_plan is not None:
            shift_action_plan = self.shift(action_plan, 1)

        # hidden shape == (batch_size, hidden size)

        # score shape == (batch_size, max_length, 1)
        # we get 1 at the last axis because the energy layer projects each
        # (position, plan step) pair down to a single score
        if action_plan is not None:
            beta = torch.tanh(self.w3(action_plan.permute(0, 2, 1)))
            beta = beta.unsqueeze(3)
        else:
            beta = 0

        # e_ij = a(s_i-1, h_j)
        w1 = torch.tanh(self.w1(query))
        w1 = w1.view(w1.shape[0], w1.shape[1], self.plan_length, self.units)
        w2 = torch.tanh(self.w2(values))
        w2 = w2.view(w2.shape[0], w2.shape[1], self.plan_length, self.units)

        scores = self.energy_layer(w1 + w2 + beta)
        scores = scores.squeeze(3).permute(0, 2, 1)

        # Turn scores to probabilities.
        # attention_weights shape == (batch_size, 1, max_length)
        attention_weights = F.softmax(scores[:, [0], :], dim=-1)
        self.alphas = attention_weights

        # The context vector is the weighted sum of the values.
        context = torch.bmm(attention_weights, values)

        # compute commitment
        commit_proba = F.log_softmax(self.commit_layer(query).squeeze(1), dim=-1)
        commit_plan = F.gumbel_softmax(commit_proba, 0.01)
        print("commit_plan", commit_plan.shape)

        # compute update gate
        update_w1 = torch.tanh(self.update_w1(query))
        update_w1 = update_w1.view(
            update_w1.shape[0], update_w1.shape[1], self.plan_length, self.units
        )
        update_w2 = torch.tanh(self.update_w2(values))
        update_w2 = update_w2.view(
            update_w2.shape[0], update_w2.shape[1], self.plan_length, self.units
        )
        update_context = torch.tanh(self.update_context(context))
        update_context = update_context.view(
            update_context.shape[0],
            update_context.shape[1],
            self.plan_length,
            self.units,
        )

        update_vector = self.update_layer(update_w1 + update_w2 + update_context)
        update_vector = update_vector.squeeze(3)
        update_vector = update_vector.permute(0, 2, 1)
        # context shape: [B, 1, 2 * hidden_size], alphas shape: [B, 1, max_length]

        # set action plan
        if action_plan is None:
            action_plan = scores
        else:
            action_plan = ((1 - update_vector) * shift_action_plan) + (
                update_vector * scores
            )

        # process all updates accordingly
        if next_action_vector is not None:
            # update all things selectively
            action_plan_binary = self.create_binary_repr(
                next_action_vector, shift_action_plan.shape
            )
            commit_binary = self.create_binary_repr(
                next_action_vector, shift_commit_plan.shape
            )

            action_plan = (action_plan_binary * action_plan) + (
                (1 - action_plan_binary) * shift_action_plan
            )
            commit_plan = (commit_binary * commit_plan) + (
                (1 - commit_binary) * shift_commit_plan
            )

        return context, attention_weights, action_plan, commit_plan


class Decoder(nn.Module):
    def __init__(
        self,
        input_size,
        output_size,
        hidden_size,
        embed_size=None,
        num_layers=1,
        attention_config={},
    ):
        super(Decoder, self).__init__()
        embed_size = hidden_size if embed_size is None else embed_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embed = nn.Linear(output_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + hidden_size * 2,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.attention = PlanAttention(**attention_config)
        self.projection = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x, hidden, enc_output, commit_plan=None, action_plan=None):
        x = self.embed(x)

        # enc_output shape == (batch_size, max_length, hidden_size)
        context_vector, attention_weights, action_plan, commit_plan = self.attention(
            hidden, enc_output, commit_plan, action_plan
        )

        x = torch.cat([x, context_vector], dim=-1)

        # note: the encoder hidden state is not used to initialise the GRU in
        # this light-weight sketch (it is only used as the attention query);
        # the GRU starts from zeros on each call
        output, state = self.rnn(x)
        x = self.projection(output)
        return x, state, attention_weights, action_plan, commit_plan


enc = Encoder(10, 12, 13)
x = torch.from_numpy(np.random.normal(size=(6, 3, 10))).type(torch.FloatTensor)

# always the last decoder output (i.e. y_t)
d_x = torch.from_numpy(np.random.normal(size=(6, 1, 9))).type(torch.FloatTensor)

sample_output, sample_hidden = enc(x)
print("input", x.shape)
print("sample_output", sample_output.shape, "sample_hidden", sample_hidden.shape)

attention_layer = PlanAttention(12, 9, 5, 9)
attention_result, attention_weights, action_plan, commit_plan = attention_layer(
    sample_hidden, sample_output
)
print(
    "attention_result",
    attention_result.shape,
    "attention_weights",
    attention_weights.shape,
)

attention_result, attention_weights, action_plan, commit_plan = attention_layer(
    sample_hidden, sample_output, None, action_plan
)
print(
    "attention_result",
    attention_result.shape,
    "attention_weights",
    attention_weights.shape,
    "action_plan",
    action_plan.shape,
)


decoder = Decoder(
    10,
    9,
    12,
    14,
    attention_config=dict(hidden_size=12, units=9, plan_length=5, output_size=9),
)
(
    sample_decoder_output,
    decoder_state,
    attention_weights,
    action_plan,
    commit_plan,
) = decoder(d_x, sample_hidden, sample_output)
print("sample_decoder_output", sample_decoder_output.shape)
print(
    "decoder_state", decoder_state.shape, "attention_weights", attention_weights.shape
)

(
    sample_decoder_output,
    decoder_state,
    attention_weights,
    action_plan,
    commit_plan,
) = decoder(d_x, sample_hidden, sample_output, None, action_plan)
print("sample_decoder_output", sample_decoder_output.shape)
print(
    "decoder_state", decoder_state.shape, "attention_weights", attention_weights.shape
)