1. A small improvement
def chose_action(self, ai=False):
    # Candidate next positions reachable from the current position.
    pre_s_space = self.s + self.action_space
    # Drop candidates that already appear in the chain of visited states.
    pre_s_space = [pre_s for pre_s in pre_s_space
                   if not np.all(self.s_chain == pre_s, axis=1).any()]
    if pre_s_space:
        # Restrict the action space to moves that lead to unvisited cells.
        self.action_space = pre_s_space - self.s
    a = random.choice(self.action_space)
    return a
Output:
playing 0 times!
526 times success!
With repeated states ruled out, it finds the exit roughly twice as fast, and the same holds when the environment is made a bit more complex:
playing 20000 times!
playing 40000 times!
playing 60000 times!
70613 times success!
This improvement is not actually worth much, though, because it is hand-crafted for this particular game; your little AI has not yet acquired any general learning ability.
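(If the filtering line in chose_action above looks dense, here is a minimal, self-contained sketch of the same membership test; the arrays below are made-up illustrative values, not taken from the game.)

import numpy as np

s_chain = np.array([[0, 0], [0, 1], [0, 2]])                 # positions already visited (illustrative)
s = np.array([0, 2])                                         # current position
action_space = np.array([[0, 1], [-1, 0], [0, -1], [1, 0]])

pre_s_space = s + action_space                               # candidate next positions
# np.all(s_chain == p, axis=1) marks the rows of s_chain equal to p;
# .any() is True if p was visited before, so such candidates are dropped.
fresh = [p for p in pre_s_space if not np.all(s_chain == p, axis=1).any()]
print(fresh)   # [0,1] is filtered out; [0,3], [-1,2] and [1,2] survive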
2. Q-learning
Now let's add something that genuinely deserves to be called artificial intelligence: Q-learning!
The formula:

$$Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \max_{a'} Q(s',a') - Q(s,a) \right]$$

Written as an in-place ("+=") update:

Q[s, a] += alpha * (r + gamma * max(Q[s_]) - Q[s, a])

We only use this one formula! (In the code below, alpha = 0.1 and gamma = 0.9.)
For the theory behind it, see 【强化学习】Q-Learning算法详解.
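To make the update concrete, here is a tiny worked example with alpha = 0.1 and gamma = 0.9, the same constants used in the code below; the Q values are made up purely for illustration:

import numpy as np

alpha, gamma = 0.1, 0.9
Q = np.zeros((4, 12, 4))               # Q[row, col, action]
Q[0, 2] = [0.0, 0.0, 0.4, 0.0]         # pretend values for the next state

# Suppose we were at (0, 1), took action index 0, got reward 0 and landed in (0, 2):
s, a, r, s_ = (0, 1), 0, 0, (0, 2)
Q[s[0], s[1], a] += alpha * (r + gamma * Q[s_[0], s_[1]].max() - Q[s[0], s[1], a])
print(Q[0, 1, 0])                      # 0.1 * (0 + 0.9 * 0.4 - 0) ≈ 0.036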
Change the environment's step function to:
def step(self, a=[0, 0], ai=False):
    s = self.agent.s
    s_ = np.array(s) + np.array(a)           # tentative next position
    if 0 <= s_[0] <= 3 and 0 <= s_[1] <= 11:
        self.agent.s = s_
        r = self.env[s_[0], s_[1]]            # reward comes from the grid itself
    else:
        s_ = s                                # out of bounds: stay put
        r = -1
    self.agent.post_step(s, a, r, s_)         # let the agent update its Q-table
    return s_, r
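A quick sanity check of the new step function (assuming an Env and an Agent built from the full listing at the end of this post have been created and registered):

env = Env()
agent = Agent()
env.register(agent)

s_, r = env.step([0, 1])    # move right from (0, 0) into an empty cell
print(s_, r)                # [0 1] 0.0
s_, r = env.step([-1, 0])   # try to move up, which leaves the grid
print(s_, r)                # [0 1] -1  -> position unchanged, penalty received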
And the changes to the agent:
class Agent:
    def __init__(self):
        self.action_space = np.array([[0, 1], [-1, 0], [0, -1], [1, 0]])   # right, up, left, down
        self.s = np.array([0, 0])
        self.s_chain = np.expand_dims(self.s, 0)
        self.sar_chain = np.expand_dims(np.hstack([np.array([0, 0]), [0, 0], 0]), 0)
        self.Q = np.zeros((4, 12, 4))        # Q[row, col, action]
        self.epsilon = 0.2

    def chose_action(self):
        # Explore with probability 1 - epsilon, otherwise act greedily on the Q-table.
        if self.epsilon < random.random():
            a = random.choice(self.action_space)
        else:
            a = self.action_space[np.argmax(self.Q[self.s[0], self.s[1]])]
        return a

    def post_step(self, s, a, r, s_):
        self.s_chain = np.vstack([self.s_chain, s])
        self.sar_chain = np.vstack([self.sar_chain, np.hstack([s, a, r])])
        a_number = np.where((self.action_space == a).all(axis=1))   # index of the action taken
        if r == -1:
            # A trap or leaving the grid: write the penalty straight into the table.
            self.Q[s_[0], s_[1]] = -1
        # Q-learning update with alpha = 0.1, gamma = 0.9 (the reward itself is
        # baked into Q[s_] above rather than added as a separate term here).
        update = 0.1 * (0.9 * self.Q[s_[0], s_[1]].max() - self.Q[s[0], s[1], a_number])
        self.Q[s[0], s[1], a_number] += update

    def reset(self):
        self.action_space = np.array([[0, 1], [-1, 0], [0, -1], [1, 0]])
        self.s = np.array([0, 0])
        self.s_chain = np.expand_dims(self.s, 0)
        self.sar_chain = np.expand_dims(np.hstack([np.array([0, 0]), [0, 0], 0]), 0)
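The one slightly cryptic line in post_step is the a_number lookup; on its own it just converts the action vector back into an index along the Q-table's last axis:

import numpy as np

action_space = np.array([[0, 1], [-1, 0], [0, -1], [1, 0]])
a = [0, -1]                                           # the action that was just taken
a_number = np.where((action_space == a).all(axis=1))  # row-wise comparison, then the matching index
print(a_number)                                       # (array([2]),) -> action [0, -1] lives at index 2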
3. Taking it for a spin
Make the environment a little more complex, and after about 300,000 rounds of bumping into walls it learns to find the exit:
327737 times success!
What it has learned is the Q-table, which serves as its knowledge base. Let's open it up and take a look:
plt.imshow((agent.Q[:,:,:-1]*255).astype(np.uint8))
You can see that cells closer to the exit are brighter, while the very dark squares are positions it considers unsafe.
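The imshow call above reinterprets three of the four action channels as RGB, which works but is a bit indirect. A slightly more direct alternative (a suggestion, not part of the original) is to plot the best Q value per cell:

from matplotlib import pyplot as plt

plt.imshow(agent.Q.max(axis=2))   # best achievable value from each cell
plt.colorbar()
plt.show()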
4. Finally, the complete code:
import numpy as np
from matplotlib import pyplot as plt
import random
from itertools import count


class Env:
    def __init__(self):
        self.action_space = []
        self.agent = None
        self.env = np.zeros((4, 12))
        self.env[-1, 1:-1] = -1      # the cliff along the bottom row
        self.env[:2, 3] = -1         # extra obstacles
        self.env[1:-1, 8] = -1
        self.env[-1, -1] = 1         # the exit
        self.env_show = self.env.copy()

    def step(self, a=[0, 0], ai=False):
        s = self.agent.s
        s_ = np.array(s) + np.array(a)
        if 0 <= s_[0] <= 3 and 0 <= s_[1] <= 11:
            self.agent.s = s_
            r = self.env[s_[0], s_[1]]
        else:
            s_ = s                   # out of bounds: stay put
            r = -1
        self.agent.post_step(s, a, r, s_)
        return s_, r

    def play(self):
        self.reset()
        for t in count(1):
            a = self.agent.chose_action()
            if a is not None:
                s, r = self.step(a)
                if r in [-1, 1]:     # episode ends on a trap or the exit
                    break
            else:
                r = None
                break
        return t, r

    def play_until_success(self):
        for t in count(1):
            _, r = self.play()
            if r:
                if t % 20000 == 0:
                    print(f"playing {t} times!")
                if r == 1:
                    print(f"{t} times success!")
                    self.render()
                    break
            else:
                break

    def render(self):
        for i, j in self.agent.s_chain:
            self.env_show[i, j] = 0.5
        plt.imshow(self.env_show)
        plt.show()

    def reset(self):
        self.agent.reset()
        self.env_show = self.env.copy()

    def register(self, agent):
        self.agent = agent


class Agent:
    def __init__(self):
        self.action_space = np.array([[0, 1], [-1, 0], [0, -1], [1, 0]])   # right, up, left, down
        self.s = np.array([0, 0])
        self.s_chain = np.expand_dims(self.s, 0)
        self.sar_chain = np.expand_dims(np.hstack([np.array([0, 0]), [0, 0], 0]), 0)
        self.Q = np.zeros((4, 12, 4))
        self.epsilon = 0.2

    def chose_action(self):
        # Explore with probability 1 - epsilon, otherwise act greedily on the Q-table.
        if self.epsilon < random.random():
            a = random.choice(self.action_space)
        else:
            a = self.action_space[np.argmax(self.Q[self.s[0], self.s[1]])]
        return a

    def post_step(self, s, a, r, s_):
        self.s_chain = np.vstack([self.s_chain, s])
        self.sar_chain = np.vstack([self.sar_chain, np.hstack([s, a, r])])
        a_number = np.where((self.action_space == a).all(axis=1))
        if r == -1:
            self.Q[s_[0], s_[1]] = -1
        update = 0.1 * (0.9 * self.Q[s_[0], s_[1]].max() - self.Q[s[0], s[1], a_number])
        self.Q[s[0], s[1], a_number] += update

    def reset(self):
        self.action_space = np.array([[0, 1], [-1, 0], [0, -1], [1, 0]])
        self.s = np.array([0, 0])
        self.s_chain = np.expand_dims(self.s, 0)
        self.sar_chain = np.expand_dims(np.hstack([np.array([0, 0]), [0, 0], 0]), 0)


env = Env()
agent = Agent()
env.register(agent)
# env.render()
env.play_until_success()
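Once training finishes, the greedy policy can also be read straight out of the Q-table. A small sketch (the arrow characters are just illustrative labels for the four actions, in the same order as agent.action_space):

# [0,1] -> right, [-1,0] -> up, [0,-1] -> left, [1,0] -> down
arrows = np.array(['>', '^', '<', 'v'])
greedy = arrows[agent.Q.argmax(axis=2)]   # best action index per cell, mapped to an arrow
for row in greedy:
    print(' '.join(row))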