403 lines
14 KiB
Python
403 lines
14 KiB
Python
"""JSON Web Key."""
|
|
|
|
import abc
|
|
import json
|
|
import logging
|
|
import math
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
Union,
|
|
)
|
|
|
|
import cryptography.exceptions
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives import hashes, serialization
|
|
from cryptography.hazmat.primitives.asymmetric import ec, rsa
|
|
|
|
import josepy.util
|
|
from josepy import errors, json_util, util
|
|
|
|
# Module-level logger; JWK.load uses it to report falling back to a symmetric key.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class JWK(json_util.TypedJSONObjectWithFields, metaclass=abc.ABCMeta):
    """JSON Web Key.

    Abstract base class. Concrete key types subclass it and are registered
    with ``@JWK.register``; ``type_field_name`` ("kty") is the JSON member
    that carries the key type tag.
    """

    type_field_name = "kty"
    TYPES: Dict[str, Type["JWK"]] = {}
    cryptography_key_types: Tuple[Type[Any], ...] = ()
    """Subclasses should override."""

    required: Sequence[str] = NotImplemented
    """Required members of public key's representation as defined by JWK/JWA."""

    _thumbprint_json_dumps_params: Dict[str, Union[Optional[int], Sequence[str], bool]] = {
        # "no whitespace or line breaks before or after any syntactic
        # elements"
        "indent": None,
        "separators": (",", ":"),
        # "members ordered lexicographically by the Unicode [UNICODE]
        # code points of the member names"
        "sort_keys": True,
    }

    key: Any

    def thumbprint(
        self, hash_function: Callable[[], hashes.HashAlgorithm] = hashes.SHA256
    ) -> bytes:
        """Compute JWK Thumbprint.

        https://tools.ietf.org/html/rfc7638

        Only the members listed in :attr:`required` are hashed. They are
        serialized as compact JSON with sorted keys.

        :param hash_function: Zero-argument callable returning a
            ``cryptography`` hash algorithm instance (default: SHA-256).
        :returns: The raw digest bytes.
        :rtype: bytes
        """
        digest = hashes.Hash(hash_function(), backend=default_backend())
        digest.update(
            json.dumps(
                {k: v for k, v in self.to_json().items() if k in self.required},
                **self._thumbprint_json_dumps_params,  # type: ignore[arg-type]
            ).encode()
        )
        return digest.finalize()

    @abc.abstractmethod
    def public_key(self) -> "JWK":  # pragma: no cover
        """Generate JWK with public key.

        For symmetric cryptosystems, this would return ``self``.
        """
        raise NotImplementedError()

    @classmethod
    def _load_cryptography_key(
        cls, data: bytes, password: Optional[bytes] = None, backend: Optional[Any] = None
    ) -> Any:
        """Deserialize ``data`` into a ``cryptography`` key object.

        Loaders are tried in order: private-key PEM, private-key DER,
        public-key PEM, public-key DER. The first one that succeeds wins.

        :param bytes data: Serialized key.
        :param bytes password: Optional password for private keys.
        :param backend: Optional backend; ``default_backend()`` if omitted.
        :raises errors.Error: if no loader accepts ``data``. The message
            includes the error reported by each loader.
        """
        backend = default_backend() if backend is None else backend
        exceptions: Dict[str, Exception] = {}

        # private key?
        loader_private: Any
        for loader_private in (
            serialization.load_pem_private_key,
            serialization.load_der_private_key,
        ):
            try:
                return loader_private(data, password, backend)
            except (ValueError, TypeError, cryptography.exceptions.UnsupportedAlgorithm) as error:
                # Per the cryptography docs, TypeError signals a password
                # mismatch (given for an unencrypted key, or missing for an
                # encrypted one).
                exceptions[str(loader_private)] = error

        # public key?
        loader_public: Any
        for loader_public in (serialization.load_pem_public_key, serialization.load_der_public_key):
            try:
                return loader_public(data, backend)
            except (ValueError, cryptography.exceptions.UnsupportedAlgorithm) as error:
                exceptions[str(loader_public)] = error

        # no luck
        raise errors.Error("Unable to deserialize key: {0}".format(exceptions))

    @classmethod
    def load(
        cls, data: bytes, password: Optional[bytes] = None, backend: Optional[Any] = None
    ) -> "JWK":
        """Load serialized key as JWK.

        If ``data`` cannot be parsed as an asymmetric key, it is treated as
        raw symmetric key material and a :class:`JWKOct` is returned. This
        happens even when ``load`` is called on an asymmetric subclass.

        :param bytes data: Public or private key serialized as PEM or DER.
        :param bytes password: Optional password.
        :param backend: A `.PEMSerializationBackend` and
            `.DERSerializationBackend` provider.

        :raises errors.Error: if the loaded key does not match the calling
            subclass, or if no registered JWK type supports it.

        :returns: JWK of an appropriate type.
        :rtype: `JWK`
        """
        try:
            key = cls._load_cryptography_key(data, password, backend)
        except errors.Error as error:
            logger.debug("Loading symmetric key, asymmetric failed: %s", error)
            return JWKOct(key=data)

        if cls.typ is not NotImplemented and not isinstance(key, cls.cryptography_key_types):
            # Name the requested JWK class itself; ``cls.__class__`` would be
            # the metaclass (ABCMeta), which tells the caller nothing.
            raise errors.Error("Unable to deserialize {0} into {1}".format(key.__class__, cls))
        for jwk_cls in cls.TYPES.values():
            if isinstance(key, jwk_cls.cryptography_key_types):
                return jwk_cls(key=key)
        raise errors.Error("Unsupported algorithm: {0}".format(key.__class__))
