/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#ifndef tls_filter_h_
#define tls_filter_h_

#include <functional>
#include <memory>
#include <ostream>
#include <set>
#include <utility>
#include <vector>
#include "sslt.h"
#include "test_io.h"
#include "tls_agent.h"
#include "tls_parser.h"
#include "tls_protect.h"

extern "C" {
#include "libssl_internals.h"
}

namespace nss_test {

class TlsCipherSpec;

// Protocol variant (stream vs. datagram) plus wire version; the common base of
// TlsRecordHeader and TlsHandshakeFilter::HandshakeHeader.
class TlsVersioned {
 public:
  // Defaults to a stream (TLS) variant with version 0.
  TlsVersioned() = default;
  TlsVersioned(SSLProtocolVariant var, uint16_t ver)
      : variant_(var), version_(ver) {}

  SSLProtocolVariant variant() const { return variant_; }
  uint16_t version() const { return version_; }
  bool is_dtls() const { return ssl_variant_datagram == variant_; }

  // Writes a description of this object to |stream| (defined out of line).
  void WriteStream(std::ostream& stream) const;

 protected:
  SSLProtocolVariant variant_ = ssl_variant_stream;
  uint16_t version_ = 0;
};

class TlsRecordHeader : public TlsVersioned {
 public:
  TlsRecordHeader()
      : TlsVersioned(), content_type_(0), sequence_number_(0), header_() {}
  TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct,
                  uint64_t seqno)
      : TlsVersioned(var, ver),
        content_type_(ct),
        sequence_number_(seqno),
        header_() {}

  uint8_t content_type() const { return content_type_; }
  uint64_t sequence_number() const { return sequence_number_; }
  uint16_t epoch() const {
    return static_cast<uint16_t>(sequence_number_ >> 48);
  }
  size_t header_length() const;
  const DataBuffer& header() const { return header_; }

  // Parse the header; return true if successful; body in an outparam if OK.
  bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser,
             DataBuffer* body);
  // Write the header and body to a buffer at the given offset.
  // Return the offset of the end of the write.
  size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
  size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const;

 private:
  static uint64_t RecoverSequenceNumber(uint64_t expected, uint32_t partial,
                                        size_t partial_bits);
  static uint64_t ParseSequenceNumber(uint64_t expected, uint32_t raw,
                                      size_t seq_no_bits, size_t epoch_bits);

  uint8_t content_type_;
  uint64_t sequence_number_;
  DataBuffer header_;
};

// A captured record: its parsed header and its contents. Both members are
// const, so a TlsRecord cannot be modified once it is constructed.
struct TlsRecord {
  const TlsRecordHeader header;  // parsed record header
  const DataBuffer buffer;       // record contents
};

// Builds a T from |agent| followed by |args|, installs it on |agent| with
// SetFilter(), and hands back the shared pointer so the test can inspect it.
template <class T, typename... Args>
inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent,
                                        Args&&... args) {
  std::shared_ptr<T> installed =
      std::make_shared<T>(agent, std::forward<Args>(args)...);
  agent->SetFilter(installed);
  return installed;
}

// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
 public:
  TlsRecordFilter(const std::shared_ptr<TlsAgent>& a)
      : agent_(a),
        count_(0),
        cipher_spec_(),
        dropped_record_(false),
        in_sequence_number_(0),
        out_sequence_number_(0) {}

  // The agent this filter was created for. Only a weak reference is held, so
  // this returns an empty pointer once that agent has been destroyed.
  std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); }

  // External interface. Overrides PacketFilter (marked `override` so that a
  // signature drift in the base class is a compile error, not a silent new
  // virtual).
  PacketFilter::Action Filter(const DataBuffer& input,
                              DataBuffer* output) override;

  // Report how many packets were altered by the filter.
  size_t filtered_packets() const { return count_; }

  // Enable decryption. This only works properly for TLS 1.3 and above.
  // Enabling it for lower version tests will cause undefined
  // behavior.
  void EnableDecryption();
  // Unprotect turns |cipherText| (framed by |header|) into |plaintext| and
  // reports the inner content type; Protect goes the other way, optionally
  // adding |padding|. Both return a bool (presumably true on success).
  bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText,
                 uint8_t* inner_content_type, DataBuffer* plaintext);
  bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type,
               const DataBuffer& plaintext, DataBuffer* ciphertext,
               size_t padding = 0);

 protected:
  // There are two filter functions which can be overridden. Both are
  // called with the header and the record but the outer one is called
  // with a raw pointer to let you write into the buffer and lets you
  // do anything with this section of the stream. The inner one
  // just lets you change the record contents. By default, the
  // outer one calls the inner one, so if you override the outer
  // one, the inner one is never called unless you call it yourself.
  virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                            const DataBuffer& record,
                                            size_t* offset, DataBuffer* output);

  // The record filter receives the record contentType, version and DTLS
  // sequence number (which is zero for TLS), plus the existing record payload.
  // It returns an action (KEEP, CHANGE, DROP).  It writes to the `changed`
  // outparam with the new record contents if it chooses to CHANGE the record.
  virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                            const DataBuffer& data,
                                            DataBuffer* changed) {
    return KEEP;
  }

  bool is_dtls13() const;

 private:
  static void CipherSpecChanged(void* arg, PRBool sending,
                                ssl3CipherSpec* newSpec);

  std::weak_ptr<TlsAgent> agent_;
  size_t count_;
  std::unique_ptr<TlsCipherSpec> cipher_spec_;
  // Whether we dropped a record since the cipher spec changed.
  bool dropped_record_;
  // The sequence number we use for reading records as they are written.
  uint64_t in_sequence_number_;
  // The sequence number we use for writing modified records.
  uint64_t out_sequence_number_;
};

