209 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			209 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			C++
		
	
	
	
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
 | |
| 
 | |
| #ifndef _DNS_DNSTLSSOCKET_H
 | |
| #define _DNS_DNSTLSSOCKET_H
 | |
| 
 | |
#include <openssl/ssl.h>
#include <chrono>
#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <netdutils/Slice.h>
#include <netdutils/Status.h>

#include "DnsTlsServer.h"
#include "IDnsTlsSocket.h"
#include "LockedQueue.h"
 | |
| 
 | |
| namespace android {
 | |
| namespace net {
 | |
| 
 | |
| class IDnsTlsSocketObserver;
 | |
| class DnsTlsSessionCache;
 | |
| 
 | |
// A class for managing a TLS socket that sends and receives messages in
// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
// This class is not aware of query-response pairing or anything else about DNS.
// For the observer:
// This class is not re-entrant: the observer is not permitted to wait for a call to query()
// or the destructor in a callback.  Doing so will result in deadlocks.
// This class may call the observer at any time after initialize(), until the destructor
// returns (but not after).
//
// Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle:
//
//                                UNINITIALIZED
//                                      |
//                                      v
//                                 INITIALIZED
//                                      |
//                                      v
//                            +----CONNECTING------+
//            Handshake fails |                    | Handshake succeeds
//   (onClose() when          |                    |
//    mAsyncHandshake is set) |                    v
//                            |        +---> CONNECTED --+
//                            |        |           |     |
//                            |        +-----------+     | Idle timeout
//                            |   Send/Recv queries      | onClose()
//                            |   onResponse()           |
//                            |                          |
//                            |                          |
//                            +--> WAIT_FOR_DELETE <-----+
//
//
// TODO: Add onHandshakeFinished() for handshake results.
 | |
| class DnsTlsSocket : public IDnsTlsSocket {
 | |
|   public:
 | |
|     enum class State {
 | |
|         UNINITIALIZED,
 | |
|         INITIALIZED,
 | |
|         CONNECTING,
 | |
|         CONNECTED,
 | |
|         WAIT_FOR_DELETE,
 | |
|     };
 | |
| 
 | |
|     DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
 | |
|                  IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache)
 | |
|         : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {}
 | |
|     ~DnsTlsSocket();
 | |
| 
 | |
|     // Creates the SSL context for this session. Returns false on failure.
 | |
|     // This method should be called after construction and before use of a DnsTlsSocket.
 | |
|     // Only call this method once per DnsTlsSocket.
 | |
|     bool initialize() EXCLUDES(mLock);
 | |
| 
 | |
|     // If async handshake is enabled, this function simply signals a handshake request, and the
 | |
|     // handshake will be performed in the loop thread; otherwise, if async handshake is disabled,
 | |
|     // this function performs the handshake and returns after the handshake finishes.
 | |
|     bool startHandshake() EXCLUDES(mLock);
 | |
| 
 | |
|     // Send a query on the provided SSL socket.  |query| contains
 | |
|     // the body of a query, not including the ID header. This function will typically return before
 | |
|     // the query is actually sent.  If this function fails, DnsTlsSocketObserver will be
 | |
|     // notified that the socket is closed.
 | |
|     // Note that success here indicates successful sending, not receipt of a response.
 | |
|     // Thread-safe.
 | |
|     bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock);
 | |
| 
 | |
|   private:
 | |
|     // Lock to be held by the SSL event loop thread.  This is not normally in contention.
 | |
|     std::mutex mLock;
 | |
| 
 | |
|     // Forwards queries and receives responses.  Blocks until the idle timeout.
 | |
|     void loop() EXCLUDES(mLock);
 | |
|     std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock);
 | |
| 
 | |
|     // On success, sets mSslFd to a socket connected to mAddr (the
 | |
|     // connection will likely be in progress if mProtocol is IPPROTO_TCP).
 | |
|     // On error, returns the errno.
 | |
|     netdutils::Status tcpConnect() REQUIRES(mLock);
 | |
| 
 | |
|     bssl::UniquePtr<SSL> prepareForSslConnect(int fd) REQUIRES(mLock);
 | |
| 
 | |
|     // Connect an SSL session on the provided socket.  If connection fails, closing the
 | |
|     // socket remains the caller's responsibility.
 | |
|     bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock);
 | |
| 
 | |
|     // Connect an SSL session on the provided socket. This is an interruptible version
 | |
|     // which allows to terminate connection handshake any time.
 | |
|     bssl::UniquePtr<SSL> sslConnectV2(int fd) REQUIRES(mLock);
 | |
