TensorFlowで、Session()で動かすプログラム構造は、Graphである。Graphクラスについて、簡単に整理しておく。
デフォルトグラフの指定は、以下のようになる。
1 |
graph = tf.get_default_graph() |
新しいグラフを作成して、デフォルトに指定したければ、
1 2 3 |
newgraph = tf.Graph() with newgraph.as_default(): … |
グラフデータへアクセスする関数には以下のようなものがある。
1 2 3 4 5 6 7 |
コード:get_tensor_by_name(name) コード:get_operation_by_name(name) コード:get_operations() コード:get_all_collection_keys() コード:get_collection(name, scope=None) コード:add_to_collection(name, value) コード:add_to_collections(name, value) |
上記コードをいくつか以下のように組み込んでみる。
1 2 3 4 5 6 7 8 9 |
import tensorflow as tf a = tf.constant(1, name='first') b = tf.constant(2, name='second') sum = a + b; print(tf.get_default_graph().get_operations()) [<tf.Operation 'first' type=Const>, <tf.Operation 'second' type=Const>, <tf.Operation 'add' type=Add>] |
1 2 3 4 5 6 7 |
import tensorflow as tf a = tf.constant(1, name='first') b = tf.constant(2, name='second') sum = a + b; print(tf.get_default_graph().get_tensor_by_name('add:0')) Tensor("add:0", shape=(), dtype=int32) |
Graphをprotocol bufferにJSON様に出力するためには、
1 |
write_graph(graph/graph_def, logdir, name, as_text=True) |
デフォルトフォルダにgraph_output.datとして出力する場合は、以下のようになる。
1 |
tf.train.write_graph(tf.get_default_graph(), os.getcwd(), 'graph_output.dat') |
具体的に以下の例を動かしてみると、
1 2 3 4 5 6 7 8 9 10 |
import tensorflow as tf import os a = tf.constant(1, name='first') b = tf.constant(2, name='second') sum = a + b; with tf.Session() as sess: sess.run(sum) tf.train.write_graph(sess.graph, os.getcwd(), 'graph_output.dat') |
graph_output.datファイルにGraphが以下のように記述されて出力される。
記述には、3つのノードと、バージョンがJSON様に記載されていいる。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
node { name: "first" op: "Const" attr { key: "dtype" value { type: DT_INT32 } } attr { key: "value" value { tensor { dtype: DT_INT32 tensor_shape { } int_val: 1 } } } } node { name: "second" op: "Const" attr { key: "dtype" value { type: DT_INT32 } } attr { key: "value" value { tensor { dtype: DT_INT32 tensor_shape { } int_val: 2 } } } } node { name: "add" op: "Add" input: "first" input: "second" attr { key: "T" value { type: DT_INT32 } } } versions { producer: 27 } |
Graphには、各種変数等の情報を収めていいる以下のようなCollection Keysが設定されている。
1 2 3 4 5 6 7 8 |
GLOBAL_VARIABLES LOCAL_VARIABLES MODEL_VARIABLES TRAINABLE_VARIABLES MOVING_AVERAGE_VARIABLES SUMMARIES QUEUE_RUNNERS REGULARIZATION_LOSSES |