446 lines
14 KiB
C++
446 lines
14 KiB
C++
/*
|
|
* Copyright (C) 2020 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include <host/libs/websocket/websocket_server.h>
|
|
|
|
#include <string>
|
|
#include <unordered_map>
|
|
|
|
#include <android-base/logging.h>
|
|
#include <libwebsockets.h>
|
|
|
|
#include <common/libs/utils/files.h>
|
|
#include <host/libs/websocket/websocket_handler.h>
|
|
|
|
namespace cuttlefish {
|
|
namespace {
|
|
|
|
std::string GetPath(struct lws* wsi) {
|
|
auto len = lws_hdr_total_length(wsi, WSI_TOKEN_GET_URI);
|
|
std::string path(len + 1, '\0');
|
|
auto ret = lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_GET_URI);
|
|
if (ret <= 0) {
|
|
len = lws_hdr_total_length(wsi, WSI_TOKEN_HTTP_COLON_PATH);
|
|
path.resize(len + 1, '\0');
|
|
ret =
|
|
lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_HTTP_COLON_PATH);
|
|
}
|
|
if (ret < 0) {
|
|
LOG(FATAL) << "Something went wrong getting the path";
|
|
}
|
|
path.resize(len);
|
|
return path;
|
|
}
|
|
|
|
// CORS response headers, as (name, value) pairs. The names carry their
// trailing ':' because they are handed to lws_add_http_header_by_name() as is.
const std::vector<std::pair<std::string, std::string>> kCORSHeaders = {
    {
        "Access-Control-Allow-Origin:",
        "*",
    },
    {
        "Access-Control-Allow-Methods:",
        "POST, GET, OPTIONS",
    },
    {
        "Access-Control-Allow-Headers:",
        "Content-Type, Access-Control-Allow-Headers, Authorization, "
        "X-Requested-With, Accept",
    },
};
|
|
|
|
bool AddCORSHeaders(struct lws* wsi, unsigned char** buffer_ptr,
|
|
unsigned char* buffer_end) {
|
|
for (const auto& header : kCORSHeaders) {
|
|
const auto& name = header.first;
|
|
const auto& value = header.second;
|
|
if (lws_add_http_header_by_name(
|
|
wsi, reinterpret_cast<const unsigned char*>(name.c_str()),
|
|
reinterpret_cast<const unsigned char*>(value.c_str()), value.size(),
|
|
buffer_ptr, buffer_end)) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool WriteCommonHttpHeaders(int status, const char* mime_type,
|
|
size_t content_len, struct lws* wsi) {
|
|
constexpr size_t BUFF_SIZE = 2048;
|
|
uint8_t header_buffer[LWS_PRE + BUFF_SIZE];
|
|
const auto start = &header_buffer[LWS_PRE];
|
|
auto p = &header_buffer[LWS_PRE];
|
|
auto end = start + BUFF_SIZE;
|
|
if (lws_add_http_common_headers(wsi, status, mime_type, content_len, &p,
|
|
end)) {
|
|
LOG(ERROR) << "Failed to write headers for response";
|
|
return false;
|
|
}
|
|
if (!AddCORSHeaders(wsi, &p, end)) {
|
|
LOG(ERROR) << "Failed to write CORS headers for response";
|
|
return false;
|
|
}
|
|
if (lws_finalize_write_http_header(wsi, start, &p, end)) {
|
|
LOG(ERROR) << "Failed to finalize headers for response";
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
} // namespace
|
|
// Convenience constructor for a server without a certificates directory:
// delegates to the full constructor with an empty `certs_dir`.
WebSocketServer::WebSocketServer(const char* protocol_name,
                                 const std::string& assets_dir,
                                 int server_port)
    : WebSocketServer(protocol_name, /*certs_dir=*/std::string(), assets_dir,
                      server_port) {}
