| /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ |
| /* vim: set ts=2 et sw=2 tw=80: */ |
| /* This Source Code Form is subject to the terms of the Mozilla Public |
| * License, v. 2.0. If a copy of the MPL was not distributed with this file, |
| * You can obtain one at http://mozilla.org/MPL/2.0/. */ |
| |
| #include "tls_agent.h" |
| #include "databuffer.h" |
| #include "keyhi.h" |
| #include "pk11func.h" |
| #include "ssl.h" |
| #include "sslerr.h" |
| #include "sslexp.h" |
| #include "sslproto.h" |
| #include "tls_filter.h" |
| #include "tls_parser.h" |
| |
| // This is an internal header, used to get DTLS_1_3_DRAFT_VERSION. |
| #include "ssl3prot.h" |
| |
| extern "C" { |
| // This is not something that should make you happy. |
| #include "libssl_internals.h" |
| } |
| |
| #define GTEST_HAS_RTTI 0 |
| #include "gtest/gtest.h" |
| #include "gtest_utils.h" |
| #include "nss_scoped_ptrs.h" |
| |
| extern std::string g_working_dir_path; |
| |
| namespace nss_test { |
| |
| const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"}; |
| |
| const std::string TlsAgent::kClient = "client"; // both sign and encrypt |
| const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger |
| const std::string TlsAgent::kRsa8192 = "rsa8192"; // biggest allowed |
| const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt |
| const std::string TlsAgent::kServerRsaSign = "rsa_sign"; |
| const std::string TlsAgent::kServerRsaPss = "rsa_pss"; |
| const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt"; |
| const std::string TlsAgent::kServerEcdsa256 = "ecdsa256"; |
| const std::string TlsAgent::kServerEcdsa384 = "ecdsa384"; |
| const std::string TlsAgent::kServerEcdsa521 = "ecdsa521"; |
| const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa"; |
| const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa"; |
| const std::string TlsAgent::kServerDsa = "dsa"; |
| const std::string TlsAgent::kDelegatorEcdsa256 = "delegator_ecdsa256"; |
| const std::string TlsAgent::kDelegatorRsae2048 = "delegator_rsae2048"; |
| const std::string TlsAgent::kDelegatorRsaPss2048 = "delegator_rsa_pss2048"; |
| |
| static const uint8_t kCannedTls13ServerHello[] = { |
| 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, |
| 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, |
| 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, |
| 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, |
| 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03, |
| 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9, |
| 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08, |
| 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04}; |
| |
| TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var) |
| : name_(nm), |
| variant_(var), |
| role_(rl), |
| server_key_bits_(0), |
| adapter_(new DummyPrSocket(role_str(), var)), |
| ssl_fd_(nullptr), |
| state_(STATE_INIT), |
| timer_handle_(nullptr), |
| falsestart_enabled_(false), |
| expected_version_(0), |
| expected_cipher_suite_(0), |
| expect_client_auth_(false), |
| expect_ech_(false), |
| expect_psk_(ssl_psk_none), |
| can_falsestart_hook_called_(false), |
| sni_hook_called_(false), |
| auth_certificate_hook_called_(false), |
| expected_received_alert_(kTlsAlertCloseNotify), |
| expected_received_alert_level_(kTlsAlertWarning), |
| expected_sent_alert_(kTlsAlertCloseNotify), |
| expected_sent_alert_level_(kTlsAlertWarning), |
| handshake_callback_called_(false), |
| resumption_callback_called_(false), |
| error_code_(0), |
| send_ctr_(0), |
| recv_ctr_(0), |
| expect_readwrite_error_(false), |
| handshake_callback_(), |
| auth_certificate_callback_(), |
| sni_callback_(), |
| skip_version_checks_(false), |
| resumption_token_(), |
| policy_() { |
| memset(&info_, 0, sizeof(info_)); |
| memset(&csinfo_, 0, sizeof(csinfo_)); |
| SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_); |
| EXPECT_EQ(SECSuccess, rv); |
| } |
| |
| TlsAgent::~TlsAgent() { |
| if (timer_handle_) { |
| timer_handle_->Cancel(); |
| } |
| |
| if (adapter_) { |
| Poller::Instance()->Cancel(READABLE_EVENT, adapter_); |
| } |
| |
| // Add failures manually, if any, so we don't throw in a destructor. |
| if (expected_received_alert_ != kTlsAlertCloseNotify || |
| expected_received_alert_level_ != kTlsAlertWarning) { |
| ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str(); |
| } |
| if (expected_sent_alert_ != kTlsAlertCloseNotify || |
| expected_sent_alert_level_ != kTlsAlertWarning) { |
| ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str(); |
| } |
| } |
| |
| void TlsAgent::SetState(State s) { |
| if (state_ == s) return; |
| |
| LOG("Changing state from " << state_ << " to " << s); |
| state_ = s; |
| } |
| |
| /*static*/ bool TlsAgent::LoadCertificate(const std::string& name, |
| ScopedCERTCertificate* cert, |
| ScopedSECKEYPrivateKey* priv) { |
| cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr)); |
| EXPECT_NE(nullptr, cert); |
| if (!cert) return false; |
| EXPECT_NE(nullptr, cert->get()); |
| if (!cert->get()) return false; |
| |
| priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr)); |
| EXPECT_NE(nullptr, priv); |
| if (!priv) return false; |
| EXPECT_NE(nullptr, priv->get()); |
| if (!priv->get()) return false; |
| |
| return true; |
| } |
| |
| // Loads a key pair from the certificate identified by |id|. |
| /*static*/ bool TlsAgent::LoadKeyPairFromCert(const std::string& name, |
| ScopedSECKEYPublicKey* pub, |
| ScopedSECKEYPrivateKey* priv) { |
| ScopedCERTCertificate cert; |
| if (!TlsAgent::LoadCertificate(name, &cert, priv)) { |
| return false; |
| } |
| |
| pub->reset(SECKEY_ExtractPublicKey(&cert->subjectPublicKeyInfo)); |
| if (!pub->get()) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
| void TlsAgent::DelegateCredential(const std::string& name, |
| const ScopedSECKEYPublicKey& dc_pub, |
| SSLSignatureScheme dc_cert_verify_alg, |
| PRUint32 dc_valid_for, PRTime now, |
| SECItem* dc) { |
| ScopedCERTCertificate cert; |
| ScopedSECKEYPrivateKey cert_priv; |
| EXPECT_TRUE(TlsAgent::LoadCertificate(name, &cert, &cert_priv)) |
| << "Could not load delegate certificate: " << name |
| << "; test db corrupt?"; |
| |
| EXPECT_EQ(SECSuccess, |
| SSL_DelegateCredential(cert.get(), cert_priv.get(), dc_pub.get(), |
| dc_cert_verify_alg, dc_valid_for, now, dc)); |
| } |
| |
| void TlsAgent::EnableDelegatedCredentials() { |
| ASSERT_TRUE(EnsureTlsSetup()); |
| SetOption(SSL_ENABLE_DELEGATED_CREDENTIALS, PR_TRUE); |
| } |
| |
| void TlsAgent::AddDelegatedCredential(const std::string& dc_name, |
| SSLSignatureScheme dc_cert_verify_alg, |
| PRUint32 dc_valid_for, PRTime now) { |
| ASSERT_TRUE(EnsureTlsSetup()); |
| |
| ScopedSECKEYPublicKey pub; |
| ScopedSECKEYPrivateKey priv; |
| EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(dc_name, &pub, &priv)); |
| |
| StackSECItem dc; |
| TlsAgent::DelegateCredential(name_, pub, dc_cert_verify_alg, dc_valid_for, |
| now, &dc); |
| |
| SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, |
| nullptr, &dc, priv.get()}; |
| EXPECT_TRUE(ConfigServerCert(name_, true, &extra_data)); |
| } |
| |
| bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits, |
| const SSLExtraServerCertData* serverCertData) { |
| ScopedCERTCertificate cert; |
| ScopedSECKEYPrivateKey priv; |
| if (!TlsAgent::LoadCertificate(id, &cert, &priv)) { |
| return false; |
| } |
| |
| if (updateKeyBits) { |
| ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get())); |
| EXPECT_NE(nullptr, pub.get()); |
| if (!pub.get()) return false; |
| server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get()); |
| } |
| |
| SECStatus rv = |
| SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null); |
| EXPECT_EQ(SECFailure, rv); |
| rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData, |
| serverCertData ? sizeof(*serverCertData) : 0); |
| return rv == SECSuccess; |
| } |
| |
| bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { |
| // Don't set up twice |
| if (ssl_fd_) return true; |
| NssManagePolicy policyManage(policy_); |
| |
| ScopedPRFileDesc dummy_fd(adapter_->CreateFD()); |
| EXPECT_NE(nullptr, dummy_fd); |
| if (!dummy_fd) { |
| return false; |
| } |
| if (adapter_->variant() == ssl_variant_stream) { |
| ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get())); |
| } else { |
| ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get())); |
| } |
| |
| EXPECT_NE(nullptr, ssl_fd_); |
| if (!ssl_fd_) { |
| return false; |
| } |
| dummy_fd.release(); // Now subsumed by ssl_fd_. |
| |
| SECStatus rv; |
| if (!skip_version_checks_) { |
| rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); |
| EXPECT_EQ(SECSuccess, rv); |
| if (rv != SECSuccess) return false; |
| } |
| |
| ScopedCERTCertList anchors(CERT_NewCertList()); |
| rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get()); |
| if (rv != SECSuccess) return false; |
| |
| if (role_ == SERVER) { |
| EXPECT_TRUE(ConfigServerCert(name_, true)); |
| |
| rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this); |
| EXPECT_EQ(SECSuccess, rv); |
| if (rv != SECSuccess) return false; |
| |
| rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024); |
| EXPECT_EQ(SECSuccess, rv); |
| if (rv != SECSuccess) return false; |
| } else { |
| rv = SSL_SetURL(ssl_fd(), "server"); |
| EXPECT_EQ(SECSuccess, rv); |
| if (rv != SECSuccess) return false; |
| } |
| |
| rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this); |
| EXPECT_EQ(SECSuccess, rv); |
| if (rv != SECSuccess) return false; |
| |
| rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this); |
| EXPECT_EQ(SECSuccess, rv); |
| if (rv != SECSuccess) return false; |
| |
| rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this); |
| EXPECT_EQ(SECSuccess, rv); |
| if (rv != SECSuccess) return false; |
| |
| rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this); |
| EXPECT_EQ(SECSuccess, rv); |
| if (rv != SECSuccess) return false; |
| |
| // All these tests depend on having this disabled to start with. |
| SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_FALSE); |
| |
| return true; |
| } |
| |
| bool TlsAgent::MaybeSetResumptionToken() { |
| if (!resumption_token_.empty()) { |
| LOG("setting external resumption token"); |
| SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(), |
| resumption_token_.size()); |
| |
| // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR |
| // if the resumption token was bad (expired/malformed/etc.). |
| if (expect_psk_ == ssl_psk_resume) { |
| // Only in case we expect resumption this has to be successful. We might |
| // not expect resumption due to some reason but the token is totally fine. |
| EXPECT_EQ(SECSuccess, rv); |
| } |
| if (rv != SECSuccess) { |
| EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError()); |
| resumption_token_.clear(); |
| EXPECT_FALSE(expect_psk_ == ssl_psk_resume); |
| if (expect_psk_ == ssl_psk_resume) return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) { |
| EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get())); |
| } |
| |
| void TlsAgent::SetupClientAuth() { |
| EXPECT_TRUE(EnsureTlsSetup()); |
| ASSERT_EQ(CLIENT, role_); |
| |
| EXPECT_EQ(SECSuccess, |
| SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook, |
| reinterpret_cast<void*>(this))); |
| } |
| |
| void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) { |
| ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr)); |
| |
| ASSERT_EQ(expected->nnames, caNames->nnames); |
| |
| for (size_t i = 0; i < static_cast<size_t>(expected->nnames); ++i) { |
| EXPECT_EQ(SECEqual, |
| SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i]))); |
| } |
| } |
| |
| SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, |
| CERTDistNames* caNames, |
| CERTCertificate** clientCert, |
| SECKEYPrivateKey** clientKey) { |
| TlsAgent* agent = reinterpret_cast<TlsAgent*>(self); |
| ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); |
| EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; |
| |
| // See bug 1573945 |
| // CheckCertReqAgainstDefaultCAs(caNames); |
| |
| ScopedCERTCertificate cert; |
| ScopedSECKEYPrivateKey priv; |
| if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) { |
| return SECFailure; |
| } |
| |
| *clientCert = cert.release(); |
| *clientKey = priv.release(); |
| return SECSuccess; |
| } |
| |
| bool TlsAgent::GetPeerChainLength(size_t* count) { |
| CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd()); |
| if (!chain) return false; |
| *count = 0; |
| |
| for (PRCList* cursor = PR_NEXT_LINK(&chain->list); cursor != &chain->list; |
| cursor = PR_NEXT_LINK(cursor)) { |
| CERTCertListNode* node = (CERTCertListNode*)cursor; |
| std::cerr << node->cert->subjectName << std::endl; |
| ++(*count); |
| } |
| |
| CERT_DestroyCertList(chain); |
| |
| return true; |
| } |
| |
| void TlsAgent::CheckCipherSuite(uint16_t suite) { |
| EXPECT_EQ(csinfo_.cipherSuite, suite); |
| } |
| |
| void TlsAgent::RequestClientAuth(bool requireAuth) { |
| ASSERT_EQ(SERVER, role_); |
| |
| SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE); |
| SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE); |
| |
| EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook( |
| ssl_fd(), &TlsAgent::ClientAuthenticated, this)); |
| expect_client_auth_ = true; |
| } |
| |
| void TlsAgent::StartConnect(PRFileDesc* model) { |
| EXPECT_TRUE(EnsureTlsSetup(model)); |
| |
| SECStatus rv; |
| rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE); |
| EXPECT_EQ(SECSuccess, rv); |
| SetState(STATE_CONNECTING); |
| } |
| |
| void TlsAgent::DisableAllCiphers() { |
| for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { |
| SECStatus rv = |
| SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE); |
| EXPECT_EQ(SECSuccess, rv); |
| } |
| } |
| |
| // Not actually all groups, just the onece that we are actually willing |
| // to use. |
| const std::vector<SSLNamedGroup> kAllDHEGroups = { |
| ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, |
| ssl_grp_ec_secp521r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, |
| ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192}; |
| |
| const std::vector<SSLNamedGroup> kECDHEGroups = { |
| ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, |
| ssl_grp_ec_secp521r1}; |
| |
| const std::vector<SSLNamedGroup> kFFDHEGroups = { |
| ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096, |
| ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192}; |
| |
| // Defined because the big DHE groups are ridiculously slow. |
| const std::vector<SSLNamedGroup> kFasterDHEGroups = { |
| ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, |
| ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072}; |
| |
| void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) { |
| EXPECT_TRUE(EnsureTlsSetup()); |
| |
| for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { |
| SSLCipherSuiteInfo csinfo; |
| |
| SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, |
| sizeof(csinfo)); |
| ASSERT_EQ(SECSuccess, rv); |
| EXPECT_EQ(sizeof(csinfo), csinfo.length); |
| |
| if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) { |
| rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); |
| EXPECT_EQ(SECSuccess, rv); |
| } |
| } |
| } |
| |
| void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) { |
| switch (kea) { |
| case ssl_kea_dh: |
| ConfigNamedGroups(kFFDHEGroups); |
| break; |
| case ssl_kea_ecdh: |
| ConfigNamedGroups(kECDHEGroups); |
| break; |
| default: |
| break; |
| } |
| } |
| |
| void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) { |
| if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa || |
| authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) { |
| ConfigNamedGroups(kECDHEGroups); |
| } |
| } |
| |
| void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) { |
| EXPECT_TRUE(EnsureTlsSetup()); |
| |
| for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { |
| SSLCipherSuiteInfo csinfo; |
| |
| SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, |
| sizeof(csinfo)); |
| ASSERT_EQ(SECSuccess, rv); |
| |
| if ((csinfo.authType == authType) || |
| (csinfo.keaType == ssl_kea_tls13_any)) { |
| rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); |
| EXPECT_EQ(SECSuccess, rv); |
| } |
| } |
| } |
| |
| void TlsAgent::EnableSingleCipher(uint16_t cipher) { |
| DisableAllCiphers(); |
| SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE); |
| EXPECT_EQ(SECSuccess, rv); |
| } |
| |
| void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) { |
| EXPECT_TRUE(EnsureTlsSetup()); |
| SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size()); |
| EXPECT_EQ(SECSuccess, rv); |
| } |
| |
| void TlsAgent::Set0RttEnabled(bool en) { |
| SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); |
| } |
| |
| void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { |
| vrange_.min = minver; |
| vrange_.max = maxver; |
| |
| if (ssl_fd()) { |
| SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); |
| EXPECT_EQ(SECSuccess, rv); |
| } |
| } |
| |
| SECStatus ResumptionTokenCallback(PRFileDesc* fd, |
| const PRUint8* resumptionToken, |
| unsigned int len, void* ctx) { |
| EXPECT_NE(nullptr, resumptionToken); |
| if (!resumptionToken) { |
| return SECFailure; |
| } |
| |
| std::vector<uint8_t> new_token(resumptionToken, resumptionToken + len); |
| reinterpret_cast<TlsAgent*>(ctx)->SetResumptionToken(new_token); |
| reinterpret_cast<TlsAgent*>(ctx)->SetResumptionCallbackCalled(); |
| return SECSuccess; |
| } |
| |
| void TlsAgent::SetResumptionTokenCallback() { |
| EXPECT_TRUE(EnsureTlsSetup()); |
| SECStatus rv = |
| SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this); |
| EXPECT_EQ(SECSuccess, rv); |
| } |
| |
| void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) { |
| *minver = vrange_.min; |
| *maxver = vrange_.max; |
| } |
| |
| void TlsAgent::SetExpectedVersion(uint16_t ver) { expected_version_ = ver; } |
| |
| void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; } |
| |
| void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; } |
| |
| void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; } |
| |
| void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, |
| size_t count) { |
| EXPECT_TRUE(EnsureTlsSetup()); |
| EXPECT_LE(count, SSL_SignatureMaxCount()); |
| EXPECT_EQ(SECSuccess, |
| SSL_SignatureSchemePrefSet(ssl_fd(), schemes, |
| static_cast<unsigned int>(count))); |
| EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0)) |
| << "setting no schemes should fail and do nothing"; |
| |
| std::vector<SSLSignatureScheme> configuredSchemes(count); |
| unsigned int configuredCount; |
| EXPECT_EQ(SECFailure, |
| SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1)) |
| << "get schemes, schemes is nullptr"; |
| EXPECT_EQ(SECFailure, |
| SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], |
| &configuredCount, 0)) |
| << "get schemes, too little space"; |
| EXPECT_EQ(SECFailure, |
| SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr, |
| configuredSchemes.size())) |
| << "get schemes, countOut is nullptr"; |
| |
| EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet( |
| ssl_fd(), &configuredSchemes[0], &configuredCount, |
| configuredSchemes.size())); |
| // SignatureSchemePrefSet drops unsupported algorithms silently, so the |
| // number that are configured might be fewer. |
| EXPECT_LE(configuredCount, count); |
| unsigned int i = 0; |
| for (unsigned int j = 0; j < count && i < configuredCount; ++j) { |
| if (i < configuredCount && schemes[j] == configuredSchemes[i]) { |
| ++i; |
| } |
| } |
| EXPECT_EQ(i, configuredCount) << "schemes in use were all set"; |
| } |
| |
| void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group, |
| size_t kea_size) const { |
| EXPECT_EQ(STATE_CONNECTED, state_); |
| EXPECT_EQ(kea, info_.keaType); |
| if (kea_size == 0) { |
| switch (kea_group) { |
| case ssl_grp_ec_curve25519: |
| kea_size = 255; |
| break; |
| case ssl_grp_ec_secp256r1: |
| kea_size = 256; |
| break; |
| case ssl_grp_ec_secp384r1: |
| kea_size = 384; |
| break; |
| case ssl_grp_ffdhe_2048: |
| kea_size = 2048; |
| break; |
| case ssl_grp_ffdhe_3072: |
| kea_size = 3072; |
| break; |
| case ssl_grp_ffdhe_custom: |
| break; |
| default: |
| if (kea == ssl_kea_rsa) { |
| kea_size = server_key_bits_; |
| } else { |
| EXPECT_TRUE(false) << "need to update group sizes"; |
| } |
| } |
| } |
| if (kea_group != ssl_grp_ffdhe_custom) { |
| EXPECT_EQ(kea_size, info_.keaKeyBits); |
| EXPECT_EQ(kea_group, info_.keaGroup); |
| } |
| } |
| |
| void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const { |
| if (kea_group != ssl_grp_ffdhe_custom) { |
| EXPECT_EQ(kea_group, info_.originalKeaGroup); |
| } |
| } |
| |
| void TlsAgent::CheckAuthType(SSLAuthType auth, |
| SSLSignatureScheme sig_scheme) const { |
| EXPECT_EQ(STATE_CONNECTED, state_); |
| EXPECT_EQ(auth, info_.authType); |
| if (auth != ssl_auth_psk) { |
| EXPECT_EQ(server_key_bits_, info_.authKeyBits); |
| } |
| if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) { |
| switch (auth) { |
| case ssl_auth_rsa_sign: |
| sig_scheme = ssl_sig_rsa_pkcs1_sha1md5; |
| break; |
| case ssl_auth_ecdsa: |
| sig_scheme = ssl_sig_ecdsa_sha1; |
| break; |
| default: |
| break; |
| } |
| } |
| EXPECT_EQ(sig_scheme, info_.signatureScheme); |
| |
| if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) { |
| return; |
| } |
| |
| // Check authAlgorithm, which is the old value for authType. This is a second |
| // switch statement because default label is different. |
| switch (auth) { |
| case ssl_auth_rsa_sign: |
| case ssl_auth_rsa_pss: |
| EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) |
| << "authAlgorithm for RSA is always decrypt"; |
| break; |
| case ssl_auth_ecdh_rsa: |
| EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) |
| << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)"; |
| break; |
| case ssl_auth_ecdh_ecdsa: |
| EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm) |
| << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)"; |
| break; |
| default: |
| EXPECT_EQ(auth, csinfo_.authAlgorithm) |
| << "authAlgorithm is (usually) the same as authType"; |
| break; |
| } |
| } |
| |
| void TlsAgent::EnableFalseStart() { |
| EXPECT_TRUE(EnsureTlsSetup()); |
| |
| falsestart_enabled_ = true; |
| EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback( |
| ssl_fd(), CanFalseStartCallback, this)); |
| SetOption(SSL_ENABLE_FALSE_START, PR_TRUE); |
| } |
| |
| void TlsAgent::ExpectEch(bool expected) { expect_ech_ = expected; } |
| |
| void TlsAgent::ExpectPsk(SSLPskType psk) { expect_psk_ = psk; } |
| |
| void TlsAgent::ExpectResumption() { expect_psk_ = ssl_psk_resume; } |
| |
| void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { |
| EXPECT_TRUE(EnsureTlsSetup()); |
| EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len)); |
| } |
| |
| void TlsAgent::AddPsk(const ScopedPK11SymKey& psk, std::string label, |
| SSLHashType hash, uint16_t zeroRttSuite) { |
| EXPECT_TRUE(EnsureTlsSetup()); |
| EXPECT_EQ(SECSuccess, SSL_AddExternalPsk0Rtt( |
| ssl_fd(), psk.get(), |
| reinterpret_cast<const uint8_t*>(label.data()), |
| label.length(), hash, zeroRttSuite, 1000)); |
| } |
| |
| void TlsAgent::RemovePsk(std::string label) { |
| EXPECT_EQ(SECSuccess, |
| SSL_RemoveExternalPsk( |
| ssl_fd(), reinterpret_cast<const uint8_t*>(label.data()), |
| label.length())); |
| } |
| |
| void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, |
| const std::string& expected) const { |
| SSLNextProtoState alpn_state; |
| char chosen[10]; |
| unsigned int chosen_len; |
| SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state, |
| reinterpret_cast<unsigned char*>(chosen), |
| &chosen_len, sizeof(chosen)); |
| EXPECT_EQ(SECSuccess, rv); |
| EXPECT_EQ(expected_state, alpn_state); |
| if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) { |
| EXPECT_EQ("", expected); |
| } else { |
| EXPECT_NE("", expected); |
| EXPECT_EQ(expected, std::string(chosen, chosen_len)); |
| } |
| } |
| |
| void TlsAgent::CheckEpochs(uint16_t expected_read, |
| uint16_t expected_write) const { |
| uint16_t read_epoch = 0; |
| uint16_t write_epoch = 0; |
| EXPECT_EQ(SECSuccess, |
| SSL_GetCurrentEpoch(ssl_fd(), &read_epoch, &write_epoch)); |
| EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch"; |
| EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch"; |
| } |
| |
| void TlsAgent::EnableSrtp() { |
| EXPECT_TRUE(EnsureTlsSetup()); |
| const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80, |
| SRTP_AES128_CM_HMAC_SHA1_32}; |
| EXPECT_EQ(SECSuccess, |
| SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers))); |
| } |
| |
| void TlsAgent::CheckSrtp() const { |
| uint16_t actual; |
| EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual)); |
| EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); |
| } |
| |
| void TlsAgent::CheckErrorCode(int32_t expected) const { |
| EXPECT_EQ(STATE_ERROR, state_); |
| EXPECT_EQ(expected, error_code_) |
| << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " |
| << PORT_ErrorToName(expected) << std::endl; |
| } |
| |
| static uint8_t GetExpectedAlertLevel(uint8_t alert) { |
| if (alert == kTlsAlertCloseNotify) { |
| return kTlsAlertWarning; |
| } |
| return kTlsAlertFatal; |
| } |
| |
| void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) { |
| expected_received_alert_ = alert; |
| if (level == 0) { |
| expected_received_alert_level_ = GetExpectedAlertLevel(alert); |
| } else { |
| expected_received_alert_level_ = level; |
| } |
| } |
| |
| void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) { |
| expected_sent_alert_ = alert; |
| if (level == 0) { |
| expected_sent_alert_level_ = GetExpectedAlertLevel(alert); |
| } else { |
| expected_sent_alert_level_ = level; |
| } |
| } |
| |
| void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) { |
| LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal") |
| << " alert " << (sent ? "sent" : "received") << ": " |
| << static_cast<int>(alert->description)); |
| |
| auto& expected = sent ? expected_sent_alert_ : expected_received_alert_; |
| auto& expected_level = |
| sent ? expected_sent_alert_level_ : expected_received_alert_level_; |
| /* Silently pass close_notify in case the test has already ended. */ |
| if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning && |
| alert->description == expected && alert->level == expected_level) { |
| return; |
| } |
| |
| EXPECT_EQ(expected, alert->description); |
| EXPECT_EQ(expected_level, alert->level); |
| expected = kTlsAlertCloseNotify; |
| expected_level = kTlsAlertWarning; |
| } |
| |
| void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { |
| ASSERT_EQ(0, error_code_); |
| WAIT_(error_code_ != 0, delay); |
| EXPECT_EQ(expected, error_code_) |
| << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " |
| << PORT_ErrorToName(expected) << std::endl; |
| } |
| |
| void TlsAgent::CheckPreliminaryInfo() { |
| SSLPreliminaryChannelInfo preinfo; |
| EXPECT_EQ(SECSuccess, |
| SSL_GetPreliminaryChannelInfo(ssl_fd(), &preinfo, sizeof(preinfo))); |
| EXPECT_EQ(sizeof(preinfo), preinfo.length); |
| EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version); |
| |
| // A version of 0 is invalid and indicates no expectation. This value is |
| // initialized to 0 so that tests that don't explicitly set an expected |
| // version can negotiate a version. |
| if (!expected_version_) { |
| expected_version_ = preinfo.protocolVersion; |
| } |
| EXPECT_EQ(expected_version_, preinfo.protocolVersion); |
| |
| // As with the version; 0 is the null cipher suite (and also invalid). |
| if (!expected_cipher_suite_) { |
| expected_cipher_suite_ = preinfo.cipherSuite; |
| } |
| EXPECT_EQ(expected_cipher_suite_, preinfo.cipherSuite); |
| } |
| |
| // Check that all the expected callbacks have been called. |
| void TlsAgent::CheckCallbacks() const { |
| // If false start happens, the handshake is reported as being complete at the |
| // point that false start happens. |
| if (expect_psk_ == ssl_psk_resume || !falsestart_enabled_) { |
| EXPECT_TRUE(handshake_callback_called_); |
| } |
| |
| // These callbacks shouldn't fire if we are resuming, except on TLS 1.3. |
| if (role_ == SERVER) { |
| PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn); |
| EXPECT_EQ(((expect_psk_ != ssl_psk_resume && have_sni) || |
| expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3), |
| sni_hook_called_); |
| } else { |
| EXPECT_EQ(expect_psk_ == ssl_psk_none, auth_certificate_hook_called_); |
| // Note that this isn't unconditionally called, even with false start on. |
| // But the callback is only skipped if a cipher that is ridiculously weak |
| // (80 bits) is chosen. Don't test that: plan to remove bad ciphers. |
| EXPECT_EQ(falsestart_enabled_ && expect_psk_ != ssl_psk_resume, |
| can_falsestart_hook_called_); |
| } |
| } |
| |
| void TlsAgent::ResetPreliminaryInfo() { |
| expected_version_ = 0; |
| expected_cipher_suite_ = 0; |
| } |
| |
| void TlsAgent::UpdatePreliminaryChannelInfo() { |
| SECStatus rv = |
| SSL_GetPreliminaryChannelInfo(ssl_fd(), &pre_info_, sizeof(pre_info_)); |
| EXPECT_EQ(SECSuccess, rv); |
| EXPECT_EQ(sizeof(pre_info_), pre_info_.length); |
| } |
| |
| void TlsAgent::ValidateCipherSpecs() { |
| PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd()); |
| // We use one ciphersuite in each direction. |
| PRInt32 expected = 2; |
| if (variant_ == ssl_variant_datagram) { |
| // For DTLS 1.3, the client retains the cipher spec for early data and the |
| // handshake so that it can retransmit EndOfEarlyData and its final flight. |
| // It also retains the handshake read cipher spec so that it can read ACKs |
| // from the server. The server retains the handshake read cipher spec so it |
| // can read the client's retransmitted Finished. |
| if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { |
| if (role_ == CLIENT) { |
| expected = info_.earlyDataAccepted ? 5 : 4; |
| } else { |
| expected = 3; |
| } |
| } else { |
| // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec |
| // until the holddown timer runs down. |
| if (expect_psk_ == ssl_psk_resume) { |
| if (role_ == CLIENT) { |
| expected = 3; |
| } |
| } else { |
| if (role_ == SERVER) { |
| expected = 3; |
| } |
| } |
| } |
| } |
| // This function will be run before the handshake completes if false start is |
| // enabled. In that case, the client will still be reading cleartext, but |
| // will have a spec prepared for reading ciphertext. With DTLS, the client |
| // will also have a spec retained for retransmission of handshake messages. |
| if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) { |
| EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_); |
| expected = (variant_ == ssl_variant_datagram) ? 4 : 3; |
| } |
| EXPECT_EQ(expected, cipherSpecs); |
| if (expected != cipherSpecs) { |
| SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd()); |
| } |
| } |
| |
| void TlsAgent::Connected() { |
| if (state_ == STATE_CONNECTED) { |
| return; |
| } |
| |
| LOG("Handshake success"); |
| CheckPreliminaryInfo(); |
| CheckCallbacks(); |
| |
| SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_)); |
| EXPECT_EQ(SECSuccess, rv); |
| EXPECT_EQ(sizeof(info_), info_.length); |
| |
| EXPECT_EQ(expect_psk_ == ssl_psk_resume, info_.resumed == PR_TRUE); |
| EXPECT_EQ(expect_psk_, info_.pskType); |
| EXPECT_EQ(expect_ech_, info_.echAccepted); |
| |
| // Preliminary values are exposed through callbacks during the handshake. |
| // If either expected values were set or the callbacks were called, check |
| // that the final values are correct. |
| UpdatePreliminaryChannelInfo(); |
| EXPECT_EQ(expected_version_, info_.protocolVersion); |
| EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite); |
| |
| rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_)); |
| EXPECT_EQ(SECSuccess, rv); |
| EXPECT_EQ(sizeof(csinfo_), csinfo_.length); |
| |
| ValidateCipherSpecs(); |
| |
| SetState(STATE_CONNECTED); |
| } |
| |
| void TlsAgent::EnableExtendedMasterSecret() { |
| SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); |
| } |
| |
| void TlsAgent::CheckExtendedMasterSecret(bool expected) { |
| if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) { |
| expected = PR_TRUE; |
| } |
| ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE) |
| << "unexpected extended master secret state for " << name_; |
| } |
| |
| void TlsAgent::CheckEarlyDataAccepted(bool expected) { |
| if (version() < SSL_LIBRARY_VERSION_TLS_1_3) { |
| expected = false; |
| } |
| ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE) |
| << "unexpected early data state for " << name_; |
| } |
| |
| void TlsAgent::CheckSecretsDestroyed() { |
| ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd())); |
| } |
| |
| void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) { |
| ASSERT_TRUE(EnsureTlsSetup()); |
| |
| SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver); |
| ASSERT_EQ(SECSuccess, rv); |
| } |
| |
| void TlsAgent::Handshake() { |
| LOGV("Handshake"); |
| SECStatus rv = SSL_ForceHandshake(ssl_fd()); |
| if (rv == SECSuccess) { |
| Connected(); |
| Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, |
| &TlsAgent::ReadableCallback); |
| return; |
| } |
| |
| int32_t err = PR_GetError(); |
| if (err == PR_WOULD_BLOCK_ERROR) { |
| LOGV("Would have blocked"); |
| if (variant_ == ssl_variant_datagram) { |
| if (timer_handle_) { |
| timer_handle_->Cancel(); |
| timer_handle_ = nullptr; |
| } |
| |
| PRIntervalTime timeout; |
| rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout); |
| if (rv == SECSuccess) { |
| Poller::Instance()->SetTimer( |
| timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_); |
| } |
| } |
| Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, |
| &TlsAgent::ReadableCallback); |
| return; |
| } |
| |
| LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": " |
| << PORT_ErrorToString(err)); |
| error_code_ = err; |
| SetState(STATE_ERROR); |
| } |
| |
| void TlsAgent::PrepareForRenegotiate() { |
| EXPECT_EQ(STATE_CONNECTED, state_); |
| |
| SetState(STATE_CONNECTING); |
| } |
| |
| void TlsAgent::StartRenegotiate() { |
| PrepareForRenegotiate(); |
| |
| SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE); |
| EXPECT_EQ(SECSuccess, rv); |
| } |
| |
| void TlsAgent::SendDirect(const DataBuffer& buf) { |
| LOG("Send Direct " << buf); |
| auto peer = adapter_->peer().lock(); |
| if (peer) { |
| peer->PacketReceived(buf); |
| } else { |
| LOG("Send Direct peer absent"); |
| } |
| } |
| |
| void TlsAgent::SendRecordDirect(const TlsRecord& record) { |
| DataBuffer buf; |
| |
| auto rv = record.header.Write(&buf, 0, record.buffer); |
| EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv); |
| SendDirect(buf); |
| } |
| |
| static bool ErrorIsFatal(PRErrorCode code) { |
| return code != PR_WOULD_BLOCK_ERROR && code != SSL_ERROR_RX_SHORT_DTLS_READ; |
| } |
| |
| void TlsAgent::SendData(size_t bytes, size_t blocksize) { |
| uint8_t block[16385]; // One larger than the maximum record size. |
| |
| ASSERT_LE(blocksize, sizeof(block)); |
| |
| while (bytes) { |
| size_t tosend = std::min(blocksize, bytes); |
| |
| for (size_t i = 0; i < tosend; ++i) { |
| block[i] = 0xff & send_ctr_; |
| ++send_ctr_; |
| } |
| |
| SendBuffer(DataBuffer(block, tosend)); |
| bytes -= tosend; |
| } |
| } |
| |
| void TlsAgent::SendBuffer(const DataBuffer& buf) { |
| LOGV("Writing " << buf.len() << " bytes"); |
| int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len()); |
| if (expect_readwrite_error_) { |
| EXPECT_GT(0, rv); |
| EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_); |
| error_code_ = PR_GetError(); |
| expect_readwrite_error_ = false; |
| } else { |
| ASSERT_EQ(buf.len(), static_cast<size_t>(rv)); |
| } |
| } |
| |
| bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, |
| uint64_t seq, uint8_t ct, |
| const DataBuffer& buf) { |
| // Ensure that we are doing TLS 1.3. |
| EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3); |
| if (variant_ != ssl_variant_datagram) { |
| ADD_FAILURE(); |
| return false; |
| } |
| |
| LOGV("Encrypting " << buf.len() << " bytes"); |
| uint8_t dtls13_ct = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | |
| kCtDtlsCiphertextLengthPresent; |
| TlsRecordHeader header(variant_, expected_version_, dtls13_ct, seq); |
| TlsRecordHeader out_header(header); |
| DataBuffer padded = buf; |
| padded.Write(padded.len(), ct, 1); |
| DataBuffer ciphertext; |
| if (!spec->Protect(header, padded, &ciphertext, &out_header)) { |
| return false; |
| } |
| |
| DataBuffer record; |
| auto rv = out_header.Write(&record, 0, ciphertext); |
| EXPECT_EQ(out_header.header_length() + ciphertext.len(), rv); |
| SendDirect(record); |
| return true; |
| } |
| |
| void TlsAgent::ReadBytes(size_t amount) { |
| uint8_t block[16384]; |
| |
| size_t remaining = amount; |
| while (remaining > 0) { |
| int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block))); |
| LOGV("ReadBytes " << rv); |
| |
| if (rv > 0) { |
| size_t count = static_cast<size_t>(rv); |
| for (size_t i = 0; i < count; ++i) { |
| ASSERT_EQ(recv_ctr_ & 0xff, block[i]); |
| recv_ctr_++; |
| } |
| remaining -= rv; |
| } else { |
| PRErrorCode err = 0; |
| if (rv < 0) { |
| err = PR_GetError(); |
| LOG("Read error " << PORT_ErrorToName(err) << ": " |
| << PORT_ErrorToString(err)); |
| if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) { |
| error_code_ = err; |
| expect_readwrite_error_ = false; |
| } |
| } |
| if (err != 0 && ErrorIsFatal(err)) { |
| // If we hit a fatal error, we're done. |
| remaining = 0; |
| } |
| break; |
| } |
| } |
| |
| // If closed, then don't bother waiting around. |
| if (remaining) { |
| LOGV("Re-arming"); |
| Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, |
| &TlsAgent::ReadableCallback); |
| } |
| } |
| |
| void TlsAgent::ResetSentBytes(size_t bytes) { send_ctr_ = bytes; } |
| |
| void TlsAgent::SetOption(int32_t option, int value) { |
| ASSERT_TRUE(EnsureTlsSetup()); |
| EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value)); |
| } |
| |
| void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { |
| SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); |
| SetOption(SSL_ENABLE_SESSION_TICKETS, |
| mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); |
| } |
| |
| void TlsAgent::DisableECDHEServerKeyReuse() { |
| ASSERT_EQ(TlsAgent::SERVER, role_); |
| SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); |
| } |
| |
| static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"}; |
| ::testing::internal::ParamGenerator<std::string> |
| TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr); |
| |
| void TlsAgentTestBase::SetUp() { |
| SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); |
| } |
| |
| void TlsAgentTestBase::TearDown() { |
| agent_ = nullptr; |
| SSL_ClearSessionCache(); |
| SSL_ShutdownServerSessionIDCache(); |
| } |
| |
| void TlsAgentTestBase::Reset(const std::string& server_name) { |
| agent_.reset( |
| new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name, |
| role_, variant_)); |
| if (version_) { |
| agent_->SetVersionRange(version_, version_); |
| } |
| agent_->adapter()->SetPeer(sink_adapter_); |
| agent_->StartConnect(); |
| } |
| |
| void TlsAgentTestBase::EnsureInit() { |
| if (!agent_) { |
| Reset(); |
| } |
| const std::vector<SSLNamedGroup> groups = { |
| ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, |
| ssl_grp_ffdhe_2048}; |
| agent_->ConfigNamedGroups(groups); |
| } |
| |
| void TlsAgentTestBase::ExpectAlert(uint8_t alert) { |
| EnsureInit(); |
| agent_->ExpectSendAlert(alert); |
| } |
| |
| void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, |
| TlsAgent::State expected_state, |
| int32_t error_code) { |
| std::cerr << "Process message: " << buffer << std::endl; |
| EnsureInit(); |
| agent_->adapter()->PacketReceived(buffer); |
| agent_->Handshake(); |
| |
| ASSERT_EQ(expected_state, agent_->state()); |
| |
| if (expected_state == TlsAgent::STATE_ERROR) { |
| ASSERT_EQ(error_code, agent_->error_code()); |
| } |
| } |
| |
| void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, |
| uint16_t version, const uint8_t* buf, |
| size_t len, DataBuffer* out, |
| uint64_t sequence_number) { |
| // Fixup the content type for DTLSCiphertext |
| if (variant == ssl_variant_datagram && |
| version >= SSL_LIBRARY_VERSION_TLS_1_3 && |
| type == ssl_ct_application_data) { |
| type = kCtDtlsCiphertext | kCtDtlsCiphertext16bSeqno | |
| kCtDtlsCiphertextLengthPresent; |
| } |
| |
| size_t index = 0; |
| if (variant == ssl_variant_stream) { |
| index = out->Write(index, type, 1); |
| index = out->Write(index, version, 2); |
| } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 && |
| (type & kCtDtlsCiphertextMask) == kCtDtlsCiphertext) { |
| uint32_t epoch = (sequence_number >> 48) & 0x3; |
| index = out->Write(index, type | epoch, 1); |
| uint32_t seqno = sequence_number & ((1ULL << 16) - 1); |
| index = out->Write(index, seqno, 2); |
| } else { |
| index = out->Write(index, type, 1); |
| index = out->Write(index, TlsVersionToDtlsVersion(version), 2); |
| index = out->Write(index, sequence_number >> 32, 4); |
| index = out->Write(index, sequence_number & PR_UINT32_MAX, 4); |
| } |
| index = out->Write(index, len, 2); |
| out->Write(index, buf, len); |
| } |
| |
| void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version, |
| const uint8_t* buf, size_t len, |
| DataBuffer* out, uint64_t seq_num) const { |
| MakeRecord(variant_, type, version, buf, len, out, seq_num); |
| } |
| |
| void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type, |
| const uint8_t* data, size_t hs_len, |
| DataBuffer* out, |
| uint64_t seq_num) const { |
| return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0, |
| 0); |
| } |
| |
| void TlsAgentTestBase::MakeHandshakeMessageFragment( |
| uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out, |
| uint64_t seq_num, uint32_t fragment_offset, |
| uint32_t fragment_length) const { |
| size_t index = 0; |
| if (!fragment_length) fragment_length = hs_len; |
| index = out->Write(index, hs_type, 1); // Handshake record type. |
| index = out->Write(index, hs_len, 3); // Handshake length |
| if (variant_ == ssl_variant_datagram) { |
| index = out->Write(index, seq_num, 2); |
| index = out->Write(index, fragment_offset, 3); |
| index = out->Write(index, fragment_length, 3); |
| } |
| if (data) { |
| index = out->Write(index, data, fragment_length); |
| } else { |
| for (size_t i = 0; i < fragment_length; ++i) { |
| index = out->Write(index, 1, 1); |
| } |
| } |
| } |
| |
| void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type, |
| size_t hs_len, |
| DataBuffer* out) { |
| size_t index = 0; |
| index = out->Write(index, ssl_ct_handshake, 1); // Content Type |
| index = out->Write(index, 3, 1); // Version high |
| index = out->Write(index, 1, 1); // Version low |
| index = out->Write(index, 4 + hs_len, 2); // Length |
| |
| index = out->Write(index, hs_type, 1); // Handshake record type. |
| index = out->Write(index, hs_len, 3); // Handshake length |
| for (size_t i = 0; i < hs_len; ++i) { |
| index = out->Write(index, 1, 1); |
| } |
| } |
| |
| DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() { |
| DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello)); |
| if (variant_ == ssl_variant_datagram) { |
| sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2); |
| // The version should be at the end. |
| uint32_t v; |
| EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v)); |
| EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v); |
| sh.Write(sh.len() - 2, 0x7f00 | DTLS_1_3_DRAFT_VERSION, 2); |
| } |
| return sh; |
| } |
| |
| } // namespace nss_test |