rocksdb/thrift/lib/cpp/transport/TSSLSocket.h

366 lines
9.8 KiB
C++

// Copyright (c) 2009- Facebook
// Distributed under the Thrift Software License
//
// See accompanying file LICENSE or visit the Thrift site at:
// http://developers.facebook.com/thrift/
#ifndef THRIFT_TRANSPORT_TSSLSOCKET_H_
#define THRIFT_TRANSPORT_TSSLSOCKET_H_ 1
#include <string>
#include <boost/shared_ptr.hpp>
#include <openssl/ssl.h>
#include "thrift/lib/cpp/concurrency/Mutex.h"
#include "thrift/lib/cpp/transport/TSocket.h"
namespace apache { namespace thrift { namespace transport {
class PasswordCollector;
class SSLContext;
class TSocketAddress;
/**
* OpenSSL implementation for SSL socket interface.
*/
class TSSLSocket: public TVirtualTransport<TSSLSocket, TSocket> {
public:
/**
* Constructor.
*/
explicit TSSLSocket(const boost::shared_ptr<SSLContext>& ctx);
/**
* Constructor, create an instance of TSSLSocket given an existing socket.
*
* @param socket An existing socket
*/
TSSLSocket(const boost::shared_ptr<SSLContext>& ctx, int socket);
/**
* Constructor.
*
* @param host Remote host name
* @param port Remote port number
*/
TSSLSocket(const boost::shared_ptr<SSLContext>& ctx,
const std::string& host,
int port);
/**
* Constructor.
*/
TSSLSocket(const boost::shared_ptr<SSLContext>& ctx,
const TSocketAddress& address);
/**
* Destructor.
*/
~TSSLSocket();
/**
* TTransport interface.
*/
bool isOpen();
bool peek();
void open();
void close();
uint32_t read(uint8_t* buf, uint32_t len);
void write(const uint8_t* buf, uint32_t len);
void flush();
/**
* Set whether to use client or server side SSL handshake protocol.
*
* @param flag Use server side handshake protocol if true.
*/
void server(bool flag) { server_ = flag; }
/**
* Determine whether the SSL socket is server or client mode.
*/
bool server() const { return server_; }
protected:
/**
* Verify peer certificate after SSL handshake completes.
*/
virtual void verifyCertificate();
/**
* Initiate SSL handshake if not already initiated.
*/
void checkHandshake();
bool server_;
SSL* ssl_;
boost::shared_ptr<SSLContext> ctx_;
};
/**
* SSL socket factory. SSL sockets should be created via SSL factory.
*/
class TSSLSocketFactory {
public:
/**
* Constructor/Destructor
*/
explicit TSSLSocketFactory(const boost::shared_ptr<SSLContext>& context);
virtual ~TSSLSocketFactory();
/**
* Create an instance of TSSLSocket with a fresh new socket.
*/
virtual boost::shared_ptr<TSSLSocket> createSocket();
/**
* Create an instance of TSSLSocket with the given socket.
*
* @param socket An existing socket.
*/
virtual boost::shared_ptr<TSSLSocket> createSocket(int socket);
/**
* Create an instance of TSSLSocket.
*
* @param host Remote host to be connected to
* @param port Remote port to be connected to
*/
virtual boost::shared_ptr<TSSLSocket> createSocket(const std::string& host,
int port);
/**
* Set/Unset server mode.
*
* @param flag Server mode if true
*/
virtual void server(bool flag) { server_ = flag; }
/**
* Determine whether the socket is in server or client mode.
*
* @return true, if server mode, or, false, if client mode
*/
virtual bool server() const { return server_; }
private:
boost::shared_ptr<SSLContext> ctx_;
bool server_;
};
/**
* SSL exception.
*/
class TSSLException: public TTransportException {
public:
explicit TSSLException(const std::string& message):
TTransportException(TTransportException::INTERNAL_ERROR, message) {}
virtual const char* what() const throw() {
if (message_.empty()) {
return "TSSLException";
} else {
return message_.c_str();
}
}
};
/**
* Wrap OpenSSL SSL_CTX into a class.
*/
class SSLContext {
public:
enum SSLVersion {
SSLv2,
SSLv3,
TLSv1
};
/**
* Constructor.
*
* @param version The lowest or oldest SSL version to support.
*/
explicit SSLContext(SSLVersion version = TLSv1);
virtual ~SSLContext();
/**
* Set ciphers to be used in SSL handshake process.
*
* @param ciphers A list of ciphers
*/
virtual void ciphers(const std::string& enable);
/**
* Enable/Disable authentication. Peer name validation can only be done
* if checkPeerCert is true.
*
* @param checkPeerCert If true, require peer to present valid certificate
* @param checkPeerName If true, validate that the certificate common name
* or alternate name(s) of peer matches the hostname
* used to connect.
* @param peerName If non-empty, validate that the certificate common
* name of peer matches the given string (altername
* name(s) are not used in this case).
*/
virtual void authenticate(bool checkPeerCert, bool checkPeerName,
const std::string& peerName = std::string());
/**
* Load server certificate.
*
* @param path Path to the certificate file
* @param format Certificate file format
*/
virtual void loadCertificate(const char* path, const char* format = "PEM");
/**
* Load private key.
*
* @param path Path to the private key file
* @param format Private key file format
*/
virtual void loadPrivateKey(const char* path, const char* format = "PEM");
/**
* Load trusted certificates from specified file.
*
* @param path Path to trusted certificate file
*/
virtual void loadTrustedCertificates(const char* path);
/**
* Load trusted certificates from specified X509 certificate store.
*
* @param store X509 certificate store.
*/
virtual void loadTrustedCertificates(X509_STORE* store);
/**
* Default randomize method.
*/
virtual void randomize();
/**
* Override default OpenSSL password collector.
*
* @param collector Instance of user defined password collector
*/
virtual void passwordCollector(boost::shared_ptr<PasswordCollector> collector);
/**
* Obtain password collector.
*
* @return User defined password collector
*/
virtual boost::shared_ptr<PasswordCollector> passwordCollector() {
return collector_;
}
/**
* Create an SSL object from this context.
*/
SSL* createSSL() const;
/**
* Possibly validate the peer's certificate name, depending on how this
* SSLContext was configured by authenticate().
*
* @return True if the peer's name is acceptable, false otherwise
*/
bool validatePeerName(TSSLSocket* sock, SSL* ssl) const;
/**
* Set the options on the SSL_CTX object.
*/
void setOptions(long options);
#ifdef OPENSSL_NPN_NEGOTIATED
/**
* Set the list of protocols that a TLS server should advertise for
* Next Protocol Negotiation (NPN).
*
* @param protocols List of protocol names, or NULL to disable NPN.
* Note: if non-null, this method makes a copy, so
* the caller needn't keep the list in scope after
* the call completes.
*/
void setAdvertisedNextProtocols(const std::list<std::string>* protocols);
#endif // OPENSSL_NPN_NEGOTIATED
/**
* Gets the underlying SSL_CTX for advanced usage
*/
SSL_CTX *getSSLCtx() const {
return ctx_;
}
enum SSLLockType {
LOCK_MUTEX,
LOCK_SPINLOCK,
LOCK_NONE
};
/**
* Set preferences for how to treat locks in OpenSSL. This must be
* called before the instantiation of any SSLContext objects, otherwise
* the defaults will be used.
*
* OpenSSL has a lock for each module rather than for each object or
* data that needs locking. Some locks protect only refcounts, and
* might be better as spinlocks rather than mutexes. Other locks
* may be totally unnecessary if the objects being protected are not
* shared between threads in the application.
*
* By default, all locks are initialized as mutexes. OpenSSL's lock usage
* may change from version to version and you should know what you are doing
* before disabling any locks entirely.
*
* Example: if you don't share SSL sessions between threads in your
* application, you may be able to do this
*
* setSSLLockTypes({{CRYPTO_LOCK_SSL_SESSION, SSLContext::LOCK_NONE}})
*/
static void setSSLLockTypes(std::map<int, SSLLockType> lockTypes);
protected:
SSL_CTX* ctx_;
private:
bool checkPeerName_;
std::string peerFixedName_;
boost::shared_ptr<PasswordCollector> collector_;
static concurrency::Mutex mutex_;
static uint64_t count_;
#ifdef OPENSSL_NPN_NEGOTIATED
/**
* Wire-format list of advertised protocols for use in NPN.
*/
unsigned char* advertisedNextProtocols_;
unsigned advertisedNextProtocolsLength_;
static int advertisedNextProtocolCallback(SSL* ssl,
const unsigned char** out, unsigned int* outlen, void* data);
#endif // OPENSSL_NPN_NEGOTIATED
static int passwordCallback(char* password, int size, int, void* data);
static void initializeOpenSSL();
static void cleanupOpenSSL();
/**
* Helper to match a hostname versus a pattern.
*/
static bool matchName(const char* host, const char* pattern, int size);
};
typedef boost::shared_ptr<SSLContext> SSLContextPtr;
/**
* Override the default password collector.
*/
class PasswordCollector {
public:
virtual ~PasswordCollector() {}
/**
* Interface for customizing how to collect private key password.
*
* By default, OpenSSL prints a prompt on screen and request for password
* while loading private key. To implement a custom password collector,
* implement this interface and register it with TSSLSocketFactory.
*
* @param password Pass collected password back to OpenSSL
* @param size Maximum length of password including NULL character
*/
virtual void getPassword(std::string& password, int size) = 0;
};
}}}
#endif