blob: 2b6e8864568a81c10d8f28acc339a79f19069972 [file] [log] [blame]
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/. */
#ifndef tls_filter_h_
#define tls_filter_h_
#include <functional>
#include <memory>
#include <set>
#include <utility>
#include <vector>
#include "sslt.h"
#include "test_io.h"
#include "tls_agent.h"
#include "tls_parser.h"
#include "tls_protect.h"
extern "C" {
#include "libssl_internals.h"
}
namespace nss_test {
class TlsCipherSpec;
// Common base for objects that carry a protocol variant (stream or datagram)
// together with a 16-bit protocol version.
class TlsVersioned {
 public:
  // Defaults to the stream variant with a version of zero.
  TlsVersioned() : TlsVersioned(ssl_variant_stream, 0) {}
  TlsVersioned(SSLProtocolVariant var, uint16_t ver)
      : variant_(var), version_(ver) {}

  SSLProtocolVariant variant() const { return variant_; }
  uint16_t version() const { return version_; }
  // True when the variant is the datagram one.
  bool is_dtls() const { return ssl_variant_datagram == variant_; }

  // Writes a description of this object to |stream|; defined out of line.
  void WriteStream(std::ostream& stream) const;

 protected:
  SSLProtocolVariant variant_;
  uint16_t version_;
};
// Describes the header of one (D)TLS record: the inherited variant and
// version, plus a content type, a 64-bit sequence number and a header buffer.
class TlsRecordHeader : public TlsVersioned {
 public:
  // Zero content type and sequence number, default TlsVersioned base.
  TlsRecordHeader()
      : TlsVersioned(), content_type_(0), sequence_number_(0), header_() {}
  TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct,
                  uint64_t seqno)
      : TlsVersioned(var, ver),
        content_type_(ct),
        sequence_number_(seqno),
        header_() {}

  uint8_t content_type() const { return content_type_; }
  uint64_t sequence_number() const { return sequence_number_; }
  // The epoch is taken from the bits above the low 48 of the sequence number.
  uint16_t epoch() const {
    const uint64_t shifted = sequence_number_ >> 48;
    return static_cast<uint16_t>(shifted);
  }
  size_t header_length() const;
  const DataBuffer& header() const { return header_; }

