summaryrefslogtreecommitdiff
path: root/tools/mcuboot/imgtool/keys/rsa.py
blob: f8273bf5aebdd79e9885698a2ca1604e1672b8d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
RSA Key management
"""

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.padding import PSS, MGF1
from cryptography.hazmat.primitives.hashes import SHA256

from .general import KeyClass


# RSA key sizes, in bits, that bootutil will recognize.  RSA.generate()
# refuses to create a key of any other size.
RSA_KEY_SIZES = [2048, 3072]


class RSAUsageError(Exception):
    """Raised when an RSA key is asked to do something it cannot do.

    Examples in this module: a private-key operation requested on a public
    key, an unsupported key size, or a private key whose DER layout cannot
    be minimized.
    """


class RSAPublic(KeyClass):
    """Wrapper around an RSA public key.

    Only the operations that need nothing but the public half of the key are
    implemented; the private-key operations raise RSAUsageError.
    """

    def __init__(self, key):
        """Store ``key``, the underlying key object from cryptography."""
        self.key = key

    def key_size(self):
        """Return the key size reported by the underlying key (in bits)."""
        return self.key.key_size

    def shortname(self):
        """Return the short identifier for this key type."""
        return "rsa"

    def _unsupported(self, name):
        """Raise RSAUsageError stating that operation ``name`` needs a
        private key."""
        raise RSAUsageError("Operation {} requires private key".format(name))

    def _get_public(self):
        """Return the public key object (overridden by private-key
        subclasses)."""
        return self.key

    def get_public_bytes(self):
        """Return the public key serialized as DER."""
        # The key embedded into MCUboot is in PKCS1 format.
        return self._get_public().public_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PublicFormat.PKCS1)

    def get_private_bytes(self, minimal):
        """Not available on a public key; always raises RSAUsageError."""
        self._unsupported('get_private_bytes')

    def export_private(self, path, passwd=None):
        """Not available on a public key; always raises RSAUsageError."""
        self._unsupported('export_private')

    def export_public(self, path):
        """Write the public key to the given file."""
        pem = self._get_public().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo)
        with open(path, 'wb') as f:
            f.write(pem)

    def sig_type(self):
        """Return the signature type name, which embeds the key size."""
        return "PKCS1_PSS_RSA{}_SHA256".format(self.key_size())

    def sig_tlv(self):
        """Return the TLV name for this signature, which embeds the key
        size."""
        return "RSA{}".format(self.key_size())

    def sig_len(self):
        """Return the signature length in bytes, as an integer."""
        # Floor division keeps the result an int; "/" would yield a float
        # (e.g. 256.0) on Python 3, which is wrong for a byte count.
        return self.key_size() // 8

    def verify(self, signature, payload):
        """Verify ``signature`` over ``payload`` using RSA-PSS with SHA-256.

        The salt length is fixed at 32, matching what sign() produces.  The
        result of the underlying verify() call is returned unchanged.
        """
        k = self.key
        # A subclass may hold a private key here; verification needs the
        # public half.
        if isinstance(self.key, rsa.RSAPrivateKey):
            k = self.key.public_key()
        return k.verify(signature=signature, data=payload,
                        padding=PSS(mgf=MGF1(SHA256()), salt_length=32),
                        algorithm=SHA256())


class RSA(RSAPublic):
    """
    Wrapper around an RSA key, with imgtool support.
    """

    def __init__(self, key):
        """The key should be a private key from cryptography"""
        self.key = key

    @staticmethod
    def generate(key_size=2048):
        """Generate a new private key and return it wrapped in an RSA object.

        Raises RSAUsageError if ``key_size`` is not listed in RSA_KEY_SIZES.
        """
        if key_size not in RSA_KEY_SIZES:
            raise RSAUsageError("Key size {} is not supported by MCUboot"
                                .format(key_size))
        pk = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend())
        return RSA(pk)

    def _get_public(self):
        """Return the public half of the wrapped private key."""
        return self.key.public_key()

    def _build_minimal_rsa_privkey(self, der):
        """
        Builds a new DER that only includes N/E/D/P/Q RSA parameters;
        standard DER private bytes provided by OpenSSL also includes
        CRT params (DP/DQ/QP) which can be removed.

        This is a fixed-layout walk, not a general DER parser: it expects
        N and D to use a two-byte length (0x82 form), E to be exactly three
        bytes long, and P and Q to use a one-byte length (0x81 form).
        Anything else raises RSAUsageError.

        Returns a bytearray holding the truncated structure with the outer
        length field rewritten.
        """
        OFFSET_N = 7  # N is always located at this offset
        b = bytearray(der)
        off = OFFSET_N
        if b[off + 1] != 0x82:
            raise RSAUsageError("Error parsing N while minimizing")
        # tag + 0x82 + two length bytes = 4 header bytes before the value
        len_N = (b[off + 2] << 8) + b[off + 3] + 4
        off += len_N
        if b[off + 1] != 0x03:
            raise RSAUsageError("Error parsing E while minimizing")
        # Short-form length: the element is tag + length byte + value.  The
        # size must come from the length byte at off + 1; the byte at
        # off + 2 is the first byte of the exponent value itself and only
        # gives the right answer when the exponent starts with 0x01.
        len_E = b[off + 1] + 2
        off += len_E
        if b[off + 1] != 0x82:
            raise RSAUsageError("Error parsing D while minimizing")
        len_D = (b[off + 2] << 8) + b[off + 3] + 4
        off += len_D
        if b[off + 1] != 0x81:
            raise RSAUsageError("Error parsing P while minimizing")
        # tag + 0x81 + one length byte = 3 header bytes before the value
        len_P = b[off + 2] + 3
        off += len_P
        if b[off + 1] != 0x81:
            raise RSAUsageError("Error parsing Q while minimizing")
        len_Q = b[off + 2] + 3
        off += len_Q
        # adjust DER size for removed elements; the outer length excludes
        # the 4 header bytes at the start of the buffer
        b[2] = (off - 4) >> 8
        b[3] = (off - 4) & 0xff
        return b[:off]

    def get_private_bytes(self, minimal):
        """Return the private key serialized as unencrypted DER.

        When ``minimal`` is true the CRT parameters are stripped from the
        encoding by _build_minimal_rsa_privkey().
        """
        priv = self.key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption())
        if minimal:
            priv = self._build_minimal_rsa_privkey(priv)
        return priv

    def export_private(self, path, passwd=None):
        """Write the private key to the given file, protecting it with the
        optional password."""
        if passwd is None:
            enc = serialization.NoEncryption()
        else:
            enc = serialization.BestAvailableEncryption(passwd)
        pem = self.key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=enc)
        with open(path, 'wb') as f:
            f.write(pem)

    def sign(self, payload):
        """Return an RSA-PSS (SHA-256) signature over ``payload``."""
        # The verification code only allows the salt length to be the
        # same as the hash length, 32.
        return self.key.sign(
            data=payload,
            padding=PSS(mgf=MGF1(SHA256()), salt_length=32),
            algorithm=SHA256())