492 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			492 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			C++
		
	
	
	
| /*
 | |
|  * Copyright (C) 2020 The Android Open Source Project
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *      http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| #include "adb/pairing/pairing_connection.h"
 | |
| 
 | |
| #include <stddef.h>
 | |
| #include <stdint.h>
 | |
| 
 | |
| #include <functional>
 | |
| #include <memory>
 | |
| #include <string_view>
 | |
| #include <thread>
 | |
| #include <vector>
 | |
| 
 | |
| #include <adb/pairing/pairing_auth.h>
 | |
| #include <adb/tls/tls_connection.h>
 | |
| #include <android-base/endian.h>
 | |
| #include <android-base/logging.h>
 | |
| #include <android-base/macros.h>
 | |
| #include <android-base/unique_fd.h>
 | |
| 
 | |
| #include "pairing.pb.h"
 | |
| 
 | |
| using namespace adb;
 | |
| using android::base::unique_fd;
 | |
| using TlsError = tls::TlsConnection::TlsError;
 | |
| 
 | |
| const uint8_t kCurrentKeyHeaderVersion = 1;
 | |
| const uint8_t kMinSupportedKeyHeaderVersion = 1;
 | |
| const uint8_t kMaxSupportedKeyHeaderVersion = 1;
 | |
| const uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2;
 | |
| 
 | |
| struct PairingPacketHeader {
 | |
|     uint8_t version;   // PairingPacket version
 | |
|     uint8_t type;      // the type of packet (PairingPacket.Type)
 | |
|     uint32_t payload;  // Size of the payload in bytes
 | |
| } __attribute__((packed));
 | |
| 
 | |
| struct PairingAuthDeleter {
 | |
|     void operator()(PairingAuthCtx* p) { pairing_auth_destroy(p); }
 | |
| };  // PairingAuthDeleter
 | |
| using PairingAuthPtr = std::unique_ptr<PairingAuthCtx, PairingAuthDeleter>;
 | |
| 
 | |
| // PairingConnectionCtx encapsulates the protocol to authenticate two peers with
 | |
| // each other. This class will open the tcp sockets and handle the pairing
 | |
| // process. On completion, both sides will have each other's public key
 | |
| // (certificate) if successful, otherwise, the pairing failed. The tcp port
 | |
| // number is hardcoded (see pairing_connection.cpp).
 | |
| //
 | |
| // Each PairingConnectionCtx instance represents a different device trying to
 | |
| // pair. So for the device, we can have multiple PairingConnectionCtxs while the
 | |
| // host may have only one (unless host has a PairingServer).
 | |
| //
 | |
| // See pairing_connection_test.cpp for example usage.
 | |
| //
 | |
| struct PairingConnectionCtx {
 | |
|   public:
 | |
|     using Data = std::vector<uint8_t>;
 | |
|     using ResultCallback = pairing_result_cb;
 | |
|     enum class Role {
 | |
|         Client,
 | |
|         Server,
 | |
|     };
 | |
| 
 | |
|     explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
 | |
|                                   const Data& certificate, const Data& priv_key);
 | |
|     virtual ~PairingConnectionCtx();
 | |
| 
 | |
|     // Starts the pairing connection on a separate thread.
 | |
|     // Upon completion, if the pairing was successful,
 | |
|     // |cb| will be called with the peer information and certificate.
 | |
|     // Otherwise, |cb| will be called with empty data. |fd| should already
 | |
|     // be opened. PairingConnectionCtx will take ownership of the |fd|.
 | |
|     //
 | |
|     // Pairing is successful if both server/client uses the same non-empty
 | |
|     // |pswd|, and they are able to exchange the information. |pswd| and
 | |
|     // |certificate| must be non-empty. Start() can only be called once in the
 | |
|     // lifetime of this object.
 | |
|     //
 | |
|     // Returns true if the thread was successfully started, false otherwise.
 | |
|     bool Start(int fd, ResultCallback cb, void* opaque);
 | |
| 
 | |
|   private:
 | |
|     // Setup the tls connection.
 | |
|     bool SetupTlsConnection();
 | |
| 
 | |
|     /************ PairingPacketHeader methods ****************/
 | |