  // Parses a header from |parser|. Returns true on success, in which case the
  // record body is placed in |body|.
  bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser,
             DataBuffer* body);

  // Writes the header followed by |body| into |buffer| starting at |offset|.
  // Returns the offset of the end of the write.
  size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
  // Header-only counterpart of Write(); takes the body length rather than the
  // body itself. Defined out of line.
  size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const;

 private:
  // Sequence number helpers; defined out of line.
  static uint64_t RecoverSequenceNumber(uint64_t expected, uint32_t partial,
                                        size_t partial_bits);
  static uint64_t ParseSequenceNumber(uint64_t expected, uint32_t raw,
                                      size_t seq_no_bits, size_t epoch_bits);

  uint8_t content_type_;
  uint64_t sequence_number_;
  DataBuffer header_;
};
// An immutable pairing of a record header with a data buffer. This is the
// element type stored by TlsRecordRecorder.
struct TlsRecord {
  const TlsRecordHeader header;
  const DataBuffer buffer;
};
// Construct a filter of type T, passing |agent| followed by |args| to its
// constructor, install it on |agent| via SetFilter(), and hand it back.
template <class T, typename... Args>
inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent,
                                        Args&&... args) {
  std::shared_ptr<T> installed =
      std::make_shared<T>(agent, std::forward<Args>(args)...);
  agent->SetFilter(installed);
  return installed;
}
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
 public:
  TlsRecordFilter(const std::shared_ptr<TlsAgent>& a)
      : agent_(a),
        count_(0),
        cipher_spec_(),
        dropped_record_(false),
        in_sequence_number_(0),
        out_sequence_number_(0) {}

  // The agent this filter was created for. Only a weak reference is held, so
  // this returns a null pointer once the agent no longer exists.
  std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); }

  // External interface. Overrides PacketFilter.
  PacketFilter::Action Filter(const DataBuffer& input,
                              DataBuffer* output) override;

  // Report how many packets were altered by the filter.
  size_t filtered_packets() const { return count_; }

  // Enable decryption. This only works properly for TLS 1.3 and above.
  // Enabling it for lower version tests will cause undefined
  // behavior.
  void EnableDecryption();

  // Record protection helpers, defined out of line. Both return a bool and
  // deliver their result through the trailing out-parameters.
  bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText,
                 uint8_t* inner_content_type, DataBuffer* plaintext);
  bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type,
               const DataBuffer& plaintext, DataBuffer* ciphertext,
               size_t padding = 0);

 protected:
  // There are two filter functions which can be overridden. Both are
  // called with the header and the record but the outer one is called
  // with a raw pointer to let you write into the buffer and lets you
  // do anything with this section of the stream. The inner one
  // just lets you change the record contents. By default, the
  // outer one calls the inner one, so if you override the outer
  // one, the inner one is never called unless you call it yourself.
  virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                            const DataBuffer& record,
                                            size_t* offset, DataBuffer* output);

  // The record filter receives the record contentType, version and DTLS
  // sequence number (which is zero for TLS), plus the existing record payload.
  // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed`
  // outparam with the new record contents if it chooses to CHANGE the record.
  // The default implementation ignores its arguments and keeps the record;
  // the parameter names are commented out to avoid unused-parameter warnings
  // in every file that includes this header.
  virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& /*header*/,
                                            const DataBuffer& /*data*/,
                                            DataBuffer* /*changed*/) {
    return KEEP;
  }

  bool is_dtls13() const;

 private:
  // Static callback with a C-style signature; defined out of line.
  static void CipherSpecChanged(void* arg, PRBool sending,
                                ssl3CipherSpec* newSpec);

  std::weak_ptr<TlsAgent> agent_;
  size_t count_;
  std::unique_ptr<TlsCipherSpec> cipher_spec_;
  // Whether we dropped a record since the cipher spec changed.
  bool dropped_record_;
  // The sequence number we use for reading records as they are written.
  uint64_t in_sequence_number_;
  // The sequence number we use for writing modified records.
  uint64_t out_sequence_number_;
};
inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) {
v.WriteStream(stream);
return stream;
}
// Stream insertion for a record header. Emits the WriteStream() output, a
// space, a label for the content type (or "<N>" with the numeric value when
// the type is not one of the known ones), another space, and finally the
// sequence number in hexadecimal. The stream is left in decimal mode.
inline std::ostream& operator<<(std::ostream& stream,
                                const TlsRecordHeader& hdr) {
  hdr.WriteStream(stream);
  stream << ' ';

  // Map the content type onto its label; stays null for unknown types.
  const char* label = nullptr;
  switch (hdr.content_type()) {
    case ssl_ct_change_cipher_spec:
      label = "CCS";
      break;
    case ssl_ct_alert:
      label = "Alert";
      break;
    case ssl_ct_handshake:
      label = "Handshake";
      break;
    case ssl_ct_application_data:
      label = "Data";
      break;
    case ssl_ct_ack:
      label = "ACK";
      break;
    default:
      break;
  }

  if (label != nullptr) {
    stream << label;
  } else {
    stream << '<' << static_cast<int>(hdr.content_type()) << '>';
  }

  stream << ' ' << std::hex << hdr.sequence_number() << std::dec;
  return stream;
}
// Abstract filter that operates on handshake messages rather than records.
// This assumes that the handshake messages are written in a block as entire
// records and that they don't span records or anything crazy like that.
class TlsHandshakeFilter : public TlsRecordFilter {
 public:
  // With no explicit set of types, the set starts out empty, which means all
  // handshake messages are filtered (see SetHandshakeTypes()).
  TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a)
      : TlsRecordFilter(a), handshake_types_(), preceding_fragment_() {}
  TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a,
                     const std::set<uint8_t>& types)
      : TlsRecordFilter(a), handshake_types_(types), preceding_fragment_() {}

  // This filter can be set to be selective based on handshake message type. If
  // this function isn't used (or the set is empty), then all handshake messages
  // will be filtered.
  void SetHandshakeTypes(const std::set<uint8_t>& types) {
    handshake_types_ = types;
  }

  // The header of a single handshake message.
  class HandshakeHeader : public TlsVersioned {
   public:
    HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {}

    uint8_t handshake_type() const { return handshake_type_; }

    // Parsing and serialization; all defined out of line.
    bool Parse(TlsParser* parser, const TlsRecordHeader& record_header,
               const DataBuffer& preceding_fragment, DataBuffer* body,
               bool* complete);
    size_t Write(DataBuffer* buffer, size_t offset,
                 const DataBuffer& body) const;
    size_t WriteFragment(DataBuffer* buffer, size_t offset,
                         const DataBuffer& body, size_t fragment_offset,
                         size_t fragment_length) const;

   private:
    // Reads the length from the record header.
    // This also reads the DTLS fragment information and checks it.
    bool ReadLength(TlsParser* parser, const TlsRecordHeader& header,
                    uint32_t expected_offset, uint32_t* length,
                    bool* last_fragment);

    uint8_t handshake_type_;
    uint16_t message_seq_;
    // fragment_offset is always zero in these tests.
  };

 protected:
  // Overrides the inner (three-argument) TlsRecordFilter::FilterRecord.
  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& input,
                                    DataBuffer* output) override;

  // Implemented by subclasses to act on one handshake message.
  virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                               const DataBuffer& input,
                                               DataBuffer* output) = 0;

 private:
  bool IsFilteredType(const HandshakeHeader& header,
                      const DataBuffer& handshake);

  std::set<uint8_t> handshake_types_;
  DataBuffer preceding_fragment_;
};
// Make a copy of the first instance of a handshake message.
class TlsHandshakeRecorder : public TlsHandshakeFilter {
 public:
  // Restrict the filter to a single handshake message type.
  TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a,
                       uint8_t handshake_type)
      : TlsHandshakeFilter(a, {handshake_type}), buffer_() {}
  // Restrict the filter to a set of handshake message types.
  TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a,
                       const std::set<uint8_t>& handshake_types)
      : TlsHandshakeFilter(a, handshake_types), buffer_() {}

  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

  // Empties the recorded buffer.
  void Reset() { buffer_.Truncate(0); }
  const DataBuffer& buffer() const { return buffer_; }

 private:
  DataBuffer buffer_;
};
// Replace all instances of a handshake message.
class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
 public:
  // |replacement| is copied into the filter at construction time.
  TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& a,
                                      uint8_t handshake_type,
                                      const DataBuffer& replacement)
      : TlsHandshakeFilter(a, {handshake_type}), buffer_(replacement) {}

  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  DataBuffer buffer_;
};
// Make a copy of each record of a given type.
class TlsRecordRecorder : public TlsRecordFilter {
 public:
  // Filter on the content type |ct|.
  TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a, uint8_t ct)
      : TlsRecordFilter(a), filter_(true), ct_(ct), records_() {}
  // No content type filtering.
  TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a)
      : TlsRecordFilter(a),
        filter_(false),
        ct_(ssl_ct_handshake),  // dummy (std::optional would need C++17)
        records_() {}

  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& input,
                                    DataBuffer* output) override;

  // Number of records stored so far.
  size_t count() const { return records_.size(); }
  void Clear() { records_.clear(); }
  // Unchecked access: |i| must be less than count().
  const TlsRecord& record(size_t i) const { return records_[i]; }

 private:
  bool filter_;
  uint8_t ct_;
  std::vector<TlsRecord> records_;
};
// Make a copy of the complete conversation in a caller-supplied buffer.
class TlsConversationRecorder : public TlsRecordFilter {
 public:
  // |buffer| is held by reference, so it must outlive this filter. Previously
  // the filter kept a private copy of |buffer| with no accessor, which made
  // anything recorded into it unreachable for the caller.
  TlsConversationRecorder(const std::shared_ptr<TlsAgent>& a,
                          DataBuffer& buffer)
      : TlsRecordFilter(a), buffer_(buffer) {}

  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& input,
                                    DataBuffer* output) override;

 private:
  // The caller's buffer; FilterRecord (defined out of line) is expected to
  // accumulate the conversation here.
  DataBuffer& buffer_;
};
// Make a copy of record headers (the storage is a list of TlsRecordHeader,
// not of whole records).
class TlsHeaderRecorder : public TlsRecordFilter {
 public:
  TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& a) : TlsRecordFilter(a) {}

  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& input,
                                    DataBuffer* output) override;

  // Looks up a stored header by |index|; defined out of line.
  // NOTE(review): presumably returns nullptr when |index| is out of range,
  // given the pointer return type -- confirm against the implementation.
  const TlsRecordHeader* header(size_t index);

 private:
  std::vector<TlsRecordHeader> headers_;
};
typedef std::initializer_list<std::shared_ptr<PacketFilter>>
    ChainedPacketFilterInit;

// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
 public:
  ChainedPacketFilter() {}
  // Taken by const reference: the previous by-value signature copied the
  // whole vector once for the argument and again into |filters_|.
  ChainedPacketFilter(
      const std::vector<std::shared_ptr<PacketFilter>>& filters)
      : filters_(filters) {}
  ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {}
  virtual ~ChainedPacketFilter() {}

  PacketFilter::Action Filter(const DataBuffer& input,
                              DataBuffer* output) override;

  // Appends |filter| to the end of the chain. The chain holds a shared_ptr,
  // so ownership is shared with any other holders of the same pointer.
  void Add(std::shared_ptr<PacketFilter> filter) {
    // The by-value argument is ours to consume; moving it avoids an extra
    // reference count round trip.
    filters_.push_back(std::move(filter));
  }

 private:
  std::vector<std::shared_ptr<PacketFilter>> filters_;
};
// Callback type taking a parser and a versioned header and returning a bool.
using TlsExtensionFinder =
    std::function<bool(TlsParser* parser, const TlsVersioned& header)>;

// Abstract handshake filter whose subclasses act on individual extensions
// through FilterExtension().
class TlsExtensionFilter : public TlsHandshakeFilter {
 public:
  // By default, look at ClientHello, ServerHello, HelloRetryRequest and
  // EncryptedExtensions messages.
  TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a)
      : TlsHandshakeFilter(a,
                           {kTlsHandshakeClientHello, kTlsHandshakeServerHello,
                            kTlsHandshakeHelloRetryRequest,
                            kTlsHandshakeEncryptedExtensions}) {}
  // Look only at the handshake message types listed in |types|.
  TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a,
                     const std::set<uint8_t>& types)
      : TlsHandshakeFilter(a, types) {}

  // Defined out of line.
  static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);

 protected:
  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

  // Implemented by subclasses to act on one extension of type
  // |extension_type| whose contents are |input|.
  virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
                                               const DataBuffer& input,
                                               DataBuffer* output) = 0;

 private:
  PacketFilter::Action FilterExtensions(TlsParser* parser,
                                        const DataBuffer& input,
                                        DataBuffer* output);
};
// Extension filter that holds on to the data of one extension type and
// exposes it, together with a flag saying whether anything was captured.
class TlsExtensionCapture : public TlsExtensionFilter {
 public:
  // |ext| is the extension type of interest. |last| defaults to false; its
  // effect is determined by FilterExtension(), which is defined out of line.
  TlsExtensionCapture(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
                      bool last = false)
      : TlsExtensionFilter(a),
        extension_(ext),
        captured_(false),
        last_(last),
        data_() {}