// Streams a TlsVersioned by delegating the formatting to WriteStream().
inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) {
  std::ostream& out = stream;
  v.WriteStream(out);
  return out;
}

// Streams a record header as: the WriteStream() prefix, a short name for the
// content type (or the numeric value in angle brackets when unknown), and the
// sequence number in hexadecimal.
inline std::ostream& operator<<(std::ostream& stream,
                                const TlsRecordHeader& hdr) {
  hdr.WriteStream(stream);
  stream << ' ';
  switch (hdr.content_type()) {
    case ssl_ct_change_cipher_spec:
      stream << "CCS";
      break;
    case ssl_ct_alert:
      stream << "Alert";
      break;
    case ssl_ct_handshake:
      stream << "Handshake";
      break;
    case ssl_ct_application_data:
      stream << "Data";
      break;
    case ssl_ct_ack:
      stream << "ACK";
      break;
    default:
      stream << '<' << static_cast<int>(hdr.content_type()) << '>';
      break;
  }
  // Print the sequence number in hex, then restore the caller's formatting
  // flags instead of forcing std::dec, so whatever base the caller had
  // selected on |stream| is left intact.
  const std::ios_base::fmtflags saved_flags = stream.flags();
  stream << ' ' << std::hex << hdr.sequence_number();
  stream.flags(saved_flags);
  return stream;
}

// Abstract filter that operates on handshake messages rather than records.
// This assumes that the handshake messages are written in a block as entire
// records and that they don't span records or anything crazy like that.
// NOTE(review): HandshakeHeader::Parse() takes a |preceding_fragment| and
// reports |complete|, and this class keeps |preceding_fragment_|, so some
// cross-record fragments may now be handled; confirm the assumption above.
class TlsHandshakeFilter : public TlsRecordFilter {
 public:
  TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a)
      : TlsRecordFilter(a), handshake_types_(), preceding_fragment_() {}
  TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a,
                     const std::set<uint8_t>& types)
      : TlsRecordFilter(a), handshake_types_(types), preceding_fragment_() {}

  // This filter can be set to be selective based on handshake message type.  If
  // this function isn't used (or the set is empty), then all handshake messages
  // will be filtered.
  void SetHandshakeTypes(const std::set<uint8_t>& types) {
    handshake_types_ = types;
  }

  // Parsed header of a single handshake message.
  class HandshakeHeader : public TlsVersioned {
   public:
    HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {}

    uint8_t handshake_type() const { return handshake_type_; }
    bool Parse(TlsParser* parser, const TlsRecordHeader& record_header,
               const DataBuffer& preceding_fragment, DataBuffer* body,
               bool* complete);
    size_t Write(DataBuffer* buffer, size_t offset,
                 const DataBuffer& body) const;
    size_t WriteFragment(DataBuffer* buffer, size_t offset,
                         const DataBuffer& body, size_t fragment_offset,
                         size_t fragment_length) const;

   private:
    // Reads the length from the record header.
    // This also reads the DTLS fragment information and checks it.
    bool ReadLength(TlsParser* parser, const TlsRecordHeader& header,
                    uint32_t expected_offset, uint32_t* length,
                    bool* last_fragment);

    uint8_t handshake_type_;
    uint16_t message_seq_;
    // fragment_offset is always zero in these tests.
  };

 protected:
  // Overrides the inner TlsRecordFilter hook; `override` makes the compiler
  // verify the signature still matches the base class.
  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& input,
                                    DataBuffer* output) override;
  virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                               const DataBuffer& input,
                                               DataBuffer* output) = 0;

 private:
  bool IsFilteredType(const HandshakeHeader& header,
                      const DataBuffer& handshake);

  std::set<uint8_t> handshake_types_;
  DataBuffer preceding_fragment_;
};

