// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "cast/common/channel/connection_namespace_handler.h" #include #include #include #include "cast/common/channel/message_util.h" #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_socket_error_handler.h" #include "cast/common/channel/virtual_connection.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "util/json/json_serialization.h" #include "util/json/json_value.h" #include "util/osp_logging.h" namespace openscreen { namespace cast { namespace { using ::testing::_; using ::testing::Invoke; using ::testing::NiceMock; using ::cast::channel::CastMessage; using ::cast::channel::CastMessage_ProtocolVersion; class MockVirtualConnectionPolicy : public ConnectionNamespaceHandler::VirtualConnectionPolicy { public: ~MockVirtualConnectionPolicy() override = default; MOCK_METHOD(bool, IsConnectionAllowed, (const VirtualConnection& virtual_conn), (const, override)); }; CastMessage MakeVersionedConnectMessage( const std::string& source_id, const std::string& destination_id, absl::optional version, std::vector version_list) { CastMessage connect_message = MakeConnectMessage(source_id, destination_id); Json::Value message(Json::ValueType::objectValue); message[kMessageKeyType] = kMessageTypeConnect; if (version) { message[kMessageKeyProtocolVersion] = version.value(); } if (!version_list.empty()) { Json::Value list(Json::ValueType::arrayValue); for (CastMessage_ProtocolVersion v : version_list) { list.append(v); } message[kMessageKeyProtocolVersionList] = std::move(list); } ErrorOr result = json::Stringify(message); OSP_DCHECK(result); connect_message.set_payload_utf8(std::move(result.value())); return connect_message; } void VerifyConnectionMessage(const CastMessage& message, const std::string& source_id, const std::string& destination_id) { EXPECT_EQ(message.source_id(), source_id); EXPECT_EQ(message.destination_id(), destination_id); EXPECT_EQ(message.namespace_(), kConnectionNamespace); ASSERT_EQ(message.payload_type(), ::cast::channel::CastMessage_PayloadType_STRING); } Json::Value ParseConnectionMessage(const CastMessage& message) { ErrorOr result = json::Parse(message.payload_utf8()); OSP_CHECK(result) << message.payload_utf8(); return result.value(); } } // namespace class ConnectionNamespaceHandlerTest : public ::testing::Test { public: void SetUp() override { socket_ = fake_cast_socket_pair_.socket.get(); router_.TakeSocket(&mock_error_handler_, std::move(fake_cast_socket_pair_.socket)); ON_CALL(vc_policy_, IsConnectionAllowed(_)) .WillByDefault( Invoke([](const VirtualConnection& virtual_conn) { return true; })); } protected: void ExpectCloseMessage(MockCastSocketClient* mock_client, const std::string& source_id, const std::string& destination_id) { EXPECT_CALL(*mock_client, OnMessage(_, _)) .WillOnce(Invoke([&source_id, &destination_id](CastSocket* socket, CastMessage message) { VerifyConnectionMessage(message, source_id, destination_id); Json::Value value = ParseConnectionMessage(message); absl::optional type = MaybeGetString( value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType)); ASSERT_TRUE(type) << message.payload_utf8(); EXPECT_EQ(type.value(), kMessageTypeClose) << message.payload_utf8(); })); } void ExpectConnectedMessage( MockCastSocketClient* mock_client, const std::string& source_id, const std::string& destination_id, absl::optional version = absl::nullopt) { EXPECT_CALL(*mock_client, OnMessage(_, _)) .WillOnce(Invoke([&source_id, &destination_id, version]( CastSocket* socket, CastMessage message) { VerifyConnectionMessage(message, source_id, destination_id); Json::Value value = ParseConnectionMessage(message); absl::optional type = MaybeGetString( value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType)); ASSERT_TRUE(type) << message.payload_utf8(); EXPECT_EQ(type.value(), kMessageTypeConnected) << message.payload_utf8(); if (version) { absl::optional message_version = MaybeGetInt( value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion)); ASSERT_TRUE(message_version) << message.payload_utf8(); EXPECT_EQ(message_version.value(), version.value()); } })); } FakeCastSocketPair fake_cast_socket_pair_; MockSocketErrorHandler mock_error_handler_; CastSocket* socket_; NiceMock vc_policy_; VirtualConnectionRouter router_; ConnectionNamespaceHandler connection_namespace_handler_{&router_, &vc_policy_}; const std::string sender_id_{"sender-5678"}; const std::string receiver_id_{"receiver-3245"}; }; TEST_F(ConnectionNamespaceHandlerTest, Connect) { connection_namespace_handler_.OnMessage( &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); EXPECT_CALL(fake_cast_socket_pair_.mock_peer_client, OnMessage(_, _)) .Times(0); } TEST_F(ConnectionNamespaceHandlerTest, PolicyDeniesConnection) { EXPECT_CALL(vc_policy_, IsConnectionAllowed(_)) .WillOnce( Invoke([](const VirtualConnection& virtual_conn) { return false; })); ExpectCloseMessage(&fake_cast_socket_pair_.mock_peer_client, receiver_id_, sender_id_); connection_namespace_handler_.OnMessage( &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); EXPECT_FALSE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); } TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersion) { ExpectConnectedMessage( &fake_cast_socket_pair_.mock_peer_client, receiver_id_, sender_id_, ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2); connection_namespace_handler_.OnMessage( &router_, socket_, MakeVersionedConnectMessage( sender_id_, receiver_id_, ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2, {})); EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); } TEST_F(ConnectionNamespaceHandlerTest, ConnectWithVersionList) { ExpectConnectedMessage( &fake_cast_socket_pair_.mock_peer_client, receiver_id_, sender_id_, ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3); connection_namespace_handler_.OnMessage( &router_, socket_, MakeVersionedConnectMessage( sender_id_, receiver_id_, ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_2, {::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_3, ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0})); EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); } TEST_F(ConnectionNamespaceHandlerTest, Close) { connection_namespace_handler_.OnMessage( &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); connection_namespace_handler_.OnMessage( &router_, socket_, MakeCloseMessage(sender_id_, receiver_id_)); EXPECT_FALSE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); } TEST_F(ConnectionNamespaceHandlerTest, CloseUnknown) { connection_namespace_handler_.OnMessage( &router_, socket_, MakeConnectMessage(sender_id_, receiver_id_)); EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); connection_namespace_handler_.OnMessage( &router_, socket_, MakeCloseMessage(sender_id_ + "098", receiver_id_)); EXPECT_TRUE(router_.GetConnectionData( VirtualConnection{receiver_id_, sender_id_, socket_->socket_id()})); } } // namespace cast } // namespace openscreen