|     // Tries to write out the header and payload.
 | |
|     bool WriteHeader(const PairingPacketHeader* header, std::string_view payload);
 | |
|     // Tries to parse incoming data into the |header|. Returns true if header
 | |
|     // is valid and header version is supported. |header| is filled on success.
 | |
|     // |header| may contain garbage if unsuccessful.
 | |
|     bool ReadHeader(PairingPacketHeader* header);
 | |
|     // Creates a PairingPacketHeader.
 | |
|     void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type,
 | |
|                       uint32_t payload_size);
 | |
|     // Checks if actual matches expected.
 | |
|     bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual);
 | |
| 
 | |
|     /*********** State related methods **************/
 | |
|     // Handles the State::ExchangingMsgs state.
 | |
|     bool DoExchangeMsgs();
 | |
|     // Handles the State::ExchangingPeerInfo state.
 | |
|     bool DoExchangePeerInfo();
 | |
| 
 | |
|     // The background task to do the pairing.
 | |
|     void StartWorker();
 | |
| 
 | |
|     // Calls |cb_| and sets the state to Stopped.
 | |
|     void NotifyResult(const PeerInfo* p);
 | |
| 
 | |
|     static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd);
 | |
| 
 | |
|     enum class State {
 | |
|         Ready,
 | |
|         ExchangingMsgs,
 | |
|         ExchangingPeerInfo,
 | |
|         Stopped,
 | |
|     };
 | |
| 
 | |
|     std::atomic<State> state_{State::Ready};
 | |
|     Role role_;
 | |
|     Data pswd_;
 | |
|     PeerInfo peer_info_;
 | |
|     Data cert_;
 | |
|     Data priv_key_;
 | |
| 
 | |
|     // Peer's info
 | |
|     PeerInfo their_info_;
 | |
| 
 | |
|     ResultCallback cb_;
 | |
|     void* opaque_ = nullptr;
 | |
|     std::unique_ptr<tls::TlsConnection> tls_;
 | |
|     PairingAuthPtr auth_;
 | |
|     unique_fd fd_;
 | |
|     std::thread thread_;
 | |
|     static constexpr size_t kExportedKeySize = 64;
 | |
| };  // PairingConnectionCtx
 | |
| 
 | |
| PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
 | |
|                                            const Data& cert, const Data& priv_key)
 | |
|     : role_(role), pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key) {
 | |
|     CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty());
 | |
| }
 | |
| 
 | |
| PairingConnectionCtx::~PairingConnectionCtx() {
 | |
|     // Force close the fd and wait for the worker thread to finish.
 | |
|     fd_.reset();
 | |
|     if (thread_.joinable()) {
 | |
|         thread_.join();
 | |
|     }
 | |
| }
 | |
| 
 | |
| bool PairingConnectionCtx::SetupTlsConnection() {
 | |
|     tls_ = tls::TlsConnection::Create(
 | |
|             role_ == Role::Server ? tls::TlsConnection::Role::Server
 | |
|                                   : tls::TlsConnection::Role::Client,
 | |
|             std::string_view(reinterpret_cast<const char*>(cert_.data()), cert_.size()),
 | |
|             std::string_view(reinterpret_cast<const char*>(priv_key_.data()), priv_key_.size()),
 | |
|             fd_);
 | |
| 
 | |
|     if (tls_ == nullptr) {
 | |
|         LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get();
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     // Allow any peer certificate
 | |
|     tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
 | |
| 
 | |
|     // SSL doesn't seem to behave correctly with fdevents so just do a blocking
 | |
|     // read for the pairing data.
 | |
|     if (tls_->DoHandshake() != TlsError::Success) {
 | |
|         LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get();
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     // To ensure the connection is not stolen while we do the PAKE, append the
 | |
|     // exported key material from the tls connection to the password.
 | |
|     std::vector<uint8_t> exportedKeyMaterial = tls_->ExportKeyingMaterial(kExportedKeySize);
 | |
|     if (exportedKeyMaterial.empty()) {
 | |
|         LOG(ERROR) << "Failed to export key material";
 | |
|         return false;
 | |
|     }
 | |
|     pswd_.insert(pswd_.end(), std::make_move_iterator(exportedKeyMaterial.begin()),
 | |
|                  std::make_move_iterator(exportedKeyMaterial.end()));
 | |
|     auth_ = CreatePairingAuthPtr(role_, pswd_);
 | |
| 
 | |
|     return true;
 | |
| }
 | |
| 
 | |
| bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header,
 | |
|                                        std::string_view payload) {
 | |
|     PairingPacketHeader network_header = *header;
 | |
|     network_header.payload = htonl(network_header.payload);
 | |
|     if (!tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(&network_header),
 | |
|                                            sizeof(PairingPacketHeader))) ||
 | |
|         !tls_->WriteFully(payload)) {
 | |
|         LOG(ERROR) << "Failed to write out PairingPacketHeader";
 | |
|         state_ = State::Stopped;
 | |
|         return false;
 | |
|     }
 | |
