Hello and welcome to part 3 of our series of experimenting with neural networks. In this tutorial, we're going to take the same generative model that we've been working with, but now play with the MNIST dataset... in a way you probably won't see anywhere else.
My first question is whether we can get a generative model to simply generate numbers. Then, if so, can we get a generative model to generate the class after seeing the drawing (using the generative model to classify)? And then, finally, could we get the generative model to draw us a number we ask for? Let's see!
Step 1, we need to compile a training dataset. This is all experimental, and I really do not know the BEST way to do this.
To grab mnist, we can just do:
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

batch_xs, batch_ys = mnist.train.next_batch(100)
data = batch_xs[0]
label = batch_ys[0]
In the above case, data will be the pixel values, and label is the one-hot array, which tells us which number this is. The MNIST dataset here has mnist.train, mnist.test, and mnist.validation. We can train the model with mnist.train, and then see how we did with mnist.validation. We don't need to use mnist.test, since this is a generative model. If we wanted to, we could throw it into the training set.
Moving along, let's visualize this a bit. Our data:
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.09019608 0.46274513 0.76470596 0.9960785 0.3019608 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.10588236 0.78823537 0.9921569 0.9921569 0.9921569 0.9607844 0.12156864 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.26666668 0.7803922 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.13333334 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.44705886 0.9686275 0.9921569 0.95294124 0.77647066 0.47058827 0.2392157 0.2392157 0.03137255 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.07058824 0.9725491 0.9921569 0.9568628 0.4039216 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.03529412 0.8000001 0.9921569 0.9921569 0.4039216 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.33333334 0.9921569 0.9921569 0.8117648 0.09019608 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.13333334 0.9843138 0.9921569 0.80392164 0.3647059 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.61960787 0.9921569 0.9921569 0.56078434 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.6745098 0.9921569 0.9921569 0.56078434 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.16078432 0.9058824 0.9921569 0.9921569 0.2509804 0.19215688 0.19215688 0.19215688 0.16862746 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.21960786 0.9921569 0.9921569 0.9921569 0.80392164 0.9921569 0.9921569 0.9921569 0.9568628 0.6313726 0.05882353 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.21960786 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.7843138 0.04313726 0. 0. 0. 0. 0. 0. 0. 0. 
0. 0. 0. 0. 0. 0. 0. 0. 0.21960786 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.7843138 0.06666667 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.12941177 0.86274517 0.9921569 0.9921569 0.9921569 0.85098046 0.5882353 0.6627451 0.9921569 0.9921569 0.9921569 0.9921569 0.6431373 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.6745098 0.9921569 0.9921569 0.9921569 0.9450981 0.86274517 0.8862746 0.9921569 0.9921569 0.9921569 0.9921569 0.6666667 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.21960786 0.8941177 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.5372549 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.21568629 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.9921569 0.61960787 0.03921569 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.02352941 0.37647063 0.9333334 0.9921569 0.9921569 0.9921569 0.9921569 0.9450981 0.909804 0.04705883 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.20000002 0.454902 0.454902 0.5568628 0.6392157 0.19607845 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. ]
And the label:
[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
Each index maps to a digit. In this case, the 1 is at index 6, so this number is a six. We can bring in matplotlib to graph this:
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# mnist.train, mnist.test, mnist.validation

batch_xs, batch_ys = mnist.train.next_batch(100)
data = batch_xs[0]
label = batch_ys[0]

pixels = data.reshape((28,28))
print(data)
print(label)

plt.imshow(pixels)
plt.show()
The first thing we need to do is drastically simplify this data. Before we do, we can change the cmap to gray:
plt.imshow(pixels, cmap='gray')
Giving us more of what we'd expect for single-valued pixel data:
Great! Now, let's try to threshold this. Since we're using a character-level generative network, every character costs us processing, so we'd like to avoid long floats. Ideally, we'd use ints only. Unfortunately, if we just do
pixels.astype(int)
Things aren't going to go well for us. We need a better rounding method, since converting straight to an int truncates, taking anything under 1 down to 0. Luckily, there's numpy's rint, so we can instead try:
plt.imshow(np.rint(pixels).astype(int), cmap='gray')
Giving us:
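If you'd rather see the difference numerically than visually, here's a quick sketch comparing truncation with rounding on a few pixel values like the ones in our array:

```python
import numpy as np

# A handful of sample pixel values, like the ones in our data
pixels = np.array([0.0, 0.09019608, 0.46274513, 0.76470596, 0.9960785])

# astype(int) truncates: everything below 1.0 becomes 0
truncated = pixels.astype(int)

# np.rint rounds to the nearest integer first, so anything
# at 0.5 or above survives as a 1
rounded = np.rint(pixels).astype(int)

print(truncated)  # [0 0 0 0 0]
print(rounded)    # [0 0 0 1 1]
```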
Alright, not too hard. Let's use that code and see how our array is looking with the following updated code:
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# mnist.train, mnist.test, mnist.validation

batch_xs, batch_ys = mnist.train.next_batch(100)
data = np.rint(batch_xs[0]).astype(int)
label = batch_ys[0]

pixels = data.reshape((28,28))
print(data)
print(label)

plt.imshow(pixels, cmap='gray')
plt.show()
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
What if we print out pixels? We should actually be able to tell what the number is from the values alone. After printing data and label, add:

print(pixels)
The pixels print is:
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
Can you tell what that is? I can definitely see it's a 5, which it is: [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
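If squinting at the 1s and 0s is still too hard, we could map them to characters so the digit pops out in a terminal. The 5x5 array below is made up purely for illustration; with our real data you'd loop over the 28x28 pixels array instead:

```python
import numpy as np

# A tiny hypothetical stand-in for the 28x28 binarized digit
pixels = np.array([[0, 1, 1, 1, 0],
                   [0, 1, 0, 0, 0],
                   [0, 1, 1, 1, 0],
                   [0, 0, 0, 1, 0],
                   [0, 1, 1, 1, 0]])

# Map 1 -> '#' and 0 -> '.' so the shape is easy to see
for row in pixels:
    print(''.join('#' if p else '.' for p in row))
```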
Let's convert the label array to ints as well:
label = np.rint(batch_ys[0]).astype(int)
Now an example of the labels is:
[0 0 0 0 0 0 0 0 0 1]
Which means we're looking at a 9.
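Rather than eyeballing where the 1 sits, we can let numpy find it for us with np.argmax, a small sketch:

```python
import numpy as np

label = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])

# np.argmax returns the index of the largest value,
# which for a one-hot array is the digit itself
digit = np.argmax(label)
print(digit)  # 9
```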
At this point, we've got a pretty simple, scaled-down version of MNIST that we could, in theory, begin to feed through a generative network, but we should probably consider other things we could do to make this easier. Also, in order to get it to draw us the number we want, how might we go about structuring our data?
Does our training data even need to be an array? Do we really have any constraints at all? Well, our training data is *actually* eventually going to be a string, not an array, so we don't require it to be an array at all.
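As a rough sketch of what that might look like, we could flatten the int pixels and tack the label on the end as one string per sample. The separator and tiny 4-pixel stand-in here are pure assumptions for illustration; we haven't settled on a real format yet:

```python
import numpy as np

# Hypothetical stand-ins: 4 binarized pixels instead of 784,
# and a one-hot label for a 5
data = np.array([0, 1, 1, 0])
label = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0])

# Join the pixels into one string and append the digit,
# using '::' as an arbitrary separator
sample = ''.join(map(str, data)) + '::' + str(np.argmax(label))
print(sample)  # 0110::5
```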
These are all considerations that we're going to be handling in the next part of this tutorial!