Pytorch RNN sequence tagging

Posted on 03/09/2017 in posts python machine learning deeplearning


In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
import operator

import numpy as np

RNN intuition

Let us assume that we have an input $x = [x_1, x_2, ..., x_N]$ and we need to learn the mapping for some output $y = [y_1, y_2, ..., y_N]$, where $N$ is variable for each instance. In this case we can't just use a simple feed forward neural network which maps $x \rightarrow y$, as this will not work with variable length sequences. Furthermore, the number or parameters required for training such a network would be proportional to $size(x_i)*N$. This is a major memory cost. Additionally, if the sequence has some common mapping between $x_i$ and $y_i$, then we would be learning redundant weights for each pair in the sequence. This is where an RNN network is more useful. The basic idea is that each input $x_i$ is processed in a similar fashion using the same processing module and some additional context variable (which we will henseforth refer to as the hidden state). This hidden state should capture some information about the part of the sequence which has already been processed. Now at each step of the sequence we need to do the following:

  • Generate the output based on the previous hidden state and current input
  • Update the hidden state based on the previous hidden state and current input.

The order of the above steps is not fixed and forms the basis of many RNN spin-offs. What is important, at each step, is to have a new output and a new hidden state. Sometimes, the hidden state and the outputs are the same, to make the network smaller. But the core idea remains same. Below we would like to formalize the general intuition of an RNN module.

Initialize an initial hidden state $h_{0}$ with some initial value.

