pinocchio  3.7.0
A fast and flexible implementation of Rigid Body Dynamics algorithms and their analytical derivatives
 
qnet.py
"""
Example of Q-table learning with a simple discretized 1-pendulum environment using a
linear Q network.
"""

import signal
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from dpendulum import DPendulum
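
# Note: this example uses the TensorFlow 1.x graph API (tf.placeholder,
# tf.InteractiveSession, tf.train.GradientDescentOptimizer). On TensorFlow 2 it is only
# expected to run through the compatibility layer, e.g. `import tensorflow.compat.v1 as tf`
# together with `tf.disable_v2_behavior()`.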

# --- Random seed
RANDOM_SEED = int((time.time() % 10) * 1000)
print(f"Seed = {RANDOM_SEED}")
np.random.seed(RANDOM_SEED)
tf.set_random_seed(RANDOM_SEED)

# --- Hyper parameters
NEPISODES = 500  # Number of training episodes
NSTEPS = 50  # Max episode length
LEARNING_RATE = 0.1  # Step length in optimizer
DECAY_RATE = 0.99  # Discount factor

# --- Environment
env = DPendulum()
NX = env.nx
NU = env.nu

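# DPendulum (imported above) is the discretized single-pendulum environment described
# in the docstring; it exposes NX discrete states and NU discrete controls, which is
# what makes the one-hot / tabular treatment below possible.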

# --- Q-value networks
class QValueNetwork:
    def __init__(self):
        x = tf.placeholder(shape=[1, NX], dtype=tf.float32)
        W = tf.Variable(tf.random_uniform([NX, NU], 0, 0.01, seed=100))
        qvalue = tf.matmul(x, W)
        u = tf.argmax(qvalue, 1)

        qref = tf.placeholder(shape=[1, NU], dtype=tf.float32)
        loss = tf.reduce_sum(tf.square(qref - qvalue))
        optim = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

        self.x = x  # Network input
        self.qvalue = qvalue  # Q-value as a function of x
        self.u = u  # Policy as a function of x
        # Reference Q-value at next step (to be set to l + Q o f,
        # i.e. reward + DECAY_RATE * max Q at the next state)
        self.qref = qref
        self.optim = optim  # Optimizer


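# Because the input x is a one-hot encoding of the state, the single linear layer above
# is an NX x NU table in disguise: tf.matmul(onehot(x), W) selects row x of W, so
# W[x, u] plays the role of the tabular Q(x, u), and a gradient step on the loss only
# updates that row toward the reference values fed through qref.
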
# --- TensorFlow initialization
tf.reset_default_graph()
qvalue = QValueNetwork()
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()


def onehot(ix, n=NX):
    """Return a vector which is 0 everywhere except index <ix> set to 1."""
    return np.array(
        [
            [(i == ix) for i in range(n)],
        ],
        np.float64,
    )


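# disturb() adds exploration noise to the greedy action: a Gaussian perturbation whose
# amplitude decays as the episode index i grows (scale 10 / (i / 50 + 10)), clipped to
# the valid control range [0, NU - 1].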
def disturb(u, i):
    u += int(np.random.randn() * 10 / (i / 50 + 10))
    return np.clip(u, 0, NU - 1)


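# rendertrial() rolls out the current greedy policy (argmax of the Q-values) for up to
# maxiter steps with rendering, stopping as soon as the unit reward is collected.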
def rendertrial(maxiter=100):
    x = env.reset()
    for i in range(maxiter):
        u = sess.run(qvalue.u, feed_dict={qvalue.x: onehot(x)})
        x, r = env.step(u)
        env.render()
        if r == 1:
            print("Reward!")
            break


signal.signal(
    signal.SIGTSTP, lambda x, y: rendertrial()
)  # Roll-out when CTRL-Z is pressed
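# (signal.SIGTSTP is POSIX-only; this rendering hook is not available on Windows.)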

# --- History of search
h_rwd = []  # Learning history (for plot).

# --- Training
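# At each step, the target for the chosen action u is the one-step lookahead value
# reward + DECAY_RATE * max_u' Q(x2, u'). The other entries of Qref keep the network's
# current prediction, so the squared loss penalizes only the Bellman residual of the
# action actually taken, and one gradient step moves Q(x, u) toward that target.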
for episode in range(1, NEPISODES):
    x = env.reset()
    rsum = 0.0

    for step in range(NSTEPS - 1):
        # Greedy policy ...
        u = sess.run(qvalue.u, feed_dict={qvalue.x: onehot(x)})[0]
        u = disturb(u, episode)  # ... with noise
        x2, reward = env.step(u)

        # Compute reference Q-value at state x respecting HJB
        Q2 = sess.run(qvalue.qvalue, feed_dict={qvalue.x: onehot(x2)})
        Qref = sess.run(qvalue.qvalue, feed_dict={qvalue.x: onehot(x)})
        Qref[0, u] = reward + DECAY_RATE * np.max(Q2)

        # Update Q-table to better fit HJB
        sess.run(qvalue.optim, feed_dict={qvalue.x: onehot(x), qvalue.qref: Qref})

        rsum += reward
        x = x2
        if reward == 1:
            break

    h_rwd.append(rsum)
    if not episode % 20:
        print(f"Episode #{episode} done with {sum(h_rwd[-20:])} successes")

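# Final report: overall success rate over training, one greedy roll-out with rendering,
# and a plot of the running success rate (cumulative reward divided by episode index).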
print(f"Total rate of success: {sum(h_rwd) / NEPISODES:.3f}")
rendertrial()
plt.plot(np.cumsum(h_rwd) / range(1, NEPISODES))
plt.show()