博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow 中 Cross Entropy算法理解
阅读量:5165 次
发布时间:2019-06-13

本文共 1765 字,大约阅读时间需要 5 分钟。

关于tensorflow 中cross entropy 的 numpy实现

 

import tensorflow as tfimport numpy as npimport osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'# Make up some testing data, need to be rank 2x = np.array([        [0.,2.,1.],        [0.,0.,2.]        ])label = np.array([        [0.,1.,0.],        [0.,0.,1.]        ])label2 = np.array([1,2])# Numpy part #def sigmoid(logits):    return (1/(1+np.exp(-logits)))def softmax(logits):    sf = np.exp(logits)    sf = sf/np.sum(sf, axis=1).reshape(-1,1)    return sfdef cross_entropy2(softmax, labels):    return -(labels * np.log(softmax) + (1- labels) * ( np.log(1- softmax)))def cross_entropy(softmax, labels):    return -np.sum(labels * np.log(softmax),axis=1)numpy_sig = cross_entropy2( sigmoid(x), label )numpy_softmax = cross_entropy( softmax(x), label )print(softmax(x))print("my sigmoid_cross_entropy_with_logits is \n %s \n "%(numpy_sig))print("my softmax_cross_entropy_with_logits is \n %s \n "%(numpy_softmax))# Tensorflow part #g = tf.Graph()with g.as_default():    tf_x = tf.constant(x)    tf_label = tf.constant(label)    tf_label2 = tf.constant(label2)    tf_ret = tf.nn.sigmoid_cross_entropy_with_logits(logits= tf_x,labels=tf_label)    tf_softmax = tf.nn.softmax_cross_entropy_with_logits(logits= tf_x,labels=tf_label)    tf_softmax_2 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits= tf_x,labels=tf_label2)with tf.Session(graph=g) as ss:    r_sig,r_softmax,r_softmax_sparse = ss.run([tf_ret,tf_softmax,tf_softmax_2])print("tensorflow sigmoid_cross_entropy_with_logits is \n %s \n "%(r_sig))print("tensorflow softmax_cross_entropy_with_logits is \n %s \n "%(r_softmax))print("tensorflow sparse_softmax_cross_entropy_with_logits is \n %s \n "%(r_softmax_sparse))

 

转载于:https://www.cnblogs.com/kakamilan/p/7116332.html

你可能感兴趣的文章
软件测试-HW03
查看>>
linux第1天 fork exec 守护进程
查看>>
Ajax原理学习
查看>>
最新最潮的24段魔尺立体几何玩法(2016版)
查看>>
C# 3.0 LINQ的准备工作
查看>>
CodeForces - 449D Jzzhu and Numbers
查看>>
mysql批量插入更新操作
查看>>
静态代码审查工具FxCop插件开发(c#)
查看>>
创建代码仓库
查看>>
理解裸机部署过程ironic
查看>>
Django 组件-ModelForm
查看>>
zabbix 二 zabbix agent 客户端
查看>>
大数据分析中,有哪些常见的大数据分析模型?
查看>>
Generate SSH key
查看>>
URL中不应出现汉字
查看>>
SSH框架面试总结----1
查看>>
如何防止Arp攻击
查看>>
luoguP1313 [NOIp2011]计算系数 [组合数学]
查看>>
清明 DAY2
查看>>
[LintCode] 全排列
查看>>