1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import unicode_literals
3
+
4
+ from collections import OrderedDict
5
+ from importlib import import_module
6
+ from django .conf import settings
7
+ from elasticsearch import Elasticsearch , helpers , TransportError
8
+
9
+ from .repository import BaseRepository
10
+
11
+ ADD , UPDATE , REMOVE = range (3 )
12
+
13
+
14
+ def without (keys , dct , move_up = None ):
15
+ """Returns dictionary without listed keys
16
+
17
+ Optionally can move up keys from nested dictionary to parent before removing key.
18
+ :param keys: list of keys to remove
19
+ :param dct: dictionary to perform removing
20
+ :param move_up: definiton of keys which should be moved to parent
21
+ """
22
+ _dct = dct .copy ()
23
+ if move_up :
24
+ for k , v in move_up .iteritems ():
25
+ for moved_key in v :
26
+ _dct [moved_key ] = _dct [k ][moved_key ]
27
+ return {k : v for k , v in _dct .iteritems () if k not in keys }
28
+
29
+
30
+ class RepositoryError (Exception ):
31
+ def __init__ (self , message , cause = None ):
32
+ # Bu, exceptions chaining is avaliable only in py3.
33
+ super (RepositoryError , self ).__init__ (message + ', caused by ' + repr (cause ))
34
+ self .cause = cause
35
+
36
+
37
+ class EntityNotFound (RepositoryError ):
38
+ pass
39
+
40
+
41
+ class PersistedEntity (object ):
42
+
43
+ def __init__ (self , entity , state = ADD , index = 'default' ):
44
+ self ._initial_value = None
45
+ self ._entity = entity
46
+ if state == UPDATE :
47
+ self .reset_state ()
48
+ self .state = state
49
+ self ._index = index
50
+ self ._registry = {}
51
+
52
+ def get_stmt (self ):
53
+ if self .state == ADD :
54
+ return self ._add ()
55
+ elif self .state == UPDATE :
56
+ return self ._update ()
57
+ elif self .state == REMOVE :
58
+ return self ._remove ()
59
+
60
+ def reset_state (self ):
61
+ self ._initial_value = self ._entity .to_storage ()
62
+ if 'id' in self ._initial_value :
63
+ del self ._initial_value ['id' ]
64
+ self .state = UPDATE
65
+
66
+ def set_id (self , _id ):
67
+ self ._entity ['id' ] = _id
68
+
69
+ def _add (self ):
70
+ source = self ._entity .to_storage ()
71
+ stmt = {
72
+ '_index' : self ._index ,
73
+ '_type' : self ._entity .type ,
74
+ }
75
+ if 'id' in source :
76
+ stmt ['_id' ] = source ['id' ]
77
+ del source ['id' ]
78
+ if '_parent' in self ._entity :
79
+ stmt ['_parent' ] = self ._entity ['_parent' ]
80
+ stmt ['_source' ] = source
81
+ return stmt
82
+
83
+ def _update (self ):
84
+ if 'id' not in self ._entity :
85
+ return None
86
+ diff = self ._update_diff ()
87
+ if not diff :
88
+ return None
89
+ return {
90
+ '_op_type' : 'update' ,
91
+ '_index' : self ._index ,
92
+ '_type' : self ._entity .type ,
93
+ '_id' : self ._entity ['id' ],
94
+ 'doc' : diff
95
+ }
96
+
97
+ def _remove (self ):
98
+ if 'id' not in self ._entity :
99
+ return None
100
+ return {
101
+ '_op_type' : 'delete' ,
102
+ '_index' : self ._index ,
103
+ '_type' : self ._entity .type ,
104
+ '_id' : self ._entity ['id' ],
105
+ }
106
+
107
+ def _update_diff (self ):
108
+ current_state = self ._entity .to_storage ()
109
+ if 'id' in current_state :
110
+ del current_state ['id' ]
111
+ diff = {}
112
+ for k , v in current_state .iteritems ():
113
+ if (k not in self ._initial_value ) or (k in self ._initial_value and v != self ._initial_value [k ]):
114
+ diff [k ] = v
115
+ for k in set (self ._initial_value .keys ()) - set (current_state .keys ()):
116
+ diff [k ] = None
117
+ return diff
118
+
119
+
120
+ class EntityManager (object ):
121
+ def __init__ (self , index = 'default' , es_settings = None ):
122
+ if es_settings :
123
+ self .es = Elasticsearch (** es_settings )
124
+ else :
125
+ self .es = Elasticsearch ()
126
+ self ._index = index
127
+ self ._registry = {}
128
+
129
+ def persist (self , entity ):
130
+ if not hasattr (entity , 'to_storage' ) or not hasattr (entity , '__getitem__' ) or not hasattr (entity , 'type' ):
131
+ raise TypeError ('entity object must have to_storage, type and behave like a dict methods' )
132
+ self ._persist (entity , state = ADD )
133
+
134
+ def remove (self , entity ):
135
+ self ._persist (entity , state = REMOVE )
136
+
137
+ def flush (self ):
138
+ actions = OrderedDict ()
139
+ for pe in self ._registry .itervalues ():
140
+ stmt = pe .get_stmt ()
141
+ if stmt :
142
+ actions [pe ] = stmt
143
+ blk = [result for result in helpers .streaming_bulk (self .es , actions .itervalues ())] # TODO: exceptions?
144
+ for i , pe in enumerate (actions .iterkeys ()):
145
+ if 'create' in blk [i ][1 ]:
146
+ pe .set_id (blk [i ][1 ]['create' ]['_id' ])
147
+ pe .reset_state ()
148
+
149
+ def find (self , _id , _type , scope = None ):
150
+ kwargs = {'id' : _id , 'index' : self ._index , 'doc_type' : _type .get_type ()}
151
+ if scope :
152
+ kwargs ['_source' ] = _type .get_fields (scope )
153
+ try :
154
+ _data = self .es .get (** kwargs )
155
+ except TransportError as e : # TODO: the might be other errors like server unavaliable
156
+ raise EntityNotFound ('Entity {type} {_id} not found.' .format (type = _type .get_type (), _id = _id ), e )
157
+ source = _data ['_source' ]
158
+ source ['id' ] = _data ['_id' ]
159
+ entity = _type (source , scope )
160
+ self ._persist (entity , state = UPDATE )
161
+ return entity
162
+
163
+ def find_many (self , _ids , _type , scope = None ):
164
+ kwargs = {'body' : {'ids' : _ids }, 'index' : self ._index }
165
+ if scope :
166
+ kwargs ['_source' ] = _type .get_fields (scope )
167
+ try :
168
+ _data = self .es .mget (** kwargs )
169
+ except TransportError as e : # TODO: the might be other errors like server unavaliable
170
+ raise EntityNotFound ('Entity {type} {_id} not found.' .format (
171
+ type = _type .get_type (), _id = ', ' .join (_ids )), e )
172
+ entities = []
173
+ for _entity in _data ['docs' ]:
174
+ source = _entity ['_source' ]
175
+ source ['id' ] = _entity ['_id' ]
176
+ entity = _type (source , scope )
177
+ self ._persist (entity , state = UPDATE )
178
+ entities .append (entity )
179
+ return entities
180
+
181
+ def query (self , query , _type , scope = None ):
182
+ kwargs = {}
183
+ if scope :
184
+ kwargs ['_source' ] = _type .get_fields (scope )
185
+ try :
186
+ data = self .es .search (index = self ._index , doc_type = _type .get_type (), body = query .update (kwargs ))
187
+ except TransportError as e :
188
+ raise RepositoryError ('Transport returned error' , cause = e )
189
+ entities = []
190
+ for record in data ['hits' ]['hits' ]:
191
+ source = record ['_source' ]
192
+ source ['id' ] = record ['_id' ]
193
+ source ['_score' ] = record ['_score' ]
194
+ entity = _type (source , scope )
195
+ self ._persist (entity , state = UPDATE )
196
+ entities .append (entity )
197
+ return entities , without (['hits' ], data , move_up = {'hits' : ['max_score' , 'total' ]})
198
+
199
+ def query_one (self , query , _type , scope = None ):
200
+ entities , meta = self .query (query , _type , scope )
201
+ if len (entities ) == 1 :
202
+ return entities [0 ]
203
+ raise RepositoryError ('Expected one result, found {num}' .format (num = len (entities )))
204
+
205
+ def get_repository (self , repository ):
206
+ app , repository_class_name = repository .split (':' )
207
+ if app not in settings .INSTALLED_APPS :
208
+ app = filter (lambda _app : _app .endswith (app ), settings .INSTALLED_APPS )
209
+ if not app :
210
+ raise RepositoryError ('Given application {app} are not in INSTALLED_APPS' .format (app = app ))
211
+ try :
212
+ module = import_module (app + '.' + 'repositories' )
213
+ except ImportError :
214
+ raise RepositoryError ('Given application {app} has no repositories' .format (app = app ))
215
+ if not hasattr (module , repository_class_name ):
216
+ raise RepositoryError ('Given repository {repository_class_name} does not exists in application {app}' .format (
217
+ repository_class_name = repository_class_name , app = app
218
+ ))
219
+ repository_class = getattr (module , repository_class_name )
220
+ if not isinstance (repository_class , BaseRepository ):
221
+ raise RepositoryError ('Custom repository must be subclass of BaseRepository' )
222
+ return repository_class (self )
223
+
224
+ def get_client (self ):
225
+ return self .es
226
+
227
+ def _persist (self , entity , state ):
228
+ if id (entity ) in self ._registry :
229
+ self ._registry [id (entity )].state = state
230
+ else :
231
+ self ._registry [id (entity )] = PersistedEntity (entity , state = state , index = self ._index )
0 commit comments