Newer
Older
XinYang_IOS / Pods / OpenVPNAdapter / Sources / OpenVPN3 / openvpn / ws / websocket.hpp
@zhangfeng zhangfeng on 7 Dec 2023 11 KB 1.8.0
//    OpenVPN -- An application to securely tunnel IP networks
//               over a single port, with support for SSL/TLS-based
//               session authentication and key exchange,
//               packet encryption, packet authentication, and
//               packet compression.
//
//    Copyright (C) 2012-2020 OpenVPN Inc.
//
//    This program is free software: you can redistribute it and/or modify
//    it under the terms of the GNU Affero General Public License Version 3
//    as published by the Free Software Foundation.
//
//    This program is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//    GNU Affero General Public License for more details.
//
//    You should have received a copy of the GNU Affero General Public License
//    along with this program in the COPYING file.
//    If not, see <http://www.gnu.org/licenses/>.

#pragma once

#include <string>
#include <cstdint>
#include <ostream>
#include <tuple>
#include <utility>

#include <openvpn/common/exception.hpp>
#include <openvpn/common/rc.hpp>
#include <openvpn/common/base64.hpp>
#include <openvpn/common/socktypes.hpp>
#include <openvpn/common/endian64.hpp>
#include <openvpn/crypto/hashstr.hpp>
#include <openvpn/buffer/buffer.hpp>
#include <openvpn/random/randapi.hpp>

namespace openvpn {
  namespace WebSocket {

    // Exception type thrown by the code in this header, both for framing
    // violations seen by Receiver and for invalid PerRequest configuration.
    OPENVPN_EXCEPTION(websocket_error);

    // Receiver is defined further down; Status names it as a friend.
    class Receiver;

