#!/usr/bin/env python3

# SPDX-FileCopyrightText: 2022 FIT-Connect contributors
#
# SPDX-License-Identifier: EUPL-1.2

import argparse
import base64
import pathlib
import random
import tempfile
import textwrap

from OpenSSL import crypto
from jwcrypto import jwk


def create_self_signed_cert(*, key_size=4096, valid_days=10 * 365):
    """Create an RSA key pair and a self-signed X.509 test certificate for it.

    The certificate subject (and, because it is self-signed, its issuer) is
    ``C=DE, O=Testbehoerde, CN=FIT Connect Testzertifikat``. The certificate
    is valid from now for ``valid_days`` days and is signed with SHA-512
    using the generated key.

    Args:
        key_size: RSA modulus size in bits (default 4096).
        valid_days: Validity period in days, counted from the current time
            (default 10 * 365).

    Returns:
        A ``(cert, keypair)`` tuple of ``OpenSSL.crypto.X509`` and
        ``OpenSSL.crypto.PKey``.
    """
    # create key pair
    keypair = crypto.PKey()
    keypair.generate_key(crypto.TYPE_RSA, key_size)

    # create self-signed cert
    cert = crypto.X509()
    subject = cert.get_subject()
    subject.C = "DE"
    subject.O = "Testbehoerde"
    subject.CN = "FIT Connect Testzertifikat"
    # RFC 5280 requires a positive serial number of at most 20 octets. Draw
    # it from the OS CSPRNG: the module-level ``random`` functions are not
    # meant for security-relevant values. 159 bits keep the DER encoding
    # (including a possible leading zero byte) within 20 octets.
    cert.set_serial_number(random.SystemRandom().randrange(1, 1 << 159))
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(valid_days * 24 * 60 * 60)
    cert.set_issuer(cert.get_subject())
    cert.set_pubkey(keypair)
    cert.sign(keypair, "sha512")

    return cert, keypair


def cert_to_x5c(cert):
    """Return *cert* encoded as a single JWK ``x5c`` array entry.

    RFC 7517 section 4.7 defines each ``x5c`` entry as the standard
    (not URL-safe) base64 encoding of the DER-encoded certificate, with no
    line breaks and no PEM armour.

    Args:
        cert: ``OpenSSL.crypto.X509`` certificate to encode.

    Returns:
        The base64-encoded DER certificate as a ``str``.
    """
    # Dump DER directly instead of stripping the armour from a PEM dump.
    # This does not depend on the exact header/footer/line-ending layout
    # of OpenSSL's PEM output.
    cert_der = crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
    return base64.b64encode(cert_der).decode("ascii")


def _derive_jwk_pair(cert, keypair, alg, public_op, private_op):
    """Build the ``(public_jwk, private_jwk)`` pair for one certificate/key pair.

    The public JWK is imported from the PEM certificate and carries the
    certificate in ``x5c``. The private JWK is imported from the PEM private
    key. Both get ``alg`` and a single-element ``key_ops`` list.
    """
    public_jwk = jwk.JWK.from_pem(
        crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
    )
    public_jwk.setdefault("alg", alg)
    public_jwk.setdefault("x5c", [cert_to_x5c(cert)])
    public_jwk.setdefault("key_ops", [public_op])

    private_jwk = jwk.JWK.from_pem(
        crypto.dump_privatekey(crypto.FILETYPE_PEM, keypair)
    )
    private_jwk.setdefault("alg", alg)
    private_jwk.setdefault("key_ops", [private_op])

    return public_jwk, private_jwk


def _write_json(path, content, *, private=False):
    """Write the JSON string *content* to *path* as UTF-8 bytes.

    With ``private=True`` the file is restricted to owner read/write (0o600)
    before the secret is written, so it is never readable by other local
    users, even briefly.
    """
    if private:
        # touch() applies the mode only when it creates the file; chmod()
        # also tightens a file left over from an earlier run.
        path.touch(mode=0o600, exist_ok=True)
        path.chmod(0o600)
    path.write_bytes(content.encode("UTF-8"))