// Make a copy of the first instance of a handshake message.
// The captured bytes are exposed through buffer(); Reset() empties it.
class TlsHandshakeRecorder : public TlsHandshakeFilter {
 public:
  TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a,
                       uint8_t handshake_type)
      : TlsHandshakeFilter(a, {handshake_type}), buffer_() {}
  TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a,
                       const std::set<uint8_t>& handshake_types)
      : TlsHandshakeFilter(a, handshake_types), buffer_() {}

  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

  // Discards anything captured so far.
  void Reset() { buffer_.Truncate(0); }

  const DataBuffer& buffer() const { return buffer_; }

 private:
  DataBuffer buffer_;
};

// Replace all instances of a handshake message.
// Only messages of |handshake_type| reach FilterHandshake; |replacement| is
// kept by value for the lifetime of the filter.
class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
 public:
  TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& a,
                                      uint8_t handshake_type,
                                      const DataBuffer& replacement)
      : TlsHandshakeFilter(a, {handshake_type}), buffer_(replacement) {}

  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  DataBuffer buffer_;
};

// Make a copy of each record of a given type.
// The single-argument constructor turns content-type filtering off
// (filter_ = false); the stored records are available via record(i).
class TlsRecordRecorder : public TlsRecordFilter {
 public:
  TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a, uint8_t ct)
      : TlsRecordFilter(a), filter_(true), ct_(ct), records_() {}
  TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a)
      : TlsRecordFilter(a),
        filter_(false),
        ct_(ssl_ct_handshake),  // dummy (<optional> is C++17)
        records_() {}
  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& input,
                                    DataBuffer* output) override;

  size_t count() const { return records_.size(); }
  void Clear() { records_.clear(); }

  // Unchecked access: |i| must be less than count().
  const TlsRecord& record(size_t i) const { return records_[i]; }

 private:
  bool filter_;
  uint8_t ct_;
  std::vector<TlsRecord> records_;
};

// Make a copy of the complete conversation.
// The constructor takes |buffer| by non-const reference, so the recording must
// land in the caller's object: keep a reference to it rather than a private
// copy the caller could never observe. The caller's buffer must therefore
// outlive any traffic that passes through this filter.
class TlsConversationRecorder : public TlsRecordFilter {
 public:
  TlsConversationRecorder(const std::shared_ptr<TlsAgent>& a,
                          DataBuffer& buffer)
      : TlsRecordFilter(a), buffer_(buffer) {}

  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& input,
                                    DataBuffer* output) override;

 private:
  DataBuffer& buffer_;
};

// Make a copy of each record's header (the record bodies are not kept).
class TlsHeaderRecorder : public TlsRecordFilter {
 public:
  TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& a) : TlsRecordFilter(a) {}
  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& input,
                                    DataBuffer* output) override;
  // Pointer to the |index|-th recorded header (defined out of line;
  // presumably nullptr when |index| is out of range — verify).
  const TlsRecordHeader* header(size_t index);

 private:
  std::vector<TlsRecordHeader> headers_;
};

// Brace-enclosed list of filters, as accepted by ChainedPacketFilter.
using ChainedPacketFilterInit =
    std::initializer_list<std::shared_ptr<PacketFilter>>;

// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
 public:
  ChainedPacketFilter() {}
  // Copies the list of handles; the filters themselves are shared, not cloned.
  // Taken by const reference to avoid copying the vector twice.
  ChainedPacketFilter(
      const std::vector<std::shared_ptr<PacketFilter>>& filters)
      : filters_(filters) {}
  ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {}
  virtual ~ChainedPacketFilter() {}

  PacketFilter::Action Filter(const DataBuffer& input,
                              DataBuffer* output) override;

  // Appends |filter| to the end of the chain. Ownership is shared with the
  // caller through the shared_ptr; the by-value parameter is moved in so no
  // extra reference-count round trip happens.
  void Add(std::shared_ptr<PacketFilter> filter) {
    filters_.push_back(std::move(filter));
  }

 private:
  std::vector<std::shared_ptr<PacketFilter>> filters_;
};

// Callable taking a TlsParser and a TlsVersioned header and returning bool.
using TlsExtensionFinder =
    std::function<bool(TlsParser* parser, const TlsVersioned& header)>;

