159 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			159 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			C++
		
	
	
	
| /*
 | |
|  * Copyright (C) 2018 The Android Open Source Project
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *      http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| #define LOG_TAG "resolv"
 | |
| 
 | |
| #include "DnsTlsQueryMap.h"
 | |
| 
 | |
| #include <android-base/logging.h>
 | |
| 
 | |
| #include "Experiments.h"
 | |
| 
 | |
| namespace android {
 | |
| namespace net {
 | |
| 
 | |
| DnsTlsQueryMap::DnsTlsQueryMap() {
 | |
|     mMaxTries = Experiments::getInstance()->getFlag("dot_maxtries", kMaxTries);
 | |
|     if (mMaxTries < 1) mMaxTries = 1;
 | |
| }
 | |
| 
 | |
| std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(
 | |
|         const netdutils::Slice query) {
 | |
|     std::lock_guard guard(mLock);
 | |
| 
 | |
|     // Store the query so it can be matched to the response or reissued.
 | |
|     if (query.size() < 2) {
 | |
|         LOG(WARNING) << "Query is too short";
 | |
|         return nullptr;
 | |
|     }
 | |
|     int32_t newId = getFreeId();
 | |
|     if (newId < 0) {
 | |
|         LOG(WARNING) << "All query IDs are in use";
 | |
|         return nullptr;
 | |
|     }
 | |
| 
 | |
|     // Make a copy of the query.
 | |
|     std::vector<uint8_t> tmp(query.base(), query.base() + query.size());
 | |
|     Query q = {.newId = static_cast<uint16_t>(newId), .query = std::move(tmp)};
 | |
| 
 | |
|     const auto [it, inserted] = mQueries.try_emplace(newId, q);
 | |
|     if (!inserted) {
 | |
|         LOG(ERROR) << "Failed to store pending query";
 | |
|         return nullptr;
 | |
|     }
 | |
|     return std::make_unique<QueryFuture>(q, it->second.result.get_future());
 | |
| }
 | |
| 
 | |
| void DnsTlsQueryMap::expire(QueryPromise* p) {
 | |
|     Result r = { .code = Response::network_error };
 | |
|     p->result.set_value(r);
 | |
| }
 | |
| 
 | |
| void DnsTlsQueryMap::markTried(uint16_t newId) {
 | |
|     std::lock_guard guard(mLock);
 | |
|     auto it = mQueries.find(newId);
 | |
|     if (it != mQueries.end()) {
 | |
|         it->second.tries++;
 | |
|     }
 | |
| }
 | |
| 
 | |
| void DnsTlsQueryMap::cleanup() {
 | |
|     std::lock_guard guard(mLock);
 | |
|     for (auto it = mQueries.begin(); it != mQueries.end();) {
 | |
|         auto& p = it->second;
 | |
|         if (p.tries >= mMaxTries) {
 | |
|             expire(&p);
 | |
|             it = mQueries.erase(it);
 | |
|         } else {
 | |
|             ++it;
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| int32_t DnsTlsQueryMap::getFreeId() {
 | |
|     if (mQueries.empty()) {
 | |
|         return 0;
 | |
|     }
 | |
|     uint16_t maxId = mQueries.rbegin()->first;
 | |
|     if (maxId < UINT16_MAX) {
 | |
|         return maxId + 1;
 | |
|     }
 | |
|     if (mQueries.size() == UINT16_MAX + 1) {
 | |
|         // Map is full.
 | |
|         return -1;
 | |
|     }
 | |
|     // Linear scan.
 | |
|     uint16_t nextId = 0;
 | |
|     for (auto& pair : mQueries) {
 | |
|         uint16_t id = pair.first;
 | |
|         if (id != nextId) {
 | |
|             // Found a gap.
 | |
|             return nextId;
 | |
|         }
 | |
|         nextId = id + 1;
 | |
|     }
 | |
|     // Unreachable (but the compiler isn't smart enough to prove it).
 | |
|     return -1;
 | |
| }
 | |
| 
 | |
| std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
 | |
|     std::lock_guard guard(mLock);
 | |
|     std::vector<DnsTlsQueryMap::Query> queries;
 | |
|     for (auto& q : mQueries) {
 | |
|         queries.push_back(q.second.query);
 | |
|     }
 | |
|     return queries;
 | |
| }
 | |
| 
 | |
| bool DnsTlsQueryMap::empty() {
 | |
|     std::lock_guard guard(mLock);
 | |
|     return mQueries.empty();
 | |
| }
 | |
| 
 | |
| void DnsTlsQueryMap::clear() {
 | |
|     std::lock_guard guard(mLock);
 | |
|     for (auto& q : mQueries) {
 | |
|         expire(&q.second);
 | |
|     }
 | |
|     mQueries.clear();
 | |
| }
 | |
| 
 | |
| void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
 | |
|     LOG(VERBOSE) << "Got response of size " << response.size();
 | |
|     if (response.size() < 2) {
 | |
|         LOG(WARNING) << "Response is too short";
 | |
|         return;
 | |
|     }
 | |
|     uint16_t id = response[0] << 8 | response[1];
 | |
|     std::lock_guard guard(mLock);
 | |
|     auto it = mQueries.find(id);
 | |
|     if (it == mQueries.end()) {
 | |
|         LOG(WARNING) << "Discarding response: unknown ID " << id;
 | |
|         return;
 | |
|     }
 | |
|     Result r = { .code = Response::success, .response = std::move(response) };
 | |
|     // Rewrite ID to match the query
 | |
|     const uint8_t* data = it->second.query.query.data();
 | |
|     r.response[0] = data[0];
 | |
|     r.response[1] = data[1];
 | |
|     LOG(DEBUG) << "Sending result to dispatcher";
 | |
|     it->second.result.set_value(std::move(r));
 | |
|     mQueries.erase(it);
 | |
| }
 | |
| 
 | |
| }  // end of namespace net
 | |
| }  // end of namespace android
 |