|     return true;
 | |
| }
 | |
| 
 | |
| bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) {
 | |
|     auto data = tls_->ReadFully(sizeof(PairingPacketHeader));
 | |
|     if (data.empty()) {
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     uint8_t* p = data.data();
 | |
|     // First byte is always PairingPacketHeader version
 | |
|     header->version = *p;
 | |
|     ++p;
 | |
|     if (header->version < kMinSupportedKeyHeaderVersion ||
 | |
|         header->version > kMaxSupportedKeyHeaderVersion) {
 | |
|         LOG(ERROR) << "PairingPacketHeader version mismatch (us=" << kCurrentKeyHeaderVersion
 | |
|                    << " them=" << header->version << ")";
 | |
|         return false;
 | |
|     }
 | |
|     // Next byte is the PairingPacket::Type
 | |
|     if (!adb::proto::PairingPacket::Type_IsValid(*p)) {
 | |
|         LOG(ERROR) << "Unknown PairingPacket type=" << static_cast<uint32_t>(*p);
 | |
|         return false;
 | |
|     }
 | |
|     header->type = *p;
 | |
|     ++p;
 | |
|     // Last, the payload size
 | |
|     header->payload = ntohl(*(reinterpret_cast<uint32_t*>(p)));
 | |
|     if (header->payload == 0 || header->payload > kMaxPayloadSize) {
 | |
|         LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload
 | |
|                    << ")";
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     return true;
 | |
| }
 | |
| 
 | |
| void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header,
 | |
|                                         adb::proto::PairingPacket::Type type,
 | |
|                                         uint32_t payload_size) {
 | |
|     header->version = kCurrentKeyHeaderVersion;
 | |
|     uint8_t type8 = static_cast<uint8_t>(static_cast<int>(type));
 | |
|     header->type = type8;
 | |
|     header->payload = payload_size;
 | |
| }
 | |
| 
 | |
| bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type,
 | |
|                                            uint8_t actual) {
 | |
|     uint8_t expected = *reinterpret_cast<uint8_t*>(&expected_type);
 | |
|     if (actual != expected) {
 | |
|         LOG(ERROR) << "Unexpected header type (expected=" << static_cast<uint32_t>(expected)
 | |
|                    << " actual=" << static_cast<uint32_t>(actual) << ")";
 | |
|         return false;
 | |
|     }
 | |
|     return true;
 | |
| }
 | |
| 
 | |
| void PairingConnectionCtx::NotifyResult(const PeerInfo* p) {
 | |
|     cb_(p, fd_.get(), opaque_);
 | |
|     state_ = State::Stopped;
 | |
| }
 | |
| 
 | |
| bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) {
 | |
|     if (fd < 0) {
 | |
|         return false;
 | |
|     }
 | |
|     fd_.reset(fd);
 | |
| 
 | |
|     State expected = State::Ready;
 | |
|     if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) {
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     cb_ = cb;
 | |
|     opaque_ = opaque;
 | |
| 
 | |
|     thread_ = std::thread([this] { StartWorker(); });
 | |
|     return true;
 | |
| }
 | |