    // Derive the value a server sends as Sec-WebSocket-Accept from the
    // client's Sec-WebSocket-Key: the key with the fixed GUID below
    // appended is hashed with SHA1 and returned via final_base64().
    inline std::string accept_confirmation(DigestFactory& digest_factory,
					   const std::string& websocket_key)
    {
      static const char guid[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

      std::string keyed(websocket_key);
      keyed += guid;

      HashString sha1(digest_factory, CryptoAlgs::SHA1);
      sha1.update(keyed);
      return sha1.final_base64();
    }

    // Constants and small helpers describing the WebSocket frame format.
    class Protocol
    {
    public:
      // Not referenced in this header; presumably the headroom callers
      // reserve in front of a payload for the frame header -- TODO confirm.
      static constexpr size_t MAX_HEAD = 16;

      // Frame opcodes known to this implementation.
      enum Opcode
	{
	  Text = 0x1,
	  Binary = 0x2,
	  Close = 0x8,
	  Ping = 0x9,
	  Pong = 0xA,
	};

      // Return a printable name for an opcode; values that are not one
      // of the Opcode enumerators are rendered as "WS-OPCODE-<number>".
      static std::string opcode_to_string(const unsigned int opcode)
      {
	const char* name = nullptr;
	switch (opcode)
	  {
	  case Text:
	    name = "Text";
	    break;
	  case Binary:
	    name = "Binary";
	    break;
	  case Close:
	    name = "Close";
	    break;
	  case Ping:
	    name = "Ping";
	    break;
	  case Pong:
	    name = "Pong";
	    break;
	  default:
	    break;
	  }

	if (name)
	  return std::string(name);
	return "WS-OPCODE-" + std::to_string(opcode);
      }

      // A 4-byte masking key that is stored as one 32-bit word and
      // applied byte by byte, in the word's in-memory byte order.
      union MaskingKey
      {
      public:
	MaskingKey(std::uint32_t mask)
	  : mask32(mask)
	{
	}

	// XOR every byte of buf, in place, with the key byte selected by
	// the byte's position modulo 4.
	void xor_buf(Buffer& buf) const
	{
	  const size_t n_bytes = buf.size();
	  std::uint8_t* const bytes = buf.data();
	  size_t pos = 0;
	  while (pos < n_bytes)
	    {
	      bytes[pos] ^= mask8[pos % 4];
	      ++pos;
	    }
	}

	// Prepend the 4 key bytes to buf exactly as they sit in memory.
	void prepend_mask(Buffer& buf) const
	{
	  buf.prepend(&mask32, sizeof(mask32));
	}

      private:
	std::uint32_t mask32;
	std::uint8_t mask8[4];
      };
    };

    // Describes one WebSocket frame: its opcode, its FIN flag and, for
    // Close frames, the close status code.
    class Status
    {
    public:
      // Construct an undefined status (opcode 0, see defined()).
      Status()
	: opcode_(0),
	  fin_(false),
	  close_status_code_(0)
      {
      }

      // Construct a status for the given opcode.  fin defaults to true
      // and close_status_code defaults to 0.
      Status(unsigned int opcode,
	     bool fin=true,
	     unsigned int close_status_code=0)
	: opcode_(opcode),
	  fin_(fin),
	  close_status_code_(close_status_code)
      {
      }

      // Copy ref but substitute a different opcode.
      Status(const Status& ref,
	     const unsigned int opcode)
	: opcode_(opcode),
	  fin_(ref.fin_),
	  close_status_code_(ref.close_status_code_)
      {
      }

      // True once an opcode other than 0 has been set.
      bool defined() const
      {
	return opcode_ != 0;
      }

      unsigned int opcode() const
      {
	return opcode_;
      }

      bool fin() const
      {
	return fin_;
      }

      unsigned int close_status_code() const
      {
	return close_status_code_;
      }

      // Field-wise equality.  All three fields are compared against the
      // corresponding field of rhs; fin_ used to be compared with itself,
      // so statuses that differed only in the FIN flag compared equal.
      bool operator==(const Status& rhs) const
      {
	return std::tie(opcode_, fin_, close_status_code_)
	  == std::tie(rhs.opcode_, rhs.fin_, rhs.close_status_code_);
      }

      // Exact negation of operator==.
      bool operator!=(const Status& rhs) const
      {
	return !(*this == rhs);
      }

      // Render as "[op=<name> fin=<0|1>]"; Close frames additionally
      // carry " status=<code>" before the closing bracket.
      std::string to_string() const
      {
	std::string ret;

	ret.reserve(64);
	ret += "[op=";
	ret += Protocol::opcode_to_string(opcode_);
	ret += " fin=";
	ret += std::to_string(fin_); // bool is promoted, yielding "0" or "1"
	if (opcode_ == Protocol::Close)
	  {
	    ret += " status=";
	    ret += std::to_string(close_status_code_);
	  }
	ret += ']';
	return ret;
      }

    private:
      // Receiver writes these fields directly while parsing a frame.
      friend class Receiver;

      unsigned int opcode_;
      bool fin_;
      unsigned int close_status_code_;
    };

    class Sender
    {
    public:
      Sender(RandomAPI::Ptr cli_rng_arg) // only provide rng on client side
	: cli_rng(std::move(cli_rng_arg))
      {
	if (cli_rng)
	  cli_rng->assert_crypto();
      }

      void frame(Buffer& buf, const Status& s) const
      {
	if (s.opcode() == Protocol::Close)
	  {
	    const std::uint16_t cs = htons(s.close_status_code());
	    buf.prepend(&cs, sizeof(cs));
	  }

	const size_t payload_len = buf.size();
	if (cli_rng)
	  {
	    const Protocol::MaskingKey mk(cli_rng->rand_get<std::uint32_t>());
	    mk.xor_buf(buf);
	    mk.prepend_mask(buf);
	  }
	prepend_payload_length(buf, payload_len);

	std::uint8_t head = s.opcode() & 0xF;
	if (s.fin())
	  head |= 0x80;
	buf.prepend(&head, sizeof(head));

	//OPENVPN_LOG("WS SEND HEAD\n" << dump_hex(buf));
      }

    private:
      void prepend_payload_length(Buffer& buf, const size_t len) const
      {
	std::uint8_t len8;

	if (len <= 125)
	  len8 = len;
	else if (len <= 65535)
	  {
	    len8 = 126;
	    const std::uint16_t len16 = htons(len);
	    buf.prepend(&len16, sizeof(len16));
	  }
	else
	  {
	    len8 = 127;
	    const std::uint64_t len64 = Endian::rev64(len);
	    buf.prepend(&len64, sizeof(len64));
	  }

	if (cli_rng)
	  len8 |= 0x80;
	buf.prepend(&len8, sizeof(len8));
      }

      RandomAPI::Ptr cli_rng;
    };

    class Receiver
    {
    public:
      Receiver(const bool is_client_arg)
	: is_client(is_client_arg)
      {
	reset_pod();
      }

      Buffer buf_unframed()
      {
	verify_message_complete();
	if (size > buf.size())
	  throw websocket_error("Receiver::buf_unframed: internal error");
	return Buffer(buf.data(), size, true);
      }

      // return true if message is complete
      bool complete()
      {
	// already complete?
	if (header_complete)
	  return complete_();

	// we need at least 2 bytes before we can do anything
	if (buf.size() < 2)
	  return false;

	// get first 2 bytes of header
	Buffer b(buf.data(), buf.size(), true);
	const std::uint8_t* head = b.read_alloc(2);
	s.opcode_ = head[0] & 0xF;
	s.fin_ = bool(head[0] & 0x80);
	if (head[0] & 0x70)
	  throw websocket_error("Receiver: reserved bits are set");
	if (bool(head[1] & 0x80) == is_client)
	  throw websocket_error("Receiver: bad masking direction");

	// process payload length
	const std::uint8_t pl = head[1] & 0x7f;
	if (pl <= 125)
	  {
	    size = pl;
	  }
	else if (pl == 126)
	  {
	    std::uint16_t len16;
	    if (b.size() < sizeof(len16))
	      return false;
	    b.read(&len16, sizeof(len16));
	    size = ntohs(len16);
	  }
	else // pl == 127
	  {
	    std::uint64_t len64;
	    if (b.size() < sizeof(len64))
	      return false;
	    b.read(&len64, sizeof(len64));
	    size = Endian::rev64(len64);
	  }

	// read mask (server side only)
	if (!is_client)
	  {
	    if (b.size() < sizeof(mask))
	      return false;
	    b.read(&mask, sizeof(mask));
	  }

	buf.advance(b.offset());
	header_complete = true;
	return complete_();
      }

      void add_buf(BufferAllocated&& inbuf)
      {
	if (!buf.allocated())
	  {
	    buf = std::move(inbuf);
	    buf.or_flags(BufferAllocated::GROW);
	  }
	else
	  buf.append(inbuf);
      }

      void reset()
      {
	verify_message_complete();
	s = Status();
	reset_buf();
	reset_pod();
      }

      Status status() const
      {
	verify_message_complete();
	return s;
      }

    private:
      void reset_buf()
      {
	if (buf.allocated())
	  {
	    if (size < buf.size())
	      {
		buf.advance(size);
		buf.realign(0);
	      }
	    else if (size == buf.size())
	      buf.clear();
	    else
	      throw websocket_error("Receiver::reset_buf: bad size");
	  }
      }

      void reset_pod()
      {
	header_complete = false;
	message_complete = false;
	mask = 0;
	size = 0;
      }

      void verify_message_complete() const
      {
	if (!message_complete)
	  throw websocket_error("Receiver: message incomplete");
      }

      bool complete_()
      {
	if (message_complete)
	  return true;

	if (header_complete && size <= buf.size())
	  {
	    // un-xor the data on the server side only
	    if (!is_client)
	      {
		Buffer b(buf.data(), size, true);
		const Protocol::MaskingKey mk(mask);
		mk.xor_buf(b);
	      }

	    // get close status code
	    if (s.opcode_ == Protocol::Close && size >= 2)
	      {
		std::uint16_t cs;
		buf.read(&cs, sizeof(cs));
		size -= sizeof(cs);
		s.close_status_code_ = ntohs(cs);
	      }

	    message_complete = true;
	    return true;
	  }
	return false;
      }

      const bool is_client;
      bool header_complete;
      bool message_complete;
      std::uint32_t mask;
      std::uint64_t size;
      Status s;
      BufferAllocated buf;
    };

    namespace Client {

      // Client-side WebSocket settings, held by PerRequest through a
      // reference-counted pointer.
      struct Config : public RC<thread_unsafe_refcount>
      {
	using Ptr = RCPtr<Config>;

	// sent as the "Origin" header when non-empty
	std::string origin;

	// sent as the "Sec-WebSocket-Protocol" header when non-empty
	std::string protocol;

	// source of the websocket key and of the sender's masking keys
	RandomAPI::Ptr rng;

	// used to recompute the expected accept value
	DigestFactory::Ptr digest_factory;

	// compression (these two fields are not referenced in this header)
	bool compress = false;
	size_t compress_threshold = 256;
      };

      // Per-connection client state: generates the handshake headers,
      // verifies the server's accept value and owns the frame
      // sender/receiver pair.
      class PerRequest : public RC<thread_unsafe_refcount>
      {
      private:
	// Declared ahead of sender on purpose: members are initialized in
	// declaration order and sender's initializer reads conf->rng.
	Config::Ptr conf;

      public:
	typedef RCPtr<PerRequest> Ptr;

	// Throws websocket_error if conf_arg is null or lacks an rng or a
	// digest factory.
	PerRequest(Config::Ptr conf_arg)
	  : conf(validate_conf(std::move(conf_arg))),
	    sender(conf->rng),
	    receiver(true)
	{
	}

	// Write the WebSocket upgrade request headers to os.  A fresh
	// random Sec-WebSocket-Key is generated on every call.
	void client_headers(std::ostream& os)
	{
	  generate_websocket_key();
	  os << "Sec-WebSocket-Key: " << websocket_key << "\r\n";
	  os << "Sec-WebSocket-Version: 13\r\n";
	  if (!conf->protocol.empty())
	    os << "Sec-WebSocket-Protocol: " << conf->protocol << "\r\n";
	  os << "Connection: Upgrade\r\n";
	  os << "Upgrade: websocket\r\n";
	  if (!conf->origin.empty())
	    os << "Origin: " << conf->origin << "\r\n";
	}

	// Return true if ws_accept equals the accept value computed from
	// the key most recently generated by client_headers().
	bool confirm_websocket_key(const std::string& ws_accept) const
	{
	  return ws_accept == accept_confirmation(*conf->digest_factory, websocket_key);
	}

	Sender sender;
	Receiver receiver;

      private:
	// Check that the config is usable and hand it back unchanged.
	static Config::Ptr validate_conf(Config::Ptr conf)
	{
	  if (!conf)
	    throw websocket_error("no config");
	  // The rng is dereferenced here and again when generating the key,
	  // so reject a missing one up front instead of dereferencing null.
	  if (!conf->rng)
	    throw websocket_error("no rng in config");
	  conf->rng->assert_crypto();
	  if (!conf->digest_factory)
	    throw websocket_error("no digest factory in config");
	  return conf;
	}

	// Fill websocket_key with the base64 encoding of 16 random bytes.
	void generate_websocket_key()
	{
	  std::uint8_t data[16];
	  conf->rng->rand_bytes(data, sizeof(data));
	  websocket_key = base64->encode(data, sizeof(data));
	}

	std::string websocket_key;
      };

    }

    namespace Server {

      // Server-side WebSocket settings, held by PerRequest through a
      // reference-counted pointer.
      struct Config : public RC<thread_unsafe_refcount>
      {
	using Ptr = RCPtr<Config>;

	// sent as the "Sec-WebSocket-Protocol" header when non-empty
	std::string protocol;

	// used to compute the accept value from the client's key
	DigestFactory::Ptr digest_factory;
      };

      // Per-connection server state: computes the accept value for the
      // client's key, emits the upgrade response headers and owns the
      // frame sender/receiver pair.
      class PerRequest : public RC<thread_unsafe_refcount>
      {
      private:
	Config::Ptr conf;

      public:
	typedef RCPtr<PerRequest> Ptr;

	// Throws websocket_error if conf_arg is null or has no digest
	// factory.  The sender is given no rng and the receiver is created
	// in server mode.
	PerRequest(Config::Ptr conf_arg)
	  : conf(validate_conf(std::move(conf_arg))),
	    sender(RandomAPI::Ptr()),
	    receiver(false)
	{
	}

	// Remember the accept value that answers the client's key.
	void set_websocket_key(const std::string& websocket_key)
	{
	  websocket_accept = accept_confirmation(*conf->digest_factory, websocket_key);
	}

	// Write the upgrade response headers to os.  The accept and
	// protocol headers are only emitted when their values are set.
	void server_headers(std::ostream& os)
	{
	  os << "Upgrade: websocket\r\n"
	     << "Connection: Upgrade\r\n";

	  const bool have_accept = !websocket_accept.empty();
	  if (have_accept)
	    os << "Sec-WebSocket-Accept: " << websocket_accept << "\r\n";

	  const bool have_protocol = !conf->protocol.empty();
	  if (have_protocol)
	    os << "Sec-WebSocket-Protocol: " << conf->protocol << "\r\n";
	}

	Sender sender;
	Receiver receiver;

      private:
	// Check that the config is usable and hand it back unchanged.
	static Config::Ptr validate_conf(Config::Ptr conf)
	{
	  if (conf)
	    {
	      if (conf->digest_factory)
		return conf;
	      throw websocket_error("no digest factory in config");
	    }
	  throw websocket_error("no config");
	}

	std::string websocket_accept;
      };

    }

  }
}