""""
File: cnn_mnist_class.py
Date: 2019-01-31 11:30
Author: Amy
Use tensorflow to implement cnn.
Dataset: minist (train/test)
"""
import os
import tensorflow as tf
import utils
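# The utils module is assumed (a sketch of its interface, inferred from how it
# is called below) to provide: download_mnist(folder), read_mnist(folder,
# flatten) returning (images, one-hot labels) tuples for the train/valid/test
# splits, and safe_mkdir(path).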
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
class ConvNet(object):
def __init__(self):
self.batch_size = 4
self.learning_rate = 0.001
self.n_epochs = 5
self.n_train = 55000
self.n_valid = 5000
self.n_test = 10000
self.train_init, self.valid_init, self.test_init = None, None, None
self.x, self.y_true, self.y_pred = None, None, None
self.logits = None
self.n_correct = None
self.loss = None
self.opt = None
        # A plain Python flag cannot toggle dropout once the graph is built,
        # so keep_prob is a placeholder instead: it defaults to 1.0 (no
        # dropout) and the training loop feeds 0.75.
        self.keep_prob = tf.placeholder_with_default(1.0, shape=[], name="keep_prob")
self.summary_loss, self.summary_acc = None, None
self.global_step = tf.get_variable("global_step", initializer=tf.constant(0), trainable=False)
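        # global_step is stored with the checkpoints, so a restored run
        # resumes step counting (and summary indexing) where it left off.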
""" Build the graph. """
@staticmethod
def conv_relu(inputs, n_filters, filter_size, stride, padding, scope_name):
""" Conv - Relu layer. """
with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:
in_channels = inputs.shape[-1]
filter_wgts = tf.get_variable("filter_weight",
shape=[filter_size, filter_size, in_channels, n_filters],
initializer=tf.truncated_normal_initializer())
filter_biases = tf.get_variable("filter_biases",
shape=[n_filters],
initializer=tf.random_normal_initializer())
conv = tf.nn.conv2d(inputs,
filter_wgts,
strides=[1, stride, stride, 1],
padding=padding)
return tf.nn.relu(conv + filter_biases, name=scope.name)
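    # Shape example: a [batch, 28, 28, 1] input with n_filters=32,
    # filter_size=5, stride=1 and SAME padding yields [batch, 28, 28, 32].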
@staticmethod
def max_pooling(inputs, pool_size, stride, padding="VALID", scope_name="pool"):
""" Max Pooling layer. """
        with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
pool = tf.nn.max_pool(inputs,
ksize=[1, pool_size, pool_size, 1],
strides=[1, stride, stride, 1],
padding=padding)
return pool
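    # Shape example: a 2x2 pool with stride 2 and VALID padding halves the
    # spatial dimensions, e.g. [batch, 28, 28, 32] -> [batch, 14, 14, 32].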
@staticmethod
def fully_connected(inputs, out_dim, scope_name="fc"):
""" Fully connected layer.
inputs is a streched out 1-d Tensor.
"""
in_dim = inputs.shape[-1]
        with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
wgts = tf.get_variable("fc_weight",
shape=[in_dim, out_dim],
initializer=tf.random_normal_initializer())
biases = tf.get_variable("fc_bais",
shape=[out_dim],
initializer=tf.zeros_initializer())
out = tf.matmul(inputs, wgts) + biases
return out
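    # Shape example: inputs of shape [batch, 3136] with out_dim=1024 use a
    # 3136x1024 weight matrix and produce a [batch, 1024] output.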
def _import_data(self):
""" Prepare data. """
with tf.name_scope("data_scope"):
mnist_folder = "data/mnist"
utils.download_mnist(mnist_folder)
train, valid, test = utils.read_mnist(mnist_folder, flatten=True)
train_data = tf.data.Dataset.from_tensor_slices(train)
train_data = train_data.shuffle(buffer_size=10000)
train_data = train_data.batch(self.batch_size)
valid_data = tf.data.Dataset.from_tensor_slices(valid)
valid_data = valid_data.batch(self.batch_size)
test_data = tf.data.Dataset.from_tensor_slices(test)
test_data = test_data.batch(self.batch_size)
            iterator = tf.data.Iterator.from_structure(train_data.output_types,
                                                       train_data.output_shapes)
self.train_init = iterator.make_initializer(train_data)
self.valid_init = iterator.make_initializer(valid_data)
self.test_init = iterator.make_initializer(test_data)
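            # One reinitializable iterator serves all three splits: running
            # the matching init op switches the data source without
            # rebuilding the graph.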
self.x, self.y_true = iterator.get_next()
self.x = tf.reshape(self.x, [-1, 28, 28, 1])
print("x shape: {}, y shape: {}".format(self.x.shape, self.y_true.shape))
def inference(self):
""" Model structure. """
conv1 = self.conv_relu(inputs=self.x,
n_filters=32,
filter_size=5,
stride=1,
padding="SAME",
scope_name="conv1")
pool1 = self.max_pooling(inputs=conv1,
pool_size=2,
stride=2,
padding="VALID",
scope_name="pool1")
conv2 = self.conv_relu(inputs=pool1,
n_filters=64,
filter_size=5,
stride=1,
padding="SAME",
scope_name="conv2")
pool2 = self.max_pooling(inputs=conv2,
pool_size=2,
stride=2,
padding="VALID",
scope_name="pool2")
        in_dim = int(pool2.shape[1] * pool2.shape[2] * pool2.shape[3])
pool2 = tf.reshape(pool2, [-1, in_dim])
fc1 = self.fully_connected(inputs=pool2,
out_dim=1024,
scope_name="fc1")
        # Dropout is governed by the keep_prob placeholder: 1.0 (no dropout)
        # by default, 0.75 when the training loop feeds it.
        dropout = tf.nn.dropout(tf.nn.relu(fc1),
                                keep_prob=self.keep_prob,
                                name="relu_dropout")
self.logits = self.fully_connected(inputs=dropout,
out_dim=10,
scope_name="logits")
def _loss(self):
""" Define loss function. """
with tf.name_scope("loss_scope"):
entropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.logits,
labels=self.y_true,
name="entropy")
self.loss = tf.reduce_mean(entropy, name="loss")
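            # softmax_cross_entropy_with_logits_v2 expects unscaled logits
            # and, as used here, one-hot labels; do not softmax self.logits
            # before passing it in.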
def _optimize(self):
""" Optimization. """
with tf.name_scope("optimizer_scope"):
self.opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate) \
.minimize(self.loss, global_step=self.global_step)
def _eval(self):
""" Evaluation. """
with tf.name_scope("eval_scope"):
self.y_pred = tf.nn.softmax(self.logits)
correct_preds = tf.equal(tf.argmax(self.y_pred, axis=1), tf.argmax(self.y_true, axis=1))
self.n_correct = tf.reduce_sum(tf.cast(correct_preds, tf.float32))
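            # n_correct counts hits per batch; eval_once sums it over all
            # batches and divides by the split size to get accuracy.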
def _summarize(self):
""" Add summary. """
with tf.name_scope("summary_scope"):
self.summary_loss = tf.summary.scalar("loss", self.loss)
self.summary_acc = tf.summary.scalar("acc", self.n_correct / self.batch_size)
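            # Note: this is per-batch accuracy, so it is only a rough
            # estimate of accuracy over the whole split.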
def build_graph(self):
""" Build computational graph. """
self._import_data()
self.inference()
self._loss()
self._optimize()
self._eval()
self._summarize()
def eval_once(self, sess, init, n_eval):
""" Eval once. """
sess.run(init)
total_correct = 0
try:
while True:
n_c = sess.run(self.n_correct)
total_correct += n_c
except tf.errors.OutOfRangeError:
pass
return total_correct / (n_eval * 1.0)
def train(self, sess):
""" Training. """
saver = tf.train.Saver()
initial_step = 0
utils.safe_mkdir("checkpointrs")
ckpt = tf.train.get_checkpoint_state(os.path.dirname("checkpoints/checkpoint"))
if ckpt and ckpt.model_checkpoint_path:
print("Restore from checkpoints!")
saver.restore(sess, ckpt.model_checkpoint_path)
writer = tf.summary.FileWriter("./graphs/my_convnet/lr_{}batch_{}".\
format(self.learning_rate, self.batch_size), sess.graph)
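        # Assuming TensorBoard is installed, the curves can be inspected with:
        #   tensorboard --logdir graphs/my_convnet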
initial_step = self.global_step.eval()
step = initial_step
for epo in range(1, self.n_epochs + 1):
sess.run(self.train_init)
total_loss = 0.
total_batch = 0
try:
while True:
                    _, loss_val, summary_loss = sess.run(
                        [self.opt, self.loss, self.summary_loss],
                        feed_dict={self.keep_prob: 0.75})
writer.add_summary(summary_loss, global_step=step)
total_loss += loss_val
total_batch += 1
if (step + 1) % 1000 == 0:
print("step [{}] loss: {:.4f}".format(step + 1, total_loss / (total_batch * 1.0)))
saver.save(sess, "checkpoints/my_convnet", step)
step += 1
            except tf.errors.OutOfRangeError:
                # End of epoch: re-initialize the training iterator so the
                # accuracy summary below can pull one fresh batch.
                sess.run(self.train_init)
summary_acc = sess.run(self.summary_acc)
writer.add_summary(summary_acc, global_step=step)
            acc = self.eval_once(sess, self.valid_init, self.n_valid)
            print("epoch: {:2d} valid acc: {:.4f}".format(epo, acc))
writer.close()
def test(self, sess):
""" Testing. """
acc = self.eval_once(sess, self.test_init, self.n_test)
print("Test acc: {:.4f}".format(acc))
def main():
model = ConvNet()
model.build_graph()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
model.train(sess)
model.test(sess)
if __name__ == "__main__":
main()