181 lines
5.5 KiB
C++
181 lines
5.5 KiB
C++
// Copyright 2016 The Chromium Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include "mojo/core/broker_host.h"
|
|
|
|
#include <utility>
|
|
|
|
#include "base/logging.h"
|
|
#include "base/memory/platform_shared_memory_region.h"
|
|
#include "base/memory/ref_counted.h"
|
|
#include "base/threading/thread_task_runner_handle.h"
|
|
#include "build/build_config.h"
|
|
#include "mojo/core/broker_messages.h"
|
|
#include "mojo/core/platform_handle_utils.h"
|
|
|
|
#if defined(OS_WIN)
|
|
#include <windows.h>
|
|
#endif
|
|
|
|
namespace mojo {
|
|
namespace core {
|
|
|
|
BrokerHost::BrokerHost(base::ProcessHandle client_process,
                       ConnectionParams connection_params,
                       const ProcessErrorCallback& process_error_callback)
    : process_error_callback_(process_error_callback)
#if defined(OS_WIN)
      ,
      client_process_(ScopedProcessHandle::CloneFrom(client_process))
#endif
{
  // A broker host is useless without at least one usable endpoint.
  CHECK(connection_params.endpoint().is_valid() ||
        connection_params.server_endpoint().is_valid());

  // Register with the current thread's message loop so that
  // WillDestroyCurrentMessageLoop() is invoked when that loop goes away.
  base::MessageLoopCurrent::Get()->AddDestructionObserver(this);

  // Bind the channel to the task runner of the thread creating this host.
  auto task_runner = base::ThreadTaskRunnerHandle::Get();
  channel_ = Channel::Create(this, std::move(connection_params),
                             std::move(task_runner));
  channel_->Start();
}
|
|
|
|
BrokerHost::~BrokerHost() {
  // Destruction always happens on the creation thread (the IO thread), so the
  // current message loop is the one the constructor registered with.
  base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this);

