# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import train_Car_Data
import time
data = train_Car_Data.load_Data(download=False)
new_Data = train_Car_Data.covert2onehot(data)
# Prepare all the data: training set and test set
new_Data = new_Data.values.astype(np.float32)  # convert the one-hot DataFrame to float32
np.random.shuffle(new_Data)                    # shuffle the rows
sep = int(0.7 * len(new_Data))                 # index splitting off the first 70% of the data
train_data = new_Data[:sep]
test_Data = new_Data[sep:]
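For reference, covert2onehot presumably one-hot encodes the six categorical attributes of the UCI car evaluation data (21 attribute values in total) plus the four class labels, which gives the 25 columns used below. A minimal pandas sketch of such a conversion (covert2onehot_sketch is a hypothetical stand-in, not the real function in train_Car_Data):

import pandas as pd

def covert2onehot_sketch(df):
    # Hypothetical stand-in for train_Car_Data.covert2onehot:
    # one-hot encode every categorical column (6 attributes -> 21 columns,
    # the class label -> 4 columns, 25 columns in total).
    return pd.get_dummies(df)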
# Build the network
tf_input = tf.placeholder(tf.float32, [None, 25], "input")  # shape [rows, columns]; there are 25 columns
tfx = tf_input[:, :21]  # all rows, first 21 columns as features
tfy = tf_input[:, 21:]  # last 4 columns are the one-hot labels
l1 = tf.layers.dense(tfx, 128, tf.nn.relu, name="l1")  # args: inputs, units, activation, name; first hidden layer
l2 = tf.layers.dense(l1, 128, tf.nn.relu, name="l2")   # second hidden layer
out = tf.layers.dense(l2, 4, name="l3")                # output layer (logits)
prediction = tf.nn.softmax(out, name="pred")           # softmax turns logits into probabilities, kept for later comparison
loss = tf.losses.softmax_cross_entropy(onehot_labels=tfy, logits=out)  # cross-entropy between the one-hot labels and the logits
accuracy = tf.metrics.accuracy(  # returns (acc, update_op) and creates 2 local variables
    labels=tf.argmax(tfy, axis=1), predictions=tf.argmax(out, axis=1),
)[1]
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)  # plain gradient descent
train_op = opt.minimize(loss)  # training op: minimize the loss
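Note that tf.metrics.accuracy keeps running totals in local variables, so the value printed below is a cumulative average over all evaluations so far, not the accuracy of a single evaluation. If per-evaluation accuracy were wanted instead, the counters could be reset before each run (an optional sketch, not used in this script):

reset_metrics = tf.variables_initializer(tf.local_variables())  # re-zeroes the metric's count/total
# sess.run(reset_metrics)  # call before each evaluation for per-run accuracy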
sess = tf.Session()
sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))  # init weights and metric counters
# plt.ion()
# fig , (ax1,ax2) = plt.subplots(1,2,figsize=(8,4))
# accuracies,steps = [], []
writer = tf.summary.FileWriter('./my_graph', sess.graph)  # write the graph once for TensorBoard
for t in range(4000):
    batch_index = np.random.randint(len(train_data), size=32)  # random mini-batch of 32
    sess.run(train_op, {tf_input: train_data[batch_index]})
    if t % 50 == 0:
        acc_, pre_, loss_ = sess.run([accuracy, prediction, loss], {tf_input: test_Data})
        # accuracies.append(acc_)
        # steps.append(t)
        print("Set: %i " % t, "| Accurate: %.2f" % acc_, "| Loss: %.2f" % loss_)
# ax1.cla()
#
# for c in range(4):
# bp = ax1.bar(x=c+0.1,height=sum((np.argmax(pre_,axis=1) == c)),width=0.2,color='red')
# bt = ax1.bar(x=c-0.1,height=sum((np.argmax(test_Data[:,21:],axis=1) == c)),width= 0.2,color='blue')
# ax1.set_xticks(range(4),["accepted", "good", "unaccepted", "very good"])
# ax1.legend(handles=[bp, bt], labels=["prediction", "target"])
# ax1.set_ylim((0,400))
# ax2.cla()
# ax2.plot(steps,accuracies,label="accuracy")
# ax2.set_ylim(ymax=1)
# ax2.set_ylabel("accuracy")
#
#
# plt.ioff()
# plt.show()
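After training, the prediction op can classify an individual row; a minimal usage sketch, assuming the script above has just finished running:

sample = test_Data[:1]  # one 25-column row; only the first 21 feature columns feed `prediction`
probs = sess.run(prediction, {tf_input: sample})
print("predicted class index:", np.argmax(probs, axis=1)[0])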
Output:
Set: 0 | Accurate: 0.69 | Loss: 1.22
Set: 50 | Accurate: 0.74 | Loss: 0.53
Set: 100 | Accurate: 0.78 | Loss: 0.39
Set: 150 | Accurate: 0.80 | Loss: 0.29
Set: 200 | Accurate: 0.82 | Loss: 0.25
Set: 250 | Accurate: 0.84 | Loss: 0.21
Set: 300 | Accurate: 0.85 | Loss: 0.17
Set: 350 | Accurate: 0.86 | Loss: 0.17
Set: 400 | Accurate: 0.87 | Loss: 0.14
Set: 450 | Accurate: 0.88 | Loss: 0.13
Set: 500 | Accurate: 0.89 | Loss: 0.11
Set: 550 | Accurate: 0.90 | Loss: 0.10
Set: 600 | Accurate: 0.90 | Loss: 0.10
Set: 650 | Accurate: 0.91 | Loss: 0.09
Set: 700 | Accurate: 0.91 | Loss: 0.08
Set: 750 | Accurate: 0.91 | Loss: 0.07
Set: 800 | Accurate: 0.92 | Loss: 0.06
Set: 850 | Accurate: 0.92 | Loss: 0.06
Set: 900 | Accurate: 0.93 | Loss: 0.06
Set: 950 | Accurate: 0.93 | Loss: 0.05
Set: 1000 | Accurate: 0.93 | Loss: 0.05
Set: 1050 | Accurate: 0.93 | Loss: 0.05
Set: 1100 | Accurate: 0.94 | Loss: 0.06
Set: 1150 | Accurate: 0.94 | Loss: 0.04
Set: 1200 | Accurate: 0.94 | Loss: 0.04
Set: 1250 | Accurate: 0.94 | Loss: 0.04
Set: 1300 | Accurate: 0.94 | Loss: 0.03
Set: 1350 | Accurate: 0.95 | Loss: 0.03
Set: 1400 | Accurate: 0.95 | Loss: 0.03
Set: 1450 | Accurate: 0.95 | Loss: 0.03
Set: 1500 | Accurate: 0.95 | Loss: 0.03
Set: 1550 | Accurate: 0.95 | Loss: 0.03
Set: 1600 | Accurate: 0.95 | Loss: 0.03
Set: 1650 | Accurate: 0.95 | Loss: 0.03
Set: 1700 | Accurate: 0.96 | Loss: 0.02
Set: 1750 | Accurate: 0.96 | Loss: 0.03
Set: 1800 | Accurate: 0.96 | Loss: 0.02
Set: 1850 | Accurate: 0.96 | Loss: 0.02
Set: 1900 | Accurate: 0.96 | Loss: 0.02
Set: 1950 | Accurate: 0.96 | Loss: 0.02
Set: 2000 | Accurate: 0.96 | Loss: 0.02
Set: 2050 | Accurate: 0.96 | Loss: 0.02
Set: 2100 | Accurate: 0.96 | Loss: 0.02
Set: 2150 | Accurate: 0.96 | Loss: 0.02
Set: 2200 | Accurate: 0.97 | Loss: 0.02
Set: 2250 | Accurate: 0.97 | Loss: 0.02
Set: 2300 | Accurate: 0.97 | Loss: 0.02
Set: 2350 | Accurate: 0.97 | Loss: 0.02
Set: 2400 | Accurate: 0.97 | Loss: 0.02
Set: 2450 | Accurate: 0.97 | Loss: 0.02
Set: 2500 | Accurate: 0.97 | Loss: 0.02
Set: 2550 | Accurate: 0.97 | Loss: 0.02
Set: 2600 | Accurate: 0.97 | Loss: 0.01
Set: 2650 | Accurate: 0.97 | Loss: 0.01
Set: 2700 | Accurate: 0.97 | Loss: 0.01
Set: 2750 | Accurate: 0.97 | Loss: 0.01
Set: 2800 | Accurate: 0.97 | Loss: 0.01
Set: 2850 | Accurate: 0.97 | Loss: 0.01
Set: 2900 | Accurate: 0.97 | Loss: 0.01
Set: 2950 | Accurate: 0.97 | Loss: 0.01
Set: 3000 | Accurate: 0.97 | Loss: 0.01
Set: 3050 | Accurate: 0.97 | Loss: 0.01
Set: 3100 | Accurate: 0.97 | Loss: 0.01
Set: 3150 | Accurate: 0.97 | Loss: 0.01
Set: 3200 | Accurate: 0.98 | Loss: 0.01
Set: 3250 | Accurate: 0.98 | Loss: 0.01
Set: 3300 | Accurate: 0.98 | Loss: 0.01
Set: 3350 | Accurate: 0.98 | Loss: 0.01
Set: 3400 | Accurate: 0.98 | Loss: 0.01
Set: 3450 | Accurate: 0.98 | Loss: 0.01
Set: 3500 | Accurate: 0.98 | Loss: 0.01
Set: 3550 | Accurate: 0.98 | Loss: 0.01
Set: 3600 | Accurate: 0.98 | Loss: 0.01
Set: 3650 | Accurate: 0.98 | Loss: 0.01
Set: 3700 | Accurate: 0.98 | Loss: 0.01
Set: 3750 | Accurate: 0.98 | Loss: 0.01
Set: 3800 | Accurate: 0.98 | Loss: 0.01
Set: 3850 | Accurate: 0.98 | Loss: 0.01
Set: 3900 | Accurate: 0.98 | Loss: 0.01
Set: 3950 | Accurate: 0.98 | Loss: 0.01
Process finished with exit code 0
You can see that the accuracy steadily increases while the loss steadily decreases.
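Since the graph was saved to ./my_graph by the FileWriter, it can also be inspected visually by running tensorboard --logdir=./my_graph and opening the address it prints.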