|
|
|
|
// Stores the server configuration. Nothing is created in libwebsockets here;
// that happens in InitializeLwsObjects().
//
//   protocol_name: name given to the websocket protocol entry.
//   certs_dir:     directory holding server.crt / server.key (and optionally
//                  CA.crt); when empty no SSL options are set.
//   assets_dir:    directory used as origin of the "/" file mount.
//   server_port:   port handed to the lws context.
WebSocketServer::WebSocketServer(const char* protocol_name,
                                 const std::string& certs_dir,
                                 const std::string& assets_dir,
                                 int server_port)
    : protocol_name_(protocol_name),
      assets_dir_(assets_dir),
      certs_dir_(certs_dir),
      server_port_(server_port) {
}
|
|
|
|
void WebSocketServer::InitializeLwsObjects() {
|
|
std::string cert_file = certs_dir_ + "/server.crt";
|
|
std::string key_file = certs_dir_ + "/server.key";
|
|
std::string ca_file = certs_dir_ + "/CA.crt";
|
|
|
|
retry_ = {
|
|
.secs_since_valid_ping = 3,
|
|
.secs_since_valid_hangup = 10,
|
|
};
|
|
|
|
struct lws_protocols protocols[] = //
|
|
{{
|
|
.name = protocol_name_.c_str(),
|
|
.callback = WebsocketCallback,
|
|
.per_session_data_size = 0,
|
|
.rx_buffer_size = 0,
|
|
.id = 0,
|
|
.user = this,
|
|
.tx_packet_size = 0,
|
|
},
|
|
{
|
|
.name = "__http_polling__",
|
|
.callback = DynHttpCallback,
|
|
.per_session_data_size = 0,
|
|
.rx_buffer_size = 0,
|
|
.id = 0,
|
|
.user = this,
|
|
.tx_packet_size = 0,
|
|
},
|
|
{
|
|
.name = nullptr,
|
|
.callback = nullptr,
|
|
.per_session_data_size = 0,
|
|
.rx_buffer_size = 0,
|
|
.id = 0,
|
|
.user = nullptr,
|
|
.tx_packet_size = 0,
|
|
}};
|
|
|
|
dyn_mounts_.reserve(dyn_handler_factories_.size());
|
|
for (auto& handler_entry : dyn_handler_factories_) {
|
|
auto& path = handler_entry.first;
|
|
dyn_mounts_.push_back({
|
|
.mount_next = nullptr,
|
|
.mountpoint = path.c_str(),
|
|
.mountpoint_len = static_cast<uint8_t>(path.size()),
|
|
.origin = "__http_polling__",
|
|
.def = nullptr,
|
|
.protocol = nullptr,
|
|
.cgienv = nullptr,
|
|
.extra_mimetypes = nullptr,
|
|
.interpret = nullptr,
|
|
.cgi_timeout = 0,
|
|
.cache_max_age = 0,
|
|
.auth_mask = 0,
|
|
.cache_reusable = 0,
|
|
.cache_revalidate = 0,
|
|
.cache_intermediaries = 0,
|
|
.origin_protocol = LWSMPRO_CALLBACK, // dynamic
|
|
.basic_auth_login_file = nullptr,
|
|
});
|
|
}
|
|
struct lws_http_mount* next_mount = nullptr;
|
|
// Set up the linked list after all the mounts have been created to ensure
|
|
// pointers are not invalidated.
|
|
for (auto& mount : dyn_mounts_) {
|
|
mount.mount_next = next_mount;
|
|
next_mount = &mount;
|
|
}
|
|
|
|
static_mount_ = {
|
|
.mount_next = next_mount,
|
|
.mountpoint = "/",
|
|
.mountpoint_len = 1,
|
|
.origin = assets_dir_.c_str(),
|
|
.def = "index.html",
|
|
.protocol = nullptr,
|
|
.cgienv = nullptr,
|
|
.extra_mimetypes = nullptr,
|
|
.interpret = nullptr,
|
|
.cgi_timeout = 0,
|
|
.cache_max_age = 0,
|
|
.auth_mask = 0,
|
|
.cache_reusable = 0,
|
|
.cache_revalidate = 0,
|
|
.cache_intermediaries = 0,
|
|
.origin_protocol = LWSMPRO_FILE, // files in a dir
|
|
.basic_auth_login_file = nullptr,
|
|
};
|
|
|
|
struct lws_context_creation_info info;
|
|
headers_ = {NULL, NULL, "content-security-policy:",
|
|
"default-src 'self' https://ajax.googleapis.com; "
|
|
"style-src 'self' https://fonts.googleapis.com/; "
|
|
"font-src https://fonts.gstatic.com/; "};
|
|
|
|
memset(&info, 0, sizeof info);
|
|
info.port = server_port_;
|
|
info.mounts = &static_mount_;
|
|
info.protocols = protocols;
|
|
info.vhost_name = "localhost";
|
|
info.headers = &headers_;
|
|
info.retry_and_idle_policy = &retry_;
|
|
|
|
if (!certs_dir_.empty()) {
|
|
info.options |= LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
|
|
info.ssl_cert_filepath = cert_file.c_str();
|
|
info.ssl_private_key_filepath = key_file.c_str();
|
|
if (FileExists(ca_file)) {
|
|
info.ssl_ca_filepath = ca_file.c_str();
|
|
}
|
|
}
|
|
|
|
context_ = lws_create_context(&info);
|
|
if (!context_) {
|
|
LOG(FATAL) << "Failed to create websocket context";
|
|
}
|
|
}
|
|
|
|
void WebSocketServer::RegisterHandlerFactory(
|
|
const std::string& path,
|
|
std::unique_ptr<WebSocketHandlerFactory> handler_factory_p) {
|
|
handler_factories_[path] = std::move(handler_factory_p);
|
|
}
|
|
|
|
void WebSocketServer::RegisterDynHandlerFactory(
|
|
const std::string& path,
|
|
DynHandlerFactory handler_factory) {
|
|
dyn_handler_factories_[path] = std::move(handler_factory);
|
|
}
|
|
|
|
void WebSocketServer::Serve() {
|
|
InitializeLwsObjects();
|
|
int n = 0;
|
|
while (n >= 0) {
|
|
n = lws_service(context_, 0);
|
|
}
|
|
lws_context_destroy(context_);
|
|
}
|
|
|
|
int WebSocketServer::WebsocketCallback(struct lws* wsi,
|
|
enum lws_callback_reasons reason,
|
|
void* user, void* in, size_t len) {
|
|
auto protocol = lws_get_protocol(wsi);
|
|
if (!protocol) {
|
|
// Some callback reasons are always handled by the first protocol, before a
|
|
// wsi struct is even created.
|
|
return lws_callback_http_dummy(wsi, reason, user, in, len);
|
|
}
|
|
return reinterpret_cast<WebSocketServer*>(protocol->user)
|
|
->ServerCallback(wsi, reason, user, in, len);
|
|
}
|
|
|
|
int WebSocketServer::DynHttpCallback(struct lws* wsi,
|
|
enum lws_callback_reasons reason,
|
|
void* user, void* in, size_t len) {
|
|
auto protocol = lws_get_protocol(wsi);
|
|
if (!protocol) {
|
|
LOG(ERROR) << "No protocol associated with connection";
|
|
return 1;
|
|
}
|
|
return reinterpret_cast<WebSocketServer*>(protocol->user)
|
|
->DynServerCallback(wsi, reason, user, in, len);
|
|
}
|
|
|
|
int WebSocketServer::DynServerCallback(struct lws* wsi,
|
|
enum lws_callback_reasons reason,
|
|
void* user, void* in, size_t len) {
|
|
switch (reason) {
|
|
case LWS_CALLBACK_HTTP: {
|
|
char* path_raw;
|
|
int path_len;
|
|
auto method = lws_http_get_uri_and_method(wsi, &path_raw, &path_len);
|
|
if (method < 0) {
|
|
return 1;
|
|
}
|
|
std::string path(path_raw, path_len);
|
|
auto handler = InstantiateDynHandler(path, wsi);
|
|
if (!handler) {
|
|
if (!WriteCommonHttpHeaders(static_cast<int>(HttpStatusCode::NotFound),
|
|
"application/json", 0, wsi)) {
|
|
return 1;
|
|
}
|
|
return lws_http_transaction_completed(wsi);
|
|
}
|
|
dyn_handlers_[wsi] = std::move(handler);
|
|
switch (method) {
|
|
case LWSHUMETH_GET: {
|
|
auto status = dyn_handlers_[wsi]->DoGet();
|
|
if (!WriteCommonHttpHeaders(static_cast<int>(status),
|
|
"application/json",
|
|
dyn_handlers_[wsi]->content_len(), wsi)) {
|
|
return 1;
|
|
}
|
|
// Write the response later, when the server is ready
|
|
lws_callback_on_writable(wsi);
|
|
break;
|
|
}
|
|
case LWSHUMETH_POST:
|
|
// Do nothing until the body has been read
|
|
break;
|
|
case LWSHUMETH_OPTIONS: {
|
|
// Response for CORS preflight
|
|
auto status = HttpStatusCode::NoContent;
|
|
if (!WriteCommonHttpHeaders(static_cast<int>(status), "", 0, wsi)) {
|
|
return 1;
|
|
}
|
|
lws_callback_on_writable(wsi);
|
|
break;
|
|
}
|
|
default:
|
|
LOG(ERROR) << "Unsupported HTTP method: " << method;
|
|
return 1;
|
|
}
|
|
break;
|
|
}
|
|
case LWS_CALLBACK_HTTP_BODY: {
|
|
auto handler = dyn_handlers_[wsi].get();
|
|
if (!handler) {
|
|
LOG(WARNING) << "Received body for unknown wsi";
|
|
return 1;
|
|
}
|
|
handler->AppendDataIn(in, len);
|
|
break;
|
|
}
|
|
case LWS_CALLBACK_HTTP_BODY_COMPLETION: {
|
|
auto handler = dyn_handlers_[wsi].get();
|
|
if (!handler) {
|
|
LOG(WARNING) << "Unexpected body completion event from unknown wsi";
|
|
return 1;
|
|
}
|
|
auto status = handler->DoPost();
|
|
if (!WriteCommonHttpHeaders(static_cast<int>(status), "application/json",
|
|
dyn_handlers_[wsi]->content_len(), wsi)) {
|
|
return 1;
|
|
}
|
|
lws_callback_on_writable(wsi);
|
|
break;
|
|
}
|
|
case LWS_CALLBACK_HTTP_WRITEABLE: {
|
|
auto handler = dyn_handlers_[wsi].get();
|
|
if (!handler) {
|
|
LOG(WARNING) << "Unknown wsi became writable";
|
|
return 1;
|
|
}
|
|
auto ret = handler->OnWritable();
|
|
dyn_handlers_.erase(wsi);
|
|
// Make sure the connection (in HTTP 1) or stream (in HTTP 2) is closed
|
|
// after the response is written
|
|
return ret;
|
|
}
|
|
case LWS_CALLBACK_CLOSED_HTTP:
|
|
break;
|
|
default:
|
|
return lws_callback_http_dummy(wsi, reason, user, in, len);
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
int WebSocketServer::ServerCallback(struct lws* wsi,
|
|
enum lws_callback_reasons reason,
|
|
void* user, void* in, size_t len) {
|
|
switch (reason) {
|
|
case LWS_CALLBACK_ESTABLISHED: {
|
|
auto path = GetPath(wsi);
|
|
auto handler = InstantiateHandler(path, wsi);
|
|
if (!handler) {
|
|
// This message came on an unexpected uri, close the connection.
|
|
lws_close_reason(wsi, LWS_CLOSE_STATUS_NOSTATUS, (uint8_t*)"404", 3);
|
|
return -1;
|
|
}
|
|
handlers_[wsi] = handler;
|
|
handler->OnConnected();
|
|
break;
|
|
}
|
|
case LWS_CALLBACK_CLOSED: {
|
|
auto handler = handlers_[wsi];
|
|
if (handler) {
|
|
handler->OnClosed();
|
|
handlers_.erase(wsi);
|
|
}
|
|
break;
|
|
}
|
|
case LWS_CALLBACK_SERVER_WRITEABLE: {
|
|
auto handler = handlers_[wsi];
|
|
if (handler) {
|
|
auto should_close = handler->OnWritable();
|
|
if (should_close) {
|
|
lws_close_reason(wsi, LWS_CLOSE_STATUS_NORMAL, nullptr, 0);
|
|
return 1;
|
|
}
|
|
} else {
|
|
LOG(WARNING) << "Unknown wsi became writable";
|
|
return -1;
|
|
}
|
|
break;
|
|
}
|
|
case LWS_CALLBACK_RECEIVE: {
|
|
auto handler = handlers_[wsi];
|
|
if (handler) {
|
|
bool is_final = (lws_remaining_packet_payload(wsi) == 0) &&
|
|
lws_is_final_fragment(wsi);
|
|
handler->OnReceive(reinterpret_cast<const uint8_t*>(in), len,
|
|
lws_frame_is_binary(wsi), is_final);
|
|
} else {
|
|
LOG(WARNING) << "Unknown wsi sent data";
|
|
}
|
|
break;
|
|
}
|
|
default:
|
|
return lws_callback_http_dummy(wsi, reason, user, in, len);
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
std::shared_ptr<WebSocketHandler> WebSocketServer::InstantiateHandler(
|
|
const std::string& uri_path, struct lws* wsi) {
|
|
auto it = handler_factories_.find(uri_path);
|
|
if (it == handler_factories_.end()) {
|
|
LOG(ERROR) << "Wrong path provided in URI: " << uri_path;
|
|
return nullptr;
|
|
} else {
|
|
LOG(VERBOSE) << "Creating handler for " << uri_path;
|
|
return it->second->Build(wsi);
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<DynHandler> WebSocketServer::InstantiateDynHandler(
|
|
const std::string& uri_path, struct lws* wsi) {
|
|
auto it = dyn_handler_factories_.find(uri_path);
|
|
if (it == dyn_handler_factories_.end()) {
|
|
LOG(ERROR) << "Wrong path provided in URI: " << uri_path;
|
|
return nullptr;
|
|
} else {
|
|
LOG(VERBOSE) << "Creating handler for " << uri_path;
|
|
return it->second(wsi);
|
|
}
|
|
}
|
|
|
|
} // namespace cuttlefish
|