Skip to content

Commit

Permalink
Merge pull request #290 from saltstack/cve/3004.1/CVE-2022-22936-bugfix
Browse files Browse the repository at this point in the history
Test fix
  • Loading branch information
garethgreenaway authored Feb 28, 2022
2 parents d1811dd + 8890a6c commit 064729c
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 54 deletions.
48 changes: 27 additions & 21 deletions salt/transport/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,7 +1422,7 @@ class PubServer(salt.ext.tornado.tcpserver.TCPServer):
TCP publisher
"""

def __init__(self, opts, io_loop=None):
def __init__(self, opts, io_loop=None, pack_publish=lambda _: _):
super().__init__(ssl_options=opts.get("ssl"))
self.io_loop = io_loop
self.opts = opts
Expand All @@ -1449,6 +1449,10 @@ def __init__(self, opts, io_loop=None):
)
else:
self.event = None
self._pack_publish = pack_publish

def pack_publish(self, load):
return self._pack_publish(load)

def close(self):
if self._closing:
Expand Down Expand Up @@ -1557,6 +1561,7 @@ def handle_stream(self, stream, address):
@salt.ext.tornado.gen.coroutine
def publish_payload(self, package, _):
log.debug("TCP PubServer sending payload: %s", package)
payload = self.pack_publish(package)
payload = salt.transport.frame.frame_msg(package["payload"])

to_remove = []
Expand Down Expand Up @@ -1632,7 +1637,9 @@ def _publish_daemon(self, **kwargs):
self.io_loop = salt.ext.tornado.ioloop.IOLoop.current()

# Spin up the publisher
pub_server = PubServer(self.opts, io_loop=self.io_loop)
pub_server = PubServer(
self.opts, io_loop=self.io_loop, pack_publish=self.pack_publish
)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
_set_tcp_keepalive(sock, self.opts)
Expand Down Expand Up @@ -1675,10 +1682,7 @@ def pre_fork(self, process_manager, kwargs=None):
"""
process_manager.add_process(self._publish_daemon, kwargs=kwargs)

