io-chess
UCI chess engine
Loading...
Searching...
No Matches
MCTS.h
Go to the documentation of this file.
1#pragma once
2
10
#include "../eval/IEvaluator.h"
#include "ISearch.h"
#include <atomic>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
16
21struct MCTSNode {
23 MCTSNode *parent = nullptr;
24 std::vector<std::unique_ptr<MCTSNode>> children;
25
26 int visits = 0;
27 float valueSum = 0.0f;
28 bool expanded = false;
29
30 MCTSNode() = default;
31 MCTSNode(Move m, MCTSNode *p) : move(m), parent(p) {}
32
39 float ucb(float explorationConstant) const {
40 if (visits == 0)
41 return std::numeric_limits<float>::infinity();
42 float avgValue = valueSum / visits;
43 float exploration =
44 explorationConstant * std::sqrt(std::log(parent->visits) / visits);
45 return avgValue + exploration;
46 }
47
53 float meanValue() const { return visits > 0 ? valueSum / visits : 0.0f; }
54};
55
63class MCTS : public ISearch {
64private:
66
67 std::unique_ptr<MCTSNode> root_;
68 std::atomic<bool> stopFlag_{false};
69 std::atomic<bool> searching_{false};
70 std::atomic<uint64_t> nodes_{0};
71
72 // MCTS parameters
73 float explorationConstant_ = 1.41f;
74 int maxPlayouts_ = 100000;
75
77
78public:
79 MCTS(IEvaluator &eval);
80
81 // ISearch interface
82 Move startSearch(Board &root, const SearchParams &params) override;
83 void stop() override { stopFlag_ = true; }
84 bool isSearching() const override { return searching_; }
85 void setInfoCallback(InfoCallback callback) override { infoCallback_ = callback; }
86 uint64_t getNodes() const override { return nodes_; }
87
92
93private:
94 // Core MCTS phases
95 MCTSNode *selection(MCTSNode *node, Board &board);
96 void expansion(MCTSNode *node, Board &board);
97 float simulation(Board &board);
98 void backpropagation(MCTSNode *node, float value);
99
100 // Utilities
101 MCTSNode *selectBestChild(MCTSNode *node) const;
102 MCTSNode *selectUCBChild(MCTSNode *node) const;
103};
Abstract interface for board evaluation algorithms.
Search interface and shared data structures.
std::function< void(int depth, int score, int nodes, int nps, const std::vector< Move > &pv)> InfoCallback
Callback function type for sending UCI 'info' updates during search.
Definition ISearch.h:145
chess::Move Move
Alias for chess::Move.
Definition Types.h:15
chess::Board Board
Alias for chess::Board.
Definition Types.h:14
Abstract interface for evaluators.
Definition IEvaluator.h:25
Abstract interface for search algorithms.
Definition ISearch.h:155
MCTSNode * selectBestChild(MCTSNode *node) const
Definition MCTS.cpp:153
void expansion(MCTSNode *node, Board &board)
Definition MCTS.cpp:112
std::atomic< bool > stopFlag_
Signals the search to stop.
Definition MCTS.h:68
bool isSearching() const override
Checks if the search algorithm is currently running.
Definition MCTS.h:84
MCTSNode * selectUCBChild(MCTSNode *node) const
Definition MCTS.cpp:171
MCTSNode * selection(MCTSNode *node, Board &board)
Definition MCTS.cpp:104
void backpropagation(MCTSNode *node, float value)
Definition MCTS.cpp:142
void setInfoCallback(InfoCallback callback) override
Sets the callback function for periodic information updates.
Definition MCTS.h:85
void stop() override
Asynchronously signals the search to stop immediately.
Definition MCTS.h:83
uint64_t getNodes() const override
Retrieves the total number of nodes evaluated by the search algorithm.
Definition MCTS.h:86
std::unique_ptr< MCTSNode > root_
Root node of the search tree.
Definition MCTS.h:67
std::atomic< bool > searching_
True while a search is active.
Definition MCTS.h:69
int maxPlayouts_
Hard limit on search iterations.
Definition MCTS.h:74
float explorationConstant_
Weight of exploration in UCB formula.
Definition MCTS.h:73
MCTS(IEvaluator &eval)
Definition MCTS.cpp:10
Move startSearch(Board &root, const SearchParams &params) override
Starts the search process on the given root board.
Definition MCTS.cpp:12
float simulation(Board &board)
Definition MCTS.cpp:130
IEvaluator & evalCtx_
The evaluator used for leaf node evaluation.
Definition MCTS.h:65
InfoCallback infoCallback_
Callback for UCI info updates.
Definition MCTS.h:76
void setExplorationConstant(float c)
Sets the exploration constant (Cp) used in the UCB calculation.
Definition MCTS.h:91
std::atomic< uint64_t > nodes_
Total nodes visited.
Definition MCTS.h:70
Represents a single node in the Monte Carlo Search Tree.
Definition MCTS.h:21
int visits
Number of times this node has been visited.
Definition MCTS.h:26
float valueSum
Cumulative value from all rollouts/evaluations passing through this node.
Definition MCTS.h:27
MCTSNode()=default
MCTSNode(Move m, MCTSNode *p)
Definition MCTS.h:31
float ucb(float explorationConstant) const
Calculates the Upper Confidence Bound (UCB1) for this node.
Definition MCTS.h:39
std::vector< std::unique_ptr< MCTSNode > > children
Child nodes.
Definition MCTS.h:24
bool expanded
True if the node's children have been generated.
Definition MCTS.h:28
Move move
Move that led to this node.
Definition MCTS.h:22
float meanValue() const
Calculates the mean expected value of the node.
Definition MCTS.h:53
MCTSNode * parent
Pointer to the parent node.
Definition MCTS.h:23
Parameters defining the constraints for a search operation.
Definition ISearch.h:125