diff --git a/library/pk.c b/library/pk.c index ec3741b13..097777f2c 100644 --- a/library/pk.c +++ b/library/pk.c @@ -1126,6 +1126,12 @@ int mbedtls_pk_verify_ext(mbedtls_pk_type_t type, const void *options, return mbedtls_pk_verify(ctx, md_alg, hash, hash_len, sig, sig_len); } + /* Ensure the PK context is of the right type otherwise mbedtls_pk_rsa() + * below would return a NULL pointer. */ + if (mbedtls_pk_get_type(ctx) != MBEDTLS_PK_RSA) { + return MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE; + } + #if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PKCS1_V21) int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; const mbedtls_pk_rsassa_pss_options *pss_opts;