tf.train.Saver를 통해 모델을 저장하거나 불러 올 수 있습니다. Saver가 생성될 때 현재 그래프의 Variable에 save/restore operation 을 추가합니다. save 메소드와 restore 메소드를 통해 해당 operation을 실행할 수 있습니다. save는 Variable 값을 파일로 저장해 줍니다. 추가로, 모델의 구조를 분리된 파일로 저장합니다. 아래는 예시입니다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def model(x):
	with tf.name_scope('model'):
		w = tf.Variable(tf.random_normal((5, 5), 0.0, 1.0), name='w')
		b = tf.Variable(tf.random_normal((1, 5), 0.0, 1.0), name='b')
		return x @ w + b, w, b

tf.reset_default_graph()
x = tf.placeholder(tf.float32, (1, 5,)) 
h, w, b = model(x)

saver = tf.train.Saver()
dirname = 'mysave'
import os
if not os.path.exists(dirname):
	os.mkdir(dirname)

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	path = os.path.join(dirname, 'save')
	saver.save(sess, path)
	saver.save(sess, path, global_step=10)
	w_val, b_val = sess.run((w, b,))
	print(w_val)
	print(b_val)
1
2
3
4
5
6
[[ 0.8011036   1.0809381   2.6776495  -0.04687155 -0.93143463]
 [-1.7361611  -0.11210998  1.8819457  -0.11741646  1.1123091 ]
 [-0.04859339 -0.07581387 -1.3568013   1.3217454  -2.1227653 ]
 [-0.5395399  -0.7515742  -0.7829265   0.5170578  -0.3033594 ]
 [ 0.4305463  -0.18414573 -0.932086    1.529165    0.40401545]]
[[ 0.08475562 -1.3056839   1.6758546   0.6217058   0.48102817]]

restore은 저장된 파일로 부터 값을 읽어와 해당하는 Variable에 대입시켜줍니다. 그러나 그래프의 variable 값을 불러올 때는 주로 그래프의 구조도 모르는 경우가 많습니다. 만약 이전에 저장한 훈련을 이어서하고 싶다면 이전과 똑같은 구조의 그래프를 만들고 Saver 객체를 생성해야 합니다. 다행히 tf.train.import_meta_graph가 이러한 기능을 지원합니다. save 메소드를 실행하면 variable 값과 함께 그래프의 구조(Saver 객체가 포함된)도 별도의 파일로 저장됩니다(name-ckpt.meta 파일). 이 파일을 읽고 똑같은 구조의 그래프를 만든후 해당 그래프의 Saver 객체를 반환해줍니다. 반환된 Saver 객체를 이용해 variable 값을 복원하면 됩니다. 아래는 위에서 저장한값을 복원하는 코드입니다. restore이 variable 을 초기화해주기 때문에 initializer를 실행하지 않아도 됩니다.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
tf.reset_default_graph()
path = os.path.join(dirname, 'save-10.meta')
saver = tf.train.import_meta_graph(path)

g = tf.get_default_graph()
w = g.get_tensor_by_name('model/w:0')
b = g.get_tensor_by_name('model/b:0')

with tf.Session() as sess:
	fw = tf.summary.FileWriter(logdir='logdir', graph=sess.graph)
	path = os.path.join(dirname, 'save-10')
	saver.restore(sess, path)
	w_val, b_val = sess.run((w, b,))
	print(w_val)
	print(b_val)
1
2
3
4
5
6
7
INFO:tensorflow:Restoring parameters from mysave/save-10
[[ 0.8011036   1.0809381   2.6776495  -0.04687155 -0.93143463]
 [-1.7361611  -0.11210998  1.8819457  -0.11741646  1.1123091 ]
 [-0.04859339 -0.07581387 -1.3568013   1.3217454  -2.1227653 ]
 [-0.5395399  -0.7515742  -0.7829265   0.5170578  -0.3033594 ]
 [ 0.4305463  -0.18414573 -0.932086    1.529165    0.40401545]]
[[ 0.08475562 -1.3056839   1.6758546   0.6217058   0.48102817]]

다음과 같이 그래프의 특정 variable에 대해서만 save/restore operation 을 추가할 수 있습니다.

1
2
var_list = [var1, var2]
tf.train.Saver(var_list)