|
|
|
|
|
|
@JWK.register
class JWKOct(JWK):
    """Symmetric JWK.

    Wraps raw secret key bytes, serialized as the base64url ``k`` member.
    """

    typ = "oct"
    __slots__ = ("key",)
    required = ("k", JWK.type_field_name)
    key: bytes

    def fields_to_partial_json(self) -> Dict[str, str]:
        """Return the type-specific members: ``k`` holding the encoded key."""
        # TODO: an "alg" member SHOULD also be present to identify the
        # algorithm intended for this key, unless the application determines
        # the algorithm by some other means or convention.
        encoded_key = json_util.encode_b64jose(self.key)
        return {"k": encoded_key}

    @classmethod
    def fields_from_json(cls, jobj: Mapping[str, Any]) -> "JWKOct":
        """Rebuild a symmetric key from the base64url ``k`` member of ``jobj``."""
        raw_key = json_util.decode_b64jose(jobj["k"])
        return cls(key=raw_key)

    def public_key(self) -> "JWKOct":
        """Return ``self``; a symmetric key has no separate public half."""
        return self
|
|
|
|
|
|
@JWK.register
class JWKRSA(JWK):
    """RSA JWK.

    :ivar key: :class:`~cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey`
        or :class:`~cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey` wrapped
        in :class:`~josepy.util.ComparableRSAKey`
    """

    typ = "RSA"
    cryptography_key_types = (rsa.RSAPublicKey, rsa.RSAPrivateKey)
    __slots__ = ("key",)
    required = ("e", JWK.type_field_name, "n")
    key: josepy.util.ComparableRSAKey

    # Names of the CRT private-key members, in RSAPrivateNumbers-compatible order.
    _CRT_PARAM_NAMES = ("p", "q", "dp", "dq", "qi")

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Wrap a raw ``cryptography`` RSA key in :class:`util.ComparableRSAKey`."""
        if "key" in kwargs and not isinstance(kwargs["key"], util.ComparableRSAKey):
            kwargs["key"] = util.ComparableRSAKey(kwargs["key"])
        super().__init__(*args, **kwargs)

    @classmethod
    def _encode_param(cls, data: int) -> str:
        """Encode Base64urlUInt.

        Uses the minimal big-endian byte length, but at least one byte so
        that zero is encoded as a single zero byte.

        :param int data: Non-negative integer.
        :rtype: str
        """
        # Integer ceiling division; avoids a float round-trip through ``/``.
        length = max((data.bit_length() + 7) // 8, 1)
        return json_util.encode_b64jose(data.to_bytes(byteorder="big", length=length))

    @classmethod
    def _decode_param(cls, data: str) -> int:
        """Decode Base64urlUInt.

        :raises errors.DeserializationError: if ``decode_b64jose`` raises
            ``ValueError`` or ``data`` decodes to an empty byte string.
        """
        try:
            binary = json_util.decode_b64jose(data)
        except ValueError as error:
            raise errors.DeserializationError() from error
        if not binary:
            raise errors.DeserializationError()
        return int.from_bytes(binary, byteorder="big")

    def public_key(self) -> "JWKRSA":
        return type(self)(key=self.key.public_key())

    @classmethod
    def fields_from_json(cls, jobj: Mapping[str, Any]) -> "JWKRSA":
        """Build an RSA JWK from its JSON members.

        Without ``d`` a public key is returned. With ``d``:

        - If any CRT member (``p``, ``q``, ``dp``, ``dq``, ``qi``) or ``oth``
          is present, all five CRT members must be present and are used.
        - Otherwise they are recovered from ``n``, ``e`` and ``d``.

        :raises errors.Error: if only some of the CRT members are present.
        :raises errors.DeserializationError: if a member is not a valid
            Base64urlUInt.
        """
        n, e = (cls._decode_param(jobj[x]) for x in ("n", "e"))
        public_numbers = rsa.RSAPublicNumbers(e=e, n=n)

        # public key
        if "d" not in jobj:
            return cls(key=public_numbers.public_key(default_backend()))

        # private key
        d = cls._decode_param(jobj["d"])
        crt_names = cls._CRT_PARAM_NAMES
        if "oth" in jobj or any(name in jobj for name in crt_names):
            # "If the producer includes any of the other private
            # key parameters, then all of the others MUST be
            # present, with the exception of "oth", which MUST
            # only be present when more than two prime factors
            # were used."
            missing = [name for name in crt_names if jobj.get(name) is None]
            if missing:
                # Report only the names: the values that *are* present are
                # private key material and must not leak into error messages.
                raise errors.Error(
                    "Some private parameters are missing: {0}".format(", ".join(missing))
                )
            p, q, dp, dq, qi = (cls._decode_param(str(jobj[name])) for name in crt_names)

            # TODO: check for oth
        else:
            # cryptography>=0.8
            p, q = rsa.rsa_recover_prime_factors(n, e, d)
            dp = rsa.rsa_crt_dmp1(d, p)
            dq = rsa.rsa_crt_dmq1(d, q)
            qi = rsa.rsa_crt_iqmp(p, q)

        key = rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, public_numbers).private_key(
            default_backend()
        )

        return cls(key=key)

    def fields_to_partial_json(self) -> Dict[str, Any]:
        """Serialize the key's numbers as Base64urlUInt members.

        Public keys yield ``n`` and ``e``. Private keys additionally yield
        ``d`` and the CRT members ``p``, ``q``, ``dp``, ``dq``, ``qi``.
        """
        if isinstance(self.key._wrapped, rsa.RSAPublicKey):
            numbers = self.key.public_numbers()
            params = {
                "n": numbers.n,
                "e": numbers.e,
            }
        else:  # rsa.RSAPrivateKey
            private = self.key.private_numbers()
            public = self.key.public_key().public_numbers()
            params = {
                "n": public.n,
                "e": public.e,
                "d": private.d,
                "p": private.p,
                "q": private.q,
                "dp": private.dmp1,
                "dq": private.dmq1,
                "qi": private.iqmp,
            }
        return {key: self._encode_param(value) for key, value in params.items()}
|
|
|
|
|
|
@JWK.register
class JWKEC(JWK):
    """EC JWK.

    :ivar key: :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey`
        or :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey` wrapped
        in :class:`~josepy.util.ComparableECKey`
    """

    typ = "EC"
    __slots__ = ("key",)
    cryptography_key_types = (ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey)
    required = ("crv", JWK.type_field_name, "x", "y")
    key: josepy.util.ComparableECKey

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Wrap a raw ``cryptography`` EC key in :class:`util.ComparableECKey`."""
        if "key" in kwargs and not isinstance(kwargs["key"], util.ComparableECKey):
            kwargs["key"] = util.ComparableECKey(kwargs["key"])
        super().__init__(*args, **kwargs)

    @classmethod
    def _encode_param(cls, data: int, length: int) -> str:
        """Encode Base64urlUInt as exactly ``length`` big-endian bytes.

        :param int data: Non-negative integer.
        :param int length: Output size in bytes (fixed per curve).
        :rtype: str
        """
        return json_util.encode_b64jose(data.to_bytes(byteorder="big", length=length))

    @classmethod
    def _decode_param(cls, data: str, name: str, valid_length: int) -> int:
        """Decode Base64urlUInt that must be exactly ``valid_length`` bytes.

        :param str data: Base64url-encoded value.
        :param str name: Member name, used in the error message.
        :param int valid_length: Required decoded length in bytes.
        :raises errors.DeserializationError: if decoding raises ``ValueError``
            or the decoded length is wrong.
        """
        try:
            binary = json_util.decode_b64jose(data)
        except ValueError as error:
            raise errors.DeserializationError() from error
        if len(binary) != valid_length:
            raise errors.DeserializationError(
                f'Expected parameter "{name}" to be {valid_length} bytes '
                f"after base64-decoding; got {len(binary)} bytes instead"
            )
        return int.from_bytes(binary, byteorder="big")

    @classmethod
    def _curve_name_to_crv(cls, curve_name: str) -> str:
        """Map a ``cryptography`` curve name to its JWK ``crv`` value.

        :raises errors.SerializationError: for unsupported curves.
        """
        if curve_name == "secp256r1":
            return "P-256"
        if curve_name == "secp384r1":
            return "P-384"
        if curve_name == "secp521r1":
            return "P-521"
        raise errors.SerializationError(f"Unsupported curve: {curve_name}")

    @classmethod
    def _crv_to_curve(cls, crv: str) -> ec.EllipticCurve:
        """Map a JWK ``crv`` value to a ``cryptography`` curve instance.

        :raises errors.DeserializationError: for unsupported ``crv`` values.
        """
        # crv is case-sensitive
        if crv == "P-256":
            return ec.SECP256R1()
        if crv == "P-384":
            return ec.SECP384R1()
        if crv == "P-521":
            return ec.SECP521R1()
        raise errors.DeserializationError(f"Unsupported crv: {crv!r}")

    @classmethod
    def expected_length_for_curve(cls, curve: ec.EllipticCurve) -> int:
        """Return the encoded size in bytes of coordinates/scalars on ``curve``.

        :raises ValueError: for unsupported curves.
        """
        if isinstance(curve, ec.SECP256R1):
            return 32
        elif isinstance(curve, ec.SECP384R1):
            return 48
        elif isinstance(curve, ec.SECP521R1):
            return 66
        raise ValueError(f"Unexpected curve: {curve}")

    def fields_to_partial_json(self) -> Dict[str, Any]:
        """Serialize the key as ``x``, ``y`` (plus ``d`` if private) and ``crv``.

        :raises errors.SerializationError: if the wrapped key is neither an EC
            public nor private key, or uses an unsupported curve.
        """
        params: Dict[str, int] = {}
        if isinstance(self.key._wrapped, ec.EllipticCurvePublicKey):
            public = self.key.public_numbers()
        elif isinstance(self.key._wrapped, ec.EllipticCurvePrivateKey):
            private = self.key.private_numbers()
            public = self.key.public_key().public_numbers()
            params["d"] = private.private_value
        else:
            raise errors.SerializationError(
                "Supplied key is neither of type EllipticCurvePublicKey "
                "nor EllipticCurvePrivateKey"
            )
        params["x"] = public.x
        params["y"] = public.y
        # Resolve crv before the length lookup. Otherwise an unsupported
        # curve would surface as a bare ValueError from
        # expected_length_for_curve, and this SerializationError would
        # never be raised.
        crv = self._curve_name_to_crv(public.curve.name)
        length = self.expected_length_for_curve(public.curve)
        encoded = {key: self._encode_param(value, length) for key, value in params.items()}
        encoded["crv"] = crv
        return encoded

    @classmethod
    def fields_from_json(cls, jobj: Mapping[str, Any]) -> "JWKEC":
        """Build an EC JWK from ``crv``, ``x``, ``y`` and optionally ``d``."""
        curve = cls._crv_to_curve(jobj["crv"])
        expected_length = cls.expected_length_for_curve(curve)
        x, y = (cls._decode_param(jobj[n], n, expected_length) for n in ("x", "y"))
        public_numbers = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve)

        # public key
        if "d" not in jobj:
            return cls(key=public_numbers.public_key(default_backend()))

        # private key
        d = cls._decode_param(jobj["d"], "d", expected_length)
        key = ec.EllipticCurvePrivateNumbers(d, public_numbers).private_key(default_backend())
        return cls(key=key)

    def public_key(self) -> "JWKEC":
        """Return a JWKEC holding only the public half of this key."""
        # Use the wrapper's public_key() when it has one; otherwise rebuild
        # the public key from its public numbers.
        # NOTE(review): cryptography's EllipticCurvePrivateKey does provide
        # public_key(). Whether the fallback branch is reachable depends on
        # util.ComparableECKey, which is not visible here — confirm.
        if hasattr(self.key, "public_key"):
            key = self.key.public_key()
        else:
            key = self.key.public_numbers().public_key(default_backend())
        return type(self)(key=key)
|