- Added support for PKCS#1 v2.1 encoding and thus support for the RSAES-OAEP and RSASSA-PSS operations (enabled by POLARSSL_PKCS1_V21)

This commit is contained in:
Paul Bakker 2011-03-08 14:16:06 +00:00
parent fea43a2501
commit 9dcc32236b
16 changed files with 1884 additions and 123 deletions

View file

@ -34,6 +34,7 @@
#if defined(POLARSSL_RSA_C)
#include "polarssl/rsa.h"
#include "polarssl/md.h"
#include <stdlib.h>
#include <string.h>
@ -291,6 +292,55 @@ cleanup:
return( 0 );
}
#if defined(POLARSSL_PKCS1_V21)
/**
* Generate and apply the MGF1 operation (from PKCS#1 v2.1) to a buffer.
*
* @param dst buffer to mask
* @param dlen length of destination buffer
* @param src source of the mask generation
* @param slen length of the source buffer
* @param md_ctx message digest context to use
* @param hlen length of the digest result
*/
static void mgf_mask( unsigned char *dst, int dlen, unsigned char *src, int slen,
md_context_t *md_ctx )
{
unsigned char mask[POLARSSL_MD_MAX_SIZE];
unsigned char counter[4];
unsigned char *p;
int i, use_len, hlen;
memset( mask, 0, POLARSSL_MD_MAX_SIZE );
memset( counter, 0, 4 );
hlen = md_ctx->md_info->size;
// Generate and apply dbMask
//
p = dst;
while( dlen > 0 )
{
use_len = hlen;
if( dlen < hlen )
use_len = dlen;
md_starts( md_ctx );
md_update( md_ctx, src, slen );
md_update( md_ctx, counter, 4 );
md_finish( md_ctx, mask );
for( i = 0; i < use_len; ++i )
*p++ ^= mask[i];
counter[3]++;
dlen -= use_len;
}
}
#endif
/*
* Add the message padding, then do an RSA operation
*/
@ -303,14 +353,22 @@ int rsa_pkcs1_encrypt( rsa_context *ctx,
{
int nb_pad, olen;
unsigned char *p = output;
#if defined(POLARSSL_PKCS1_V21)
const md_info_t *md_info;
md_context_t md_ctx;
int i, hlen;
#endif
olen = ctx->len;
if( f_rng == NULL )
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
switch( ctx->padding )
{
case RSA_PKCS_V15:
if( ilen < 0 || olen < ilen + 11 || f_rng == NULL )
if( ilen < 0 || olen < ilen + 11 )
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
nb_pad = olen - 3 - ilen;
@ -336,6 +394,50 @@ int rsa_pkcs1_encrypt( rsa_context *ctx,
*p++ = 0;
memcpy( p, input, ilen );
break;
#if defined(POLARSSL_PKCS1_V21)
case RSA_PKCS_V21:
md_info = md_info_from_type( ctx->hash_id );
if( md_info == NULL )
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
hlen = md_get_size( md_info );
if( ilen < 0 || olen < ilen + 2 * hlen + 2 || f_rng == NULL )
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
memset( output, 0, olen );
memset( &md_ctx, 0, sizeof( md_context_t ) );
md_init_ctx( &md_ctx, md_info );
*p++ = 0;
// Generate a random octet string seed
//
for( i = 0; i < hlen; ++i )
*p++ = (unsigned char) f_rng( p_rng );
// Construct DB
//
md( md_info, p, 0, p );
p += hlen;
p += olen - 2 * hlen - 2 - ilen;
*p++ = 1;
memcpy( p, input, ilen );
// maskedDB: Apply dbMask to DB
//
mgf_mask( output + hlen + 1, olen - hlen - 1, output + 1, hlen,
&md_ctx );
// maskedSeed: Apply seedMask to seed
//
mgf_mask( output + 1, hlen, output + hlen + 1, olen - hlen - 1,
&md_ctx );
break;
#endif
default:
@ -359,6 +461,12 @@ int rsa_pkcs1_decrypt( rsa_context *ctx,
int ret, ilen;
unsigned char *p;
unsigned char buf[1024];
#if defined(POLARSSL_PKCS1_V21)
unsigned char lhash[POLARSSL_MD_MAX_SIZE];
const md_info_t *md_info;
md_context_t md_ctx;
int hlen;
#endif
ilen = ctx->len;
@ -390,6 +498,56 @@ int rsa_pkcs1_decrypt( rsa_context *ctx,
p++;
break;
#if defined(POLARSSL_PKCS1_V21)
case RSA_PKCS_V21:
if( *p++ != 0 )
return( POLARSSL_ERR_RSA_INVALID_PADDING );
md_info = md_info_from_type( ctx->hash_id );
if( md_info == NULL )
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
hlen = md_get_size( md_info );
memset( &md_ctx, 0, sizeof( md_context_t ) );
md_init_ctx( &md_ctx, md_info );
// Generate lHash
//
md( md_info, lhash, 0, lhash );
// seed: Apply seedMask to maskedSeed
//
mgf_mask( buf + 1, hlen, buf + hlen + 1, ilen - hlen - 1,
&md_ctx );
// DB: Apply dbMask to maskedDB
//
mgf_mask( buf + hlen + 1, ilen - hlen - 1, buf + 1, hlen,
&md_ctx );
p += hlen;
// Check validity
//
if( memcmp( lhash, p, hlen ) != 0 )
return( POLARSSL_ERR_RSA_INVALID_PADDING );
p += hlen;
while( *p == 0 && p < buf + ilen )
p++;
if( p == buf + ilen )
return( POLARSSL_ERR_RSA_INVALID_PADDING );
if( *p++ != 0x01 )
return( POLARSSL_ERR_RSA_INVALID_PADDING );
break;
#endif
default:
return( POLARSSL_ERR_RSA_INVALID_PADDING );
@ -408,6 +566,8 @@ int rsa_pkcs1_decrypt( rsa_context *ctx,
* Do an RSA operation to sign the message digest
*/
int rsa_pkcs1_sign( rsa_context *ctx,
int (*f_rng)(void *),
void *p_rng,
int mode,
int hash_id,
int hashlen,
@ -416,6 +576,15 @@ int rsa_pkcs1_sign( rsa_context *ctx,
{
int nb_pad, olen;
unsigned char *p = sig;
#if defined(POLARSSL_PKCS1_V21)
unsigned char salt[POLARSSL_MD_MAX_SIZE];
const md_info_t *md_info;
md_context_t md_ctx;
int i, hlen, msb, offset = 0;
#else
(void) f_rng;
(void) p_rng;
#endif
olen = ctx->len;
@ -468,63 +637,152 @@ int rsa_pkcs1_sign( rsa_context *ctx,
memset( p, 0xFF, nb_pad );
p += nb_pad;
*p++ = 0;
switch( hash_id )
{
case SIG_RSA_RAW:
memcpy( p, hash, hashlen );
break;
case SIG_RSA_MD2:
memcpy( p, ASN1_HASH_MDX, 18 );
memcpy( p + 18, hash, 16 );
p[13] = 2; break;
case SIG_RSA_MD4:
memcpy( p, ASN1_HASH_MDX, 18 );
memcpy( p + 18, hash, 16 );
p[13] = 4; break;
case SIG_RSA_MD5:
memcpy( p, ASN1_HASH_MDX, 18 );
memcpy( p + 18, hash, 16 );
p[13] = 5; break;
case SIG_RSA_SHA1:
memcpy( p, ASN1_HASH_SHA1, 15 );
memcpy( p + 15, hash, 20 );
break;
case SIG_RSA_SHA224:
memcpy( p, ASN1_HASH_SHA2X, 19 );
memcpy( p + 19, hash, 28 );
p[1] += 28; p[14] = 4; p[18] += 28; break;
case SIG_RSA_SHA256:
memcpy( p, ASN1_HASH_SHA2X, 19 );
memcpy( p + 19, hash, 32 );
p[1] += 32; p[14] = 1; p[18] += 32; break;
case SIG_RSA_SHA384:
memcpy( p, ASN1_HASH_SHA2X, 19 );
memcpy( p + 19, hash, 48 );
p[1] += 48; p[14] = 2; p[18] += 48; break;
case SIG_RSA_SHA512:
memcpy( p, ASN1_HASH_SHA2X, 19 );
memcpy( p + 19, hash, 64 );
p[1] += 64; p[14] = 3; p[18] += 64; break;
default:
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
}
break;
#if defined(POLARSSL_PKCS1_V21)
case RSA_PKCS_V21:
if( f_rng == NULL )
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
switch( hash_id )
{
case SIG_RSA_MD2:
case SIG_RSA_MD4:
case SIG_RSA_MD5:
hashlen = 16;
break;
case SIG_RSA_SHA1:
hashlen = 20;
break;
case SIG_RSA_SHA224:
hashlen = 28;
break;
case SIG_RSA_SHA256:
hashlen = 32;
break;
case SIG_RSA_SHA384:
hashlen = 48;
break;
case SIG_RSA_SHA512:
hashlen = 64;
break;
default:
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
}
md_info = md_info_from_type( ctx->hash_id );
if( md_info == NULL )
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
hlen = md_get_size( md_info );
memset( sig, 0, olen );
memset( &md_ctx, 0, sizeof( md_context_t ) );
md_init_ctx( &md_ctx, md_info );
msb = mpi_msb( &ctx->N ) - 1;
// Generate salt of length hlen
//
for( i = 0; i < hlen; ++i )
salt[i] = (unsigned char) f_rng( p_rng );
// Note: EMSA-PSS encoding is over the length of N - 1 bits
//
msb = mpi_msb( &ctx->N ) - 1;
p += olen - hlen * 2 - 2;
*p++ = 0x01;
memcpy( p, salt, hlen );
p += hlen;
// Generate H = Hash( M' )
//
md_starts( &md_ctx );
md_update( &md_ctx, p, 8 );
md_update( &md_ctx, hash, hashlen );
md_update( &md_ctx, salt, hlen );
md_finish( &md_ctx, p );
// Compensate for boundary condition when applying mask
//
if( msb % 8 == 0 )
offset = 1;
// maskedDB: Apply dbMask to DB
//
mgf_mask( sig + offset, olen - hlen - 1 - offset, p, hlen, &md_ctx );
msb = mpi_msb( &ctx->N ) - 1;
sig[0] &= 0xFF >> ( olen * 8 - msb );
p += hlen;
*p++ = 0xBC;
break;
#endif
default:
return( POLARSSL_ERR_RSA_INVALID_PADDING );
}
switch( hash_id )
{
case SIG_RSA_RAW:
memcpy( p, hash, hashlen );
break;
case SIG_RSA_MD2:
memcpy( p, ASN1_HASH_MDX, 18 );
memcpy( p + 18, hash, 16 );
p[13] = 2; break;
case SIG_RSA_MD4:
memcpy( p, ASN1_HASH_MDX, 18 );
memcpy( p + 18, hash, 16 );
p[13] = 4; break;
case SIG_RSA_MD5:
memcpy( p, ASN1_HASH_MDX, 18 );
memcpy( p + 18, hash, 16 );
p[13] = 5; break;
case SIG_RSA_SHA1:
memcpy( p, ASN1_HASH_SHA1, 15 );
memcpy( p + 15, hash, 20 );
break;
case SIG_RSA_SHA224:
memcpy( p, ASN1_HASH_SHA2X, 19 );
memcpy( p + 19, hash, 28 );
p[1] += 28; p[14] = 4; p[18] += 28; break;
case SIG_RSA_SHA256:
memcpy( p, ASN1_HASH_SHA2X, 19 );
memcpy( p + 19, hash, 32 );
p[1] += 32; p[14] = 1; p[18] += 32; break;
case SIG_RSA_SHA384:
memcpy( p, ASN1_HASH_SHA2X, 19 );
memcpy( p + 19, hash, 48 );
p[1] += 48; p[14] = 2; p[18] += 48; break;
case SIG_RSA_SHA512:
memcpy( p, ASN1_HASH_SHA2X, 19 );
memcpy( p + 19, hash, 64 );
p[1] += 64; p[14] = 3; p[18] += 64; break;
default:
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
}
return( ( mode == RSA_PUBLIC )
? rsa_public( ctx, sig, sig )
: rsa_private( ctx, sig, sig ) );
@ -543,7 +801,12 @@ int rsa_pkcs1_verify( rsa_context *ctx,
int ret, len, siglen;
unsigned char *p, c;
unsigned char buf[1024];
#if defined(POLARSSL_PKCS1_V21)
unsigned char zeros[8];
const md_info_t *md_info;
md_context_t md_ctx;
int hlen, msb;
#endif
siglen = ctx->len;
if( siglen < 16 || siglen > (int) sizeof( buf ) )
@ -572,67 +835,158 @@ int rsa_pkcs1_verify( rsa_context *ctx,
p++;
}
p++;
len = siglen - (int)( p - buf );
if( len == 34 )
{
c = p[13];
p[13] = 0;
if( memcmp( p, ASN1_HASH_MDX, 18 ) != 0 )
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
if( ( c == 2 && hash_id == SIG_RSA_MD2 ) ||
( c == 4 && hash_id == SIG_RSA_MD4 ) ||
( c == 5 && hash_id == SIG_RSA_MD5 ) )
{
if( memcmp( p + 18, hash, 16 ) == 0 )
return( 0 );
else
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
}
}
if( len == 35 && hash_id == SIG_RSA_SHA1 )
{
if( memcmp( p, ASN1_HASH_SHA1, 15 ) == 0 &&
memcmp( p + 15, hash, 20 ) == 0 )
return( 0 );
else
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
}
if( ( len == 19 + 28 && p[14] == 4 && hash_id == SIG_RSA_SHA224 ) ||
( len == 19 + 32 && p[14] == 1 && hash_id == SIG_RSA_SHA256 ) ||
( len == 19 + 48 && p[14] == 2 && hash_id == SIG_RSA_SHA384 ) ||
( len == 19 + 64 && p[14] == 3 && hash_id == SIG_RSA_SHA512 ) )
{
c = p[1] - 17;
p[1] = 17;
p[14] = 0;
if( p[18] == c &&
memcmp( p, ASN1_HASH_SHA2X, 18 ) == 0 &&
memcmp( p + 19, hash, c ) == 0 )
return( 0 );
else
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
}
if( len == hashlen && hash_id == SIG_RSA_RAW )
{
if( memcmp( p, hash, hashlen ) == 0 )
return( 0 );
else
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
}
break;
#if defined(POLARSSL_PKCS1_V21)
case RSA_PKCS_V21:
if( buf[siglen - 1] != 0xBC )
return( POLARSSL_ERR_RSA_INVALID_PADDING );
switch( hash_id )
{
case SIG_RSA_MD2:
case SIG_RSA_MD4:
case SIG_RSA_MD5:
hashlen = 16;
break;
case SIG_RSA_SHA1:
hashlen = 20;
break;
case SIG_RSA_SHA224:
hashlen = 28;
break;
case SIG_RSA_SHA256:
hashlen = 32;
break;
case SIG_RSA_SHA384:
hashlen = 48;
break;
case SIG_RSA_SHA512:
hashlen = 64;
break;
default:
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
}
md_info = md_info_from_type( ctx->hash_id );
if( md_info == NULL )
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
hlen = md_get_size( md_info );
memset( &md_ctx, 0, sizeof( md_context_t ) );
memset( zeros, 0, 8 );
md_init_ctx( &md_ctx, md_info );
// Note: EMSA-PSS verification is over the length of N - 1 bits
//
msb = mpi_msb( &ctx->N ) - 1;
// Compensate for boundary condition when applying mask
//
if( msb % 8 == 0 )
{
p++;
siglen -= 1;
}
if( buf[0] >> ( 8 - siglen * 8 + msb ) )
return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
mgf_mask( p, siglen - hlen - 1, p + siglen - hlen - 1, hlen, &md_ctx );
buf[0] &= 0xFF >> ( siglen * 8 - msb );
while( *p == 0 && p < buf + siglen )
p++;
if( p == buf + siglen )
return( POLARSSL_ERR_RSA_INVALID_PADDING );
if( *p++ != 0x01 )
return( POLARSSL_ERR_RSA_INVALID_PADDING );
// Generate H = Hash( M' )
//
md_starts( &md_ctx );
md_update( &md_ctx, zeros, 8 );
md_update( &md_ctx, hash, hashlen );
md_update( &md_ctx, p, hlen );
md_finish( &md_ctx, p );
if( memcmp( p, p + hlen, hlen ) == 0 )
return( 0 );
else
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
break;
#endif
default:
return( POLARSSL_ERR_RSA_INVALID_PADDING );
}
len = siglen - (int)( p - buf );
if( len == 34 )
{
c = p[13];
p[13] = 0;
if( memcmp( p, ASN1_HASH_MDX, 18 ) != 0 )
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
if( ( c == 2 && hash_id == SIG_RSA_MD2 ) ||
( c == 4 && hash_id == SIG_RSA_MD4 ) ||
( c == 5 && hash_id == SIG_RSA_MD5 ) )
{
if( memcmp( p + 18, hash, 16 ) == 0 )
return( 0 );
else
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
}
}
if( len == 35 && hash_id == SIG_RSA_SHA1 )
{
if( memcmp( p, ASN1_HASH_SHA1, 15 ) == 0 &&
memcmp( p + 15, hash, 20 ) == 0 )
return( 0 );
else
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
}
if( ( len == 19 + 28 && p[14] == 4 && hash_id == SIG_RSA_SHA224 ) ||
( len == 19 + 32 && p[14] == 1 && hash_id == SIG_RSA_SHA256 ) ||
( len == 19 + 48 && p[14] == 2 && hash_id == SIG_RSA_SHA384 ) ||
( len == 19 + 64 && p[14] == 3 && hash_id == SIG_RSA_SHA512 ) )
{
c = p[1] - 17;
p[1] = 17;
p[14] = 0;
if( p[18] == c &&
memcmp( p, ASN1_HASH_SHA2X, 18 ) == 0 &&
memcmp( p + 19, hash, c ) == 0 )
return( 0 );
else
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
}
if( len == hashlen && hash_id == SIG_RSA_RAW )
{
if( memcmp( p, hash, hashlen ) == 0 )
return( 0 );
else
return( POLARSSL_ERR_RSA_VERIFY_FAILED );
}
return( POLARSSL_ERR_RSA_INVALID_PADDING );
}
@ -789,7 +1143,7 @@ int rsa_self_test( int verbose )
sha1( rsa_plaintext, PT_LEN, sha1sum );
if( rsa_pkcs1_sign( &rsa, RSA_PRIVATE, SIG_RSA_SHA1, 20,
if( rsa_pkcs1_sign( &rsa, NULL, NULL, RSA_PRIVATE, SIG_RSA_SHA1, 20,
sha1sum, rsa_ciphertext ) != 0 )
{
if( verbose != 0 )

View file

@ -667,7 +667,8 @@ static int ssl_write_certificate_verify( ssl_context *ssl )
if( ssl->rsa_key )
{
ret = rsa_pkcs1_sign( ssl->rsa_key, RSA_PRIVATE, SIG_RSA_RAW,
ret = rsa_pkcs1_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
RSA_PRIVATE, SIG_RSA_RAW,
36, hash, ssl->out_msg + 6 );
} else {
#if defined(POLARSSL_PKCS11_C)

View file

@ -619,7 +619,8 @@ static int ssl_write_server_key_exchange( ssl_context *ssl )
if ( ssl->rsa_key )
{
ret = rsa_pkcs1_sign( ssl->rsa_key, RSA_PRIVATE,
ret = rsa_pkcs1_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
RSA_PRIVATE,
SIG_RSA_RAW, 36, hash, ssl->out_msg + 6 + n );
}
#if defined(POLARSSL_PKCS11_C)