Tensorflow 使用 pb 文件保存(恢复)模型计算图和参数

本文详细介绍了如何使用TensorFlow的graph_util模块将模型保存为pb文件,并提供了代码示例。同时,也阐述了如何从pb文件中恢复模型,以便于在不同环境中部署和使用。

一 、 保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量
       参数2:GraphDef 对象,它描述了计算网络
       参数3:Graph图中需要输出的节点的名称的列表
返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:
        constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
        with open( pbName, mode='wb') as f:
            f.write(constant_graph.SerializeToString())
需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:
        graph0 = tf.GraphDef()
        with open( pbName, mode='rb') as f:
            graph0.ParseFromString( f.read() )
            tf.import_graph_def( graph0 , name = ''  )

三、代码:


import tensorflow as tf 
from tensorflow.python.framework import graph_util

pbName = 'graphA.pb'
def graphCreate() :
    with tf.Session() as sess :
        var1 = tf.placeholder ( tf.int32 , name='var1' ) 
        var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                                              #’var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                                              # var2的名称是’var2:0'
        var3 = tf.Variable( 30 , name='var3' )
        var4 = tf.Variable( 40 , name='var4' )
        var4op = tf.assign( var4 , 1000 , name = 'var4op1'  )
        sum = tf.Variable( 4, name='sum' )
        sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
        sum = tf.add( sum , var3 , name='sum_var3' )
        sumOps = tf.add( sum , var4 , name='sum_operation'  )
        oper = tf.get_default_graph().get_operations()
        with open( 'operation.csv','wt' ) as f:
            s = 'name,type,output\n'
            f.write( s ) 
            for o in oper:
                s = o.name
                s += ','+ o.type 
                inp = o.inputs
                oup = o.outputs
                for iip in inp :
                    s #s += ','+ str(iip)
                for iop in oup :
                    s += ',' + str(iop)
                s += '\n'
                f.write( s ) 
                  
            for var in tf.global_variables():
                print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                                                #张量的名字使用操作名加:0来表示
        init = tf.global_variables_initializer()
        sess.run( init )
        sess.run( var4op )
        print('sum_operation result is  Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )

        constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
        with open( pbName, mode='wb') as f:
            f.write(constant_graph.SerializeToString())

def graphGet() :
    print("start get:" )
    with tf.Graph().as_default():
        graph0 = tf.GraphDef()
        with open( pbName, mode='rb') as f:
            graph0.ParseFromString( f.read() )
            tf.import_graph_def( graph0 , name = ''  )
        with tf.Session() as sess :
            init = tf.global_variables_initializer()
            sess.run(init)
            v1 = sess.graph.get_tensor_by_name('var1:0' )
            v2 = sess.graph.get_tensor_by_name('var2:0' )
            v3 = sess.graph.get_tensor_by_name('var3:0' )
            v4 = sess.graph.get_tensor_by_name('var4:0' )
            
            sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
            print('sumTensor is : ' , sumTensor )
            print( sess.run( sumTensor ,  feed_dict={v1:1} ) )  
    
graphCreate()
graphGet()
    

四、保存pb函数代码里的操作名称/类型/返回的张量:
 

operation nameoperation typeoutput  
var1PlaceholderTensor("var1:0" dtype=int32) 
var2/initial_valueConstTensor("var2/initial_value:0" shape=() dtype=int32)
var2VariableV2Tensor("var2:0" shape=() dtype=int32_ref)
var2/AssignAssignTensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/readIdentityTensor("var2/read:0" shape=() dtype=int32)
var3/initial_valueConstTensor("var3/initial_value:0" shape=() dtype=int32)
var3VariableV2Tensor("var3:0" shape=() dtype=int32_ref)
var3/AssignAssignTensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/readIdentityTensor("var3/read:0" shape=() dtype=int32)
var4/initial_valueConstTensor("var4/initial_value:0" shape=() dtype=int32)
var4VariableV2Tensor("var4:0" shape=() dtype=int32_ref)
var4/AssignAssignTensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/readIdentityTensor("var4/read:0" shape=() dtype=int32)
var4op1/valueConstTensor("var4op1/value:0" shape=() dtype=int32)
var4op1AssignTensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_valueConstTensor("sum/initial_value:0" shape=() dtype=int32)
sumVariableV2Tensor("sum:0" shape=() dtype=int32_ref)
sum/AssignAssignTensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/readIdentityTensor("sum/read:0" shape=() dtype=int32)
var1_var2AddTensor("var1_var2:0" dtype=int32) 
sum_var3AddTensor("sum_var3:0" dtype=int32) 
sum_operationAddTensor("sum_operation:0" dtype=int32) 
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值