diff --git a/azurelinuxagent/common/utils/cryptutil.py b/azurelinuxagent/common/utils/cryptutil.py index 0471a1fa0f..bc37a729e9 100644 --- a/azurelinuxagent/common/utils/cryptutil.py +++ b/azurelinuxagent/common/utils/cryptutil.py @@ -33,6 +33,7 @@ DECRYPT_SECRET_CMD = "{0} cms -decrypt -inform DER -inkey {1} -in /dev/stdin" + class CryptUtil(object): def __init__(self, openssl_cmd): self.openssl_cmd = openssl_cmd @@ -53,8 +54,8 @@ def get_pubkey_from_prv(self, file_name): if not os.path.exists(file_name): raise IOError(errno.ENOENT, "File not found", file_name) else: - cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.openssl_cmd, - file_name) + cmd = "{0} pkey -in {1} -pubout 2>/dev/null".format(self.openssl_cmd, + file_name) pub = shellutil.run_get_output(cmd)[1] return pub @@ -117,7 +118,7 @@ def asn1_to_ssh(self, pubkey): keydata.extend(b"\0") keydata.extend(self.num_to_bytes(n)) keydata_base64 = base64.b64encode(bytebuffer(keydata)) - return ustr(b"ssh-rsa " + keydata_base64 + b"\n", + return ustr(b"ssh-rsa " + keydata_base64 + b"\n", encoding='utf-8') except ImportError as e: raise CryptError("Failed to load pyasn1.codec.der") diff --git a/bin/waagent2.0 b/bin/waagent2.0 index 25aa0ce0ff..400a74d19a 100644 --- a/bin/waagent2.0 +++ b/bin/waagent2.0 @@ -3292,7 +3292,7 @@ class Certificates(object): index = 1 filename = str(index) + ".prv" while os.path.isfile(filename): - pubkey = RunGetOutput(Openssl + " rsa -in " + filename + " -pubout 2> /dev/null ")[1] + pubkey = RunGetOutput(Openssl + " pkey -in " + filename + " -pubout 2> /dev/null ")[1] os.rename(filename, keys[pubkey] + ".prv") os.chmod(keys[pubkey] + ".prv", 0600) MyDistro.setSelinuxContext( keys[pubkey] + '.prv','unconfined_u:object_r:ssh_home_t:s0') diff --git a/tests/data/wire/trans_pub b/tests/data/wire/trans_pub new file mode 100644 index 0000000000..e4a8ea8aad --- /dev/null +++ b/tests/data/wire/trans_pub @@ -0,0 +1,9 @@ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA09wkCR3pXk16iBIqMh5N +c5YLnHMpPK4k+3hhkxVKixTSUjprTAen6DZ8/bbOtWzBb5AnPoBVaiMgSotC6ndb +IJdlO/xFRuUeciOS9f/4n8ZoubPQbknNkikQsvYLwh9AsfYiI+Ur0s5AfTRbvhYV +wrdCpwnorDwZxVp5JdPWvtdBwYyoSNxYmSkougwm/csy58T4kx1tcNQZj4+ztmJy +7wpe8E9opWxzofaOuoFLx62NdvMvKt7NNQPPjmubJEnMI7lKTamiG5iDvfBTKQBQ +9XF3svxadLKrPW/jOs5uqfAEDKivrslH+GNMF+MU693yoUaid+K/ZWfP1exgVNmx +cQIDAQAB +-----END PUBLIC KEY----- diff --git a/tests/utils/test_crypt_util.py b/tests/utils/test_crypt_util.py index 3e5321e88e..c9ae9bc019 100644 --- a/tests/utils/test_crypt_util.py +++ b/tests/utils/test_crypt_util.py @@ -15,26 +15,9 @@ # Requires Python 2.6+ and Openssl 1.0+ # -import base64 -import binascii -import errno as errno -import glob -import random -import string -import subprocess -import sys -import tempfile -import uuid -import unittest - -import azurelinuxagent.common.conf as conf -import azurelinuxagent.common.utils.shellutil as shellutil -from azurelinuxagent.common.future import ustr -from azurelinuxagent.common.utils.cryptutil import CryptUtil from azurelinuxagent.common.exception import CryptError -from azurelinuxagent.common.version import PY_VERSION_MAJOR +from azurelinuxagent.common.utils.cryptutil import CryptUtil from tests.tools import * -from subprocess import CalledProcessError def is_python_version_26(): @@ -42,6 +25,7 @@ def is_python_version_26(): class TestCryptoUtilOperations(AgentTestCase): + def test_decrypt_encrypted_text(self): encrypted_string = load_data("wire/encrypted.enc") prv_key = os.path.join(self.tmp_dir, "TransportPrivate.pem") @@ -75,6 +59,20 @@ def test_decrypt_encrypted_text_text_not_encrypted(self): crypto = CryptUtil(conf.get_openssl_cmd()) self.assertRaises(CryptError, crypto.decrypt_secret, encrypted_string, prv_key) + def test_get_pubkey_from_crt(self): + crypto = CryptUtil(conf.get_openssl_cmd()) + prv_key = os.path.join(data_dir, "wire", "trans_prv") + expected_pub_key = os.path.join(data_dir, "wire", "trans_pub") + + with open(expected_pub_key) as fh: + self.assertEqual(fh.read(), crypto.get_pubkey_from_prv(prv_key)) + + def test_get_pubkey_from_crt_invalid_file(self): + crypto = CryptUtil(conf.get_openssl_cmd()) + prv_key = os.path.join(data_dir, "wire", "trans_prv_does_not_exist") + + self.assertRaises(IOError, crypto.get_pubkey_from_prv, prv_key) + if __name__ == '__main__': unittest.main()