Generating batches for an RNN

This week, I encountered a few implementations of a function that generates batches for a recurrent neural network. The first one was:

import numpy as np

def get_batches(arr, batch_size, n_steps):
    # each batch consumes batch_size * n_steps characters
    chars_per_batch = batch_size * n_steps
    n_batches = len(arr)//chars_per_batch

    # drop the leftover tail, then arrange the text as
    # batch_size parallel rows
    arr = arr[:n_batches * chars_per_batch]
    arr = arr.reshape((batch_size, -1))

    for n in range(0, arr.shape[1], n_steps):
        # inputs: one window of n_steps columns
        x = arr[:, n:n+n_steps]
        # targets: x shifted left by one, with the first
        # column of x recycled into the last slot
        y = np.zeros_like(x)
        y[:, :-1], y[:, -1] = x[:, 1:], x[:, 0]

        yield x, y

Let me use an example. Suppose that we make batches from the sentence The quick brown fox jumps over the lazy dog, with batch_size = 1 and n_steps = 5. Then the above function generates batches as follows: in the first batch, x and y become ['T', 'h', 'e', ' ', 'q'] and ['h', 'e', ' ', 'q', 'T'], respectively, and the second batch is x = ['u', 'i', 'c', 'k', ' '] and y = ['i', 'c', 'k', ' ', 'u'].
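
To make this concrete, here is a quick run of the first implementation on that sentence (a minimal sketch; it assumes the text has been turned into a NumPy array of single characters):

import numpy as np

text = "The quick brown fox jumps over the lazy dog"
arr = np.array(list(text))

gen = get_batches(arr, batch_size=1, n_steps=5)
x, y = next(gen)
print(x, y)  # [['T' 'h' 'e' ' ' 'q']] [['h' 'e' ' ' 'q' 'T']]
x, y = next(gen)
print(x, y)  # [['u' 'i' 'c' 'k' ' ']] [['i' 'c' 'k' ' ' 'u']]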

So, the above function shifts the elements of x by one position to fill y and puts the first element of x into the last element of y. However, one may want to feed the real next character into y at the end of each batch: u instead of T for the first batch, and b instead of u for the second. The following implementation seems to do exactly that:

def get_batches(arr, batch_size, n_steps):
    chars_per_batch = batch_size * n_steps
    n_batches = len(arr)//chars_per_batch

    arr = arr[:n_batches * chars_per_batch]
    arr = arr.reshape((batch_size, -1))

    for n in range(0, arr.shape[1], n_steps):
        x = arr[:, n:n+n_steps]

        # targets: the same window shifted one column to the right;
        # at the last window this slice runs past the end of arr
        y_temp = arr[:, n+1:n+n_steps+1]

        # copy whatever came back; any missing column stays at the
        # zero value of the dtype
        y = np.zeros(x.shape, dtype=x.dtype)
        y[:, :y_temp.shape[1]] = y_temp

        yield x, y

This looks okay, except for the very last element of arr. After the reshape, the number of columns of arr, arr.shape[1], is n_batches * n_steps. At the last iteration of the for loop, n equals arr.shape[1] - n_steps, so x is filled with the last window. For y_temp, however, the slice n+1:n+n_steps+1 tries to reach one column past the end of arr, an element that does not exist. Interestingly, I don't get any errors.¹
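
To see what actually happens at that last window, here is a small check (same example sentence, batch_size = 1). NumPy, like plain Python sequences, clamps out-of-range slices to the array bounds, so y_temp simply comes back one column short and the last column of y keeps the zero value of the dtype, an empty string for a character array:

import numpy as np

a = np.array(list("abcdef")).reshape(1, -1)
print(a[:, 4:8])   # clamped to the end: [['e' 'f']]
print(a[:, 7:9])   # entirely out of range: an empty (1, 0) array

text = "The quick brown fox jumps over the lazy dog"
arr = np.array(list(text))
*_, (x, y) = get_batches(arr, batch_size=1, n_steps=5)
print(x)  # [['l' 'a' 'z' 'y' ' ']]
print(y)  # [['a' 'z' 'y' ' ' '']]  <- the last slot stays empty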

At first I didn't understand how Python avoids raising an error in the second implementation, considering that most of the errors I get are related to the indexing and slicing of arrays; it turns out that slices, unlike integer indices, are clamped to the array bounds (see the footnote below). Finally, I found a better implementation on someone's GitHub:

def get_batches(arr, batch_size, n_steps):
    chars_per_batch = batch_size * n_steps
    n_batches = len(arr)//chars_per_batch

    arr = arr[:n_batches * chars_per_batch]
    arr = arr.reshape((batch_size, -1))

    for n in range(0, arr.shape[1], n_steps):
        x = arr[:, n:n+n_steps]

        y = np.zeros_like(x)

        try:
            # the true next character follows the window
            y[:, :-1], y[:, -1] = x[:, 1:], arr[:, n+n_steps]
        except IndexError:
            # past the end of arr: wrap around to the first character
            y[:, :-1], y[:, -1] = x[:, 1:], arr[:, 0]

        yield x, y

This implementation takes care of the potential IndexError at the last batch. I also noticed that it feeds the very first element of arr into the last element of the last batch, whereas the second implementation leaves that element at the zero value of the dtype (an empty string for a character array).
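
Running the same example through this version shows the wrap-around in the last batch (again with batch_size = 1):

import numpy as np

text = "The quick brown fox jumps over the lazy dog"
arr = np.array(list(text))
*_, (x, y) = get_batches(arr, batch_size=1, n_steps=5)
print(x)  # [['l' 'a' 'z' 'y' ' ']]
print(y)  # [['a' 'z' 'y' ' ' 'T']]  <- wraps around to the first character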


  1. I dug a bit and learned that n+1:n+n_steps+1 is turned into a slice object, and that slicing in NumPy (as in plain Python sequences) is clamped to the bounds of the array: a slice that extends past the end just returns the elements that do exist, possibly an empty array, instead of raising an exception. Integer indexing, by contrast, raises an IndexError when it goes out of range.