- Today
- Total
Running Deeper C/C++/Java
Today, I tried writing a Generative Adversarial Network.
Using the MNIST dataset, I am writing a GAN that generates images of handwritten digits that look as close to the ones in the dataset as possible.
Currently I am stuck: the generator doesn't produce proper images.
A GAN works with two neural networks: a generator and a discriminator. The generator creates images from random noise, and the discriminator checks whether an image comes from the dataset or not. Both are trained at the same time, and the goal of the generator is to trick the discriminator into thinking that its images come from the dataset. After a while, the generator starts producing images that look almost exactly like the images in the dataset.
For me, this attempt was a total failure.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
|
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
class GAN():
    """Small fully-connected GAN for 28x28 images (TensorFlow 1.x graph API).

    The generator maps a noise vector to a 28x28 image through one hidden
    layer; the discriminator maps a 28x28 image to a single sigmoid
    probability through one hidden layer. Constructing the object builds the
    graph, opens a session and restores the latest checkpoint from
    ``checkpoint_dir`` if one exists (otherwise variables are initialised).

    Args:
        lr: Learning rate used by both Adam optimizers.
        batch_size: Number of real images (and noise vectors) per step.
        n_noise: Length of the generator's input noise vector.
        layers: Number of units in the hidden layer of each network.
    """

    def __init__(self, lr, batch_size, n_noise, layers):
        self.checkpoint_dir = 'gan_check'
        self.lr = lr
        self.batch_size = batch_size
        self.n_noise = n_noise
        self.layers = layers

        # generator: noise -> hidden -> 784 sigmoid pixels -> 28x28 image
        self.noise = tf.placeholder(
            dtype=tf.float32, shape=[None, self.n_noise], name="gen")
        with tf.variable_scope('gen'):
            h = tf.layers.Dense(self.layers, activation=tf.nn.relu)(self.noise)
            h = tf.layers.Dense(28 * 28, activation=tf.nn.sigmoid)(h)
            self.genout = tf.reshape(h, [-1, 28, 28])

        # discriminator
        self.x = tf.placeholder(
            dtype=tf.float32, shape=[None, 28, 28], name="in")
        # Not used by the graph; kept so existing code that reads it still works.
        self.y = tf.placeholder(dtype=tf.uint8, shape=[None], name="out")
        with tf.variable_scope('disc'):
            # Create each layer once and call it on both inputs: calling the
            # same layer object twice shares its weights, so the real and the
            # generated branch are scored by the same discriminator.
            d_hidden = tf.layers.Dense(self.layers, activation=tf.nn.relu)
            d_prob = tf.layers.Dense(1, activation=tf.nn.sigmoid)
            # The previous wiring fed the generator's output tensor into the
            # last layer of both branches, so the real images never reached
            # the loss. Each branch now goes flatten -> hidden -> probability.
            # NOTE: the output layer's kernel is now [layers, 1] instead of
            # [784, 1]; checkpoints written by the old graph will not restore
            # and must be deleted.
            self.dor = d_prob(d_hidden(tf.layers.Flatten()(self.x)))
            self.dog = d_prob(d_hidden(tf.layers.Flatten()(self.genout)))

        # loss (eps keeps log() finite when a sigmoid output saturates)
        eps = 1e-8
        self.dl = -tf.reduce_mean(
            tf.log(self.dor + eps) + tf.log(1 - self.dog + eps))
        self.gl = -tf.reduce_mean(tf.log(self.dog + eps))

        # optimizer: each step only updates its own network's variables
        dtv = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='disc')
        gtv = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='gen')
        self.dtrain = tf.train.AdamOptimizer(self.lr).minimize(
            self.dl, var_list=dtv)
        self.gtrain = tf.train.AdamOptimizer(self.lr).minimize(
            self.gl, var_list=gtv)

        self.sess = tf.Session()
        self.load()

    def save(self):
        """Write the current variables to ``checkpoint_dir``/model.ckpt."""
        # Make sure the target directory exists before the saver writes to it.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.saver.save(self.sess, self.checkpoint_dir + '/model.ckpt')
        print('save success')

    def load(self):
        """Create the saver and restore the latest checkpoint if there is one.

        When no usable checkpoint is found, all global variables are
        initialised instead.
        """
        self.saver = tf.train.Saver(
            var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            print('load success')
        else:
            self.sess.run(tf.global_variables_initializer())
            print('load fail')

    def save_img(self, iter):
        """Generate a row of samples and save it as sample/<iter>.png.

        Args:
            iter: Index used for the file name, zero-padded to 3 digits.
        """
        size = 10
        # savefig does not create missing directories.
        os.makedirs('sample', exist_ok=True)
        f, a = plt.subplots(1, size, figsize=(size, 1))
        noise = np.random.normal(size=(size, self.n_noise))
        smp = self.sess.run(self.genout, feed_dict={self.noise: noise})
        for axis, image in zip(a, smp):
            axis.set_axis_off()
            axis.imshow(image)
        plt.savefig('sample/{}.png'.format(str(iter).zfill(3)),
                    bbox_inches='tight')
        plt.close(f)

    def train(self, x_train, run=100):
        """Train both networks, checkpointing and sampling after every epoch.

        Args:
            x_train: Images of shape (N, 28, 28). Presumably scaled to
                [0, 1] to match the generator's sigmoid output - confirm
                against the caller.
            run: Number of passes (epochs) over ``x_train``.

        Only full batches are used; a trailing partial batch is skipped.
        """
        total = len(x_train) // self.batch_size
        for epoch in range(run):
            for j in range(total):
                start = j * self.batch_size
                batch = x_train[start:start + self.batch_size]
                # The same noise drives the discriminator and generator step.
                noise = np.random.normal(size=(len(batch), self.n_noise))
                self.sess.run(
                    self.dtrain,
                    feed_dict={self.noise: noise, self.x: batch})
                self.sess.run(self.gtrain, feed_dict={self.noise: noise})
            self.save()
            self.save_img(epoch)
if __name__ == "__main__":
    # Fetch MNIST through Keras and split it into its train / test parts.
    mnist = tf.keras.datasets.mnist
    train_split, test_split = mnist.load_data()
    x_train, y_train = train_split
    x_test, y_test = test_split

    # Scale the pixel values by 1/255.
    x_train = x_train / 255.
    x_test = x_test / 255.

    # One batch is 1/100 of the training set; 128-dim noise, 256 hidden units.
    gan = GAN(lr=0.001, batch_size=len(x_train) // 100, n_noise=128, layers=256)
    gan.train(x_train, run=10000)
|
cs |
Basically, I created a Python GAN class with save and load methods and a generator/discriminator pair. However, the result was a failure; the generated images are shown below.
I will try again later. The GAN isn't learning: the fifth image kept repeating for over 10 hours.