| /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ |
| /* vim: set ts=2 et sw=2 tw=80: */ |
| /* This Source Code Form is subject to the terms of the Mozilla Public |
| * License, v. 2.0. If a copy of the MPL was not distributed with this file, |
| * You can obtain one at http://mozilla.org/MPL/2.0/. */ |
| |
| #ifndef tls_filter_h_ |
| #define tls_filter_h_ |
| |
#include <functional>
#include <initializer_list>
#include <iostream>
#include <memory>
#include <set>
#include <utility>
#include <vector>
#include "pk11pub.h"
#include "sslt.h"
#include "sslproto.h"
#include "test_io.h"
#include "tls_agent.h"
#include "tls_parser.h"
#include "tls_protect.h"
| |
| extern "C" { |
| #include "libssl_internals.h" |
| } |
| |
| namespace nss_test { |
| |
| class TlsCipherSpec; |
| |
| class TlsSendCipherSpecCapturer { |
| public: |
| TlsSendCipherSpecCapturer(const std::shared_ptr<TlsAgent>& agent) |
| : agent_(agent), send_cipher_specs_() { |
| EXPECT_EQ(SECSuccess, |
| SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this)); |
| } |
| |
| std::shared_ptr<TlsCipherSpec> spec(size_t i) { |
| if (i >= send_cipher_specs_.size()) { |
| return nullptr; |
| } |
| return send_cipher_specs_[i]; |
| } |
| |
| private: |
| static void SecretCallback(PRFileDesc* fd, PRUint16 epoch, |
| SSLSecretDirection dir, PK11SymKey* secret, |
| void* arg) { |
| auto self = static_cast<TlsSendCipherSpecCapturer*>(arg); |
| std::cerr << self->agent_->role_str() << ": capture " << dir |
| << " secret for epoch " << epoch << std::endl; |
| |
| if (dir == ssl_secret_read) { |
| return; |
| } |
| |
| SSLPreliminaryChannelInfo preinfo; |
| EXPECT_EQ(SECSuccess, |
| SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo, |
| sizeof(preinfo))); |
| EXPECT_EQ(sizeof(preinfo), preinfo.length); |
| EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite); |
| |
| // Check the version: |
| EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version); |
| ASSERT_GE(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion); |
| |
| SSLCipherSuiteInfo cipherinfo; |
| EXPECT_EQ(SECSuccess, |
| SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo, |
| sizeof(cipherinfo))); |
| EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length); |
| |
| auto spec = std::make_shared<TlsCipherSpec>(true, epoch); |
| EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret)); |
| self->send_cipher_specs_.push_back(spec); |
| } |
| |
| std::shared_ptr<TlsAgent> agent_; |
| std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_; |
| }; |
| |
| class TlsVersioned { |
| public: |
| TlsVersioned() : variant_(ssl_variant_stream), version_(0) {} |
| TlsVersioned(SSLProtocolVariant var, uint16_t ver) |
| : variant_(var), version_(ver) {} |
| |
| bool is_dtls() const { return variant_ == ssl_variant_datagram; } |
| SSLProtocolVariant variant() const { return variant_; } |
| uint16_t version() const { return version_; } |
| |
| void WriteStream(std::ostream& stream) const; |
| |
| protected: |
| SSLProtocolVariant variant_; |
| uint16_t version_; |
| }; |
| |
| class TlsRecordHeader : public TlsVersioned { |
| public: |
| TlsRecordHeader() |
| : TlsVersioned(), |
| content_type_(0), |
| guess_seqno_(0), |
| seqno_is_masked_(false), |
| sequence_number_(0), |
| header_() {} |
| TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct, |
| uint64_t seqno) |
| : TlsVersioned(var, ver), |
| content_type_(ct), |
| guess_seqno_(0), |
| seqno_is_masked_(false), |
| sequence_number_(seqno), |
| header_(), |
| sn_mask_() {} |
| |
| bool is_protected() const { |
| // *TLS < 1.3 |
| if (version() < SSL_LIBRARY_VERSION_TLS_1_3 && |
| content_type() == ssl_ct_application_data) { |
| return true; |
| } |
| |
| // TLS 1.3 |
| if (!is_dtls() && version() >= SSL_LIBRARY_VERSION_TLS_1_3 && |
| content_type() == ssl_ct_application_data) { |
| return true; |
| } |
| |
| // DTLS 1.3 |
| return is_dtls13_ciphertext(); |
| } |
| |
| uint8_t content_type() const { return content_type_; } |
| uint16_t epoch() const { |
| return static_cast<uint16_t>(sequence_number_ >> 48); |
| } |
| uint64_t sequence_number() const { return sequence_number_; } |
| void sequence_number(uint64_t seqno) { sequence_number_ = seqno; } |
| const DataBuffer& sn_mask() const { return sn_mask_; } |
| bool is_dtls13_ciphertext() const { |
| return is_dtls() && (version() >= SSL_LIBRARY_VERSION_TLS_1_3) && |
| (content_type() & kCtDtlsCiphertextMask) == kCtDtlsCiphertext; |
| } |
| |
| size_t header_length() const; |
| const DataBuffer& header() const { return header_; } |
| |
| bool MaskSequenceNumber(); |
| bool MaskSequenceNumber(const DataBuffer& mask_buf); |
| |
| // Parse the header; return true if successful; body in an outparam if OK. |
| bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser, |
| DataBuffer* body); |
| // Write the header and body to a buffer at the given offset. |
| // Return the offset of the end of the write. |
| size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; |
| size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const; |
| |
| private: |
| static uint64_t RecoverSequenceNumber(uint64_t guess_seqno, uint32_t partial, |
| size_t partial_bits); |
| uint64_t ParseSequenceNumber(uint64_t expected, uint64_t raw, |
| size_t seq_no_bits, size_t epoch_bits); |
| |
| uint8_t content_type_; |
| uint64_t guess_seqno_; |
| bool seqno_is_masked_; |
| uint64_t sequence_number_; |
| DataBuffer header_; |
| DataBuffer sn_mask_; |
| }; |
| |
| struct TlsRecord { |
| const TlsRecordHeader header; |
| const DataBuffer buffer; |
| }; |
| |
| // Make a filter and install it on a TlsAgent. |
| template <class T, typename... Args> |
| inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent, |
| Args&&... args) { |
| auto filter = std::make_shared<T>(agent, std::forward<Args>(args)...); |
| agent->SetFilter(filter); |
| return filter; |
| } |
| |
| // Abstract filter that operates on entire (D)TLS records. |
| class TlsRecordFilter : public PacketFilter { |
| public: |
| TlsRecordFilter(const std::shared_ptr<TlsAgent>& a); |
| |
| std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); } |
| |
| // External interface. Overrides PacketFilter. |
| PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); |
| |
| // Report how many packets were altered by the filter. |
| size_t filtered_packets() const { return count_; } |
| |
| // Enable decryption. This only works properly for TLS 1.3 and above. |
| // Enabling it for lower version tests will cause undefined |
| // behavior. |
| void EnableDecryption(); |
| bool decrypting() const { return decrypting_; }; |
| bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText, |
| uint16_t* protection_epoch, uint8_t* inner_content_type, |
| DataBuffer* plaintext, TlsRecordHeader* out_header); |
| bool Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header, |
| uint8_t inner_content_type, const DataBuffer& plaintext, |
| DataBuffer* ciphertext, TlsRecordHeader* out_header, |
| size_t padding = 0); |
| |
| protected: |
| // There are two filter functions which can be overriden. Both are |
| // called with the header and the record but the outer one is called |
| // with a raw pointer to let you write into the buffer and lets you |
| // do anything with this section of the stream. The inner one |
| // just lets you change the record contents. By default, the |
| // outer one calls the inner one, so if you override the outer |
| // one, the inner one is never called unless you call it yourself. |
| virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, |
| const DataBuffer& record, |
| size_t* offset, DataBuffer* output); |
| |
| // The record filter receives the record contentType, version and DTLS |
| // sequence number (which is zero for TLS), plus the existing record payload. |
| // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed` |
| // outparam with the new record contents if it chooses to CHANGE the record. |
| virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, |
| const DataBuffer& data, |
| DataBuffer* changed) { |
| return KEEP; |
| } |
| |
| bool is_dtls13() const; |
| bool is_dtls13_ciphertext(uint8_t ct) const; |
| TlsCipherSpec& spec(uint16_t epoch); |
| |
| private: |
| static void SecretCallback(PRFileDesc* fd, PRUint16 epoch, |
| SSLSecretDirection dir, PK11SymKey* secret, |
| void* arg); |
| |
| std::weak_ptr<TlsAgent> agent_; |
| size_t count_ = 0; |
| std::vector<TlsCipherSpec> cipher_specs_; |
| bool decrypting_ = false; |
| }; |
| |
| inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) { |
| v.WriteStream(stream); |
| return stream; |
| } |
| |
| inline std::ostream& operator<<(std::ostream& stream, |
| const TlsRecordHeader& hdr) { |
| hdr.WriteStream(stream); |
| stream << ' '; |
| switch (hdr.content_type()) { |
| case ssl_ct_change_cipher_spec: |
| stream << "CCS"; |
| break; |
| case ssl_ct_alert: |
| stream << "Alert"; |
| break; |
| case ssl_ct_handshake: |
| stream << "Handshake"; |
| break; |
| case ssl_ct_application_data: |
| stream << "Data"; |
| break; |
| case ssl_ct_ack: |
| stream << "ACK"; |
| break; |
| default: |
| stream << '<' << static_cast<int>(hdr.content_type()) << '>'; |
| break; |
| } |
| return stream << ' ' << std::hex << hdr.sequence_number() << std::dec; |
| } |
| |
| // Abstract filter that operates on handshake messages rather than records. |
| // This assumes that the handshake messages are written in a block as entire |
| // records and that they don't span records or anything crazy like that. |
| class TlsHandshakeFilter : public TlsRecordFilter { |
| public: |
| TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a) |
| : TlsRecordFilter(a), handshake_types_(), preceding_fragment_() {} |
| TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a, |
| const std::set<uint8_t>& types) |
| : TlsRecordFilter(a), handshake_types_(types), preceding_fragment_() {} |
| |
| // This filter can be set to be selective based on handshake message type. If |
| // this function isn't used (or the set is empty), then all handshake messages |
| // will be filtered. |
| void SetHandshakeTypes(const std::set<uint8_t>& types) { |
| handshake_types_ = types; |
| } |
| |
| class HandshakeHeader : public TlsVersioned { |
| public: |
| HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {} |
| |
| uint8_t handshake_type() const { return handshake_type_; } |
| bool Parse(TlsParser* parser, const TlsRecordHeader& record_header, |
| const DataBuffer& preceding_fragment, DataBuffer* body, |
| bool* complete); |
| size_t Write(DataBuffer* buffer, size_t offset, |
| const DataBuffer& body) const; |
| size_t WriteFragment(DataBuffer* buffer, size_t offset, |
| const DataBuffer& body, size_t fragment_offset, |
| size_t fragment_length) const; |
| |
| private: |
| // Reads the length from the record header. |
| // This also reads the DTLS fragment information and checks it. |
| bool ReadLength(TlsParser* parser, const TlsRecordHeader& header, |
| uint32_t expected_offset, uint32_t* length, |
| bool* last_fragment); |
| |
| uint8_t handshake_type_; |
| uint16_t message_seq_; |
| // fragment_offset is always zero in these tests. |
| }; |
| |
| protected: |
| virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output); |
| virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output) = 0; |
| |
| private: |
| bool IsFilteredType(const HandshakeHeader& header, |
| const DataBuffer& handshake); |
| |
| std::set<uint8_t> handshake_types_; |
| DataBuffer preceding_fragment_; |
| }; |
| |
| // Make a copy of the first instance of a handshake message. |
| class TlsHandshakeRecorder : public TlsHandshakeFilter { |
| public: |
| TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a, |
| uint8_t handshake_type) |
| : TlsHandshakeFilter(a, {handshake_type}), buffer_() {} |
| TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a, |
| const std::set<uint8_t>& handshake_types) |
| : TlsHandshakeFilter(a, handshake_types), buffer_() {} |
| |
| virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output); |
| |
| void Reset() { buffer_.Truncate(0); } |
| |
| const DataBuffer& buffer() const { return buffer_; } |
| |
| private: |
| DataBuffer buffer_; |
| }; |
| |
| // Replace all instances of a handshake message. |
| class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { |
| public: |
| TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& a, |
| uint8_t handshake_type, |
| const DataBuffer& replacement) |
| : TlsHandshakeFilter(a, {handshake_type}), buffer_(replacement) {} |
| |
| virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output); |
| |
| private: |
| DataBuffer buffer_; |
| }; |
| |
| // Make a copy of each record of a given type. |
| class TlsRecordRecorder : public TlsRecordFilter { |
| public: |
| TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a, uint8_t ct) |
| : TlsRecordFilter(a), filter_(true), ct_(ct), records_() {} |
| TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a) |
| : TlsRecordFilter(a), |
| filter_(false), |
| ct_(ssl_ct_handshake), // dummy (<optional> is C++14) |
| records_() {} |
| virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output); |
| |
| size_t count() const { return records_.size(); } |
| void Clear() { records_.clear(); } |
| |
| const TlsRecord& record(size_t i) const { return records_[i]; } |
| |
| private: |
| bool filter_; |
| uint8_t ct_; |
| std::vector<TlsRecord> records_; |
| }; |
| |
| // Make a copy of the complete conversation. |
| class TlsConversationRecorder : public TlsRecordFilter { |
| public: |
| TlsConversationRecorder(const std::shared_ptr<TlsAgent>& a, |
| DataBuffer& buffer) |
| : TlsRecordFilter(a), buffer_(buffer) {} |
| |
| virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output); |
| |
| private: |
| DataBuffer buffer_; |
| }; |
| |
| // Make a copy of the records |
| class TlsHeaderRecorder : public TlsRecordFilter { |
| public: |
| TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& a) : TlsRecordFilter(a) {} |
| virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output); |
| const TlsRecordHeader* header(size_t index); |
| |
| private: |
| std::vector<TlsRecordHeader> headers_; |
| }; |
| |
| typedef std::initializer_list<std::shared_ptr<PacketFilter>> |
| ChainedPacketFilterInit; |
| |
| // Runs multiple packet filters in series. |
| class ChainedPacketFilter : public PacketFilter { |
| public: |
| ChainedPacketFilter() {} |
| ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters) |
| : filters_(filters.begin(), filters.end()) {} |
| ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {} |
| virtual ~ChainedPacketFilter() {} |
| |
| virtual PacketFilter::Action Filter(const DataBuffer& input, |
| DataBuffer* output); |
| |
| // Takes ownership of the filter. |
| void Add(std::shared_ptr<PacketFilter> filter) { filters_.push_back(filter); } |
| |
| private: |
| std::vector<std::shared_ptr<PacketFilter>> filters_; |
| }; |
| |
| typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)> |
| TlsExtensionFinder; |
| |
| class TlsExtensionFilter : public TlsHandshakeFilter { |
| public: |
| TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a) |
| : TlsHandshakeFilter(a, |
| {kTlsHandshakeClientHello, kTlsHandshakeServerHello, |
| kTlsHandshakeHelloRetryRequest, |
| kTlsHandshakeEncryptedExtensions}) {} |
| |
| TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a, |
| const std::set<uint8_t>& types) |
| : TlsHandshakeFilter(a, types) {} |
| |
| static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); |
| |
| protected: |
| PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output) override; |
| |
| virtual PacketFilter::Action FilterExtension(uint16_t extension_type, |
| const DataBuffer& input, |
| DataBuffer* output) = 0; |
| |
| private: |
| PacketFilter::Action FilterExtensions(TlsParser* parser, |
| const DataBuffer& input, |
| DataBuffer* output); |
| }; |
| |
| class TlsExtensionCapture : public TlsExtensionFilter { |
| public: |
| TlsExtensionCapture(const std::shared_ptr<TlsAgent>& a, uint16_t ext, |
| bool last = false) |
| : TlsExtensionFilter(a), |
| extension_(ext), |
| captured_(false), |
| last_(last), |
| data_() {} |
| |
| const DataBuffer& extension() const { return data_; } |
| bool captured() const { return captured_; } |
| |
| protected: |
| PacketFilter::Action FilterExtension(uint16_t extension_type, |
| const DataBuffer& input, |
| DataBuffer* output) override; |
| |
| private: |
| const uint16_t extension_; |
| bool captured_; |
| bool last_; |
| DataBuffer data_; |
| }; |
| |
| class TlsExtensionReplacer : public TlsExtensionFilter { |
| public: |
| TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t extension, |
| const DataBuffer& data) |
| : TlsExtensionFilter(a), extension_(extension), data_(data) {} |
| PacketFilter::Action FilterExtension(uint16_t extension_type, |
| const DataBuffer& input, |
| DataBuffer* output) override; |
| |
| private: |
| const uint16_t extension_; |
| const DataBuffer data_; |
| }; |
| |
| class TlsExtensionResizer : public TlsExtensionFilter { |
| public: |
| TlsExtensionResizer(const std::shared_ptr<TlsAgent>& a, uint16_t extension, |
| size_t length) |
| : TlsExtensionFilter(a), extension_(extension), length_(length) {} |
| PacketFilter::Action FilterExtension(uint16_t extension_type, |
| const DataBuffer& input, |
| DataBuffer* output) override; |
| |
| private: |
| uint16_t extension_; |
| size_t length_; |
| }; |
| |
| class TlsExtensionAppender : public TlsHandshakeFilter { |
| public: |
| TlsExtensionAppender(const std::shared_ptr<TlsAgent>& a, |
| uint8_t handshake_type, uint16_t ext, DataBuffer& data) |
| : TlsHandshakeFilter(a, {handshake_type}), extension_(ext), data_(data) {} |
| |
| virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output); |
| |
| private: |
| bool UpdateLength(DataBuffer* output, size_t offset, size_t size); |
| |
| const uint16_t extension_; |
| const DataBuffer data_; |
| }; |
| |
| class TlsExtensionDropper : public TlsExtensionFilter { |
| public: |
| TlsExtensionDropper(const std::shared_ptr<TlsAgent>& a, uint16_t extension) |
| : TlsExtensionFilter(a), extension_(extension) {} |
| PacketFilter::Action FilterExtension(uint16_t extension_type, |
| const DataBuffer&, DataBuffer*) override; |
| |
| private: |
| uint16_t extension_; |
| }; |
| |
| class TlsHandshakeDropper : public TlsHandshakeFilter { |
| public: |
| TlsHandshakeDropper(const std::shared_ptr<TlsAgent>& a) |
| : TlsHandshakeFilter(a) {} |
| |
| protected: |
| PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output) override { |
| return DROP; |
| } |
| }; |
| |
| class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter { |
| public: |
| TlsEncryptedHandshakeMessageReplacer(const std::shared_ptr<TlsAgent>& a, |
| uint8_t old_ct, uint8_t new_ct) |
| : TlsRecordFilter(a), old_ct_(old_ct), new_ct_(new_ct) {} |
| |
| protected: |
| PacketFilter::Action FilterRecord(const TlsRecordHeader& header, |
| const DataBuffer& record, size_t* offset, |
| DataBuffer* output) override { |
| if (header.content_type() != ssl_ct_application_data) { |
| return KEEP; |
| } |
| |
| uint16_t protection_epoch = 0; |
| uint8_t inner_content_type; |
| DataBuffer plaintext; |
| TlsRecordHeader out_header; |
| if (!Unprotect(header, record, &protection_epoch, &inner_content_type, |
| &plaintext, &out_header) || |
| !plaintext.len()) { |
| return KEEP; |
| } |
| |
| if (inner_content_type != ssl_ct_handshake) { |
| return KEEP; |
| } |
| |
| size_t off = 0; |
| uint32_t msg_len = 0; |
| uint32_t msg_type = 255; // Not a real message |
| do { |
| if (!plaintext.Read(off, 1, &msg_type) || msg_type == old_ct_) { |
| break; |
| } |
| |
| // Increment and check next messages |
| if (!plaintext.Read(++off, 3, &msg_len)) { |
| break; |
| } |
| off += 3 + msg_len; |
| } while (msg_type != old_ct_); |
| |
| if (msg_type == old_ct_) { |
| plaintext.Write(off, new_ct_, 1); |
| } |
| |
| DataBuffer ciphertext; |
| bool ok = Protect(spec(protection_epoch), out_header, inner_content_type, |
| plaintext, &ciphertext, &out_header); |
| if (!ok) { |
| return KEEP; |
| } |
| *offset = out_header.Write(output, *offset, ciphertext); |
| return CHANGE; |
| } |
| |
| private: |
| uint8_t old_ct_; |
| uint8_t new_ct_; |
| }; |
| |
| class TlsExtensionInjector : public TlsHandshakeFilter { |
| public: |
| TlsExtensionInjector(const std::shared_ptr<TlsAgent>& a, uint16_t ext, |
| const DataBuffer& data) |
| : TlsHandshakeFilter(a), extension_(ext), data_(data) {} |
| |
| protected: |
| PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output) override; |
| |
| private: |
| const uint16_t extension_; |
| const DataBuffer data_; |
| }; |
| |
| class TlsExtensionDamager : public TlsExtensionFilter { |
| public: |
| TlsExtensionDamager(const std::shared_ptr<TlsAgent>& a, uint16_t extension, |
| size_t index) |
| : TlsExtensionFilter(a), extension_(extension), index_(index) {} |
| virtual PacketFilter::Action FilterExtension(uint16_t extension_type, |
| const DataBuffer& input, |
| DataBuffer* output); |
| |
| private: |
| uint16_t extension_; |
| size_t index_; |
| }; |
| |
| typedef std::function<void(void)> VoidFunction; |
| |
| class AfterRecordN : public TlsRecordFilter { |
| public: |
| AfterRecordN(const std::shared_ptr<TlsAgent>& src, |
| const std::shared_ptr<TlsAgent>& dest, unsigned int record, |
| VoidFunction func) |
| : TlsRecordFilter(src), |
| dest_(dest), |
| record_(record), |
| func_(func), |
| counter_(0) {} |
| |
| virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, |
| const DataBuffer& body, |
| DataBuffer* out) override; |
| |
| private: |
| std::weak_ptr<TlsAgent> dest_; |
| unsigned int record_; |
| VoidFunction func_; |
| unsigned int counter_; |
| }; |
| |
| // When we see the ClientKeyExchange from |client|, increment the |
| // ClientHelloVersion on |server|. |
| class TlsClientHelloVersionChanger : public TlsHandshakeFilter { |
| public: |
| TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client, |
| const std::shared_ptr<TlsAgent>& server) |
| : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}), |
| server_(server) {} |
| |
| virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output); |
| |
| private: |
| std::weak_ptr<TlsAgent> server_; |
| }; |
| |
| // Damage a record. |
| class TlsRecordLastByteDamager : public TlsRecordFilter { |
| public: |
| TlsRecordLastByteDamager(const std::shared_ptr<TlsAgent>& a) |
| : TlsRecordFilter(a) {} |
| |
| protected: |
| PacketFilter::Action FilterRecord(const TlsRecordHeader& header, |
| const DataBuffer& data, |
| DataBuffer* changed) override { |
| *changed = data; |
| changed->data()[changed->len() - 1]++; |
| return CHANGE; |
| } |
| }; |
| |
| // This class selectively drops complete writes. This relies on the fact that |
| // writes in libssl are on record boundaries. |
| class SelectiveDropFilter : public PacketFilter { |
| public: |
| SelectiveDropFilter(uint32_t pattern) : pattern_(pattern), counter_(0) {} |
| |
| protected: |
| virtual PacketFilter::Action Filter(const DataBuffer& input, |
| DataBuffer* output) override; |
| |
| private: |
| const uint32_t pattern_; |
| uint8_t counter_; |
| }; |
| |
| // This class selectively drops complete records. The difference from |
| // SelectiveDropFilter is that if multiple DTLS records are in the same |
| // datagram, we just drop one. |
| class SelectiveRecordDropFilter : public TlsRecordFilter { |
| public: |
| SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a, |
| uint32_t pattern, bool on = true) |
| : TlsRecordFilter(a), pattern_(pattern), counter_(0) { |
| if (!on) { |
| Disable(); |
| } |
| } |
| SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a, |
| std::initializer_list<size_t> records) |
| : SelectiveRecordDropFilter(a, ToPattern(records), true) {} |
| |
| void Reset(uint32_t pattern) { |
| counter_ = 0; |
| PacketFilter::Enable(); |
| pattern_ = pattern; |
| } |
| |
| void Reset(std::initializer_list<size_t> records) { |
| Reset(ToPattern(records)); |
| } |
| |
| protected: |
| PacketFilter::Action FilterRecord(const TlsRecordHeader& header, |
| const DataBuffer& data, |
| DataBuffer* changed) override; |
| |
| private: |
| static uint32_t ToPattern(std::initializer_list<size_t> records); |
| |
| uint32_t pattern_; |
| uint8_t counter_; |
| }; |
| |
| // Set the version number in the ClientHello. |
| class TlsClientHelloVersionSetter : public TlsHandshakeFilter { |
| public: |
| TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& a, |
| uint16_t version) |
| : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}), version_(version) {} |
| |
| virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output); |
| |
| private: |
| uint16_t version_; |
| }; |
| |
| // Damages the last byte of a handshake message. |
| class TlsLastByteDamager : public TlsHandshakeFilter { |
| public: |
| TlsLastByteDamager(const std::shared_ptr<TlsAgent>& a, uint8_t type) |
| : TlsHandshakeFilter(a), type_(type) {} |
| PacketFilter::Action FilterHandshake( |
| const TlsHandshakeFilter::HandshakeHeader& header, |
| const DataBuffer& input, DataBuffer* output) override { |
| if (header.handshake_type() != type_) { |
| return KEEP; |
| } |
| |
| *output = input; |
| |
| output->data()[output->len() - 1]++; |
| return CHANGE; |
| } |
| |
| private: |
| uint8_t type_; |
| }; |
| |
| class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { |
| public: |
| SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& a, |
| uint16_t suite) |
| : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}), |
| cipher_suite_(suite) {} |
| |
| protected: |
| PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |
| const DataBuffer& input, |
| DataBuffer* output) override; |
| |
| private: |
| uint16_t cipher_suite_; |
| }; |
| |
| } // namespace nss_test |
| |
| #endif |