class TlsExtensionFilter : public TlsHandshakeFilter {
 public:
  TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a)
      : TlsHandshakeFilter(a,
                           {kTlsHandshakeClientHello, kTlsHandshakeServerHello,
                            kTlsHandshakeHelloRetryRequest,
                            kTlsHandshakeEncryptedExtensions}) {}

  TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a,
                     const std::set<uint8_t>& types)
      : TlsHandshakeFilter(a, types) {}

  static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);

 protected:
  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

  virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
                                               const DataBuffer& input,
                                               DataBuffer* output) = 0;

 private:
  PacketFilter::Action FilterExtensions(TlsParser* parser,
                                        const DataBuffer& input,
                                        DataBuffer* output);
};

// Captures the data of extension |ext| into extension(); captured() tells
// whether anything was seen. |last| presumably selects the last occurrence
// rather than the first (see FilterExtension, defined out of line).
class TlsExtensionCapture : public TlsExtensionFilter {
 public:
  TlsExtensionCapture(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
                      bool last = false)
      : TlsExtensionFilter(a), extension_(ext), last_(last) {}

  bool captured() const { return captured_; }
  const DataBuffer& extension() const { return data_; }

 protected:
  PacketFilter::Action FilterExtension(uint16_t extension_type,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  const uint16_t extension_;
  bool captured_ = false;
  bool last_;
  DataBuffer data_{};
};

// Extension filter configured with an extension type and a replacement
// buffer; the rewriting itself is done in FilterExtension (defined out of
// line).
class TlsExtensionReplacer : public TlsExtensionFilter {
 public:
  TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
                       const DataBuffer& data)
      : TlsExtensionFilter(a),
        extension_(extension),
        data_(data) {}

  PacketFilter::Action FilterExtension(uint16_t extension_type,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  const uint16_t extension_;  // extension type to act on
  const DataBuffer data_;     // replacement contents
};

// Extension filter keyed on a single extension type; the dropping logic is in
// FilterExtension (defined out of line).
class TlsExtensionDropper : public TlsExtensionFilter {
 public:
  TlsExtensionDropper(const std::shared_ptr<TlsAgent>& a, uint16_t extension)
      : TlsExtensionFilter(a),
        extension_(extension) {}

  PacketFilter::Action FilterExtension(uint16_t extension_type,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  uint16_t extension_;  // extension type to act on
};

// Handshake filter (no type restriction at construction) holding an
// extension type |ext| and contents |data|; presumably injects that extension
// into handshake messages in FilterHandshake (defined out of line).
class TlsExtensionInjector : public TlsHandshakeFilter {
 public:
  TlsExtensionInjector(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
                       const DataBuffer& data)
      : TlsHandshakeFilter(a),
        extension_(ext),
        data_(data) {}

 protected:
  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  const uint16_t extension_;  // extension type
  const DataBuffer data_;     // extension contents
};

// Extension filter holding an extension type and a byte |index| within that
// extension; the damage is applied in FilterExtension (defined out of line).
class TlsExtensionDamager : public TlsExtensionFilter {
 public:
  TlsExtensionDamager(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
                      size_t index)
      : TlsExtensionFilter(a), extension_(extension), index_(index) {}
  // Overrides TlsExtensionFilter's pure-virtual hook; `override` lets the
  // compiler check the signature.
  PacketFilter::Action FilterExtension(uint16_t extension_type,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  uint16_t extension_;
  size_t index_;
};

// Nullary callback with no result (used by AfterRecordN).
using VoidFunction = std::function<void()>;

// Record filter installed on |src| that is given a record number and a
// callback. Judging by the name it runs |func| after record |record| (the
// counting is in FilterRecord, defined out of line — verify the exact
// trigger point there). |dest| is held weakly and is not kept alive.
class AfterRecordN : public TlsRecordFilter {
 public:
  AfterRecordN(const std::shared_ptr<TlsAgent>& src,
               const std::shared_ptr<TlsAgent>& dest, unsigned int record,
               VoidFunction func)
      : TlsRecordFilter(src),
        dest_(dest),
        record_(record),
        // |func| is a by-value sink; move it rather than copying the
        // std::function (and its captured state) a second time.
        func_(std::move(func)),
        counter_(0) {}

  virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                            const DataBuffer& body,
                                            DataBuffer* out) override;

 private:
  std::weak_ptr<TlsAgent> dest_;
  unsigned int record_;
  VoidFunction func_;
  unsigned int counter_;
};

// When we see the ClientKeyExchange from |client|, increment the
// ClientHelloVersion on |server|. |server| is held weakly.
class TlsClientHelloVersionChanger : public TlsHandshakeFilter {
 public:
  TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client,
                               const std::shared_ptr<TlsAgent>& server)
      : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}),
        server_(server) {}

  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  std::weak_ptr<TlsAgent> server_;
};

