Sarsa-Lambda


1. Algorithm: Sarsa-lambda is an upgraded version of Sarsa that learns how to obtain good rewards more efficiently. Sarsa and Q-learning both update only the single step that immediately precedes a received reward. Sarsa-lambda, by contrast, also propagates the update to the steps further back, and lambda controls how far back that credit reaches. lambda takes values in [0, 1]: with lambda = 0, Sarsa-lambda reduces to plain Sarsa and only the last step before the reward is updated; with lambda = 1, every step taken on the way to the reward is updated. A small numerical illustration of this weighting follows.

2. Implementation: the training loop is the same as for Sarsa; only the decision/learning class changes, as the subsections below show.
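A minimal sketch of that weighting (my own illustration, not from the original post): with an eligibility trace that decays by gamma * lambda on every step, a state-action pair visited k steps before the reward receives an update weight of roughly (gamma * lambda)^k.

    import numpy as np

    gamma, lambda_ = 0.9, 0.9                  # reward_decay and trace_decay used below
    steps_back = np.arange(6)                  # 0 = the step taken right before the reward
    weights = (gamma * lambda_) ** steps_back
    print(weights)                             # approx. [1.0, 0.81, 0.66, 0.53, 0.43, 0.35]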

2.1 Main structure:

    class SarsaLambdaTable:
        # initialization (changed)
        def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):

        # choose an action (same as before)
        def choose_action(self, observation):

        # learn and update the parameters (changed)
        def learn(self, s, a, r, s_, a_):

        # check whether a state exists (changed)
        def check_state_exist(self, state):

2.2 Preset values:

Among the preset values we add trace_decay=0.9; this is the value of lambda. It is what gives every step taken on the way to the reward some credit.

    class SarsaLambdaTable(RL):   # inherit from the RL class
        def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
            super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

            # backward-view algorithm: the eligibility trace
            self.lambda_ = trace_decay
            self.eligibility_trace = self.q_table.copy()   # empty eligibility trace table
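The parent class RL is not listed in this post; the q_table, lr, gamma and epsilon attributes and the choose_action method come from it. Below is a minimal sketch of what such a parent class is assumed to look like, following the usual tabular Q-learning/Sarsa setup from the earlier tutorials; the details may differ from the author's actual RL class.

    import numpy as np
    import pandas as pd

    class RL(object):
        def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
            self.actions = action_space            # a list of actions, e.g. [0, 1, 2, 3]
            self.lr = learning_rate
            self.gamma = reward_decay
            self.epsilon = e_greedy
            self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

        def choose_action(self, observation):
            # epsilon-greedy: mostly exploit, sometimes explore
            self.check_state_exist(observation)    # provided by the subclass below
            if np.random.uniform() < self.epsilon:
                state_action = self.q_table.loc[observation, :]
                # break ties between equally good actions at random
                action = np.random.choice(state_action[state_action == np.max(state_action)].index)
            else:
                action = np.random.choice(self.actions)
            return action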

2.3 Checking whether a state exists:

check_state_exist is almost the same as before; the only difference is that we also keep the eligibility_trace table in step with the Q table:

    class SarsaLambdaTable(RL):   # inherit from the RL class
        def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
            ...

        def check_state_exist(self, state):
            if state not in self.q_table.index:
                # append new state to q table
                to_be_append = pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
                self.q_table = self.q_table.append(to_be_append)

                # also update eligibility trace
                self.eligibility_trace = self.eligibility_trace.append(to_be_append)
                # note: on pandas >= 2.0, DataFrame.append has been removed;
                # pd.concat([df, to_be_append.to_frame().T]) does the same job
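A quick sanity check (my own example, not from the original). The state string below is just the kind of coordinate list that the maze environment in section 2.5 returns for its start cell:

    table = SarsaLambdaTable(actions=list(range(4)))
    table.check_state_exist(str([5.0, 5.0, 35.0, 35.0]))   # hypothetical observation string
    print(table.q_table)             # one all-zero row, indexed by the state string
    print(table.eligibility_trace)   # same shape, kept in sync with the Q table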

2.4 Learning:

With the parent class RL in place, the remaining work is simple: we only have to write the learn method of SarsaLambdaTable, since everything else behaves exactly like the parent class. Together with the code above, this is everything in which SarsaLambdaTable differs from RL. The update rule that learn implements is spelled out below.
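For reference, these are the standard backward-view Sarsa(lambda) equations that the learn method carries out on every step (they are not written out in the original post); here $\alpha$ is learning_rate, $\gamma$ is reward_decay, $\lambda$ is trace_decay, and $E$ is the eligibility trace table:

$$\delta = r + \gamma\, Q(s', a') - Q(s, a)$$
$$E(s, a) \leftarrow E(s, a) + 1 \quad \text{(accumulating trace; the replacing variant sets } E(s,\cdot)=0,\ E(s,a)=1\text{)}$$
$$Q(\tilde s, \tilde a) \leftarrow Q(\tilde s, \tilde a) + \alpha\, \delta\, E(\tilde s, \tilde a) \quad \text{for every entry } (\tilde s, \tilde a)$$
$$E(\tilde s, \tilde a) \leftarrow \gamma \lambda\, E(\tilde s, \tilde a) \quad \text{for every entry } (\tilde s, \tilde a)$$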

    class SarsaLambdaTable(RL):   # inherit from the RL class
        def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
            ...

        def check_state_exist(self, state):
            ...

        def learn(self, s, a, r, s_, a_):
            # this part is the same as in Sarsa
            self.check_state_exist(s_)
            q_predict = self.q_table.loc[s, a]        # .loc replaces the removed pandas .ix
            if s_ != 'terminal':
                q_target = r + self.gamma * self.q_table.loc[s_, a_]
            else:
                q_target = r
            error = q_target - q_predict

            # from here on it differs from Sarsa:
            # raise the eligibility of the state-action pair we just experienced,
            # marking it as an indispensable step on the road to the reward.

            # Method 1: accumulating trace
            # self.eligibility_trace.loc[s, a] += 1

            # Method 2: a more effective update (replacing trace)
            self.eligibility_trace.loc[s, :] *= 0
            self.eligibility_trace.loc[s, a] = 1

            # update the whole Q table at once, weighted by the trace
            self.q_table += self.lr * error * self.eligibility_trace

            # decay the eligibility trace over time: the further a step lies
            # from the reward, the less "indispensable" it becomes
            self.eligibility_trace *= self.gamma * self.lambda_
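Continuing the small example from section 2.3 (again my own illustration, with made-up state strings), a single call to learn shows the trace mechanics: the visited pair gets eligibility 1, and the whole table is then decayed by gamma * lambda.

    s = str([5.0, 5.0, 35.0, 35.0])            # hypothetical current state
    s_ = str([45.0, 5.0, 75.0, 35.0])          # hypothetical next state
    table.check_state_exist(s)
    table.learn(s, 2, 0, s_, 1)                # took action 2, got reward 0, will take action 1 next
    print(table.eligibility_trace.loc[s, 2])   # gamma * lambda_ ≈ 0.81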

2.5 Environment:

    import numpy as np
    np.random.seed(1)
    import tkinter as tk
    import time

    UNIT = 40    # pixels
    MAZE_H = 4   # grid height
    MAZE_W = 4   # grid width


    class Maze(tk.Tk, object):   # note: do not leave out object
        def __init__(self):
            super(Maze, self).__init__()
            self.action_space = ['u', 'd', 'l', 'r']
            self.n_actions = len(self.action_space)
            self.title('maze')
            self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))
            self._build_maze()

        def _build_maze(self):
            self.canvas = tk.Canvas(self, bg='white',
                                    height=MAZE_H * UNIT,
                                    width=MAZE_W * UNIT)
            # create grids
            for c in range(0, MAZE_W * UNIT, UNIT):
                x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
                self.canvas.create_line(x0, y0, x1, y1)
            for r in range(0, MAZE_H * UNIT, UNIT):
                x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
                self.canvas.create_line(x0, y0, x1, y1)

            # create origin
            origin = np.array([20, 20])

            # hell
            hell1_center = origin + np.array([UNIT * 2, UNIT])
            self.hell1 = self.canvas.create_rectangle(
                hell1_center[0] - 15, hell1_center[1] - 15,
                hell1_center[0] + 15, hell1_center[1] + 15,
                fill='black')
            # hell
            hell2_center = origin + np.array([UNIT, UNIT * 2])
            self.hell2 = self.canvas.create_rectangle(
                hell2_center[0] - 15, hell2_center[1] - 15,
                hell2_center[0] + 15, hell2_center[1] + 15,
                fill='black')

            # create oval
            oval_center = origin + UNIT * 2
            self.oval = self.canvas.create_oval(
                oval_center[0] - 15, oval_center[1] - 15,
                oval_center[0] + 15, oval_center[1] + 15,
                fill='yellow')

            # create red rect
            self.rect = self.canvas.create_rectangle(
                origin[0] - 15, origin[1] - 15,
                origin[0] + 15, origin[1] + 15,
                fill='red')

            # pack all
            self.canvas.pack()

        def reset(self):
            self.update()
            time.sleep(0.5)
            self.canvas.delete(self.rect)
            origin = np.array([20, 20])
            self.rect = self.canvas.create_rectangle(
                origin[0] - 15, origin[1] - 15,
                origin[0] + 15, origin[1] + 15,
                fill='red')
            # return observation
            return self.canvas.coords(self.rect)

        def step(self, action):
            s = self.canvas.coords(self.rect)
            base_action = np.array([0, 0])
            if action == 0:     # up
                if s[1] > UNIT:
                    base_action[1] -= UNIT
            elif action == 1:   # down
                if s[1] < (MAZE_H - 1) * UNIT:
                    base_action[1] += UNIT
            elif action == 2:   # right
                if s[0] < (MAZE_W - 1) * UNIT:
                    base_action[0] += UNIT
            elif action == 3:   # left
                if s[0] > UNIT:
                    base_action[0] -= UNIT

            self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent

            s_ = self.canvas.coords(self.rect)  # next state

            # reward function
            if s_ == self.canvas.coords(self.oval):
                reward = 1
                done = True
                s_ = 'terminal'   # report 'terminal' so the check in learn() can end the episode
            elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:
                reward = -1
                done = True
                s_ = 'terminal'
            else:
                reward = 0
                done = False

            return s_, reward, done

        def render(self):
            time.sleep(0.05)
            self.update()
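Finally, a minimal sketch of a training loop that ties the environment and SarsaLambdaTable together. It is not part of this post; the file names maze_env.py / RL_brain.py, the episode count, and the per-episode reset of the eligibility trace are my assumptions, following the structure of the earlier Sarsa tutorial.

    from maze_env import Maze               # assumed file name for the environment above
    from RL_brain import SarsaLambdaTable   # assumed file name for the learning class above

    def update():
        for episode in range(100):
            # start a new episode
            observation = env.reset()
            action = RL.choose_action(str(observation))

            # the trace only makes sense within one episode, so clear it here
            RL.eligibility_trace *= 0

            while True:
                env.render()
                observation_, reward, done = env.step(action)
                # Sarsa is on-policy: pick the next action before learning
                action_ = RL.choose_action(str(observation_))
                # learn from the transition (s, a, r, s', a')
                RL.learn(str(observation), action, reward, str(observation_), action_)
                observation, action = observation_, action_
                if done:
                    break

        print('game over')
        env.destroy()

    if __name__ == "__main__":
        env = Maze()
        RL = SarsaLambdaTable(actions=list(range(env.n_actions)))
        env.after(100, update)   # start training once the Tk window is up
        env.mainloop()

Resetting the eligibility trace at the start of each episode keeps credit earned in one episode from leaking into the next.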