def publish(self, load):
"""
Publish "load" to minions
"""
def pack_publish(self, load):
payload = {"enc": "aes"}
load["serial"] = salt.master.SMaster.get_serial()
crypticle = salt.crypt.Crypticle(
Expand All @@ -1689,20 +1693,6 @@ def publish(self, load):
master_pem_path = os.path.join(self.opts["pki_dir"], "master.pem")
log.debug("Signing data packet")
payload["sig"] = salt.crypt.sign_message(master_pem_path, payload["load"])
# Use the Salt IPC server
if self.opts.get("ipc_mode", "") == "tcp":
pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
else:
pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
# TODO: switch to the actual asynchronous interface
# pub_sock = salt.transport.ipc.IPCMessageClient(self.opts, io_loop=self.io_loop)
pub_sock = salt.utils.asynchronous.SyncWrapper(
salt.transport.ipc.IPCMessageClient,
(pull_uri,),
loop_kwarg="io_loop",
)
pub_sock.connect()

int_payload = {"payload": salt.payload.dumps(payload)}

# add some targeting stuff for lists only (for now)
Expand All @@ -1719,5 +1709,21 @@ def publish(self, load):
int_payload["topic_lst"] = match_ids
else:
int_payload["topic_lst"] = load["tgt"]
return int_payload

def publish(self, load):
"""
Publish "load" to minions
"""
# Send it over IPC!
pub_sock.send(int_payload)
if self.opts.get("ipc_mode", "") == "tcp":
pull_uri = int(self.opts.get("tcp_master_publish_pull", 4514))
else:
pull_uri = os.path.join(self.opts["sock_dir"], "publish_pull.ipc")
pub_sock = salt.utils.asynchronous.SyncWrapper(
salt.transport.ipc.IPCMessageClient,
(pull_uri,),
loop_kwarg="io_loop",
)
pub_sock.connect()
pub_sock.send(load)
27 changes: 17 additions & 10 deletions salt/transport/zeromq.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,8 @@ def _publish_daemon(self, log_queue=None):
try:
log.debug("Publish daemon getting data from puller %s", pull_uri)
package = pull_sock.recv()
package = salt.payload.loads(package)
package = self.pack_publish(package)
log.debug("Publish daemon received payload. size=%d", len(package))

unpacked_package = salt.payload.unpackage(package)
Expand Down Expand Up @@ -1031,8 +1033,8 @@ def pub_connect(self):
"""
if self.pub_sock:
self.pub_close()
ctx = zmq.Context.instance()
self._sock_data.sock = ctx.socket(zmq.PUSH)
self._sock_data._ctx = zmq.Context()
self._sock_data.sock = self._sock_data._ctx.socket(zmq.PUSH)
self.pub_sock.setsockopt(zmq.LINGER, -1)
if self.opts.get("ipc_mode", "") == "tcp":
pull_uri = "tcp://127.0.0.1:{}".format(
Expand All @@ -1054,14 +1056,10 @@ def pub_close(self):
if hasattr(self._sock_data, "sock"):
self._sock_data.sock.close()
delattr(self._sock_data, "sock")
if hasattr(self._sock_data, "_ctx"):
self._sock_data._ctx.destroy()

def publish(self, load):
"""
Publish "load" to minions. This send the load to the publisher daemon
process with does the actual sending to minions.
:param dict load: A load to be sent across the wire to minions
"""
def pack_publish(self, load):
payload = {"enc": "aes"}
load["serial"] = salt.master.SMaster.get_serial()
crypticle = salt.crypt.Crypticle(
Expand Down Expand Up @@ -1094,9 +1092,18 @@ def publish(self, load):
load.get("jid", None),
len(payload),
)
return payload

def publish(self, load):
"""
Publish "load" to minions. This send the load to the publisher daemon
process with does the actual sending to minions.
:param dict load: A load to be sent across the wire to minions
"""
if not self.pub_sock:
self.pub_connect()
self.pub_sock.send(payload)
self.pub_sock.send(salt.payload.dumps(load))
log.debug("Sent payload to publish daemon.")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ def __init__(self, master_config, minion_config, **collector_kwargs):
self.minion_config = minion_config
self.collector_kwargs = collector_kwargs
self.aes_key = salt.crypt.Crypticle.generate_key_string()
salt.master.SMaster.secrets["aes"] = {
"secret": multiprocessing.Array(
ctypes.c_char,
salt.utils.stringutils.to_bytes(self.aes_key),
),
"serial": multiprocessing.Value(
ctypes.c_longlong, lock=False # We'll use the lock from 'secret'
),
}
self.process_manager = salt.utils.process.ProcessManager(
name="ZMQ-PubServer-ProcessManager"
)
Expand All @@ -145,15 +154,6 @@ def __init__(self, master_config, minion_config, **collector_kwargs):
)

def run(self):
salt.master.SMaster.secrets["aes"] = {
"secret": multiprocessing.Array(
ctypes.c_char,
salt.utils.stringutils.to_bytes(self.aes_key),
),
"serial": multiprocessing.Value(
ctypes.c_longlong, lock=False # We'll use the lock from 'secret'
),
}
try:
while True:
payload = self.queue.get()
Expand Down Expand Up @@ -247,12 +247,16 @@ def test_issue_36469_tcp(salt_master, salt_minion):
https://github.com/saltstack/salt/issues/36469
"""

def _send_small(server_channel, sid, num=10):
def _send_small(opts, sid, num=10):
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
for idx in range(num):
load = {"tgt_type": "glob", "tgt": "*", "jid": "{}-s{}".format(sid, idx)}
server_channel.publish(load)
time.sleep(0.3)
server_channel.close_pub()

def _send_large(server_channel, sid, num=10, size=250000 * 3):
def _send_large(opts, sid, num=10, size=250000 * 3):
server_channel = salt.transport.zeromq.ZeroMQPubServerChannel(opts)
for idx in range(num):
load = {
"tgt_type": "glob",
Expand All @@ -261,16 +265,19 @@ def _send_large(server_channel, sid, num=10, size=250000 * 3):
"xdata": "0" * size,
}
server_channel.publish(load)
time.sleep(0.3)
server_channel.close_pub()

opts = dict(salt_master.config.copy(), ipc_mode="tcp", pub_hwm=0)
send_num = 10 * 4
expect = []
with PubServerChannelProcess(opts, salt_minion.config.copy()) as server_channel:
assert "aes" in salt.master.SMaster.secrets
with ThreadPoolExecutor(max_workers=4) as executor:
executor.submit(_send_small, server_channel, 1)
executor.submit(_send_large, server_channel, 2)
executor.submit(_send_small, server_channel, 3)
executor.submit(_send_large, server_channel, 4)
executor.submit(_send_small, opts, 1)
executor.submit(_send_large, opts, 2)
executor.submit(_send_small, opts, 3)
executor.submit(_send_large, opts, 4)
expect.extend(["{}-s{}".format(a, b) for a in range(10) for b in (1, 3)])
expect.extend(["{}-l{}".format(a, b) for a in range(10) for b in (2, 4)])
results = server_channel.collector.results
Expand Down
20 changes: 12 additions & 8 deletions tests/pytests/unit/transport/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,24 +210,27 @@ def test_tcp_pub_server_channel_publish_filtering(temp_salt_master):
SyncWrapper.return_value = wrap

# try simple publish with glob tgt_type
channel.publish({"test": "value", "tgt_type": "glob", "tgt": "*"})
payload = wrap.send.call_args[0][0]
payload = channel.pack_publish(
{"test": "value", "tgt_type": "glob", "tgt": "*"}
)

# verify we send it without any specific topic
assert "topic_lst" not in payload

# try simple publish with list tgt_type
channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
payload = wrap.send.call_args[0][0]
payload = channel.pack_publish(
{"test": "value", "tgt_type": "list", "tgt": ["minion01"]}
)

# verify we send it with correct topic
assert "topic_lst" in payload
assert payload["topic_lst"] == ["minion01"]

# try with syndic settings
opts["order_masters"] = True
channel.publish({"test": "value", "tgt_type": "list", "tgt": ["minion01"]})
payload = wrap.send.call_args[0][0]
payload = channel.pack_publish(
{"test": "value", "tgt_type": "list", "tgt": ["minion01"]}
)

# verify we send it without topic for syndics
assert "topic_lst" not in payload
Expand Down Expand Up @@ -257,8 +260,9 @@ def test_tcp_pub_server_channel_publish_filtering_str_list(temp_salt_master):
check_minions.return_value = {"minions": ["minion02"]}

# try simple publish with list tgt_type
channel.publish({"test": "value", "tgt_type": "list", "tgt": "minion02"})
payload = wrap.send.call_args[0][0]
payload = channel.pack_publish(
{"test": "value", "tgt_type": "list", "tgt": "minion02"}
)

# verify we send it with correct topic
assert "topic_lst" in payload
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/transport/test_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class IPCMessagePubSubCase(salt.ext.tornado.testing.AsyncTestCase):
def setUp(self):
super().setUp()
self.opts = {"ipc_write_buffer": 0}
if not os.path.exists(RUNTIME_VARS.TMP):
os.mkdir(RUNTIME_VARS.TMP)
self.socket_path = os.path.join(RUNTIME_VARS.TMP, "ipc_test.ipc")
self.pub_channel = self._get_pub_channel()
self.sub_channel = self._get_sub_channel()
Expand Down

0 comments on commit 064729c

Please sign in to comment.