-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathuct.h
126 lines (106 loc) · 4.39 KB
/
uct.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#pragma once

#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <random>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/range/join.hpp>
#include <giolib/memory.h>

#include "utils/utils.h"
#include "parsing/parser.h"
#include "parsing/unif.h"
#include "mm/library.h"
#include "mm/toolbox.h"
#include "mm/engine.h"
// Forward declarations so the class definitions below can name each other
// (through std::shared_ptr / std::weak_ptr members and parameters) before
// every definition has been seen.
class UCTProver;
class SentenceNode;
class StepNode;
enum VisitResult {
PROVED,
CONTINUE,
DEAD,
};
/**
 * Top-level search object.
 *
 * Owns the root SentenceNode of the search tree (`root`). It also holds the
 * context that tree nodes reach through their std::weak_ptr<UCTProver>:
 * - the library toolbox;
 * - the thesis and hypotheses;
 * - the antidists set;
 * - a random engine;
 * - a temporary allocator;
 * - two "useful assertion" tables.
 *
 * The constructor, destructor and init() are all protected. Instances are
 * presumably obtained through a factory supplied by
 * gio::virtual_enable_create -- confirm in giolib/memory.h.
 *
 * NOTE(review): init() is separate from the constructor, possibly because it
 * needs the object to already be owned by a shared_ptr -- confirm in the
 * implementation.
 */
class UCTProver : public gio::virtual_enable_create< UCTProver > {
public:
    // ---- Search driver -------------------------------------------------
    VisitResult visit();
    void replay_proof(CheckpointedProofEngine &engine) const;

    // ---- Read-only context shared with the tree nodes ------------------
    const LibraryToolbox &get_toolbox() const;
    const std::vector<ParsingTree2<SymTok, LabTok>> &get_hypotheses() const;
    // NOTE(review): the pairs are presumably disjoint-variable-style
    // constraints on labels -- confirm.
    const std::set<std::pair<LabTok, LabTok>> &get_antidists() const;

    // ---- Assertion filtering -------------------------------------------
    // NOTE(review): the two tables are presumably filled by
    // compute_useful_assertions(); key/value meaning to be confirmed.
    bool is_assertion_useful(const Assertion &ass) const;
    const std::unordered_map<LabTok, std::vector<LabTok>> &get_root_useful_asses() const;
    const std::unordered_map<LabTok, std::vector<LabTok>> &get_imp_con_useful_asses() const;

    // ---- Mutable shared resources --------------------------------------
    // Returned by non-const reference. Presumably `tsa` and `rand` below.
    temp_allocator &get_temp_allocator();
    std::ranlux48 &get_rand();

    // ---- Children callbacks --------------------------------------------
    // The setter takes the vector by rvalue reference (presumably moved into
    // `children_callbacks`). The getter returns a std::function by value,
    // presumably a copy of entry `idx`.
    void set_children_callbacks(std::vector<std::function<void()>> &&children_callbacks);
    std::function<void()> get_children_callback(size_t idx) const;

protected:
    UCTProver(const LibraryToolbox &tb,
              const ParsingTree2<SymTok, LabTok> &thesis,
              const std::vector<ParsingTree2<SymTok, LabTok>> &hypotheses,
              const std::set<std::pair<LabTok, LabTok>> &antidists = {});
    ~UCTProver();
    void init();

private:
    void compute_useful_assertions();

    // Declaration order == initialization order. Keep it in sync with the
    // constructor's member-initializer list and do not reorder.

    // Root of the search tree; children are owned through shared_ptr.
    std::shared_ptr<SentenceNode> root;
    std::set<std::pair<LabTok, LabTok>> antidists;
    // Non-owning reference: the toolbox must outlive the prover.
    const LibraryToolbox &tb;
    temp_stacked_allocator tsa;
    // Stored by value. Presumably the goal sentence.
    ParsingTree2<SymTok, LabTok> thesis;
    std::vector<ParsingTree2<SymTok, LabTok>> hypotheses;
    std::unordered_map<LabTok, std::vector<LabTok>> root_useful_asses;
    std::unordered_map<LabTok, std::vector<LabTok>> imp_con_useful_asses;
    std::ranlux48 rand;
    std::vector<std::function<void()>> children_callbacks;
};
/**
 * Search-tree node standing for one sentence (a ParsingTree2) to be proved.
 *
 * Links to the rest of the tree:
 * - children are StepNodes, held by std::shared_ptr;
 * - the parent StepNode is held by std::weak_ptr;
 * - the owning UCTProver is also held by std::weak_ptr.
 * Ownership therefore flows top-down from UCTProver::root.
 *
 * The constructor and destructor are protected. Instances are presumably
 * created through gio::virtual_enable_create -- confirm.
 */
class SentenceNode : public gio::virtual_enable_create< SentenceNode > {
public:
    // ---- Search --------------------------------------------------------
    VisitResult visit();
    void replay_proof(CheckpointedProofEngine &engine) const;

    // ---- Accessors -----------------------------------------------------
    // NOTE(review): unlike StepNode's accessors these are not const. Making
    // them const requires changing the out-of-line definitions as well.
    const ParsingTree2<SymTok, LabTok> &get_sentence();
    std::weak_ptr<StepNode> get_parent();
    float get_value();
    uint32_t get_visit_num();

protected:
    SentenceNode(std::weak_ptr<UCTProver> uct,
                 std::weak_ptr<StepNode> parent,
                 const ParsingTree2<SymTok, LabTok> &sentence);
    ~SentenceNode();

private:
    bool check_subst_map(const SubstMap2<SymTok, LabTok> &subst_map, const Assertion &ass);

    // Declaration order == initialization order. Do not reorder.

    // Owning prover (not owned).
    std::weak_ptr<UCTProver> uct;
    // Owned child steps.
    std::vector<std::shared_ptr<StepNode>> children;
    // Producing step (not owned).
    std::weak_ptr<StepNode> parent;
    ParsingTree2<SymTok, LabTok> sentence;
    uint32_t visit_num{0};
    bool exhausted{false};
    // NOTE(review): presumably an index into the prover's hypotheses -- confirm.
    size_t hyp_num{0};
    float value{0.0f};
    float total_children_value{0.0f};
    // Concatenated view over two LabTok vectors, plus a cursor into it.
    // NOTE(review): presumably the prover's useful-assertion lists, consumed
    // lazily across successive visits -- confirm.
    boost::range::joined_range<const std::vector<LabTok>, const std::vector<LabTok>> ass_range;
    boost::range::joined_range<const std::vector<LabTok>, const std::vector<LabTok>>::iterator ass_it;
};
/**
 * Search-tree node standing for the application of one assertion (`label`)
 * to its parent SentenceNode.
 *
 * Links to the rest of the tree:
 * - children are SentenceNodes, held by std::shared_ptr;
 * - the parent SentenceNode is held by std::weak_ptr;
 * - the owning UCTProver is also held by std::weak_ptr.
 *
 * `open_vars` maps a LabTok to a gio::safe_weak_ptr<StepNode>. Presumably
 * these are variables left unassigned by some step -- confirm.
 *
 * The constructor and destructor are protected. Instances are presumably
 * created through gio::virtual_enable_create -- confirm.
 */
class StepNode : public gio::virtual_enable_create< StepNode > {
public:
    // ---- Search --------------------------------------------------------
    VisitResult visit();
    void replay_proof(CheckpointedProofEngine &engine) const;

    // ---- Accessors -----------------------------------------------------
    std::weak_ptr<SentenceNode> get_parent() const;
    float get_value() const;
    uint32_t get_visit_num() const;
    const std::map<LabTok, gio::safe_weak_ptr<StepNode>> &get_open_vars() const;

protected:
    StepNode(std::weak_ptr<UCTProver> uct,
             std::weak_ptr<SentenceNode> parent,
             LabTok label,
             const SubstMap2<SymTok, LabTok> &const_subst_map,
             const std::map<LabTok, gio::safe_weak_ptr<StepNode>> &open_vars);
    ~StepNode();

private:
    VisitResult create_child(const ParsingTree2<SymTok, LabTok> &sent);
    VisitResult create_children();
    VisitResult visit_child(size_t i);

    // Declaration order == initialization order. Do not reorder.

    // Owning prover (not owned).
    std::weak_ptr<UCTProver> uct;
    // Owned child sentences.
    std::vector<std::shared_ptr<SentenceNode>> children;
    // Sentence this step applies to (not owned).
    std::weak_ptr<SentenceNode> parent;
    // NOTE(review): presumably the not-yet-proved subset of `children`.
    std::vector<std::shared_ptr<SentenceNode>> active_children;
    size_t worst_child{0};
    LabTok label;
    SubstMap2<SymTok, LabTok> const_subst_map;
    SubstMap2<SymTok, LabTok> unconst_subst_map;
    bool exhausted{false};
    // No `value` / `visit_num` members are stored here, although
    // get_value() / get_visit_num() exist.
    // NOTE(review): presumably these are derived from the children -- confirm.
    std::map<LabTok, gio::safe_weak_ptr<StepNode>> open_vars;
};