def create_jwk(output_dir):
    """Generate encryption and signature key pairs and write them as JWK files.

    Two independent RSA key pairs with self-signed certificates are created:
    one for encryption (``RSA-OAEP-256``, ``wrapKey``/``unwrapKey``) and one
    for signatures (``PS512``, ``verify``/``sign``). These files are written
    into *output_dir*, which is created if it does not exist:

    - ``set-public-keys.json``: JWKS containing both public keys
    - ``publicKey_encryption.json``, ``publicKey_signature_verification.json``
    - ``privateKey_decryption.json``, ``privateKey_signing.json``
      (restricted to mode 0o600)

    A summary of the written paths is printed to stdout.

    Args:
        output_dir: Target directory, as ``pathlib.Path`` or ``str``.
    """
    # Accept plain strings too; mkdir() below needs a Path.
    output_dir = pathlib.Path(output_dir)

    # encryption key pair: public key wraps, private key unwraps
    cert_enc, keypair_enc = create_self_signed_cert()
    jwk_wrapKey, jwk_unwrapKey = _derive_jwk_pair(
        cert_enc, keypair_enc, "RSA-OAEP-256", "wrapKey", "unwrapKey"
    )

    # signature key pair: public key verifies, private key signs
    cert_sig, keypair_sig = create_self_signed_cert()
    jwk_verify, jwk_sign = _derive_jwk_pair(
        cert_sig, keypair_sig, "PS512", "verify", "sign"
    )

    # create JWKS of public keys
    jwks = jwk.JWKSet()
    jwks.add(jwk_wrapKey)
    jwks.add(jwk_verify)

    # define file paths
    output_dir.mkdir(parents=True, exist_ok=True)

    keySet_file = output_dir / "set-public-keys.json"
    publicKey_wrapkey_file = output_dir / "publicKey_encryption.json"
    publicKey_verify_file = output_dir / "publicKey_signature_verification.json"
    privateKey_unwrapkey_file = output_dir / "privateKey_decryption.json"
    privateKey_sign_file = output_dir / "privateKey_signing.json"

    # public material: JWKS and individual public keys
    _write_json(keySet_file, jwks.export(private_keys=False))
    _write_json(publicKey_wrapkey_file, jwk_wrapKey.export(private_key=False))
    _write_json(publicKey_verify_file, jwk_verify.export(private_key=False))

    # secret material: owner-only file permissions
    _write_json(
        privateKey_unwrapkey_file,
        jwk_unwrapKey.export(private_key=True),
        private=True,
    )
    _write_json(
        privateKey_sign_file,
        jwk_sign.export(private_key=True),
        private=True,
    )

    print(
        textwrap.dedent(
            f"""\
        🔒 Wrote JWK representation of encryption public key (key_use=wrapKey) to {publicKey_wrapkey_file}
        🔒 Wrote JWK representation of signature validation public key (key_use=verify) to {publicKey_verify_file}
            Please upload these keys when creating a destination in the self-service portal.

        🔒 Wrote JWKS of Public Keys to {keySet_file}
            This key set can be used to update (rotate) keys via the Submission-API (PUT /destinations/{{destinationID}})

        🔒 Wrote JWK representation of decryption private key (key_use=unwrapKey) to {privateKey_unwrapkey_file}
        🔒 Wrote JWK representation of signing private key (key_use=sign) to {privateKey_sign_file}
            These keys can be used to sign and decrypt in your client application."""
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate SET JWKS.")
    parser.add_argument(
        "-o",
        "--output",
        default=None,
        help="Directory to store the generated SET JWKS in. Default: a temporary directory",
        type=pathlib.Path,
    )
    args = parser.parse_args()

    # Create the temporary directory only when no output directory was given.
    # Using tempfile.mkdtemp() directly as the argparse default would create
    # (and leak) an empty directory on every run, even with -o.
    output_dir = args.output
    if output_dir is None:
        output_dir = pathlib.Path(tempfile.mkdtemp())

    create_jwk(output_dir)