  bool captured() const { return captured_; }
  const DataBuffer& extension() const { return data_; }

 protected:
  PacketFilter::Action FilterExtension(uint16_t extension_type,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  const uint16_t extension_;
  bool captured_;
  bool last_;
  DataBuffer data_;
};
// Extension filter configured with an extension type and a data buffer; the
// buffer is copied at construction and never modified afterwards.
class TlsExtensionReplacer : public TlsExtensionFilter {
 public:
  TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
                       const DataBuffer& data)
      : TlsExtensionFilter(a), extension_(extension), data_(data) {}

  PacketFilter::Action FilterExtension(uint16_t extension_type,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  const uint16_t extension_;
  const DataBuffer data_;
};
// Extension filter configured with a single extension type.
class TlsExtensionDropper : public TlsExtensionFilter {
 public:
  TlsExtensionDropper(const std::shared_ptr<TlsAgent>& a, uint16_t extension)
      : TlsExtensionFilter(a), extension_(extension) {}

  // The two buffer parameters are left unnamed in this declaration.
  PacketFilter::Action FilterExtension(uint16_t extension_type,
                                       const DataBuffer&,
                                       DataBuffer*) override;

 private:
  uint16_t extension_;
};
// Handshake filter configured with an extension type and a data buffer.
// Unlike the other extension helpers it derives from TlsHandshakeFilter
// directly and is constructed without a handshake type restriction.
class TlsExtensionInjector : public TlsHandshakeFilter {
 public:
  TlsExtensionInjector(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
                       const DataBuffer& data)
      : TlsHandshakeFilter(a), extension_(ext), data_(data) {}