| 
 | |
| bool PairingConnectionCtx::DoExchangeMsgs() {
 | |
|     uint32_t payload = pairing_auth_msg_size(auth_.get());
 | |
|     std::vector<uint8_t> msg(payload);
 | |
|     pairing_auth_get_spake2_msg(auth_.get(), msg.data());
 | |
| 
 | |
|     PairingPacketHeader header;
 | |
|     CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, payload);
 | |
| 
 | |
|     // Write our SPAKE2 msg
 | |
|     if (!WriteHeader(&header,
 | |
|                      std::string_view(reinterpret_cast<const char*>(msg.data()), msg.size()))) {
 | |
|         LOG(ERROR) << "Failed to write SPAKE2 msg.";
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     // Read the peer's SPAKE2 msg header
 | |
|     if (!ReadHeader(&header)) {
 | |
|         LOG(ERROR) << "Invalid PairingPacketHeader.";
 | |
|         return false;
 | |
|     }
 | |
|     if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) {
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     // Read the SPAKE2 msg payload and initialize the cipher for
 | |
|     // encrypting the PeerInfo and certificate.
 | |
|     auto their_msg = tls_->ReadFully(header.payload);
 | |
|     if (their_msg.empty() ||
 | |
|         !pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size())) {
 | |
|         LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size()
 | |
|                    << "]";
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     return true;
 | |
| }
 | |
| 
 | |
| bool PairingConnectionCtx::DoExchangePeerInfo() {
 | |
|     // Encrypt PeerInfo
 | |
|     std::vector<uint8_t> buf;
 | |
|     uint8_t* p = reinterpret_cast<uint8_t*>(&peer_info_);
 | |
|     buf.assign(p, p + sizeof(peer_info_));
 | |
|     std::vector<uint8_t> outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size()));
 | |
|     CHECK(!outbuf.empty());
 | |
|     size_t outsize;
 | |
|     if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
 | |
|         LOG(ERROR) << "Failed to encrypt peer info";
 | |
|         return false;
 | |
|     }
 | |
|     outbuf.resize(outsize);
 | |
| 
 | |
|     // Write out the packet header
 | |
|     PairingPacketHeader out_header;
 | |
|     out_header.version = kCurrentKeyHeaderVersion;
 | |
|     out_header.type = static_cast<uint8_t>(static_cast<int>(adb::proto::PairingPacket::PEER_INFO));
 | |
|     out_header.payload = htonl(outbuf.size());
 | |
