pinocchio  3.2.0 A fast and flexible implementation of Rigid Body Dynamics algorithms and their analytical derivatives
qnet.py
1 """
2 Example of Q-table learning with a simple discretized 1-pendulum environment using a linear Q network.
3 """
4
5 import signal
6 import time
7
8 import matplotlib.pyplot as plt
9 import numpy as np
10 import tensorflow as tf
11
12 from dpendulum import DPendulum
13
14
# Seed the random generators from the wall clock so successive runs differ;
# the seed is printed so that a given run can be reproduced.
RANDOM_SEED = int((time.time() % 10) * 1000)
print("Seed = {:d}".format(RANDOM_SEED))
np.random.seed(RANDOM_SEED)
tf.set_random_seed(RANDOM_SEED)

# Learning hyper-parameters.
NEPISODES = 500  # Number of training episodes
NSTEPS = 50  # Max episode length
LEARNING_RATE = 0.1  # Step length in optimizer
DECAY_RATE = 0.99  # Discount factor

# Discretized pendulum: NX discrete states (one-hot encoded below) and
# NU discrete controls (indices 0 .. NU - 1).
env = DPendulum()
NX, NU = env.nx, env.nu
30
31
32
class QValueNetwork:
    """Linear Q-value network: a single trainable ``NX x NU`` weight matrix.

    For an input row ``x`` of shape ``[1, NX]`` (the script feeds one-hot
    encoded states built by ``onehot``), ``qvalue = x @ W`` is the row of
    Q-values of all ``NU`` controls, and ``u`` is its argmax (greedy policy).
    Training regresses ``qvalue`` onto a caller-supplied target ``qref`` with
    a squared-error loss.

    Attributes:
        x: input placeholder, shape ``[1, NX]``, float32.
        qvalue: Q-values of every control at ``x``, shape ``[1, NU]``.
        u: greedy control index ``argmax(qvalue, axis=1)``, shape ``[1]``.
        qref: target placeholder, shape ``[1, NU]``, float32.
        optim: training op taking one optimizer step on
            ``sum((qref - qvalue) ** 2)``.
    """

    def __init__(self):
        x = tf.placeholder(shape=[1, NX], dtype=tf.float32)
        # Small uniform initialisation in [0, 0.01) with a fixed op-level seed.
        W = tf.Variable(tf.random_uniform([NX, NU], 0, 0.01, seed=100))
        qvalue = tf.matmul(x, W)
        u = tf.argmax(qvalue, 1)

        qref = tf.placeholder(shape=[1, NU], dtype=tf.float32)
        loss = tf.reduce_sum(tf.square(qref - qvalue))
        # This line was missing from the listing (``optim`` was undefined).
        # NOTE(review): plain gradient descent with LEARNING_RATE as the step
        # length is assumed here -- confirm against upstream.
        optim = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

        # Attribute names must match how the script reads them
        # (qvalue.x, qvalue.qvalue, qvalue.u, qvalue.qref, qvalue.optim).
        self.x = x  # Network input
        self.qvalue = qvalue  # Q-value as a function of x
        self.u = u  # Policy as a function of x
        self.qref = qref  # Reference Q-value at next step (to be set to l+Q o f)
        self.optim = optim  # Optimizer
49
50
51
# Build the network on a fresh default graph. reset_default_graph() replaces
# the graph that carried the graph-level seed set by tf.set_random_seed()
# earlier in this script, so re-apply the seed to the new graph.
tf.reset_default_graph()
tf.set_random_seed(RANDOM_SEED)
qvalue = QValueNetwork()
# InteractiveSession installs itself as the default session, which is what
# the bare .run() on the initializer below relies on.
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
56
57
def onehot(ix, n=NX):
    """Return a ``(1, n)`` float row vector that is 1 at index ``ix``, 0 elsewhere.

    Args:
        ix: index to set to 1. If no index in ``range(n)`` equals ``ix``,
            the result is all zeros.
        n: vector length; defaults to the number of discrete states ``NX``.

    Returns:
        ``np.ndarray`` of shape ``(1, n)`` and dtype ``float64``, matching the
        ``[1, NX]`` network-input placeholder.
    """
    # ``np.float`` was only an alias of the builtin ``float`` (i.e. float64)
    # and was removed in NumPy 1.24; use the explicit dtype instead.
    return np.array([[i == ix for i in range(n)]], np.float64)
