async def send(self, channel, message):
    """
    Send a message onto a (general or specific) channel.
    """
    # A single send is just a bulk send with a one-message batch.
    batch = (message,)
    await self.send_bulk(channel, batch)
184+
async def send_bulk(self, channel, messages):
    """
    Send multiple messages in bulk onto a (general or specific) channel.

    The `messages` argument should be an iterable of dicts.

    Raises ChannelFull if the batch would push the channel over capacity.
    """
    # Materialise first: `messages` may be a one-shot iterator, and we need
    # both to iterate it and to know how many messages there are below.
    messages = list(messages)
    # Typecheck
    assert self.valid_channel_name(channel), "Channel name not valid"

    if not messages:
        # Nothing to send; avoid issuing an empty ZADD, which redis rejects.
        return

    # If it's a process-local channel, strip off local part and stick full name in message
    channel_non_local_name = channel
    process_local = "!" in channel
    if process_local:
        channel_non_local_name = self.non_local_name(channel)

    # Score every serialized message with the same timestamp so expiry
    # trimming treats the whole batch uniformly.
    now = time.time()
    mapping = {}
    for message in messages:
        assert isinstance(message, dict), "message is not a dict"
        # Make sure the message does not contain reserved keys
        assert "__asgi_channel__" not in message
        if process_local:
            # Copy before mutating so the caller's dict is untouched.
            message = dict(message.items())
            message["__asgi_channel__"] = channel

        mapping[self.serialize(message)] = now

    # Write out message into expiring key (avoids big items in list)
    channel_key = self.prefix + channel_non_local_name
    # Pick a connection to the right server - consistent for specific
    # channels, random for general channels
    if process_local:
        index = self.consistent_hash(channel)
    else:
        index = next(self._send_index_generator)
    # NOTE(review): the following connection pickup / expiry trim was lost in
    # a hunk gap of the paste; reconstructed from the identical pattern used
    # in group_send_bulk — confirm against the original file.
    connection = self.connection(index)
    # Discard old messages based on expiry
    await connection.zremrangebyscore(
        channel_key, min=0, max=int(time.time()) - int(self.expiry)
    )

    # Check the length of the list before send
    # This can allow the list to leak slightly over capacity, but that's fine.
    current_length = await connection.zcount(channel_key, "-inf", "+inf")

    if current_length + len(messages) > self.get_capacity(channel):
        raise ChannelFull()

    # Push onto the list then set it to expire in case it's not consumed
    await connection.zadd(channel_key, mapping)
    await connection.expire(channel_key, int(self.expiry))
218236
219237 def _backup_channel_name (self , channel ):
@@ -517,10 +535,7 @@ async def group_discard(self, group, channel):
517535 connection = self .connection (self .consistent_hash (group ))
518536 await connection .zrem (key , channel )
519537
520- async def group_send (self , group , message ):
521- """
522- Sends a message to the entire group.
523- """
538+ async def _get_group_connection_and_channels (self , group ):
524539 assert self .valid_group_name (group ), "Group name not valid"
525540 # Retrieve list of all channel names
526541 key = self ._group_key (group )
@@ -532,11 +547,36 @@ async def group_send(self, group, message):
532547
533548 channel_names = [x .decode ("utf8" ) for x in await connection .zrange (key , 0 , - 1 )]
534549
550+ return connection , channel_names
551+
552+ async def _exec_group_lua_script (
553+ self , conn_idx , group , channel_redis_keys , channel_names , script , args
554+ ):
555+ # channel_keys does not contain a single redis key more than once
556+ connection = self .connection (conn_idx )
557+ channels_over_capacity = await connection .eval (
558+ script , len (channel_redis_keys ), * channel_redis_keys , * args
559+ )
560+ if channels_over_capacity > 0 :
561+ logger .info (
562+ "%s of %s channels over capacity in group %s" ,
563+ channels_over_capacity ,
564+ len (channel_names ),
565+ group ,
566+ )
567+
async def group_send(self, group, message):
    """
    Sends a message to the entire group.
    """
    # A single message is the num_messages == 1 case of the bulk path;
    # delegating keeps the capacity accounting and LUA logic in one place
    # (group_send_bulk), mirroring how send() delegates to send_bulk().
    await self.group_send_bulk(group, (message,))
