Serenity Operating System
1/*
2 * Copyright (c) 2020, the SerenityOS developers.
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 */
6
7#include "MCTSTree.h"
8#include <AK/DeprecatedString.h>
9#include <stdlib.h>
10
// Creates a search node for the position in `board`.
// The node keeps its own heap-allocated copy of the board (expand() later drops it once
// child nodes exist). It records the move that produced the position and the side to move,
// and links back to `parent` (presumably nullptr for the root; apply_result() stops at a null parent).
MCTSTree::MCTSTree(Chess::Board const& board, MCTSTree* parent)
    : m_parent(parent)
    , m_board(make<Chess::Board>(board))
    , m_last_move(board.last_move())
    , m_turn(board.turn())
{
}
18
19MCTSTree::MCTSTree(MCTSTree&& other)
20 : m_children(move(other.m_children))
21 , m_parent(other.m_parent)
22 , m_white_points(other.m_white_points)
23 , m_simulations(other.m_simulations)
24 , m_board(move(other.m_board))
25 , m_last_move(move(other.m_last_move))
26 , m_turn(other.m_turn)
27 , m_moves_generated(other.m_moves_generated)
28{
29 other.m_parent = nullptr;
30}
31
32MCTSTree& MCTSTree::select_leaf()
33{
34 if (!expanded() || m_children.size() == 0)
35 return *this;
36
37 MCTSTree* node = nullptr;
38 double max_uct = -double(INFINITY);
39 for (auto& child : m_children) {
40 double uct = child->uct(m_turn);
41 if (uct >= max_uct) {
42 max_uct = uct;
43 node = child;
44 }
45 }
46 VERIFY(node);
47 return node->select_leaf();
48}
49
50MCTSTree& MCTSTree::expand()
51{
52 VERIFY(!expanded() || m_children.size() == 0);
53
54 if (!m_moves_generated) {
55 m_board->generate_moves([&](Chess::Move chess_move) {
56 auto clone = m_board->clone_without_history();
57 clone.apply_move(chess_move);
58 m_children.append(make<MCTSTree>(move(clone), this));
59 return IterationDecision::Continue;
60 });
61 m_moves_generated = true;
62 if (m_children.size() != 0)
63 m_board = nullptr; // Release the board to save memory.
64 }
65
66 if (m_children.size() == 0) {
67 return *this;
68 }
69
70 for (auto& child : m_children) {
71 if (child->m_simulations == 0) {
72 return *child;
73 }
74 }
75 VERIFY_NOT_REACHED();
76}
77
78int MCTSTree::simulate_game() const
79{
80 Chess::Board clone = *m_board;
81 while (!clone.game_finished()) {
82 clone.apply_move(clone.random_move());
83 }
84 return clone.game_score();
85}
86
87int MCTSTree::heuristic() const
88{
89 if (m_board->game_finished())
90 return m_board->game_score();
91
92 double winchance = max(min(double(m_board->material_imbalance()) / 6, 1.0), -1.0);
93
94 double random = double(rand()) / RAND_MAX;
95 if (winchance >= random)
96 return 1;
97 if (winchance <= -random)
98 return -1;
99
100 return 0;
101}
102
103void MCTSTree::apply_result(int game_score)
104{
105 m_simulations++;
106 m_white_points += game_score;
107
108 if (m_parent)
109 m_parent->apply_result(game_score);
110}
111
112void MCTSTree::do_round()
113{
114
115 // Note: Limit expansion to spare some memory
116 // Efficient Selectivity and Backup Operators in Monte-Carlo Tree Search.
117 // Rémi Coulom.
118 auto* node_ptr = &select_leaf();
119 if (node_ptr->m_simulations > s_number_of_visit_parameter)
120 node_ptr = &select_leaf().expand();
121
122 auto& node = *node_ptr;
123
124 int result;
125 if constexpr (s_eval_method == EvalMethod::Simulation) {
126 result = node.simulate_game();
127 } else {
128 result = node.heuristic();
129 }
130 node.apply_result(result);
131}
132
133Optional<MCTSTree&> MCTSTree::child_with_move(Chess::Move chess_move)
134{
135 for (auto& node : m_children) {
136 if (node->last_move() == chess_move)
137 return *node;
138 }
139 return {};
140}
141
142MCTSTree& MCTSTree::best_node()
143{
144 int score_multiplier = (m_turn == Chess::Color::White) ? 1 : -1;
145
146 MCTSTree* best_node_ptr = nullptr;
147 double best_score = -double(INFINITY);
148 VERIFY(m_children.size());
149 for (auto& node : m_children) {
150 double node_score = node->expected_value() * score_multiplier;
151 if (node_score >= best_score) {
152 best_node_ptr = node;
153 best_score = node_score;
154 }
155 }
156 VERIFY(best_node_ptr);
157
158 return *best_node_ptr;
159}
160
161Chess::Move MCTSTree::last_move() const
162{
163 return m_last_move.value();
164}
165
166double MCTSTree::expected_value() const
167{
168 if (m_simulations == 0)
169 return 0;
170
171 return double(m_white_points) / m_simulations;
172}
173
174double MCTSTree::uct(Chess::Color color) const
175{
176 // UCT: Upper Confidence Bound Applied to Trees.
177 // Kocsis, Levente; Szepesvári, Csaba (2006). "Bandit based Monte-Carlo Planning"
178
179 // Fun fact: Szepesvári was my data structures professor.
180 double expected = expected_value() * ((color == Chess::Color::White) ? 1 : -1);
181 return expected + s_exploration_parameter * sqrt(log(m_parent->m_simulations) / m_simulations);
182}
183
184bool MCTSTree::expanded() const
185{
186 if (!m_moves_generated)
187 return false;
188
189 for (auto& child : m_children) {
190 if (child->m_simulations == 0)
191 return false;
192 }
193
194 return true;
195}