pinocchio  3.7.0
A fast and flexible implementation of Rigid Body Dynamics algorithms and their analytical derivatives
 
Loading...
Searching...
No Matches
qtable.py
1"""
2Example of Q-table learning with a simple discretized 1-pendulum environment.
3"""
4
5import signal
6import time
7
8import matplotlib.pyplot as plt
9import numpy as np
10from dpendulum import DPendulum
11
# --- Random seed
# The seed is taken from the wall clock: the number of milliseconds elapsed in
# the current 10-second window, i.e. an integer in [0, 9999]. It is printed so
# the value used by a given run is visible in its output.
RANDOM_SEED = int(time.time() % 10 * 1000)
print(f"Seed = {RANDOM_SEED}")
np.random.seed(RANDOM_SEED)

# --- Hyper parameters
NEPISODES = 500  # Number of training episodes
NSTEPS = 50  # Max episode length
LEARNING_RATE = 0.85  # Step size applied to each Q-table correction
DECAY_RATE = 0.99  # Discount factor
22
# --- Environment
env = DPendulum()
NX = env.nx  # Number of (discrete) states
NU = env.nu  # Number of (discrete) controls

# Q-table initialized to 0: one row per discrete state, one column per
# discrete control.
Q = np.zeros((NX, NU))
29
30
def rendertrial(maxiter=100):
    """Display one roll-out of the greedy policy encoded in the Q-table.

    The roll-out starts from the state returned by ``env.reset()``. At each
    step the control with the largest Q-value for the current state is
    applied and the environment is rendered. The roll-out ends after
    ``maxiter`` steps, or earlier as soon as a step yields a reward equal
    to 1 (in which case "Reward!" is printed).

    Args:
        maxiter: Maximum number of steps to simulate.
    """
    state = env.reset()
    for _ in range(maxiter):
        action = np.argmax(Q[state])
        state, reward = env.step(action)
        env.render()
        if reward != 1:
            continue
        print("Reward!")
        return
41
42
# Roll-out when CTRL-Z is pressed.
# SIGTSTP is only defined on Unix platforms; elsewhere the attribute does not
# exist, so the registration is skipped instead of aborting the script with an
# AttributeError. The handler ignores its (signal number, frame) arguments.
if hasattr(signal, "SIGTSTP"):
    signal.signal(signal.SIGTSTP, lambda signum, frame: rendertrial())
h_rwd = []  # Learning history (for plot): summed reward of each episode.

# Episodes are numbered from 1 because the exploration noise is divided by the
# episode index. The upper bound is NEPISODES + 1 so that exactly NEPISODES
# episodes are run (and the last batch of 20 gets reported).
for episode in range(1, NEPISODES + 1):
    x = env.reset()
    rsum = 0.0
    for _ in range(NSTEPS):
        # Greedy action with noise; the noise amplitude shrinks as 1 / episode.
        u = np.argmax(Q[x, :] + np.random.randn(1, NU) / episode)
        x2, reward = env.step(u)

        # Compute reference Q-value at state x respecting HJB
        Qref = reward + DECAY_RATE * np.max(Q[x2, :])

        # Update Q-Table to better fit HJB
        Q[x, u] += LEARNING_RATE * (Qref - Q[x, u])
        x = x2
        rsum += reward
        if reward == 1:
            # The episode stops as soon as a step returns a reward of 1.
            break

    h_rwd.append(rsum)
    if not episode % 20:
        # Periodic report: reward accumulated over the last 20 episodes.
        print(f"Episode #{episode} done with {sum(h_rwd[-20:])} sucess")

# Normalize by the number of episodes actually recorded so the reported rate
# and the running average below always match the length of the history.
print(f"Total rate of success: {sum(h_rwd) / len(h_rwd):.3f}")
plt.plot(np.cumsum(h_rwd) / np.arange(1, len(h_rwd) + 1))
plt.show()
rendertrial(maxiter=100)
Definition qtable.py:31