async def group_send_bulk(self, group, messages):
    """
    Sends multiple messages in bulk to the entire group.

    The `messages` argument should be an iterable of dicts.
    """
    # Materialise first: len(messages) is needed below, and the iterable is
    # also consumed by _map_channel_keys_to_connection — a generator would
    # otherwise be exhausted before len() is taken.
    messages = list(messages)
    if not messages:
        # Nothing to send; a zero-message batch would make the LUA script
        # call ZADD with no members, which is an error.
        return

    connection, channel_names = await self._get_group_connection_and_channels(group)

    (
        connection_to_channel_keys,
        channel_keys_to_message,
        channel_keys_to_capacity,
    ) = self._map_channel_keys_to_connection(channel_names, messages)

    for connection_index, channel_redis_keys in connection_to_channel_keys.items():
        # Discard old messages based on expiry
        pipe = connection.pipeline()
        for key in channel_redis_keys:
            pipe.zremrangebyscore(
                key, min=0, max=int(time.time()) - int(self.expiry)
            )
        await pipe.execute()

        # Create a LUA script specific for this connection.
        # Make sure to use the message list specific to this channel, it is
        # stored in channel_to_message dict and each message contains the
        # __asgi_channel__ key.
        # ARGV layout: num_messages serialized messages per KEY, then one
        # capacity per KEY, then num_messages, current_time, expiry.
        group_send_lua = """
            local over_capacity = 0
            local num_messages = tonumber(ARGV[#ARGV - 2])
            local current_time = ARGV[#ARGV - 1]
            local expiry = ARGV[#ARGV]
            for i=1,#KEYS do
                if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
                    local messages = {}
                    for j=num_messages * (i - 1) + 1, num_messages * i do
                        table.insert(messages, current_time)
                        table.insert(messages, ARGV[j])
                    end
                    redis.call('ZADD', KEYS[i], unpack(messages))
                    redis.call('EXPIRE', KEYS[i], expiry)
                else
                    over_capacity = over_capacity + 1
                end
            end
            return over_capacity
        """

        # We need to filter the messages to keep those related to the connection
        args = []
        for channel_key in channel_redis_keys:
            args += channel_keys_to_message[channel_key]

        # We need to send the capacity for each channel
        args += [
            channel_keys_to_capacity[channel_key]
            for channel_key in channel_redis_keys
        ]

        args += [len(messages), time.time(), self.expiry]

        await self._exec_group_lua_script(
            connection_index,
            group,
            channel_redis_keys,
            channel_names,
            group_send_lua,
            args,
        )
704+
705+ def _map_channel_keys_to_connection (self , channel_names , messages ):
598706 """
599707 For a list of channel names, GET
600708
@@ -609,7 +717,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
609717 # Connection dict keyed by index to list of redis keys mapped on that index
610718 connection_to_channel_keys = collections .defaultdict (list )
611719 # Message dict maps redis key to the message that needs to be send on that key
612- channel_key_to_message = dict ( )
720+ channel_key_to_message = collections . defaultdict ( list )
613721 # Channel key mapped to its capacity
614722 channel_key_to_capacity = dict ()
615723
@@ -623,20 +731,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
623731 # Have we come across the same redis key?
624732 if channel_key not in channel_key_to_message :
625733 # If not, fill the corresponding dicts
626- message = dict (message .items ())
627- message ["__asgi_channel__" ] = [channel ]
628- channel_key_to_message [channel_key ] = message
734+ for message in messages :
735+ message = dict (message .items ())
736+ message ["__asgi_channel__" ] = [channel ]
737+ channel_key_to_message [channel_key ].append (message )
629738 channel_key_to_capacity [channel_key ] = self .get_capacity (channel )
630739 idx = self .consistent_hash (channel_non_local_name )
631740 connection_to_channel_keys [idx ].append (channel_key )
632741 else :
633742 # Yes, Append the channel in message dict
634- channel_key_to_message [channel_key ]["__asgi_channel__" ].append (channel )
743+ for message in channel_key_to_message [channel_key ]:
744+ message ["__asgi_channel__" ].append (channel )
635745
636746 # Now that we know what message needs to be send on a redis key we serialize it
637747 for key , value in channel_key_to_message .items ():
638748 # Serialize the message stored for each redis key
639- channel_key_to_message [key ] = self .serialize (value )
749+ for idx , message in enumerate (value ):
750+ channel_key_to_message [key ][idx ] = self .serialize (message )
640751
641752 return (
642753 connection_to_channel_keys ,
0 commit comments