  // Nothing further to tear down if no channel is held.
  if (!channel_)
    return;
  channel_->ShutDown();
}
|
|
|
|
bool BrokerHost::PrepareHandlesForClient(
    std::vector<PlatformHandleInTransit>* handles) {
#if defined(OS_WIN)
  // Try to transfer every handle into the client process, even after one has
  // failed. Succeed only if none of the transfers failed.
  size_t failed_transfers = 0;
  for (PlatformHandleInTransit& in_transit : *handles) {
    if (!in_transit.TransferToProcess(client_process_.Clone()))
      ++failed_transfers;
  }
  return failed_transfers == 0;
#else
  // Outside Windows no per-handle preparation is done here.
  return true;
#endif
}
|
|
|
|
bool BrokerHost::SendChannel(PlatformHandle handle) {
  CHECK(handle.is_valid());
  CHECK(channel_);

  // Build an INIT message that carries exactly one handle: |handle| itself.
#if defined(OS_WIN)
  InitData* init_data;
  Channel::MessagePtr message =
      CreateBrokerMessage(BrokerMessageType::INIT, 1, 0, &init_data);
  // No inline pipe name; the endpoint travels as the attached handle.
  init_data->pipe_name_length = 0;
#else
  Channel::MessagePtr message =
      CreateBrokerMessage(BrokerMessageType::INIT, 1, nullptr);
#endif

  std::vector<PlatformHandleInTransit> outgoing;
  outgoing.emplace_back(std::move(handle));

  // Preparing the handle can legitimately fail on Windows when the client runs
  // in another session (e.g. elevated). In that case nothing is written and
  // the failure is reported to the caller.
  const bool prepared = PrepareHandlesForClient(&outgoing);
  if (prepared) {
    message->SetHandles(std::move(outgoing));
    channel_->Write(std::move(message));
  }
  return prepared;
}
|
|
|
|
#if defined(OS_WIN)
|
|
|
|
void BrokerHost::SendNamedChannel(const base::StringPiece16& pipe_name) {
  // Unlike SendChannel(), no handle is attached: the pipe name itself is
  // copied into the message as UTF-16 code units.
  const size_t name_length = pipe_name.length();

  InitData* init_data;
  base::char16* name_buffer;
  Channel::MessagePtr message = CreateBrokerMessage(
      BrokerMessageType::INIT, 0, name_length * sizeof(base::char16),
      &init_data, reinterpret_cast<void**>(&name_buffer));

  init_data->pipe_name_length = static_cast<uint32_t>(name_length);
  std::copy(pipe_name.begin(), pipe_name.end(), name_buffer);

  channel_->Write(std::move(message));
}
|
|
|
|
#endif // defined(OS_WIN)
|
|
|
|
void BrokerHost::OnBufferRequest(uint32_t num_bytes) {
|
|
base::subtle::PlatformSharedMemoryRegion region =
|
|
base::subtle::PlatformSharedMemoryRegion::CreateWritable(num_bytes);
|
|
|
|
std::vector<PlatformHandleInTransit> handles(2);
|
|
if (region.IsValid()) {
|
|
PlatformHandle h[2];
|
|
ExtractPlatformHandlesFromSharedMemoryRegionHandle(
|
|
region.PassPlatformHandle(), &h[0], &h[1]);
|
|
handles[0] = PlatformHandleInTransit(std::move(h[0]));
|
|
handles[1] = PlatformHandleInTransit(std::move(h[1]));
|
|
#if !defined(OS_POSIX) || defined(OS_ANDROID) || defined(OS_FUCHSIA) || \
|
|
(defined(OS_MACOSX) && !defined(OS_IOS))
|
|
// Non-POSIX systems, as well as Android, Fuchsia, and non-iOS Mac, only use
|
|
// a single handle to represent a writable region.
|
|
DCHECK(!handles[1].handle().is_valid());
|
|
handles.resize(1);
|
|
#else
|
|
DCHECK(handles[1].handle().is_valid());
|
|
#endif
|
|
}
|
|
|
|
BufferResponseData* response;
|
|
Channel::MessagePtr message = CreateBrokerMessage(
|
|
BrokerMessageType::BUFFER_RESPONSE, handles.size(), 0, &response);
|
|
if (!handles.empty()) {
|
|
base::UnguessableToken guid = region.GetGUID();
|
|
response->guid_high = guid.GetHighForSerialization();
|
|
response->guid_low = guid.GetLowForSerialization();
|
|
PrepareHandlesForClient(&handles);
|
|
message->SetHandles(std::move(handles));
|
|
}
|
|
|
|
channel_->Write(std::move(message));
|
|
}
|
|
|
|
void BrokerHost::OnChannelMessage(const void* payload,
|
|
size_t payload_size,
|
|
std::vector<PlatformHandle> handles) {
|
|
if (payload_size < sizeof(BrokerMessageHeader))
|
|
return;
|
|
|
|
const BrokerMessageHeader* header =
|
|
static_cast<const BrokerMessageHeader*>(payload);
|
|
switch (header->type) {
|
|
case BrokerMessageType::BUFFER_REQUEST:
|
|
if (payload_size ==
|
|
sizeof(BrokerMessageHeader) + sizeof(BufferRequestData)) {
|
|
const BufferRequestData* request =
|
|
reinterpret_cast<const BufferRequestData*>(header + 1);
|
|
OnBufferRequest(request->size);
|
|
}
|
|
break;
|
|
|
|
default:
|
|
DLOG(ERROR) << "Unexpected broker message type: " << header->type;
|
|
break;
|
|
}
|
|
}
|
|
|
|
void BrokerHost::OnChannelError(Channel::Error error) {
  // Only malformed input is reported through the callback. Any channel error,
  // whatever its kind, ends this host's life.
  const bool received_malformed =
      error == Channel::Error::kReceivedMalformedData;
  if (received_malformed && process_error_callback_)
    process_error_callback_.Run("Broker host received malformed message");

  delete this;
}
|
|
|
|
// Destruction-observer hook for the message loop the constructor registered
// with. The host deletes itself, and its destructor then unregisters the
// observer and shuts down |channel_|.
void BrokerHost::WillDestroyCurrentMessageLoop() {
  delete this;
}
|
|
|
|
} // namespace core
|
|
} // namespace mojo
|