|     if (!tls_->WriteFully(
 | |
|                 std::string_view(reinterpret_cast<const char*>(&out_header), sizeof(out_header)))) {
 | |
|         LOG(ERROR) << "Unable to write PairingPacketHeader";
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     // Write out the encrypted payload
 | |
|     if (!tls_->WriteFully(
 | |
|                 std::string_view(reinterpret_cast<const char*>(outbuf.data()), outbuf.size()))) {
 | |
|         LOG(ERROR) << "Unable to write encrypted peer info";
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     // Read in the peer's packet header
 | |
|     PairingPacketHeader header;
 | |
|     if (!ReadHeader(&header)) {
 | |
|         LOG(ERROR) << "Invalid PairingPacketHeader.";
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) {
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     // Read in the encrypted peer certificate
 | |
|     buf = tls_->ReadFully(header.payload);
 | |
|     if (buf.empty()) {
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     // Try to decrypt the certificate
 | |
|     outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size()));
 | |
|     if (outbuf.empty()) {
 | |
|         LOG(ERROR) << "Unsupported payload while decrypting peer info.";
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
 | |
|         LOG(ERROR) << "Failed to decrypt";
 | |
|         return false;
 | |
|     }
 | |
|     outbuf.resize(outsize);
 | |
| 
 | |
|     // The decrypted message should contain the PeerInfo.
 | |
|     if (outbuf.size() != sizeof(PeerInfo)) {
 | |
|         LOG(ERROR) << "Got size=" << outbuf.size() << "PeerInfo.size=" << sizeof(PeerInfo);
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     p = outbuf.data();
 | |
|     ::memcpy(&their_info_, p, sizeof(PeerInfo));
 | |
|     p += sizeof(PeerInfo);
 | |
| 
 | |
|     return true;
 | |
| }
 | |
| 
 | |
| void PairingConnectionCtx::StartWorker() {
 | |
|     // Setup the secure transport
 | |
|     if (!SetupTlsConnection()) {
 | |
|         NotifyResult(nullptr);
 | |
|         return;
 | |
|     }
 | |
| 
 | |
|     for (;;) {
 | |
|         switch (state_) {
 | |
|             case State::ExchangingMsgs:
 | |
|                 if (!DoExchangeMsgs()) {
 | |
|                     NotifyResult(nullptr);
 | |
|                     return;
 | |
|                 }
 | |
|                 state_ = State::ExchangingPeerInfo;
 | |
|                 break;
 | |
|             case State::ExchangingPeerInfo:
 | |
|                 if (!DoExchangePeerInfo()) {
 | |
|                     NotifyResult(nullptr);
 | |
|                     return;
 | |
|                 }
 | |
|                 NotifyResult(&their_info_);
 | |
|                 return;
 | |
|             case State::Ready:
 | |
|             case State::Stopped:
 | |
|                 LOG(FATAL) << __func__ << ": Got invalid state";
 | |
|                 return;
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| // static
 | |
| PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) {
 | |
|     switch (role) {
 | |
|         case Role::Client:
 | |
|             return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size()));
 | |
|             break;
 | |
|         case Role::Server:
 | |
|             return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size()));
 | |
|             break;
 | |
|     }
 | |
| }
 | |
| 
 | |
| static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd,
 | |
|                                               size_t pswd_len, const PeerInfo* peer_info,
 | |
|                                               const uint8_t* x509_cert_pem, size_t x509_size,
 | |
|                                               const uint8_t* priv_key_pem, size_t priv_size) {
 | |
|     CHECK(pswd);
 | |
|     CHECK_GT(pswd_len, 0U);
 | |
|     CHECK(x509_cert_pem);
 | |
|     CHECK_GT(x509_size, 0U);
 | |
|     CHECK(priv_key_pem);
 | |
|     CHECK_GT(priv_size, 0U);
 | |
|     CHECK(peer_info);
 | |
|     std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len);
 | |
|     std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size);
 | |
|     std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size);
 | |
|     return new PairingConnectionCtx(role, vec_pswd, *peer_info, vec_x509_cert, vec_priv_key);
 | |
| }
 | |
| 
 | |
| PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len,
 | |
|                                                     const PeerInfo* peer_info,
 | |
|                                                     const uint8_t* x509_cert_pem, size_t x509_size,
 | |
|                                                     const uint8_t* priv_key_pem, size_t priv_size) {
 | |
|     return CreateConnection(PairingConnectionCtx::Role::Client, pswd, pswd_len, peer_info,
 | |
|                             x509_cert_pem, x509_size, priv_key_pem, priv_size);
 | |
| }
 | |
| 
 | |
| PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len,
 | |
|                                                     const PeerInfo* peer_info,
 | |
|                                                     const uint8_t* x509_cert_pem, size_t x509_size,
 | |
|                                                     const uint8_t* priv_key_pem, size_t priv_size) {
 | |
|     return CreateConnection(PairingConnectionCtx::Role::Server, pswd, pswd_len, peer_info,
 | |
|                             x509_cert_pem, x509_size, priv_key_pem, priv_size);
 | |
| }
 | |
| 
 | |
| void pairing_connection_destroy(PairingConnectionCtx* ctx) {
 | |
|     CHECK(ctx);
 | |
|     delete ctx;
 | |
| }
 | |
| 
 | |
| bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb,
 | |
|                               void* opaque) {
 | |
|     return ctx->Start(fd, cb, opaque);
 | |
| }
 |