| /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ |
| /* vim: set ts=2 et sw=2 tw=80: */ |
| /* This Source Code Form is subject to the terms of the Mozilla Public |
| * License, v. 2.0. If a copy of the MPL was not distributed with this file, |
| * You can obtain one at http://mozilla.org/MPL/2.0/. */ |
| |
| #ifndef tls_agent_h_ |
| #define tls_agent_h_ |
| |
#include "prio.h"
#include "ssl.h"
#include "sslproto.h"

#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "nss_policy.h"
#include "test_io.h"

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
#include "nss_scoped_ptrs.h"
#include "scoped_ptrs_ssl.h"
| |
| extern bool g_ssl_gtest_verbose; |
| |
| namespace nss_test { |
| |
// Writes "<role>: <msg>" followed by std::endl to std::cerr. |msg| may be any
// chain of stream insertions. The expansion calls role_str(), so the macro can
// only be used where a role_str() is in scope.
#define LOG(msg) std::cerr << role_str() << ": " << msg << std::endl

// Same as LOG, but emits only when g_ssl_gtest_verbose is set. Wrapped in
// do/while so the expansion behaves as a single statement.
#define LOGV(msg)              \
  do {                         \
    if (g_ssl_gtest_verbose) { \
      LOG(msg);                \
    }                          \
  } while (false)
| |
// Bit flags selecting which session resumption mechanisms are in use.
// RESUME_BOTH is the union of the two individual flags.
enum SessionResumptionMode {
  RESUME_NONE = 0,
  RESUME_SESSIONID = 1 << 0,
  RESUME_TICKET = 1 << 1,
  RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
};
| |
| class PacketFilter; |
| class TlsAgent; |
| class TlsCipherSpec; |
| struct TlsRecord; |
| |
| const extern std::vector<SSLNamedGroup> kAllDHEGroups; |
| const extern std::vector<SSLNamedGroup> kECDHEGroups; |
| const extern std::vector<SSLNamedGroup> kFFDHEGroups; |
| const extern std::vector<SSLNamedGroup> kFasterDHEGroups; |
| |
| // These functions are called from callbacks. They use bare pointers because |
| // TlsAgent sets up the callback and it doesn't know who owns it. |
| typedef std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)> |
| AuthCertificateCallbackFunction; |
| |
| typedef std::function<void(TlsAgent* agent)> HandshakeCallbackFunction; |
| |
| typedef std::function<int32_t(TlsAgent* agent, const SECItem* srvNameArr, |
| PRUint32 srvNameArrSize)> |
| SniCallbackFunction; |
| |
| class TlsAgent : public PollTarget { |
| public: |
| enum Role { CLIENT, SERVER }; |
| enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR }; |
| |
| static const std::string kClient; // the client key is sign only |
| static const std::string kRsa2048; // bigger sign and encrypt for either |
| static const std::string kRsa8192; // biggest sign and encrypt for either |
| static const std::string kServerRsa; // both sign and encrypt |
| static const std::string kServerRsaSign; |
| static const std::string kServerRsaPss; |
| static const std::string kServerRsaDecrypt; |
| static const std::string kServerEcdsa256; |
| static const std::string kServerEcdsa384; |
| static const std::string kServerEcdsa521; |
| static const std::string kServerEcdhEcdsa; |
| static const std::string kServerEcdhRsa; |
| static const std::string kServerDsa; |
| static const std::string kDelegatorEcdsa256; // draft-ietf-tls-subcerts |
| static const std::string kDelegatorRsae2048; // draft-ietf-tls-subcerts |
| static const std::string kDelegatorRsaPss2048; // draft-ietf-tls-subcerts |
| |
| TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant); |
| virtual ~TlsAgent(); |
| |
| void SetPeer(std::shared_ptr<TlsAgent>& peer) { |
| adapter_->SetPeer(peer->adapter_); |
| } |
| |
| void SetFilter(std::shared_ptr<PacketFilter> filter) { |
| adapter_->SetPacketFilter(filter); |
| } |
| void ClearFilter() { adapter_->SetPacketFilter(nullptr); } |
| |
| void StartConnect(PRFileDesc* model = nullptr); |
| void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, |
| size_t kea_size = 0) const; |
| void CheckOriginalKEA(SSLNamedGroup kea_group) const; |
| void CheckAuthType(SSLAuthType auth_type, |
| SSLSignatureScheme sig_scheme) const; |
| |
| void DisableAllCiphers(); |
| void EnableCiphersByAuthType(SSLAuthType authType); |
| void EnableCiphersByKeyExchange(SSLKEAType kea); |
| void EnableGroupsByKeyExchange(SSLKEAType kea); |
| void EnableGroupsByAuthType(SSLAuthType authType); |
| void EnableSingleCipher(uint16_t cipher); |
| |
| void Handshake(); |
| // Marks the internal state as CONNECTING in anticipation of renegotiation. |
| void PrepareForRenegotiate(); |
| // Prepares for renegotiation, then actually triggers it. |
| void StartRenegotiate(); |
| void SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx); |
| |
| static bool LoadCertificate(const std::string& name, |
| ScopedCERTCertificate* cert, |
| ScopedSECKEYPrivateKey* priv); |
| static bool LoadKeyPairFromCert(const std::string& name, |
| ScopedSECKEYPublicKey* pub, |
| ScopedSECKEYPrivateKey* priv); |
| |
| // Delegated credentials. |
| // |
| // Generate a delegated credential and sign it using the certificate |
| // associated with |name|. |
| static void DelegateCredential(const std::string& name, |
| const ScopedSECKEYPublicKey& dcPub, |
| SSLSignatureScheme dcCertVerifyAlg, |
| PRUint32 dcValidFor, PRTime now, SECItem* dc); |
| // Indicate support for the delegated credentials extension. |
| void EnableDelegatedCredentials(); |
| // Generate and configure a delegated credential to use in the handshake with |
| // clients that support this extension.. |
| void AddDelegatedCredential(const std::string& dc_name, |
| SSLSignatureScheme dcCertVerifyAlg, |
| PRUint32 dcValidFor, PRTime now); |
| void UpdatePreliminaryChannelInfo(); |
| |
| bool ConfigServerCert(const std::string& name, bool updateKeyBits = false, |
| const SSLExtraServerCertData* serverCertData = nullptr); |
| bool ConfigServerCertWithChain(const std::string& name); |
| bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr); |
| |
| void SetupClientAuth(); |
| void RequestClientAuth(bool requireAuth); |
| |
| void SetOption(int32_t option, int value); |
| void ConfigureSessionCache(SessionResumptionMode mode); |
| void Set0RttEnabled(bool en); |
| void SetFallbackSCSVEnabled(bool en); |
| void SetVersionRange(uint16_t minver, uint16_t maxver); |
| void GetVersionRange(uint16_t* minver, uint16_t* maxver); |
| void CheckPreliminaryInfo(); |
| void ResetPreliminaryInfo(); |
| void SetExpectedVersion(uint16_t version); |
| void SetServerKeyBits(uint16_t bits); |
| void ExpectReadWriteError(); |
| void EnableFalseStart(); |
| void ExpectEch(bool expected = true); |
| bool GetEchExpected() const { return expect_ech_; } |
| void ExpectPsk(SSLPskType psk = ssl_psk_external); |
| void ExpectResumption(); |
| void SkipVersionChecks(); |
| void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count); |
| void EnableAlpn(const uint8_t* val, size_t len); |
| void CheckAlpn(SSLNextProtoState expected_state, |
| const std::string& expected = "") const; |
| void EnableSrtp(); |
| void CheckSrtp() const; |
| void CheckEpochs(uint16_t expected_read, uint16_t expected_write) const; |
| void CheckErrorCode(int32_t expected) const; |
| void WaitForErrorCode(int32_t expected, uint32_t delay) const; |
| // Send data on the socket, encrypting it. |
| void SendData(size_t bytes, size_t blocksize = 1024); |
| void SendBuffer(const DataBuffer& buf); |
| bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, |
| uint64_t seq, uint8_t ct, const DataBuffer& buf); |
| // Send data directly to the underlying socket, skipping the TLS layer. |
| void SendDirect(const DataBuffer& buf); |
| void SendRecordDirect(const TlsRecord& record); |
| void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash, |
| uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL); |
| void RemovePsk(std::string label); |
| void ReadBytes(size_t max = 16384U); |
| void ResetSentBytes(size_t bytes = 0); // Hack to test drops. |
| void EnableExtendedMasterSecret(); |
| void CheckExtendedMasterSecret(bool expected); |
| void CheckEarlyDataAccepted(bool expected); |
| void CheckEchAccepted(bool expected); |
| void SetDowngradeCheckVersion(uint16_t version); |
| void CheckSecretsDestroyed(); |
| void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); |
| void DisableECDHEServerKeyReuse(); |
| bool GetPeerChainLength(size_t* count); |
| void CheckCipherSuite(uint16_t cipher_suite); |
| void SetResumptionTokenCallback(); |
| bool MaybeSetResumptionToken(); |
| void SetResumptionToken(const std::vector<uint8_t>& resumption_token) { |
| resumption_token_ = resumption_token; |
| } |
| const std::vector<uint8_t>& GetResumptionToken() const { |
| return resumption_token_; |
| } |
| void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) { |
| SECStatus rv = SSL_GetResumptionTokenInfo( |
| resumption_token_.data(), resumption_token_.size(), token.get(), |
| sizeof(SSLResumptionTokenInfo)); |
| ASSERT_EQ(SECSuccess, rv); |
| } |
| void SetResumptionCallbackCalled() { resumption_callback_called_ = true; } |
| bool resumption_callback_called() const { |
| return resumption_callback_called_; |
| } |
| |
| const std::string& name() const { return name_; } |
| |
| Role role() const { return role_; } |
| std::string role_str() const { return role_ == SERVER ? "server" : "client"; } |
| |
| SSLProtocolVariant variant() const { return variant_; } |
| |
| State state() const { return state_; } |
| |
| const CERTCertificate* peer_cert() const { |
| return SSL_PeerCertificate(ssl_fd_.get()); |
| } |
| |
| const char* state_str() const { return state_str(state()); } |
| |
| static const char* state_str(State state) { return states[state]; } |
| |
| NssManagedFileDesc ssl_fd() const { |
| return NssManagedFileDesc(ssl_fd_.get(), policy_); |
| } |
| std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; } |
| |
| const SSLChannelInfo& info() const { |
| EXPECT_EQ(STATE_CONNECTED, state_); |
| return info_; |
| } |
| |
| const SSLPreliminaryChannelInfo& pre_info() const { return pre_info_; } |
| |
| bool is_compressed() const { |
| return info().compressionMethod != ssl_compression_null; |
| } |
| uint16_t server_key_bits() const { return server_key_bits_; } |
| uint16_t min_version() const { return vrange_.min; } |
| uint16_t max_version() const { return vrange_.max; } |
| uint16_t version() const { return info().protocolVersion; } |
| |
| bool cipher_suite(uint16_t* suite) const { |
| if (state_ != STATE_CONNECTED) return false; |
| |
| *suite = info_.cipherSuite; |
| return true; |
| } |
| |
| void expected_cipher_suite(uint16_t suite) { expected_cipher_suite_ = suite; } |
| |
| std::string cipher_suite_name() const { |
| if (state_ != STATE_CONNECTED) return "UNKNOWN"; |
| |
| return csinfo_.cipherSuiteName; |
| } |
| |
| std::vector<uint8_t> session_id() const { |
| return std::vector<uint8_t>(info_.sessionID, |
| info_.sessionID + info_.sessionIDLength); |
| } |
| |
| bool auth_type(SSLAuthType* a) const { |
| if (state_ != STATE_CONNECTED) return false; |
| |
| *a = info_.authType; |
| return true; |
| } |
| |
| bool kea_type(SSLKEAType* k) const { |
| if (state_ != STATE_CONNECTED) return false; |
| |
| *k = info_.keaType; |
| return true; |
| } |
| |
| size_t received_bytes() const { return recv_ctr_; } |
| PRErrorCode error_code() const { return error_code_; } |
| |
| bool can_falsestart_hook_called() const { |
| return can_falsestart_hook_called_; |
| } |
| |
| void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) { |
| handshake_callback_ = handshake_callback; |
| } |
| |
| void SetAuthCertificateCallback( |
| AuthCertificateCallbackFunction auth_certificate_callback) { |
| auth_certificate_callback_ = auth_certificate_callback; |
| } |
| |
| void SetSniCallback(SniCallbackFunction sni_callback) { |
| sni_callback_ = sni_callback; |
| } |
| |
| void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0); |
| void ExpectSendAlert(uint8_t alert, uint8_t level = 0); |
| |
| std::string alpn_value_to_use_ = ""; |
| // set the given policy before this agent runs |
| void SetPolicy(SECOidTag oid, PRUint32 set, PRUint32 clear) { |
| policy_ = NssPolicy(oid, set, clear); |
| } |
| |
| private: |
| const static char* states[]; |
| |
| void SetState(State state); |
| void ValidateCipherSpecs(); |
| |
| // Dummy auth certificate hook. |
| static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, |
| PRBool checksig, PRBool isServer) { |
| TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); |
| agent->CheckPreliminaryInfo(); |
| agent->auth_certificate_hook_called_ = true; |
| if (agent->auth_certificate_callback_) { |
| return agent->auth_certificate_callback_(agent, checksig ? true : false, |
| isServer ? true : false); |
| } |
| return SECSuccess; |
| } |
| |
| // Client auth certificate hook. |
| static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd, |
| PRBool checksig, PRBool isServer) { |
| TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); |
| EXPECT_TRUE(agent->expect_client_auth_); |
| EXPECT_EQ(PR_TRUE, isServer); |
| if (agent->auth_certificate_callback_) { |
| return agent->auth_certificate_callback_(agent, checksig ? true : false, |
| isServer ? true : false); |
| } |
| return SECSuccess; |
| } |
| |
| static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd, |
| CERTDistNames* caNames, |
| CERTCertificate** cert, |
| SECKEYPrivateKey** privKey); |
| |
| static void ReadableCallback(PollTarget* self, Event event) { |
| TlsAgent* agent = static_cast<TlsAgent*>(self); |
| if (event == TIMER_EVENT) { |
| agent->timer_handle_ = nullptr; |
| } |
| agent->ReadableCallback_int(); |
| } |
| |
| void ReadableCallback_int() { |
| LOGV("Readable"); |
| switch (state_) { |
| case STATE_CONNECTING: |
| Handshake(); |
| break; |
| case STATE_CONNECTED: |
| ReadBytes(); |
| break; |
| default: |
| break; |
| } |
| } |
| |
| static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr, |
| PRUint32 srvNameArrSize, void* arg) { |
| TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); |
| agent->CheckPreliminaryInfo(); |
| agent->sni_hook_called_ = true; |
| EXPECT_EQ(1UL, srvNameArrSize); |
| if (agent->sni_callback_) { |
| return agent->sni_callback_(agent, srvNameArr, srvNameArrSize); |
| } |
| return 0; // First configuration. |
| } |
| |
| static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg, |
| PRBool* canFalseStart) { |
| TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); |
| agent->CheckPreliminaryInfo(); |
| EXPECT_TRUE(agent->falsestart_enabled_); |
| EXPECT_FALSE(agent->can_falsestart_hook_called_); |
| agent->can_falsestart_hook_called_ = true; |
| *canFalseStart = true; |
| return SECSuccess; |
| } |
| |
| void CheckAlert(bool sent, const SSLAlert* alert); |
| |
| static void AlertReceivedCallback(const PRFileDesc* fd, void* arg, |
| const SSLAlert* alert) { |
| reinterpret_cast<TlsAgent*>(arg)->CheckAlert(false, alert); |
| } |
| |
| static void AlertSentCallback(const PRFileDesc* fd, void* arg, |
| const SSLAlert* alert) { |
| reinterpret_cast<TlsAgent*>(arg)->CheckAlert(true, alert); |
| } |
| |
| static void HandshakeCallback(PRFileDesc* fd, void* arg) { |
| TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); |
| agent->handshake_callback_called_ = true; |
| agent->Connected(); |
| if (agent->handshake_callback_) { |
| agent->handshake_callback_(agent); |
| } |
| } |
| |
| void DisableLameGroups(); |
| void ConfigStrongECGroups(bool en); |
| void ConfigAllDHGroups(bool en); |
| void CheckCallbacks() const; |
| void Connected(); |
| |
| const std::string name_; |
| SSLProtocolVariant variant_; |
| Role role_; |
| uint16_t server_key_bits_; |
| std::shared_ptr<DummyPrSocket> adapter_; |
| ScopedPRFileDesc ssl_fd_; |
| State state_; |
| std::shared_ptr<Poller::Timer> timer_handle_; |
| bool falsestart_enabled_; |
| uint16_t expected_version_; |
| uint16_t expected_cipher_suite_; |
| bool expect_client_auth_; |
| bool expect_ech_; |
| SSLPskType expect_psk_; |
| bool can_falsestart_hook_called_; |
| bool sni_hook_called_; |
| bool auth_certificate_hook_called_; |
| uint8_t expected_received_alert_; |
| uint8_t expected_received_alert_level_; |
| uint8_t expected_sent_alert_; |
| uint8_t expected_sent_alert_level_; |
| bool handshake_callback_called_; |
| bool resumption_callback_called_; |
| SSLChannelInfo info_; |
| SSLPreliminaryChannelInfo pre_info_; |
| SSLCipherSuiteInfo csinfo_; |
| SSLVersionRange vrange_; |
| PRErrorCode error_code_; |
| size_t send_ctr_; |
| size_t recv_ctr_; |
| bool expect_readwrite_error_; |
| HandshakeCallbackFunction handshake_callback_; |
| AuthCertificateCallbackFunction auth_certificate_callback_; |
| SniCallbackFunction sni_callback_; |
| bool skip_version_checks_; |
| std::vector<uint8_t> resumption_token_; |
| NssPolicy policy_; |
| }; |
| |
| inline std::ostream& operator<<(std::ostream& stream, |
| const TlsAgent::State& state) { |
| return stream << TlsAgent::state_str(state); |
| } |
| |
| class TlsAgentTestBase : public ::testing::Test { |
| public: |
| static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll; |
| |
| TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant, |
| uint16_t version = 0) |
| : agent_(nullptr), |
| role_(role), |
| variant_(variant), |
| version_(version), |
| sink_adapter_(new DummyPrSocket("sink", variant)) {} |
| virtual ~TlsAgentTestBase() {} |
| |
| void SetUp(); |
| void TearDown(); |
| |
| void ExpectAlert(uint8_t alert); |
| |
| static void MakeRecord(SSLProtocolVariant variant, uint8_t type, |
| uint16_t version, const uint8_t* buf, size_t len, |
| DataBuffer* out, uint64_t seq_num = 0); |
| void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, |
| size_t len, DataBuffer* out, uint64_t seq_num = 0) const; |
| void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len, |
| DataBuffer* out, uint64_t seq_num = 0) const; |
| void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data, |
| size_t hs_len, DataBuffer* out, |
| uint64_t seq_num, uint32_t fragment_offset, |
| uint32_t fragment_length) const; |
| DataBuffer MakeCannedTls13ServerHello(); |
| static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len, |
| DataBuffer* out); |
| static inline TlsAgent::Role ToRole(const std::string& str) { |
| return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER; |
| } |
| |
| void Init(const std::string& server_name = TlsAgent::kServerRsa); |
| void Reset(const std::string& server_name = TlsAgent::kServerRsa); |
| |
| protected: |
| void EnsureInit(); |
| void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, |
| int32_t error_code = 0); |
| |
| std::shared_ptr<TlsAgent> agent_; |
| TlsAgent::Role role_; |
| SSLProtocolVariant variant_; |
| uint16_t version_; |
| // This adapter is here just to accept packets from this agent. |
| std::shared_ptr<DummyPrSocket> sink_adapter_; |
| }; |
| |
| class TlsAgentTest |
| : public TlsAgentTestBase, |
| public ::testing::WithParamInterface< |
| std::tuple<std::string, SSLProtocolVariant, uint16_t>> { |
| public: |
| TlsAgentTest() |
| : TlsAgentTestBase(ToRole(std::get<0>(GetParam())), |
| std::get<1>(GetParam()), std::get<2>(GetParam())) {} |
| }; |
| |
| class TlsAgentTestClient : public TlsAgentTestBase, |
| public ::testing::WithParamInterface< |
| std::tuple<SSLProtocolVariant, uint16_t>> { |
| public: |
| TlsAgentTestClient() |
| : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(GetParam()), |
| std::get<1>(GetParam())) {} |
| }; |
| |
| class TlsAgentTestClient13 : public TlsAgentTestClient {}; |
| |
| class TlsAgentStreamTestClient : public TlsAgentTestBase { |
| public: |
| TlsAgentStreamTestClient() |
| : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_stream) {} |
| }; |
| |
| class TlsAgentStreamTestServer : public TlsAgentTestBase { |
| public: |
| TlsAgentStreamTestServer() |
| : TlsAgentTestBase(TlsAgent::SERVER, ssl_variant_stream) {} |
| }; |
| |
| class TlsAgentDgramTestClient : public TlsAgentTestBase { |
| public: |
| TlsAgentDgramTestClient() |
| : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_datagram) {} |
| }; |
| |
| inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) { |
| return vr1.min == vr2.min && vr1.max == vr2.max; |
| } |
| |
| } // namespace nss_test |
| |
| #endif |