// Damage a record by incrementing its last byte. Empty records are passed
// through unchanged because they have no last byte to damage.
class TlsRecordLastByteDamager : public TlsRecordFilter {
 public:
  TlsRecordLastByteDamager(const std::shared_ptr<TlsAgent>& a)
      : TlsRecordFilter(a) {}

 protected:
  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& data,
                                    DataBuffer* changed) override {
    // For an empty record, len() - 1 would wrap around and the increment
    // below would write out of bounds.
    if (data.len() == 0) {
      return KEEP;
    }
    *changed = data;
    changed->data()[changed->len() - 1]++;
    return CHANGE;
  }
};

// Drops whole writes rather than individual records, relying on libssl
// writing on record boundaries. |pattern| is presumably a per-write bitmask;
// the counter that indexes it is advanced in Filter(), defined out of line.
class SelectiveDropFilter : public PacketFilter {
 public:
  SelectiveDropFilter(uint32_t pattern) : pattern_(pattern) {}

 protected:
  PacketFilter::Action Filter(const DataBuffer& input,
                              DataBuffer* output) override;

 private:
  const uint32_t pattern_;
  uint8_t counter_ = 0;
};

// Drops individual records selected by a bit pattern. Unlike
// SelectiveDropFilter, which drops whole writes, this drops just the selected
// record when several DTLS records share one datagram.
class SelectiveRecordDropFilter : public TlsRecordFilter {
 public:
  // |pattern| is consumed by FilterRecord (defined out of line). Passing
  // |enabled| = false installs the filter in the disabled state.
  SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
                            uint32_t pattern, bool enabled = true)
      : TlsRecordFilter(a), pattern_(pattern) {
    if (!enabled) Disable();
  }
  // Same, with the pattern built from a list of record indices by ToPattern.
  SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
                            std::initializer_list<size_t> records)
      : SelectiveRecordDropFilter(a, ToPattern(records), true) {}

  // Zero the counter, re-enable the filter and install |pattern|.
  void Reset(uint32_t pattern) {
    counter_ = 0;
    PacketFilter::Enable();
    pattern_ = pattern;
  }

  // As Reset(uint32_t), with the pattern built from record indices.
  void Reset(std::initializer_list<size_t> records) {
    const uint32_t pattern = ToPattern(records);
    Reset(pattern);
  }

 protected:
  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& data,
                                    DataBuffer* changed) override;

 private:
  // Builds a pattern from record indices (defined out of line; presumably one
  // bit per listed index).
  static uint32_t ToPattern(std::initializer_list<size_t> records);

  uint32_t pattern_;
  uint8_t counter_ = 0;
};

// Set the version number in the ClientHello to |version|. Only ClientHello
// messages are passed to FilterHandshake.
class TlsClientHelloVersionSetter : public TlsHandshakeFilter {
 public:
  TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& a,
                              uint16_t version)
      : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}), version_(version) {}

  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  uint16_t version_;
};

// Damages the last byte of a handshake message of type |type| by incrementing
// it. Messages with an empty body (e.g. ServerHelloDone) are left unchanged
// because they have no last byte.
class TlsLastByteDamager : public TlsHandshakeFilter {
 public:
  TlsLastByteDamager(const std::shared_ptr<TlsAgent>& a, uint8_t type)
      : TlsHandshakeFilter(a), type_(type) {}
  PacketFilter::Action FilterHandshake(
      const TlsHandshakeFilter::HandshakeHeader& header,
      const DataBuffer& input, DataBuffer* output) override {
    if (header.handshake_type() != type_) {
      return KEEP;
    }
    // An empty body would make len() - 1 wrap around and the increment below
    // write out of bounds.
    if (input.len() == 0) {
      return KEEP;
    }

    *output = input;

    output->data()[output->len() - 1]++;
    return CHANGE;
  }

 private:
  uint8_t type_;
};

// Handshake filter restricted to ServerHello that carries a replacement
// cipher suite |suite|; the substitution happens in FilterHandshake (defined
// out of line).
class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
 public:
  SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& a,
                              uint16_t suite)
      : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}),
        cipher_suite_(suite) {}

 protected:
  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                       const DataBuffer& input,
                                       DataBuffer* output) override;

 private:
  uint16_t cipher_suite_;  // suite written in place of the server's choice
};

}  // namespace nss_test

#endif