| 
 | |
|     // Disconnect the SSL session and close the socket.
 | |
|     void sslDisconnect() REQUIRES(mLock);
 | |
| 
 | |
|     // Writes a buffer to the socket.
 | |
|     bool sslWrite(const netdutils::Slice buffer) REQUIRES(mLock);
 | |
| 
 | |
|     // Reads exactly the specified number of bytes from the socket, or fails.
 | |
|     // Returns SSL_ERROR_NONE on success.
 | |
|     // If |wait| is true, then this function always blocks.  Otherwise, it
 | |
|     // will return SSL_ERROR_WANT_READ if there is no data from the server to read.
 | |
|     int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);
 | |
| 
 | |
|     bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);
 | |
| 
 | |
|     // Read one DNS response. It can potentially block until reading the exact bytes of
 | |
|     // the response.
 | |
|     bool readResponse() REQUIRES(mLock);
 | |
| 
 | |
|     // It is only used for DNS-OVER-TLS internal test.
 | |
|     bool setTestCaCertificate() REQUIRES(mLock);
 | |
| 
 | |
|     // Similar to query(), this function uses incrementEventFd to send a message to the
 | |
|     // loop thread.  However, instead of incrementing the counter by one (indicating a
 | |
|     // new query), it wraps the counter to negative, which we use to indicate a shutdown
 | |
|     // request.
 | |
|     void requestLoopShutdown() EXCLUDES(mLock);
 | |
| 
 | |
|     // This function sends a message to the loop thread by incrementing mEventFd.
 | |
|     bool incrementEventFd(int64_t count) EXCLUDES(mLock);
 | |
| 
 | |
|     // Transition the state from expected state |from| to new state |to|.
 | |
|     void transitionState(State from, State to) REQUIRES(mLock);
 | |
| 
 | |
|     // Queue of pending queries.  query() pushes items onto the queue and notifies
 | |
|     // the loop thread by incrementing mEventFd.  loop() reads items off the queue.
 | |
|     LockedQueue<std::vector<uint8_t>> mQueue;
 | |
| 
 | |
|     // eventfd socket used for notifying the SSL thread when queries are ready to send.
 | |
|     // This socket acts similarly to an atomic counter, incremented by query() and cleared
 | |
|     // by loop().  We have to use a socket because the SSL thread needs to wait in poll()
 | |
|     // for input from either a remote server or a query thread.  Since eventfd does not have
 | |
|     // EOF, we indicate a close request by setting the counter to a negative number.
 | |
|     // This file descriptor is opened by initialize(), and closed implicitly after
 | |
|     // destruction.
 | |
|     // Note that: data starts being read from the eventfd when the state is CONNECTED.
 | |
|     base::unique_fd mEventFd;
 | |
| 
 | |
|     // An eventfd used to listen to shutdown requests when the state is CONNECTING.
 | |
|     // TODO: let |mEventFd| exclusively handle query requests, and let |mShutdownEvent| exclusively
 | |
|     // handle shutdown requests.
 | |
|     base::unique_fd mShutdownEvent;
 | |
| 
 | |
|     // SSL Socket fields.
 | |
|     bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
 | |
|     base::unique_fd mSslFd GUARDED_BY(mLock);
 | |
|     bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock);
 | |
|     static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20);
 | |
| 
 | |
|     const unsigned mMark;  // Socket mark
 | |
|     const DnsTlsServer mServer;
 | |
|     IDnsTlsSocketObserver* _Nonnull const mObserver;
 | |
|     DnsTlsSessionCache* _Nonnull const mCache;
 | |
|     State mState GUARDED_BY(mLock) = State::UNINITIALIZED;
 | |
| 
 | |
|     // If true, defer the handshake to the loop thread; otherwise, run the handshake on caller's
 | |
|     // thread (the call to startHandshake()).
 | |
|     bool mAsyncHandshake GUARDED_BY(mLock) = false;
 | |
| 
 | |
|     // The time to wait for the attempt on connecting to the server.
 | |
|     // Set the default value 127 seconds to be consistent with TCP connect timeout.
 | |
|     // (presume net.ipv4.tcp_syn_retries = 6)
 | |
|     static constexpr int kDotConnectTimeoutMs = 127 * 1000;
 | |
|     int mConnectTimeoutMs;
 | |
| 
 | |
|     // For testing.
 | |
|     friend class DnsTlsSocketTest;
 | |
| };
 | |
| 
 | |
| }  // end of namespace net
 | |
| }  // end of namespace android
 | |
| 
 | |
| #endif  // _DNS_DNSTLSSOCKET_H
 |