At timestep n: $$ \begin{equation} h^{'}_{i} = f(x_{i},h_{i})\\ y_{i} = g(x_{i},h^{'}_{i})\\ h_{i+1} = h^{'}_{i}\\ \end{equation} $$

Here $y_{i}$ is the output and $h^{'}_{i}$ is the intermediate hidden state.

In [2]:
class Input2Hidden(nn.Module):
    def __init__(self, x_dim, concat_layers=False):
        """Input2Hidden module
        
        Args:
            x_dim: input vector dimension
            concat_layers: weather to concat input and hidden layers or sum them
        """
        super(Input2Hidden, self).__init__()
        self.concat_layers = concat_layers
        input_dim = x_dim
        if self.concat_layers:
            input_dim = 2*x_dim
        self.linear_layer = nn.Linear(input_dim, x_dim)
        
    def forward(self, x, h):
        if self.concat_layers:
            cell_input = torch.cat([x,h], dim=1)
        else:
            cell_input = x + h
        assert isinstance(cell_input, Variable)
        logit = F.tanh(self.linear_layer(cell_input))
        return logit
    
    
class Hidden2Output(nn.Module):
    def __init__(self, x_dim, out_dim, concat_layers=False):
        """Hidden2Output module
        
        Args:
            x_dim: input vector dimension
            out_dim: output vector dimension
            concat_layers: weather to concat input and hidden layers or sum them
        """
        super(Hidden2Output, self).__init__()
        input_dim = x_dim
        self.concat_layers = concat_layers
        if self.concat_layers:
            input_dim = 2*x_dim
        self.linear_layer = nn.Linear(input_dim, out_dim)
        
    def forward(self, x, h):
        if self.concat_layers:
            cell_input = torch.cat([x,h], dim=1)
        else:
            cell_input = x + h
        assert isinstance(cell_input, Variable)
        logit = F.tanh(self.linear_layer(cell_input))
        return logit
    
    
class CustomRNNCell(nn.Module):
    def __init__(self, i2h, h2o):
        super(CustomRNNCell, self).__init__()
        self.i2h = i2h
        self.h2o = h2o
        
    def forward(self, x, h):
        assert isinstance(x, Variable)
        assert isinstance(h, Variable)
        h_prime = self.i2h(x,h)
        assert isinstance(h_prime, Variable)
        output = self.h2o(x,h_prime)
        return output, h_prime
    
class Model(nn.Module):
    def __init__(self, embedding, rnn_cell):
        super(Model, self).__init__()
        self.embedding = embedding
        self.rnn_cell = rnn_cell
        self.loss_function = nn.CrossEntropyLoss()
    
    def forward(self, word_ids, hidden=None):
        if hidden is None:
            hidden = Variable(torch.zeros(
                word_ids.data.shape[0],self.embedding.embedding_dim))
        assert isinstance(hidden, Variable)
        embeddings = self.embedding(word_ids)
        max_seq_length = word_ids.data.shape[-1]
        outputs, hidden_states = [], []
        for i in range(max_seq_length):
            x = embeddings[:, i, :]
            assert isinstance(x, Variable)
            #print("x={}\nhidden={}".format(x,hidden))
            output, hidden = self.rnn_cell(x, hidden)
            assert isinstance(output, Variable)
            assert isinstance(hidden, Variable)
            #print("output: {}, hidden: {}".format(output.data.shape, hidden.data.shape))
            outputs.append(output.unsqueeze(1))
            hidden_states.append(hidden.unsqueeze(1))
        outputs = torch.cat(outputs, 1)
        hidden_states = torch.cat(hidden_states, 1)
        assert isinstance(outputs, Variable)
        assert isinstance(hidden_states, Variable)
        return outputs, hidden_states
    
    def loss(self, word_ids, target_ids, hidden=None):
        outputs, hidden_states = self.forward(word_ids, hidden=hidden)
        outputs = outputs.view(-1, outputs.data.shape[-1])
        target_ids = target_ids.view(-1)
        assert isinstance(outputs, Variable)
        assert isinstance(target_ids, Variable)
        #print("output={}\ttargets={}".format(outputs.data.shape,target_ids.data.shape))
        loss = self.loss_function(outputs, target_ids)
        return loss    
        
        
    def predict(self, word_ids, hidden=None):
        outputs, hidden_states = self.forward(word_ids, hidden=hidden)
        outputs = outputs.view(-1, outputs.data.shape[-1])
        max_scores, predictions = outputs.max(1)
        predictions = predictions.view(*word_ids.data.shape)
        #print(word_ids.data.shape, predictions.data.shape)
        assert word_ids.data.shape == predictions.data.shape, "word_ids: {}, predictions: {}".format(
            word_ids.data.shape, predictions.data.shape
        )
        return predictions
        
        
def tensors2variables(*args, requires_grad=False):
    return tuple(map(lambda x: Variable(x, requires_grad=requires_grad), args))

def get_batch(tensor_types, *args, requires_grad=False):
    return tuple(map(lambda t,arg: Variable(t(arg), requires_grad=requires_grad), tensor_types, args))

Learning to predict bit flip

Let us take a simple example of using an RNN to predict the flip in bits of an $N$ bit unsigned integer. In python for an integer n represented using $N$ bits, the unsigned bitflip can be written as (~n) & ((1 << N)-1). Four our RNN each bit from the left will be $x_i$ and each flipped bit will be $y_i$. This task doesn't require any temporal dependencies but will be a good exercise to test the accuracy of RNN implementation. Theoretically, the network should learn to do this job perfectly in a few iterations. Later we will move to an example which does require the network to learn some temporal dependencies between inputs.

For our network we define $f(x,h)$ as a simple affine layer with $tanh$ activation, which takes the concatanated input $[x_i, h_{i-1}]$ and returns a new hidden state $h^{'}_{i}$. Similarly, we have $g(x_i, h^{'}_{n})$ also represented as an affine layer with $tanh$ activation, of the concatanation of its inputs $[x_i, h^{'}_{n}]$, resulting in a new output $y_{i}$. More formally, we have

$$ \begin{equation} h^{'}_{i} = f(x_{i},h_{i}) = \sigma([x_i, h_{i-1}]W_{i2h})\\ y_{i} = g(x_{i},h^{'}_{i}) = \sigma([x_i, h^{'}_{n}]W_{h2o})\\ h_{i+1} = h^{'}_{i}\\ \end{equation} $$
In [3]:
input_size=2
embedding_size=3
output_size=2

embedding = nn.Embedding(input_size, embedding_size)
f = Input2Hidden(embedding_size, concat_layers=True)
g = Hidden2Output(embedding_size, 2, concat_layers=True)
rnn_cell = CustomRNNCell(f,g)

model = Model(embedding, rnn_cell)
In [4]:
word_ids = [[0, 1, 0, 1, 0, 1]]
target_ids = [[1, 0, 1, 0, 1, 0]]

tensor_types = (torch.LongTensor, torch.LongTensor)
word_ids_tensor, target_ids_tensor = get_batch(tensor_types, word_ids, target_ids)
print(word_ids_tensor, target_ids_tensor)
Variable containing:
 0  1  0  1  0  1
[torch.LongTensor of size 1x6]
 Variable containing:
 1  0  1  0  1  0
[torch.LongTensor of size 1x6]

In [5]:
model.forward(torch.cat([word_ids_tensor, word_ids_tensor, word_ids_tensor], 0))[0]
Out[5]:
Variable containing:
(0 ,.,.) = 
  0.1159 -0.0933
 -0.4920  0.6325
  0.1634 -0.0420
 -0.5012  0.6430
  0.1758 -0.0425
 -0.5058  0.6449

(1 ,.,.) = 
  0.1159 -0.0933
 -0.4920  0.6325
  0.1634 -0.0420
 -0.5012  0.6430
  0.1758 -0.0425
 -0.5058  0.6449

(2 ,.,.) = 
  0.1159 -0.0933
 -0.4920  0.6325
  0.1634 -0.0420
 -0.5012  0.6430
  0.1758 -0.0425
 -0.5058  0.6449
[torch.FloatTensor of size 3x6x2]
In [6]:
model.predict(torch.cat([word_ids_tensor, word_ids_tensor, word_ids_tensor], 0))
Out[6]:
Variable containing:
 0  1  0  1  0  1
 0  1  0  1  0  1
 0  1  0  1  0  1
[torch.LongTensor of size 3x6]
In [7]:
loss = model.loss(word_ids_tensor, target_ids_tensor)
loss
Out[7]:
Variable containing:
 1.1108
[torch.FloatTensor of size 1]
In [8]:
loss.backward()
In [9]:
model.predict(word_ids_tensor)
Out[9]:
Variable containing:
 0  1  0  1  0  1
[torch.LongTensor of size 1x6]
In [10]:
def create_dataset(max_len=5):
    """Create a dataset of max_len bits and their flipped values
    
    Args:
        max_len: Maximum number of bits in the number
    """
    max_val = (1<<max_len)
    X, Y = [], []
    for i in range(max_val):
        x = "{0:0{1}b}".format(i,max_len)
        y = "{0:0{1}b}".format((~i) & max_val-1,max_len)
        
        x = tuple(map(int, x))
        y = tuple(map(int, y))
        X.append(x)
        Y.append(y)
    return X, Y
        
In [11]:
X, Y = create_dataset(max_len=10)
X_tensors, Y_tensors = tuple(map(torch.LongTensor, [X, Y]))
print(X_tensors.shape, Y_tensors.shape)
assert X_tensors.shape == Y_tensors.shape, "X and Y should be of same shape"
torch.Size([1024, 10]) torch.Size([1024, 10])
In [12]:
train = data_utils.TensorDataset(X_tensors, Y_tensors)
train_loader = data_utils.DataLoader(train, batch_size=8, shuffle=True)
In [13]:
%%time
learning_rate = 1e-4
max_epochs = 20
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(1, max_epochs+1):
    for X_batch, Y_batch in train_loader:
        X_batch, Y_batch = tensors2variables(X_batch, Y_batch)
        # Forward pass: compute predicted y by passing x to the model.
        loss = model.loss(X_batch, Y_batch)

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights
        # of the model)
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model
        # parameters
        loss.backward()

        # Calling the step function on an Optimizer makes an update to its
        # parameters
        optimizer.step()
    if epoch % 2 != 0:
        continue
        
    loss = model.loss(*tensors2variables(X_tensors, Y_tensors))
    Y_predict = model.predict(Variable(X_tensors)).data
    accuracy = (Y_tensors == Y_predict).sum() /operator.mul(*Y_tensors.shape) * 100.
    print("Epoch[{:03d}]: loss={:5.3f}; accuracy={:.3f}%".format(epoch, loss.data[0], accuracy))
Epoch[002]: loss=0.982; accuracy=45.000%
Epoch[004]: loss=0.902; accuracy=50.000%
Epoch[006]: loss=0.827; accuracy=50.000%
Epoch[008]: loss=0.754; accuracy=50.000%
Epoch[010]: loss=0.681; accuracy=50.000%
Epoch[012]: loss=0.609; accuracy=50.000%
Epoch[014]: loss=0.536; accuracy=77.500%
Epoch[016]: loss=0.459; accuracy=100.000%
Epoch[018]: loss=0.383; accuracy=100.000%
Epoch[020]: loss=0.320; accuracy=100.000%
CPU times: user 32.2 s, sys: 1min 42s, total: 2min 15s
Wall time: 22.6 s
In [14]:
Y_predict = model.predict(X_batch)
Y_predict[:10].data
Out[14]:
    1     1     1     1     0     0     0     0     0     0
    1     0     0     1     1     1     1     1     0     0
    0     0     0     0     0     0     0     1     1     1
    1     1     0     1     1     0     0     1     0     0
    0     1     0     0     1     1     1     1     1     1
    0     0     1     0     1     1     0     1     0     0
    0     0     0     0     1     0     1     1     1     1
    1     1     1     0     0     0     1     1     1     0
[torch.LongTensor of size 8x10]

Learning to predict bit shift

This example requires some learning of temporal dependencies. We want to learn our network the output when the input $x$'s bits are shifted right by $K$ positions. This can be done using x >> K. Similarly, a left shift can be done using (a << K) & (1 << N) -1), where $N$ is the max length of the bit sequence.

In [15]:
def create_dataset(max_len=5, K=4):
    """Create a dataset of max_len bits and their flipped values
    
    Args:
        max_len: Maximum number of bits in the number
    """
    X, Y = [], []
    max_val = 2**max_len
    for i in range(max_val):
        x = "{0:0{1}b}".format(i,max_len)
        y = "{0:0{1}b}".format(i>>K,max_len)
        
        x = tuple(map(int, x))
        y = tuple(map(int, y))
        X.append(x)
        Y.append(y)
    return X, Y
        
In [16]:
input_size=2
embedding_size=3
output_size=2

embedding = nn.Embedding(input_size, embedding_size)
f = Input2Hidden(embedding_size, concat_layers=True)
g = Hidden2Output(embedding_size, 2, concat_layers=True)
rnn_cell = CustomRNNCell(f,g)

model = Model(embedding, rnn_cell)
In [17]:
X, Y = create_dataset(max_len=10, K=4)
X_tensors, Y_tensors = tuple(map(torch.LongTensor, [X, Y]))
print(X_tensors.shape, Y_tensors.shape)
assert X_tensors.shape == Y_tensors.shape, "X and Y should be of same shape"

train = data_utils.TensorDataset(X_tensors, Y_tensors)
train_loader = data_utils.DataLoader(train, batch_size=128, shuffle=True)
torch.Size([1024, 10]) torch.Size([1024, 10])
In [18]:
%%time
learning_rate = 1e-4
max_epochs = 10000
check_every=500
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(1, max_epochs+1):
    for X_batch, Y_batch in train_loader:
        X_batch, Y_batch = tensors2variables(X_batch, Y_batch)
        # Forward pass: compute predicted y by passing x to the model.
        loss = model.loss(X_batch, Y_batch)

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights
        # of the model)
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model
        # parameters
        loss.backward()

        # Calling the step function on an Optimizer makes an update to its
        # parameters
        optimizer.step()
    if epoch % check_every != 0:
        continue
        
    loss = model.loss(*tensors2variables(X_tensors, Y_tensors))
    Y_predict = model.predict(Variable(X_tensors)).data
    accuracy = (Y_tensors == Y_predict).sum() /operator.mul(*Y_tensors.shape) * 100.
    print("Epoch[{:03d}]: loss={:5.3f}; accuracy={:.3f}%".format(epoch, loss.data[0], accuracy))
Epoch[500]: loss=0.552; accuracy=76.133%
Epoch[1000]: loss=0.523; accuracy=78.555%
Epoch[1500]: loss=0.512; accuracy=78.896%
Epoch[2000]: loss=0.494; accuracy=79.600%
Epoch[2500]: loss=0.464; accuracy=79.600%
Epoch[3000]: loss=0.409; accuracy=83.193%
Epoch[3500]: loss=0.368; accuracy=85.518%
Epoch[4000]: loss=0.348; accuracy=87.441%
Epoch[4500]: loss=0.341; accuracy=87.500%
Epoch[5000]: loss=0.337; accuracy=87.500%
Epoch[5500]: loss=0.334; accuracy=87.500%
Epoch[6000]: loss=0.332; accuracy=87.500%
Epoch[6500]: loss=0.331; accuracy=87.500%
Epoch[7000]: loss=0.330; accuracy=87.500%
Epoch[7500]: loss=0.329; accuracy=87.500%
Epoch[8000]: loss=0.329; accuracy=87.500%
Epoch[8500]: loss=0.329; accuracy=87.500%
Epoch[9000]: loss=0.328; accuracy=87.500%
Epoch[9500]: loss=0.328; accuracy=87.500%
Epoch[10000]: loss=0.327; accuracy=87.500%
CPU times: user 19min 9s, sys: 57min 56s, total: 1h 17min 6s
Wall time: 12min 52s
In [19]:
Y_predict = model.predict(X_batch)
Y_predict[:10].data
Out[19]:
    0     0     0     0     1     1     1     1     1     1
    0     0     0     0     0     1     0     1     1     1
    0     0     0     0     1     0     1     1     1     1
    0     0     0     0     0     0     0     0     0     1
    0     0     0     0     0     0     1     0     1     0
    0     0     0     0     1     1     1     1     1     1
    0     0     0     0     0     0     0     1     0     1
    0     0     0     0     0     1     0     1     0     1
    0     0     0     0     0     1     0     1     0     1
    0     0     0     0     1     0     1     0     1     0
[torch.LongTensor of size 10x10]

Too slow to learn

The network trained above takes more than 10,000 epochs to converge to only $90\%$ accuracy. This reflects a major shortcoming of general RNN's. The shortcoming comes from a problem known as vanishing gradients, where gradients based on more distant steps become numerically too small to update the current layer, leading to information loss and failure to learn long range dependencies. Researcher's have worked around this using what is known as gated or memory based RNN cell's, which allows the information to be stored for a longer duration in the network and the gradients from long range dependencies to flow more easily. Two of the most popular variants are Long Short Term Memory (LSTM) cells and Gated Recurrent Unit (GRU) cells. The core idea is to allow some memory of the current state to be stored for the long time either in a seperate memory cell or in the hidden state. This is usually done by selectively reading and editing from the memory based on the current step. In the following sections we will understand the GRU cells which are a very simple extension of RNN and solve the vanishing gradient problem. The LSTM cells are a bit more involved and will be discussed later.

Gated Recurrent Unit (GRU)

The idea behind GRU's is to update part of the hidden state and retain the rest. This is done using the following functions:

  • reset gate - Identifies what proportion of hidden state should be reset
  • update gate - Identifies what proportion of hidden state should be updated

The implementation is as follow:

$$ \begin{equation} reset = \sigma(W_{r}[x_i, h_{i-1}])\\ update = \sigma(W_{u}[x_i, h_{i-1}])\\ interim\_hidden = tanh(W_{i}[x_i, reset \circ h_{i-1}])\\ h^{'}_{i} = update \circ interim\_hidden + (1-update) \circ h_{i-1} \\ \end{equation} $$
In [20]:
class GRUCell(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(GRUCell, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        
        self.reset_linear = nn.Linear(2*self.input_dim, self.input_dim)
        self.update_linear = nn.Linear(2*self.input_dim, self.input_dim)
        self.interim_linear = nn.Linear(2*self.input_dim, self.input_dim)
        
        self.output_linear = nn.Linear(2*self.input_dim, output_dim)
        
    def forward(self, x, h):
        concat_tensors = torch.cat([x,h], dim=1)
        reset = F.sigmoid(self.reset_linear(concat_tensors))
        update = F.sigmoid(self.update_linear(concat_tensors))
        reset_hidden = reset * h
        concat_reset_hidden = torch.cat([x, reset_hidden], dim=1)
        interim_hidden = F.tanh(self.interim_linear(concat_reset_hidden))
        h_prime = update * interim_hidden + (1-update) * h
        
        concat_out = torch.cat([x, h_prime], dim=1)
        output = F.tanh(self.output_linear(concat_out))
        return output, h_prime
        
In [21]:
input_size=2
embedding_size=3
output_size=2

embedding = nn.Embedding(input_size, embedding_size)
rnn_cell = GRUCell(embedding_size, output_size)

model = Model(embedding, rnn_cell)
In [22]:
X, Y = create_dataset(max_len=10, K=4)
X_tensors, Y_tensors = tuple(map(torch.LongTensor, [X, Y]))
print(X_tensors.shape, Y_tensors.shape)
assert X_tensors.shape == Y_tensors.shape, "X and Y should be of same shape"

train = data_utils.TensorDataset(X_tensors, Y_tensors)
train_loader = data_utils.DataLoader(train, batch_size=64, shuffle=True)
torch.Size([1024, 10]) torch.Size([1024, 10])
In [23]:
%%time
learning_rate = 1e-4
max_epochs = 3000
check_every=500
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(1, max_epochs+1):
    for X_batch, Y_batch in train_loader:
        X_batch, Y_batch = tensors2variables(X_batch, Y_batch)
        # Forward pass: compute predicted y by passing x to the model.
        loss = model.loss(X_batch, Y_batch)

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights
        # of the model)
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model
        # parameters
        loss.backward()

        # Calling the step function on an Optimizer makes an update to its
        # parameters
        optimizer.step()
    if epoch % check_every != 0:
        continue
        
    loss = model.loss(*tensors2variables(X_tensors, Y_tensors))
    Y_predict = model.predict(Variable(X_tensors)).data
    accuracy = (Y_tensors == Y_predict).sum() /operator.mul(*Y_tensors.shape) * 100.
    print("Epoch[{:03d}]: loss={:5.3f}; accuracy={:.3f}%".format(epoch, loss.data[0], accuracy))
Epoch[500]: loss=0.412; accuracy=83.740%
Epoch[1000]: loss=0.350; accuracy=86.836%
Epoch[1500]: loss=0.302; accuracy=91.025%
Epoch[2000]: loss=0.279; accuracy=92.422%
Epoch[2500]: loss=0.270; accuracy=93.242%
Epoch[3000]: loss=0.264; accuracy=93.506%
CPU times: user 21min 12s, sys: 1h 6min 27s, total: 1h 27min 39s
Wall time: 14min 38s
In [24]:
model
Out[24]:
Model (
  (embedding): Embedding(2, 3)
  (rnn_cell): GRUCell (
    (reset_linear): Linear (6 -> 3)
    (update_linear): Linear (6 -> 3)
    (interim_linear): Linear (6 -> 3)
    (output_linear): Linear (6 -> 2)
  )
  (loss_function): CrossEntropyLoss (
  )
)

Running using GPU

This makes things go faster.

In [25]:
input_size=2
embedding_size=3
output_size=2

embedding = nn.Embedding(input_size, embedding_size)
rnn_cell = GRUCell(embedding_size, output_size)

model = Model(embedding, rnn_cell).cuda()
In [26]:
X, Y = create_dataset(max_len=10, K=4)
X_tensors, Y_tensors = tuple(map(torch.LongTensor, [X, Y]))
print(X_tensors.shape, Y_tensors.shape)
assert X_tensors.shape == Y_tensors.shape, "X and Y should be of same shape"

batch_size=64
train = data_utils.TensorDataset(X_tensors, Y_tensors)
train_loader = data_utils.DataLoader(train, batch_size=batch_size, shuffle=True, pin_memory=True)
torch.Size([1024, 10]) torch.Size([1024, 10])
In [27]:
overall_hidden = Variable(torch.zeros(1,model.embedding.embedding_dim))
overall_hidden = overall_hidden.cuda()
overall_hidden
Out[27]:
Variable containing:
 0  0  0
[torch.cuda.FloatTensor of size 1x3 (GPU 0)]
In [28]:
%%time
learning_rate = 1e-4
max_epochs = 3000
check_every=500
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, max_epochs+1):
    for X_batch, Y_batch in train_loader:
        X_batch = X_batch.cuda(async=True)
        Y_batch = Y_batch.cuda(async=True)
        X_batch, Y_batch = tensors2variables(X_batch, Y_batch)
        hidden = overall_hidden.repeat(batch_size, 1)
        # Forward pass: compute predicted y by passing x to the model.
        loss = model.loss(X_batch, Y_batch, hidden=hidden)

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights
        # of the model)
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model
        # parameters
        loss.backward()

        # Calling the step function on an Optimizer makes an update to its
        # parameters
        optimizer.step()
    if epoch % check_every != 0:
        continue
    
    hidden = overall_hidden.repeat(X_tensors.shape[0], 1)
    loss = model.loss(*tensors2variables(X_tensors.cuda(), Y_tensors.cuda()),
                      hidden=hidden)
    Y_predict = model.predict(Variable(X_tensors.cuda()), hidden=hidden).data
    accuracy = (Y_tensors.cuda() == Y_predict).sum() /operator.mul(*Y_tensors.shape) * 100.
    print("Epoch[{:03d}]: loss={:5.3f}; accuracy={:.3f}%".format(epoch, loss.data[0], accuracy))
Epoch[500]: loss=0.406; accuracy=84.111%
Epoch[1000]: loss=0.320; accuracy=88.711%
Epoch[1500]: loss=0.285; accuracy=91.377%
Epoch[2000]: loss=0.270; accuracy=92.842%
Epoch[2500]: loss=0.262; accuracy=93.301%
Epoch[3000]: loss=0.255; accuracy=93.975%
CPU times: user 18min 6s, sys: 3.28 s, total: 18min 9s
Wall time: 18min 11s

It looks like the GPU version is actually a bit slower in this case. This might be due to the small size of our dataset, for which the cost of moving tensors to GPU is greater than the gain by speeding up network computations.

Increasing the network capacity

This can be done by increasing the hidden units in the network. Or in our case by increasing the embedding size as that is used to derive the number of hidden units.

In [29]:
input_size=2
embedding_size=10
output_size=2

embedding = nn.Embedding(input_size, embedding_size)
rnn_cell = GRUCell(embedding_size, output_size)

model = Model(embedding, rnn_cell).cuda()
In [30]:
X, Y = create_dataset(max_len=10, K=4)
X_tensors, Y_tensors = tuple(map(torch.LongTensor, [X, Y]))
print(X_tensors.shape, Y_tensors.shape)
assert X_tensors.shape == Y_tensors.shape, "X and Y should be of same shape"

batch_size=64
train = data_utils.TensorDataset(X_tensors, Y_tensors)
train_loader = data_utils.DataLoader(train, batch_size=batch_size, shuffle=True, pin_memory=True)
torch.Size([1024, 10]) torch.Size([1024, 10])
In [31]:
overall_hidden = Variable(torch.zeros(1,model.embedding.embedding_dim))
overall_hidden = overall_hidden.cuda()
overall_hidden
Out[31]:
Variable containing:
    0     0     0     0     0     0     0     0     0     0
[torch.cuda.FloatTensor of size 1x10 (GPU 0)]
In [32]:
%%time
learning_rate = 1e-4
max_epochs = 1000
check_every=50
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, max_epochs+1):
    for X_batch, Y_batch in train_loader:
        X_batch = X_batch.cuda(async=True)
        Y_batch = Y_batch.cuda(async=True)
        X_batch, Y_batch = tensors2variables(X_batch, Y_batch)
        hidden = overall_hidden.repeat(batch_size, 1)
        # Forward pass: compute predicted y by passing x to the model.
        loss = model.loss(X_batch, Y_batch, hidden=hidden)

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights
        # of the model)
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model
        # parameters
        loss.backward()

        # Calling the step function on an Optimizer makes an update to its
        # parameters
        optimizer.step()
    if epoch % check_every != 0:
        continue
    
    hidden = overall_hidden.repeat(X_tensors.shape[0], 1)
    loss = model.loss(*tensors2variables(X_tensors.cuda(), Y_tensors.cuda()),
                      hidden=hidden)
    Y_predict = model.predict(Variable(X_tensors.cuda()), hidden=hidden).data
    accuracy = (Y_tensors.cuda() == Y_predict).sum() /operator.mul(*Y_tensors.shape) * 100.
    print("Epoch[{:03d}]: loss={:5.3f}; accuracy={:.3f}%".format(epoch, loss.data[0], accuracy))
Epoch[050]: loss=0.642; accuracy=65.000%
Epoch[100]: loss=0.605; accuracy=70.000%
Epoch[150]: loss=0.521; accuracy=70.010%
Epoch[200]: loss=0.426; accuracy=82.100%
Epoch[250]: loss=0.376; accuracy=86.182%
Epoch[300]: loss=0.344; accuracy=88.877%
Epoch[350]: loss=0.293; accuracy=93.311%
Epoch[400]: loss=0.229; accuracy=96.768%
Epoch[450]: loss=0.188; accuracy=98.701%
Epoch[500]: loss=0.157; accuracy=99.805%
Epoch[550]: loss=0.143; accuracy=99.971%
Epoch[600]: loss=0.136; accuracy=100.000%
Epoch[650]: loss=0.131; accuracy=100.000%
Epoch[700]: loss=0.130; accuracy=100.000%
Epoch[750]: loss=0.129; accuracy=100.000%
Epoch[800]: loss=0.128; accuracy=100.000%
Epoch[850]: loss=0.128; accuracy=100.000%
Epoch[900]: loss=0.127; accuracy=100.000%
Epoch[950]: loss=0.127; accuracy=100.000%
Epoch[1000]: loss=0.127; accuracy=100.000%
CPU times: user 6min 5s, sys: 1.22 s, total: 6min 6s
Wall time: 6min 7s

As we can see, the network converges 10x quicker than the one with lower capacity and also achieves $100\%$ accuracy in just 600 epochs. This is a very useful result, as it shows that in order to learn more complex functionalities we need networks with larger capacities as well as computationally efficient structures. Luckily for us many of the standard functionalities, are usually implemented efficiently in neural network libraries. Pytorch implements many of the standard neural network modules efficiently using it's C code, which can give us an order of magniture of improvement (especially for larger networks). These modules include GRU cells and a GRU module which can process the whole sequence. We will look at these in detail below.

Using Pytorch's GRUCell

Let us check our implementation using the Pytorch's inbuild GRU cell

In [33]:
class PytorchGRUCell(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PytorchGRUCell, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        
        self.gru = torch.nn.GRUCell(self.input_dim, self.input_dim)
        
        self.output_linear = nn.Linear(2*self.input_dim, self.output_dim)
        
    def forward(self, x, h):
        h_prime = self.gru(x,h)
        
        concat_out = torch.cat([x, h_prime], dim=1)
        output = F.tanh(self.output_linear(concat_out))
        return output, h_prime
In [34]:
input_size=2
embedding_size=10
output_size=2

embedding = nn.Embedding(input_size, embedding_size)
rnn_cell = PytorchGRUCell(embedding_size, output_size)

model = Model(embedding, rnn_cell).cuda()
In [35]:
X, Y = create_dataset(max_len=10, K=4)
X_tensors, Y_tensors = tuple(map(torch.LongTensor, [X, Y]))
print(X_tensors.shape, Y_tensors.shape)
assert X_tensors.shape == Y_tensors.shape, "X and Y should be of same shape"

batch_size=64
train = data_utils.TensorDataset(X_tensors, Y_tensors)
train_loader = data_utils.DataLoader(train, batch_size=batch_size, shuffle=True, pin_memory=True)
torch.Size([1024, 10]) torch.Size([1024, 10])
In [36]:
overall_hidden = Variable(torch.zeros(1,model.embedding.embedding_dim))
overall_hidden = overall_hidden.cuda()
overall_hidden
Out[36]:
Variable containing:
    0     0     0     0     0     0     0     0     0     0
[torch.cuda.FloatTensor of size 1x10 (GPU 0)]
In [37]:
%%time
learning_rate = 1e-4
max_epochs = 1000
check_every=50
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, max_epochs+1):
    for X_batch, Y_batch in train_loader:
        X_batch = X_batch.cuda(async=True)
        Y_batch = Y_batch.cuda(async=True)
        X_batch, Y_batch = tensors2variables(X_batch, Y_batch)
        hidden = overall_hidden.repeat(batch_size, 1)
        # Forward pass: compute predicted y by passing x to the model.
        loss = model.loss(X_batch, Y_batch, hidden=hidden)

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights
        # of the model)
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model
        # parameters
        loss.backward()

        # Calling the step function on an Optimizer makes an update to its
        # parameters
        optimizer.step()
    if epoch % check_every != 0:
        continue
    
    hidden = overall_hidden.repeat(X_tensors.shape[0], 1)
    loss = model.loss(*tensors2variables(X_tensors.cuda(), Y_tensors.cuda()),
                      hidden=hidden)
    Y_predict = model.predict(Variable(X_tensors.cuda()), hidden=hidden).data
    accuracy = (Y_tensors.cuda() == Y_predict).sum() /operator.mul(*Y_tensors.shape) * 100.
    print("Epoch[{:03d}]: loss={:5.3f}; accuracy={:.3f}%".format(epoch, loss.data[0], accuracy))
Epoch[050]: loss=0.564; accuracy=70.000%
Epoch[100]: loss=0.450; accuracy=80.488%
Epoch[150]: loss=0.380; accuracy=86.357%
Epoch[200]: loss=0.334; accuracy=89.805%
Epoch[250]: loss=0.292; accuracy=92.783%
Epoch[300]: loss=0.259; accuracy=94.795%
Epoch[350]: loss=0.237; accuracy=96.182%
Epoch[400]: loss=0.221; accuracy=96.670%
Epoch[450]: loss=0.190; accuracy=97.900%
Epoch[500]: loss=0.148; accuracy=99.990%
Epoch[550]: loss=0.135; accuracy=100.000%
Epoch[600]: loss=0.131; accuracy=100.000%
Epoch[650]: loss=0.130; accuracy=100.000%
Epoch[700]: loss=0.129; accuracy=100.000%
Epoch[750]: loss=0.128; accuracy=100.000%
Epoch[800]: loss=0.128; accuracy=100.000%
Epoch[850]: loss=0.127; accuracy=100.000%
Epoch[900]: loss=0.127; accuracy=100.000%
Epoch[950]: loss=0.127; accuracy=100.000%
Epoch[1000]: loss=0.127; accuracy=100.000%
CPU times: user 3min 10s, sys: 948 ms, total: 3min 11s
Wall time: 3min 11s

Great, this implementation is almost 2x times faster than our implementation, probably because it is written using the C backend.

Using Pytorch GRU module

In [38]:
torch.cat([torch.zeros(1,2,3), torch.ones(1,2,3)], 2)
Out[38]:
(0 ,.,.) = 
  0  0  0  1  1  1
  0  0  0  1  1  1
[torch.FloatTensor of size 1x2x6]
In [39]:
class PyTorchModel(nn.Module):
    def __init__(self, input_size, embedding_size, output_size):
        super(PyTorchModel, self).__init__()
        
        self.input_size = input_size
        self.embedding_size = embedding_size
        self.output_size = output_size
        
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.rnn = nn.GRU(embedding_size, embedding_size)
        self.output_linear = nn.Linear(2*self.embedding_size, self.output_size)
        
        self.loss_function = nn.CrossEntropyLoss()
    
    def forward(self, word_ids, hidden=None):
        if hidden is None:
            hidden = Variable(torch.zeros(
                word_ids.data.shape[0],self.embedding.embedding_dim))
        assert isinstance(hidden, Variable)
        embeddings = self.embedding(word_ids)
        max_seq_length = word_ids.data.shape[-1]
        ## RNN input and output shapes are (seq_len, batch_size, input_size)
        embeddings = embeddings.permute(1,0,2)
        hidden_states, hidden = self.rnn(embeddings)
        
        concat_tensors = torch.cat([embeddings, hidden_states], 2)
        concat_tensors = concat_tensors.permute(1,0,2).contiguous()
        concat_tensors = concat_tensors.view(-1, concat_tensors.data.shape[2])
        outputs = self.output_linear(concat_tensors)
        
        hidden_states = hidden_states.permute(1,0,2)
        outputs = outputs.view(self.input_size, -1, self.output_size)
        assert isinstance(outputs, Variable)
        assert isinstance(hidden_states, Variable)
        return outputs, hidden_states
    
    def loss(self, word_ids, target_ids, hidden=None):
        outputs, hidden_states = self.forward(word_ids, hidden=hidden)
        outputs = outputs.view(-1, outputs.data.shape[-1])
        target_ids = target_ids.view(-1)
        assert isinstance(outputs, Variable)
        assert isinstance(target_ids, Variable)
        #print("output={}\ttargets={}".format(outputs.data.shape,target_ids.data.shape))
        loss = self.loss_function(outputs, target_ids)
        return loss    
        
        
    def predict(self, word_ids, hidden=None):
        outputs, hidden_states = self.forward(word_ids, hidden=hidden)
        outputs = outputs.view(-1, outputs.data.shape[-1])
        max_scores, predictions = outputs.max(1)
        predictions = predictions.view(*word_ids.data.shape)
        #print(word_ids.data.shape, predictions.data.shape)
        assert word_ids.data.shape == predictions.data.shape, "word_ids: {}, predictions: {}".format(
            word_ids.data.shape, predictions.data.shape
        )
        return predictions
In [40]:
input_size=2
embedding_size=10
output_size=2

model = PyTorchModel(input_size, embedding_size, output_size).cuda()
In [41]:
X, Y = create_dataset(max_len=10, K=4)
X_tensors, Y_tensors = tuple(map(torch.LongTensor, [X, Y]))
print(X_tensors.shape, Y_tensors.shape)
assert X_tensors.shape == Y_tensors.shape, "X and Y should be of same shape"

batch_size=64
train = data_utils.TensorDataset(X_tensors, Y_tensors)
train_loader = data_utils.DataLoader(train, batch_size=batch_size, shuffle=True, pin_memory=True)
torch.Size([1024, 10]) torch.Size([1024, 10])
In [42]:
overall_hidden = Variable(torch.zeros(1,model.embedding.embedding_dim))
overall_hidden = overall_hidden.cuda()
overall_hidden
Out[42]:
Variable containing:
    0     0     0     0     0     0     0     0     0     0
[torch.cuda.FloatTensor of size 1x10 (GPU 0)]
In [43]:
%%time
learning_rate = 1e-4
max_epochs = 1000
check_every=50
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, max_epochs+1):
    for X_batch, Y_batch in train_loader:
        X_batch = X_batch.cuda(async=True)
        Y_batch = Y_batch.cuda(async=True)
        X_batch, Y_batch = tensors2variables(X_batch, Y_batch)
        hidden = overall_hidden.repeat(batch_size, 1)
        # Forward pass: compute predicted y by passing x to the model.
        loss = model.loss(X_batch, Y_batch, hidden=hidden)

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights
        # of the model)
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model
        # parameters
        loss.backward()

        # Calling the step function on an Optimizer makes an update to its
        # parameters
        optimizer.step()
    if epoch % check_every != 0:
        continue
    
    hidden = overall_hidden.repeat(X_tensors.shape[0], 1)
    loss = model.loss(*tensors2variables(X_tensors.cuda(), Y_tensors.cuda()),
                      hidden=hidden)
    Y_predict = model.predict(Variable(X_tensors.cuda()), hidden=hidden).data
    accuracy = (Y_tensors.cuda() == Y_predict).sum() /operator.mul(*Y_tensors.shape) * 100.
    print("Epoch[{:03d}]: loss={:5.3f}; accuracy={:.3f}%".format(epoch, loss.data[0], accuracy))
Epoch[050]: loss=0.580; accuracy=70.000%
Epoch[100]: loss=0.453; accuracy=77.949%
Epoch[150]: loss=0.342; accuracy=84.902%
Epoch[200]: loss=0.255; accuracy=88.398%
Epoch[250]: loss=0.193; accuracy=94.902%
Epoch[300]: loss=0.136; accuracy=96.455%
Epoch[350]: loss=0.074; accuracy=99.297%
Epoch[400]: loss=0.032; accuracy=99.990%
Epoch[450]: loss=0.016; accuracy=100.000%
Epoch[500]: loss=0.009; accuracy=100.000%
Epoch[550]: loss=0.005; accuracy=100.000%
Epoch[600]: loss=0.003; accuracy=100.000%
Epoch[650]: loss=0.002; accuracy=100.000%
Epoch[700]: loss=0.001; accuracy=100.000%
Epoch[750]: loss=0.001; accuracy=100.000%
Epoch[800]: loss=0.001; accuracy=100.000%
Epoch[850]: loss=0.000; accuracy=100.000%
Epoch[900]: loss=0.000; accuracy=100.000%
Epoch[950]: loss=0.000; accuracy=100.000%
Epoch[1000]: loss=0.000; accuracy=100.000%
CPU times: user 57.6 s, sys: 952 ms, total: 58.5 s
Wall time: 58.7 s

This is super fast, 3x faster than using the GRUCell and 6x faster than our implementation.

This concludes our introduction to sequence tagging using Pytorch. The example covered here were very small so as to demonstrate the code required to implement a neural network as well as to give an intuition about the kind of tasks the networks can handle. More complex models can be built on top of this demo, which can handle variable length sequences, complex inference process (e.g. Linear Chain Conditional Random Fields for predicting the best sequence of outputs), and complex handling of input like words, phrases, etc.

In [44]:
print("Pytorch Version: {}".format(torch.__version__))
Pytorch Version: 0.2.0_4
In [ ]: