// PdfBase.hh
// Forked from GooFit/GooFit.
#ifndef PDF_BASE_HH
#define PDF_BASE_HH
#include "Variable.hh"
#include "GlobalCudaDefines.hh"
#include "FitControl.hh"
#include <set>
#include "BinnedDataSet.hh"
#include "UnbinnedDataSet.hh"
#include <thrust/iterator/constant_iterator.h>
#include <thrust/device_vector.h>
#include <algorithm>
// Thrust iterator aliases. Judging by the names they are combined to walk
// over events; the exact usage lives outside this header.

// Counting iterator over int values.
typedef thrust::counting_iterator<int> IndexIterator;
// Constant iterator that yields the same fptype* at every position.
typedef thrust::constant_iterator<fptype*> DataIterator;
// Constant iterator that yields the same int at every position
// (presumably a per-event size -- confirm against the callers).
typedef thrust::constant_iterator<int> SizeIterator;
// Tuple bundling one iterator of each of the three kinds above.
typedef thrust::tuple<IndexIterator, DataIterator, SizeIterator> EventTuple;
// Zip iterator that advances the three bundled iterators together.
typedef thrust::zip_iterator<EventTuple> EventIterator;

// Fixed capacity of the host-side arrays declared below.
const int maxParams = 2000;

// Globals defined in another translation unit.
// NOTE(review): name suggests a device-side event buffer -- confirm.
extern fptype* dev_event_array;
// Host-side arrays, each holding maxParams entries.
extern fptype host_normalisation[maxParams];
extern fptype host_params[maxParams];
extern unsigned int host_indices[maxParams];
// NOTE(review): presumably running counts of registered parameters and
// constants -- confirm against the implementation file.
extern int totalParams;
extern int totalConstants;
class PdfBase {
public:
PdfBase (Variable* x, std::string n);
enum Specials {ForceSeparateNorm = 1, ForceCommonNorm = 2};
__host__ virtual double calculateNLL () const = 0;
__host__ virtual fptype normalise () const = 0;
__host__ void initialiseIndices (std::vector<unsigned int> pindices);
typedef std::vector<Variable*> obsCont;
typedef obsCont::iterator obsIter;
typedef obsCont::const_iterator obsConstIter;
typedef std::vector<Variable*> parCont;
typedef parCont::iterator parIter;
typedef parCont::const_iterator parConstIter;
__host__ void addSpecialMask (int m) {specialMask |= m;}
__host__ void copyParams (const std::vector<double>& pars) const;
__host__ void copyParams ();
__host__ void copyNormFactors () const;
__host__ void generateNormRange ();
__host__ std::string getName () const {return name;}
__host__ virtual void getObservables (obsCont& ret) const;
__host__ virtual void getParameters (parCont& ret) const;
__host__ Variable* getParameterByName (string n) const;
__host__ int getSpecialMask () const {return specialMask;}
__host__ void setData (BinnedDataSet* data);
__host__ void setData (UnbinnedDataSet* data);
__host__ void setData (std::vector<std::map<Variable*, fptype> >& data);
__host__ virtual void setFitControl (FitControl* const fc, bool takeOwnerShip = true) = 0;
__host__ virtual bool hasAnalyticIntegral () const {return false;}
__host__ unsigned int getFunctionIndex () const {return functionIdx;}
__host__ unsigned int getParameterIndex () const {return parameters;}
__host__ unsigned int registerParameter (Variable* var);
__host__ unsigned int registerConstants (unsigned int amount);
__host__ virtual void recursiveSetNormalisation (fptype norm = 1) const;
__host__ void unregisterParameter (Variable* var);
__host__ void registerObservable (Variable* obs);
__host__ void setIntegrationFineness (int i);
__host__ void printProfileInfo (bool topLevel = true);
__host__ bool parametersChanged () const;
__host__ void storeParameters () const;
__host__ obsIter obsBegin () {return observables.begin();}
__host__ obsIter obsEnd () {return observables.end();}
__host__ obsConstIter obsCBegin () const {return observables.begin();}
__host__ obsConstIter obsCEnd () const {return observables.end();}
__host__ void checkInitStatus (std::vector<std::string>& unInited) const;
void clearCurrentFit ();
__host__ void SigGenSetIndices(){setIndices();}
protected:
fptype numEvents; // Non-integer to allow weighted events
unsigned int numEntries; // Eg number of bins - not always the same as number of events, although it can be.
fptype* normRanges; // This is specific to functor instead of variable so that MetricTaker::operator needn't use indices.
unsigned int parameters; // Stores index, in 'paramIndices', where this functor's information begins.
unsigned int cIndex; // Stores location of constants.
obsCont observables;
parCont parameterList;
FitControl* fitControl;
std::vector<PdfBase*> components;
int integrationBins;
int specialMask; // For storing information unique to PDFs, eg "Normalise me separately" for TddpPdf.
mutable fptype* cachedParams;
bool properlyInitialised; // Allows checking for required extra steps in, eg, Tddp and Convolution.
unsigned int functionIdx; // Stores index of device function pointer.
private:
std::string name;
__host__ void recursiveSetIndices ();
__host__ void setIndices ();
};
#endif