 protected:
  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  const uint16_t extension_;
  const DataBuffer data_;
};
// Extension filter configured with an extension type and an index.
// NOTE(review): presumably |index| selects the byte of the extension that is
// altered -- confirm against FilterExtension(), which is defined out of line.
class TlsExtensionDamager : public TlsExtensionFilter {
 public:
  TlsExtensionDamager(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
                      size_t index)
      : TlsExtensionFilter(a), extension_(extension), index_(index) {}

  PacketFilter::Action FilterExtension(uint16_t extension_type,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  uint16_t extension_;
  size_t index_;
};
// A callable that takes no arguments and returns nothing.
using VoidFunction = std::function<void()>;

// Record filter on |src| configured with a destination agent, a record
// number and a function; it keeps a counter that starts at zero.
// NOTE(review): presumably |func| is run once the counter reaches |record| --
// confirm against FilterRecord(), which is defined out of line.
class AfterRecordN : public TlsRecordFilter {
 public:
  AfterRecordN(const std::shared_ptr<TlsAgent>& src,
               const std::shared_ptr<TlsAgent>& dest, unsigned int record,
               VoidFunction func)
      : TlsRecordFilter(src),
        dest_(dest),
        record_(record),
        func_(func),
        counter_(0) {}

  virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                            const DataBuffer& body,
                                            DataBuffer* out) override;

 private:
  // Held weakly, like the source agent in the base class.
  std::weak_ptr<TlsAgent> dest_;
  unsigned int record_;
  VoidFunction func_;
  unsigned int counter_;
};
// When we see the ClientKeyExchange from |client|, increment the
// ClientHelloVersion on |server|.
class TlsClientHelloVersionChanger : public TlsHandshakeFilter {
 public:
  // The filter is installed for |client| and restricted to ClientKeyExchange
  // messages; |server| is only held weakly.
  TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client,
                               const std::shared_ptr<TlsAgent>& server)
      : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}),
        server_(server) {}

  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  std::weak_ptr<TlsAgent> server_;
};
// Damage a record by incrementing its last byte.
class TlsRecordLastByteDamager : public TlsRecordFilter {
 public:
  TlsRecordLastByteDamager(const std::shared_ptr<TlsAgent>& a)
      : TlsRecordFilter(a) {}

 protected:
  // Copies |data| into |changed|, bumps the final byte and reports CHANGE.
  // An empty record is kept as is.
  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& data,
                                    DataBuffer* changed) override {
    // An empty record has no last byte; without this guard `len() - 1` would
    // wrap around and the increment would write out of bounds.
    if (data.len() == 0) {
      return KEEP;
    }
    *changed = data;
    changed->data()[changed->len() - 1]++;
    return CHANGE;
  }
};
// Drops complete writes selectively, according to a 32-bit pattern. This
// depends on libssl performing its writes on record boundaries.
class SelectiveDropFilter : public PacketFilter {
 public:
  // The internal counter starts at zero.
  SelectiveDropFilter(uint32_t pattern) : pattern_(pattern), counter_(0) {}

 protected:
  virtual PacketFilter::Action Filter(const DataBuffer& input,
                                      DataBuffer* output) override;

 private:
  const uint32_t pattern_;
  uint8_t counter_;
};
// Drops complete records selectively. Unlike SelectiveDropFilter, when
// several DTLS records share one datagram only one of them is dropped.
class SelectiveRecordDropFilter : public TlsRecordFilter {
 public:
  // Starts with the counter at zero. When |enabled| is false the filter is
  // disabled straight away.
  SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
                            uint32_t pattern, bool enabled = true)
      : TlsRecordFilter(a), pattern_(pattern), counter_(0) {
    if (enabled) {
      return;
    }
    Disable();
  }
  // Same, with the pattern derived from |records| by ToPattern(); the filter
  // is always enabled.
  SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
                            std::initializer_list<size_t> records)
      : SelectiveRecordDropFilter(a, ToPattern(records), true) {}

  // Restarts the filter with a new pattern: zeroes the counter and enables
  // the filter again.
  void Reset(uint32_t pattern) {
    counter_ = 0;
    PacketFilter::Enable();
    pattern_ = pattern;
  }
  void Reset(std::initializer_list<size_t> records) {
    const uint32_t converted = ToPattern(records);
    Reset(converted);
  }

 protected:
  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& data,
                                    DataBuffer* changed) override;

 private:
  // Converts a list of record numbers to a pattern; defined out of line.
  static uint32_t ToPattern(std::initializer_list<size_t> records);

  uint32_t pattern_;
  uint8_t counter_;
};
// Set the version number in the ClientHello.
class TlsClientHelloVersionSetter : public TlsHandshakeFilter {
 public:
  // Restricted to ClientHello messages; |version| is the value to set.
  TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& a,
                              uint16_t version)
      : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}), version_(version) {}

  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  uint16_t version_;
};
// Damages the last byte of a handshake message.
class TlsLastByteDamager : public TlsHandshakeFilter {
 public:
  // No type restriction is passed to the base class; matching against |type|
  // happens in FilterHandshake() instead.
  TlsLastByteDamager(const std::shared_ptr<TlsAgent>& a, uint8_t type)
      : TlsHandshakeFilter(a), type_(type) {}

  // Messages of another type, and messages with an empty body, are kept
  // unchanged. Otherwise |input| is copied to |output| with its final byte
  // incremented and CHANGE is returned.
  PacketFilter::Action FilterHandshake(
      const TlsHandshakeFilter::HandshakeHeader& header,
      const DataBuffer& input, DataBuffer* output) override {
    if (header.handshake_type() != type_) {
      return KEEP;
    }
    // An empty body has no last byte; without this guard `len() - 1` would
    // wrap around and the increment would write out of bounds.
    if (input.len() == 0) {
      return KEEP;
    }
    *output = input;
    output->data()[output->len() - 1]++;
    return CHANGE;
  }

 private:
  uint8_t type_;
};
// Handshake filter restricted to ServerHello messages and configured with a
// 16-bit cipher suite value.
class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
 public:
  SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& a,
                              uint16_t suite)
      : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}),
        cipher_suite_(suite) {}

 protected:
  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  uint16_t cipher_suite_;
};
} // namespace nss_test
#endif