66
67
def disturb(u, i):
    """Add exploration noise to the control index ``u`` and clip it to range.

    The noise is a standard-normal sample scaled by ``10 / (i / 50 + 10)``
    and truncated toward zero by ``int``, so it shrinks as ``i`` (the episode
    counter in the training loop) grows. The result is clipped to the valid
    control indices ``[0, NU - 1]``.
    """
    noise = int(np.random.randn() * 10 / (i / 50 + 10))
    u += noise
    return np.clip(u, 0, NU - 1)
71
72
def rendertrial(maxiter=100):
    """Roll out and render the current greedy policy from a fresh reset.

    The rollout stops after ``maxiter`` steps, or as soon as a step returns
    a reward equal to 1 (in which case "Reward!" is printed).

    Args:
        maxiter: maximum number of simulated steps.
    """
    x = env.reset()
    for _ in range(maxiter):
        # sess.run returns a length-1 array (a batch of one state); take the
        # scalar control, exactly as the training loop does.
        u = sess.run(qvalue.u, feed_dict={qvalue.x: onehot(x)})[0]
        x, r = env.step(u)
        env.render()
        if r == 1:
            print("Reward!")
            break
82
83
# Roll out the current policy when CTRL-Z is pressed. SIGTSTP only exists on
# Unix-like platforms; skip the handler elsewhere instead of failing with
# AttributeError at import time.
if hasattr(signal, "SIGTSTP"):
    signal.signal(signal.SIGTSTP, lambda signum, frame: rendertrial())
87
88
h_rwd = []  # Learning history (for plot): summed reward of each episode.

# Episodes are numbered 1 .. NEPISODES so that exactly NEPISODES episodes run
# (the episode number also drives the noise decay in ``disturb``).
for episode in range(1, NEPISODES + 1):
    x = env.reset()
    rsum = 0.0

    for step in range(NSTEPS):  # at most NSTEPS steps per episode
        # A single forward pass at x yields both the Q-values (used as the
        # regression target below) and the greedy control.
        Qref, u = sess.run([qvalue.qvalue, qvalue.u], feed_dict={qvalue.x: onehot(x)})
        u = disturb(u[0], episode)  # Greedy policy ... with noise
        x2, reward = env.step(u)

        # Compute reference Q-value at state x respecting HJB: only the entry
        # of the applied control is replaced by reward + discounted max Q(x2).
        Q2 = sess.run(qvalue.qvalue, feed_dict={qvalue.x: onehot(x2)})
        Qref[0, u] = reward + DECAY_RATE * np.max(Q2)

        # Update Q-table to better fit HJB
        sess.run(qvalue.optim, feed_dict={qvalue.x: onehot(x), qvalue.qref: Qref})

        rsum += reward
        x = x2
        if reward == 1:
            break

    h_rwd.append(rsum)
    if not episode % 20:
        print("Episode #%d done with %d success" % (episode, sum(h_rwd[-20:])))

# Average over the episodes actually played (guarded against an empty run).
print("Total rate of success: %.3f" % (sum(h_rwd) / max(len(h_rwd), 1)))
rendertrial()
# Running mean of the per-episode reward; the denominator has the same length
# as h_rwd by construction.
plt.plot(np.cumsum(h_rwd) / np.arange(1, len(h_rwd) + 1))
plt.show()
— Q-value networks
Definition: qnet.py:33
def onehot(ix, n=NX)
Definition: qnet.py:58