Reinforcement Learning: How to Train an RL Agent from Scratch
Authors: Sofie Verrewaere, Senior Data Scientist, Hiru Ranasinghe, Machine Learning Engineer, Daniel Miskell, Machine Learning Engineer, Ollie Blackman, AI + Analytics Intern
In the first blog post of our Reinforcement Learning (RL) series, we introduced the core concepts of RL and gave guidance on identifying suitable use cases, showing where RL can provide business value across a range of industries.
This second blog post introduces a specific type of RL, called Q-learning, and shows you how to code your own RL agent, using the game of Catch as an example.
The Game of Catch
The game of Catch is an ideal example to illustrate the inner workings of RL, as it requires only a few hundred lines of code. In this game, the RL agent tries to catch a falling fruit (a single pixel) in a basket (three pixels wide). The basket can only move horizontally. The agent wins if it catches the fruit; if it misses the fruit, the game is lost. This game can be translated into an RL framework:
- The agent moves a basket trying to catch the falling fruit.
- The environment is a 10×10 grid. A fruit is dropped from the top of the grid in a random column and falls one row per time step, over 10 consecutive time steps. The agent’s basket occupies 3 cells and moves laterally along the bottom row of the grid.
- The state of the environment is defined as a three-element tuple: [fruit row position, fruit column position, basket column position].
- The three actions are: moving the basket left, right, or staying stationary.
- The reward is +1 if the basket catches the fruit, -1 if the basket misses the fruit and 0 any other time (when the fruit is still falling).
This problem lends itself well to being solved by RL as we can create an algorithm where the agent finds a way to maximise its score without us giving it any indication of how the game should be played.
In this example, we will use a specific type of RL called Q-learning.
What is Q-Learning?
In Q-learning, the agent learns a Q-value for each state-action pair. A Q-value estimates how useful it is to take a given action in a given state, measured as the expected (discounted) sum of future rewards the agent will receive.
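The post does not write out the update rule, so for reference here is the standard one-step Q-learning update from the general RL literature. Here $s$ and $a$ are the current state and action, $r$ the reward received, $s'$ the next state, $\alpha$ a learning rate and $\gamma \in [0, 1)$ a discount factor; $\alpha$ and $\gamma$ are standard symbols we introduce, not names used elsewhere in this post.

```latex
Q(s, a) \leftarrow Q(s, a) + \alpha \Big[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \Big]
```

When $s'$ is terminal (the fruit has reached the bottom row), the target $r + \gamma \max_{a'} Q(s', a')$ is replaced by $r$ alone, since no further rewards can follow.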
We can subdivide RL algorithms according to how the agent learns, the agent’s representation of the environment, the agent’s behaviour, and whether the agent waits until the end of the game to update its estimates. In the figure below you can see where Q-learning sits within the RL landscape.
Learning Methods of the Agent
We’ll start with how the agent learns. There are two main approaches to solving RL problems:
- Policy-based methods: explicitly build a representation of the policy, i.e. a (possibly probabilistic) mapping from states to valid actions that the agent keeps in memory whilst learning.
- Value-based methods: store no explicit policy, only a value function; actions are chosen from the values (for example, by picking the action with the highest value).
There is also a hybrid actor-critic approach, which employs both a value function and a policy. In our code example, we use a value-based approach built on Q-values.
The Environment’s Representation
An agent can be model-based, using a model that predicts how the environment will respond to its actions, or model-free, learning directly from interaction without any such predictions. Our Catch example is model-free: no model represents the agent’s environment, and the agent only interacts with it through the game interface.
The Agent’s Behaviour
The agent can behave either off-policy or on-policy. In our example, the current policy is updated from experiences generated by previous policies. This makes it an off-policy learner: its update targets use the best (greedy) action in the next state, rather than the action the exploring agent actually took.
When to Update the Estimates
In this code example, the estimates are the estimates of the Q-values (see Fig. 11). One option is to update the estimates only once the final outcome is known (i.e. at the end of the game); these are called Monte Carlo methods. A second option is to adjust predictions towards later, more accurate predictions about the future before the final outcome is known. This latter option is known as temporal difference (TD) learning.
In our implementation of RL to solve the game of Catch, the Q-value estimates are learned by sampling transitions stored in an experience replay buffer, and the value functions are updated during the game using temporal difference updates.
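To make the difference concrete, here is a small illustrative sketch (not taken from the repo) contrasting the two options for Catch-style rewards. The function names `monte_carlo_returns` and `td_target`, and the discount factor `gamma=0.9`, are our own assumptions. The Monte Carlo return for each step can only be computed once the whole game’s rewards are known, whereas the TD target is available right after a single step, because it bootstraps from the current Q-value estimates of the next state.

```python
def monte_carlo_returns(rewards, gamma=0.9):
    """Discounted return for every step; needs the complete list of rewards (end of game)."""
    returns = []
    g = 0.0
    # Walk backwards so each step's return reuses the return of the step after it.
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    return list(reversed(returns))


def td_target(reward, next_q_values, done, gamma=0.9):
    """One-step temporal difference (Q-learning) target; available right after one step."""
    if done:
        return reward  # nothing follows a terminal state
    # Bootstrap: use the current estimate of the best next action instead of waiting.
    return reward + gamma * max(next_q_values)
```

For a game that ends in a catch after nine moves, `monte_carlo_returns([0] * 8 + [1])` credits each move with 0.9 raised to the number of moves that follow it, but only once the game has finished; `td_target` can be evaluated after every move.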
To summarise, Q-learning is a value-based, model-free, off-policy, temporal difference algorithm.
The Experience Replay Buffer
Q-learning can be combined with an experience replay buffer. This is valuable because generating new experiences can be difficult, time-consuming, and expensive.
The experience replay buffer allows us to minimise the number of experiences we need to generate to train a robust and high-performing agent by reusing previous ones. Without an experience replay buffer, the agent would only update its Q-values based on the most recent transition. This would require much more experience to converge.
The system stores a history of the agent’s experiences. An experience (or transition) consists of the environment state and the action at a given time step. It also contains the received reward, the next state and an optional done signal.
The learning phase is based on taking random samples from the experience replay buffer. Sampling at random breaks the correlation between consecutive transitions and lets each experience be reused many times, which improves the sample efficiency and stability of training. This makes experience replay a crucial component of off-policy deep RL.
Reinforcement Learning Catch Code Walkthrough
The implementation of this game can be found in this GitHub repo.
The codebase comprises four Python files: one defining the environment, one defining the experience replay, one for training the agent and one for running a trained agent. Let’s dive into the Python code of our RL model!
Environment Definition (env.py)
The Catch environment is defined as a 10×10 pixel grid, together with the methods for interacting with it. Each time step cycles through the agent’s action, the resulting state transition, and the reward. We redraw the environment after each state change.
When the agent wants to interact with the environment, it calls the ‘act’ function (see Fig. 5), passing the chosen action as a parameter. The action is used to update the state of the environment, which yields a new reward. Every time an action is taken, we check whether the fruit has reached the bottom of the grid, which marks the end of the game.
When the ‘act’ function returns the updated values, the ‘observe’ function is called to redraw the 10×10 pixel grid to show the new state.
Additionally, the environment file defines a ‘reset’ function, which resets the environment state. The fruit row is reset to the top of the grid (row 0), and the fruit column and basket position are randomly initialised. Since the basket is three pixels wide and defined by its central position, its centre cannot be initialised in column 0 or 9.
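The following is a minimal, dependency-free sketch of such an environment, written to illustrate the ‘act’, ‘observe’ and ‘reset’ cycle described above rather than to reproduce the repo’s env.py. The action encoding (0 = left, 1 = stay, 2 = right), the flat 0/1 pixel list returned by `observe`, the `seed` argument and the rule that the game ends when the fruit reaches the bottom row are our own assumptions.

```python
import random


class Catch:
    """Illustrative 10x10 Catch environment (hypothetical sketch, not the repo's env.py)."""

    def __init__(self, grid_size=10, seed=None):
        self.grid_size = grid_size
        self.rng = random.Random(seed)
        self.reset()

    def reset(self):
        # Fruit starts in the top row at a random column. The basket is tracked by its
        # centre, which must stay in columns 1..grid_size-2 so all 3 cells fit on the grid.
        self.fruit_row = 0
        self.fruit_col = self.rng.randint(0, self.grid_size - 1)
        self.basket = self.rng.randint(1, self.grid_size - 2)
        return self.observe()

    def act(self, action):
        """Apply an action (assumed: 0 = left, 1 = stay, 2 = right); return (state, reward, done)."""
        if action not in (0, 1, 2):
            raise ValueError(f"invalid action: {action}")
        # Move the basket, clamping its centre so it never leaves the grid.
        self.basket = min(max(1, self.basket + action - 1), self.grid_size - 2)
        # The fruit falls one row per time step.
        self.fruit_row += 1
        done = self.fruit_row == self.grid_size - 1
        reward = 0
        if done:
            # Caught if the fruit lands on any of the basket's three cells.
            reward = 1 if abs(self.fruit_col - self.basket) <= 1 else -1
        return self.observe(), reward, done

    def observe(self):
        """Redraw the grid as a flat, row-major list of 0/1 pixels."""
        n = self.grid_size
        grid = [0] * (n * n)
        grid[self.fruit_row * n + self.fruit_col] = 1
        for col in range(self.basket - 1, self.basket + 2):
            grid[(n - 1) * n + col] = 1
        return grid
```

With this sketch, calling `env.act(1)` nine times after a reset ends the game: the first eight calls return a reward of 0, and the last returns +1 or -1 depending on whether the basket is under the fruit.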
Experience Replay Definition (experience_replay.py)
The experience replay file contains a helper class that defines the experience replay buffer and its associated methods. The buffer stores previous state-action interactions, and samples of these interactions are used to train a neural network to choose better actions going forward.
We define an experience as a tuple of the previous state, the action taken, the next state transitioned to, the reward received for that transition, and a done signal indicating whether the game has ended.
We use a simple Python list as our experience replay buffer. The buffer has a maximum size, and any experience added to a full buffer replaces the oldest experience. This functionality is defined in the ‘add_experience’ method.
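A sketch of such a buffer is shown below. The class and `add_experience` method follow the description above, but the `Experience` namedtuple, the default `max_size` of 500 and the `sample` helper are illustrative assumptions rather than the repo’s code.

```python
import random
from collections import namedtuple

# One transition, with fields in the order described above.
Experience = namedtuple("Experience", ["state", "action", "next_state", "reward", "done"])


class ExperienceReplay:
    """Fixed-size replay buffer backed by a plain Python list (illustrative sketch)."""

    def __init__(self, max_size=500):
        self.max_size = max_size
        self.memory = []

    def add_experience(self, experience):
        # When the buffer is full, drop the oldest experience first (first in, first out).
        if len(self.memory) >= self.max_size:
            self.memory.pop(0)
        self.memory.append(experience)

    def sample(self, batch_size):
        # Uniform random sample without replacement; smaller if the buffer holds fewer items.
        return random.sample(self.memory, min(batch_size, len(self.memory)))

    def __len__(self):
        return len(self.memory)
```

Removing from the front of a list with `pop(0)` is O(n) in the buffer size; `collections.deque(maxlen=max_size)` gives the same oldest-first eviction in O(1) if the buffer grows large.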
There is also a ‘get_qlearning_batch’ function that randomly samples the experience replay buffer and builds a training set for the Keras model. The input features are a flattened representation of the grid pixels of the current state, and the targets are the target Q-values for each action in that state.
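The post does not spell out how the targets are computed. A common way to build such a batch, sketched below under that assumption, is to start from the model’s current Q-value predictions for the sampled state and overwrite only the entry for the action actually taken with the temporal difference target: the reward if the game ended, otherwise the reward plus the discounted maximum predicted Q-value of the next state. Here `predict` is a hypothetical stand-in for the Keras model’s prediction (any function returning one Q-value per action), and `gamma=0.9` and `batch_size=32` are assumed values.

```python
import random


def get_qlearning_batch(memory, predict, batch_size=32, gamma=0.9):
    """Build (inputs, targets) for one Q-learning training step (illustrative sketch).

    memory:  list of (state, action, next_state, reward, done) tuples
    predict: function mapping a flat state list to a list of Q-values, one per action
    """
    batch = random.sample(memory, min(batch_size, len(memory)))
    inputs, targets = [], []
    for state, action, next_state, reward, done in batch:
        # Start from the current estimates so only the taken action's Q-value changes.
        target = list(predict(state))
        if done:
            target[action] = reward  # no future reward once the game has ended
        else:
            target[action] = reward + gamma * max(predict(next_state))
        inputs.append(state)
        targets.append(target)
    return inputs, targets
```

Leaving the other actions’ targets equal to the current predictions gives them zero error, so a mean-squared-error loss only adjusts the Q-value of the action that was actually taken.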
Note that the ‘get_qlearning_batch’ method of the experience buffer is specific to Q-learning.
Train File (train.py)
The Q-learning training file defines the Keras model and the agent training loop.
We define a simple neural network with two hidden layers that takes the grid pixels of the current state as its input and outputs the predicted Q-value for each action.
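To keep the example free of dependencies, the sketch below does not use Keras. It is a plain-Python forward pass that shows the shape of such a network: 100 pixel inputs, two hidden layers, and one output per action. The hidden size of 100, the ReLU activations and the random initialisation are assumptions, and the weights are untrained.

```python
import random


def init_layer(n_in, n_out, rng):
    """Small random weights and zero biases for one fully connected layer."""
    weights = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    biases = [0.0] * n_out
    return weights, biases


def dense(x, layer, relu):
    """Fully connected layer: weighted sums plus bias, with optional ReLU."""
    weights, biases = layer
    out = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]
    return [max(0.0, v) for v in out] if relu else out


def build_q_network(grid_size=10, hidden=100, n_actions=3, seed=0):
    rng = random.Random(seed)
    n_in = grid_size * grid_size  # one input per grid pixel
    return [init_layer(n_in, hidden, rng),
            init_layer(hidden, hidden, rng),
            init_layer(hidden, n_actions, rng)]


def predict_q_values(network, flat_grid):
    """Two ReLU hidden layers, then a linear output layer with one Q-value per action."""
    h = dense(flat_grid, network[0], relu=True)
    h = dense(h, network[1], relu=True)
    return dense(h, network[2], relu=False)
```

The output layer is left linear because Q-values can be negative (a miss is worth -1). In Keras, the same architecture could be expressed as a `keras.Sequential` model of `Dense` layers and trained with a mean-squared-error loss against the targets from the replay batch.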
We also define instances of the environment and the experience replay buffer.
Each epoch of the training loop plays one game of Catch, storing the experiences and rewards received in the experience replay buffer as the game is played. After each action, we create a batch of training data by sampling the experience replay buffer and train the model on these experiences.
We follow an epsilon-greedy strategy when deciding which action to take. Here, epsilon is 0.1: 10% of the time the agent takes a random action to see what happens, and the rest of the time it chooses what it believes to be the best action (so the best action is chosen at least 90% of the time). We can change the amount of exploration by changing the ‘epsilon’ variable.
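Epsilon-greedy selection takes only a few lines. In this sketch, the function name `choose_action` and the tie-breaking rule (the first of several equally valued actions wins) are our own choices.

```python
import random


def choose_action(q_values, epsilon=0.1, rng=random):
    """Epsilon-greedy: explore with probability epsilon, otherwise act greedily."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore: any action, uniformly at random
    # Exploit: index of the highest Q-value (ties go to the first such action).
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

With three actions and `epsilon=0.1`, the greedy action is picked with probability 0.9 + 0.1/3 ≈ 0.933, since a random exploratory move sometimes coincides with it.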
In the next game (epoch), the updated model is used to play another game of Catch, producing more experiences with slightly different, and hopefully improved, action decisions. The loop repeats until all epochs have run, and the resulting model artefact is then saved.
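The sketch below puts the training loop together under one clearly stated simplification: it swaps the Keras neural network for a tabular Q-function (a Python dict from state tuples to three Q-values), so it runs without any dependencies. Otherwise it follows the description above: each epoch plays one game with epsilon-greedy actions, every experience is stored in a bounded replay list, and after each action a random batch is replayed to move the Q-values towards their temporal difference targets. The compact `reset`/`step` dynamics and all hyperparameters (`alpha`, `gamma`, `max_memory`, `batch_size`) are assumptions.

```python
import random

GRID = 10
ACTIONS = (0, 1, 2)  # assumed encoding: 0 = left, 1 = stay, 2 = right


def reset(rng):
    """Start a game: (fruit_row, fruit_col, basket_centre)."""
    return (0, rng.randint(0, GRID - 1), rng.randint(1, GRID - 2))


def step(state, action):
    """Move the basket (clamped to the grid), drop the fruit one row, score at the bottom."""
    row, col, basket = state
    basket = min(max(1, basket + action - 1), GRID - 2)
    row += 1
    done = row == GRID - 1
    reward = (1 if abs(col - basket) <= 1 else -1) if done else 0
    return (row, col, basket), reward, done


def train(epochs=2000, epsilon=0.1, alpha=0.1, gamma=0.9,
          max_memory=500, batch_size=16, seed=0):
    rng = random.Random(seed)
    q = {}       # tabular stand-in for the network: state -> [Q(left), Q(stay), Q(right)]
    memory = []  # replay buffer of (state, action, next_state, reward, done)

    def q_values(s):
        return q.setdefault(s, [0.0, 0.0, 0.0])

    wins = 0
    for _ in range(epochs):  # one epoch = one game of Catch
        state, done = reset(rng), False
        while not done:
            # Epsilon-greedy action selection.
            if rng.random() < epsilon:
                action = rng.choice(ACTIONS)
            else:
                qs = q_values(state)
                action = max(ACTIONS, key=lambda a: qs[a])
            next_state, reward, done = step(state, action)

            # Store the experience, evicting the oldest one when the buffer is full.
            memory.append((state, action, next_state, reward, done))
            if len(memory) > max_memory:
                memory.pop(0)

            # Replay a random batch and move each Q-value towards its TD target.
            for s, a, s2, r, d in rng.sample(memory, min(batch_size, len(memory))):
                target = r if d else r + gamma * max(q_values(s2))
                qs = q_values(s)
                qs[a] += alpha * (target - qs[a])
            state = next_state
        if reward == 1:
            wins += 1
    return q, wins
```

The function returns the learned table and the number of games won during training; saving the table (for example with `pickle`) would play the role of saving the model artefact.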
Inference File (run.py)
The inference file defines how a trained model can be used to play a game of Catch. The trained model should predict more accurate Q-values, allowing the agent to make better decisions when moving the basket and therefore to catch the fruit more often.
At each time step:
- We pass the grid state to the model
- The model returns the estimated Q-values for each action
- The action with the highest Q-value is chosen and sent to the environment
- The environment returns the new state and the reward received
- The new state is drawn on the canvas
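The inference loop above can be sketched as follows. To keep it self-contained, it reuses compact Catch dynamics and, instead of a trained model, a hypothetical `hand_coded_q` function that simply scores moving towards the fruit highest. Any function mapping a state to three Q-values, such as a trained network’s prediction, could be passed in its place; drawing on the canvas is omitted.

```python
import random

GRID = 10


def reset(rng):
    """Start a game: (fruit_row, fruit_col, basket_centre)."""
    return (0, rng.randint(0, GRID - 1), rng.randint(1, GRID - 2))


def step(state, action):
    """Assumed dynamics: 0 = left, 1 = stay, 2 = right; the game ends at the bottom row."""
    row, col, basket = state
    basket = min(max(1, basket + action - 1), GRID - 2)
    row += 1
    done = row == GRID - 1
    reward = (1 if abs(col - basket) <= 1 else -1) if done else 0
    return (row, col, basket), reward, done


def play(q_function, rng):
    """Play one game greedily and return the final reward (+1 catch, -1 miss)."""
    state, done = reset(rng), False
    while not done:
        q_values = q_function(state)  # the model estimates a Q-value for each action
        action = max(range(len(q_values)), key=lambda a: q_values[a])  # highest Q-value wins
        state, reward, done = step(state, action)  # environment returns new state and reward
        # A full implementation would redraw the grid on the canvas here.
    return reward


def hand_coded_q(state):
    """Hypothetical stand-in for a trained model: favour moving the basket towards the fruit."""
    _, col, basket = state
    direction = (col > basket) - (col < basket)  # -1 (left), 0 (stay) or +1 (right)
    return [1.0 if a - 1 == direction else 0.0 for a in range(3)]
```

With these dynamics the agent gets nine moves per game and never needs more than seven to get the basket under the fruit, so `play(hand_coded_q, random.Random(0))` returns 1 for every game.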
Summary
Following the code above, you should be able to create a Reinforcement Learning model successfully. While we’ve used RL to develop a model to play a relatively simple game, RL has several impactful real-world applications.
As 4x Google Cloud Partner of the Year, Datatonic has a wealth of experience in both Machine Learning and a range of Google products and services. Get in touch to learn how your business can benefit from Reinforcement Learning or other Machine Learning models.
Further Reading
If you want to learn more about Reinforcement Learning, take a look at these resources:
- Coursera Course on Unsupervised Learning, Recommenders, Reinforcement Learning
- Stanford Machine Learning Course taught by Andrew Ng (from the middle of Lecture 16 onwards)
- UCL Course on Reinforcement Learning taught by David Silver