#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests using the callback client for pw_rpc."""

import unittest
from unittest import mock
from typing import Any, List, Optional, Tuple

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
from pw_rpc.internal import packet_pb2

TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""


def _message_bytes(msg) -> bytes:
    """Returns msg unchanged if it is bytes, otherwise its serialized form."""
    return msg if isinstance(msg, bytes) else msg.SerializeToString()


class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server.

    Every packet the client sends on channel 1 is decoded and appended to
    self.requests. Packets queued with the _enqueue_* helpers are handed to
    the client once send_responses_after_packets outgoing packets have been
    sent; each is checked against the status process_packet() is expected to
    return for it.
    """
    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_packet)],
            self._protos.modules())
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Decoded packets sent by the client, in send order.
        self.requests: List[packet_pb2.RpcPacket] = []
        # Encoded packets to deliver to the client, each paired with the status
        # process_packet() is expected to return for it.
        self._next_packets: List[Tuple[bytes, Status]] = []
        # Queued packets are delivered when the Nth outgoing packet is sent.
        # A float so that it can temporarily be set to infinity.
        self.send_responses_after_packets: float = 1

        # When set, raised by the channel output instead of recording a packet.
        self.output_exception: Optional[Exception] = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recent packet sent by the client."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_response(self,
                          channel_id: int,
                          method=None,
                          status: Status = Status.OK,
                          payload=b'',
                          *,
                          ids: Optional[Tuple[int, int]] = None,
                          process_status=Status.OK) -> None:
        """Queues a RESPONSE packet.

        The target is given either by method (a method object with .id and
        .service.id) or by ids=(service_id, method_id), never both.
        payload may be raw bytes or a protobuf message.
        """
        if method:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None and method is None
            service_id, method_id = ids

        self._next_packets.append((packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.RESPONSE,
            channel_id=channel_id,
            service_id=service_id,
            method_id=method_id,
            status=status.value,
            payload=_message_bytes(payload)).SerializeToString(),
                                   process_status))

    def _enqueue_server_stream(self,
                               channel_id: int,
                               method,
                               response,
                               process_status=Status.OK) -> None:
        """Queues a SERVER_STREAM packet carrying response for method."""
        self._next_packets.append((packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.SERVER_STREAM,
            channel_id=channel_id,
            service_id=method.service.id,
            method_id=method.id,
            payload=_message_bytes(response)).SerializeToString(),
                                   process_status))

    def _enqueue_error(self,
                       channel_id: int,
                       service,
                       method,
                       status: Status,
                       process_status=Status.OK) -> None:
        """Queues a SERVER_ERROR packet.

        service and method may each be an object with an .id or a raw int ID.
        """
        self._next_packets.append((packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.SERVER_ERROR,
            channel_id=channel_id,
            service_id=service if isinstance(service, int) else service.id,
            method_id=method if isinstance(method, int) else method.id,
            status=status.value).SerializeToString(), process_status))

    def _handle_packet(self, data: bytes) -> None:
        """Channel output: records the packet, then maybe delivers responses."""
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Delivers all queued packets to the client and clears the queue."""
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        try:
            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))

            self._next_packets.clear()
        finally:
            # Restore the counter even if processing fails, so a failure here
            # does not silently disable response delivery for later packets.
            self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Decodes the last sent packet's payload as message_type."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message


class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""
    def test_callback_exceptions_suppressed(self) -> None:
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(self._request(),
                        mock.Mock(side_effect=Exception(exception_msg)))

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(1,
                               ids=(999, method.id),
                               process_status=Status.OK)
        # Bad method
        self._enqueue_response(1,
                               ids=(service_id, 999),
                               process_status=Status.OK)
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(1,
                               ids=(service_id,
                                    self._service.SomeBidiStreaming.method.id),
                               process_status=Status.OK)

        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(999,
                            service_id,
                            method,
                            Status.NOT_FOUND,
                            process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_error(1, 999, method.id, Status.INVALID_ARGUMENT)
        # Bad method
        self._enqueue_error(1, service_id, 999, Status.INVALID_ARGUMENT)
        # For RPC not pending
        self._enqueue_error(1, service_id,
                            self._service.SomeBidiStreaming.method.id,
                            Status.NOT_FOUND)

        self._process_enqueued_packets()

        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        method = self._service.SomeUnary.method

        self._enqueue_response(1,
                               method,
                               Status.OK,
                               b'INVALID DATA!!!',
                               process_status=Status.OK)

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        rpc_client = client.Client.from_modules(
            callback_client.Impl(99, 100),
            [client.Channel(1, lambda *a, **b: None)],
            self._protos.modules())
        rpcs = rpc_client.channel(1).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s,
            100)

    def test_rpc_provides_request_type(self) -> None:
        self.assertIs(self._service.SomeUnary.request,
                      self._service.SomeUnary.method.request_type)

    def test_rpc_provides_response_type(self) -> None:
        # Previously a copy of the request-type test; check the response type.
        self.assertIs(self._service.SomeUnary.response,
                      self._service.SomeUnary.method.response_type)


class UnaryTest(_CallbackClientImplTestBase):
    """Tests for invoking a unary RPC."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeUnary
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        for _ in range(3):
            self._enqueue_response(1, self.method, Status.ABORTED,
                                   self.method.response_type(payload='0_o'))

            status, response = self._service.SomeUnary(
                self.method.request_type(magic_number=6))

            self.assertEqual(
                6, self._sent_payload(self.method.request_type).magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        for _ in range(3):
            self._enqueue_response(1, self.method, Status.ABORTED,
                                   self.method.response_type(payload='0_o'))

            callback = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=5), callback,
                                   callback)

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='0_o')),
                mock.call(call, Status.ABORTED)
            ])

            self.assertEqual(
                5, self._sent_payload(self.method.request_type).magic_number)

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')

        for _ in range(3):
            self._enqueue_response(1, self.method, Status.ABORTED,
                                   self.method.response_type(payload='0_o'))

            callback = mock.Mock()
            call = self.rpc.open(self._request(magic_number=5), callback,
                                 callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='0_o')),
                mock.call(call, Status.ABORTED)
            ])

    def test_blocking_server_error(self) -> None:
        for _ in range(3):
            self._enqueue_error(1, self.method.service, self.method,
                                Status.NOT_FOUND)

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(
                    self.method.request_type(magic_number=6))

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback)

            self.assertGreater(len(self.requests), 0)
            self.requests.clear()

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            # Unary RPCs do not send a cancel request to the server.
            self.assertFalse(self.requests)

        callback.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138)

    def test_blocking_timeout_as_argument(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())

    def test_nonblocking_exception_in_callback(self) -> None:
        exception = ValueError('something went wrong!')

        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke(on_completed=mock.Mock(side_effect=exception))

        with self.assertRaises(RuntimeError) as context:
            call.wait()

        self.assertEqual(context.exception.__cause__, exception)


class ServerStreamingTest(_CallbackClientImplTestBase):
    """Tests for server streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeServerStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            self.assertEqual(
                [rep1, rep2],
                self._service.SomeServerStreaming(magic_number=4).responses)

            self.assertEqual(
                4, self._sent_payload(self.method.request_type).magic_number)

    def test_deprecated_packet_format(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            # The original packet format used RESPONSE packets for the server
            # stream and a SERVER_STREAM_END packet as the last packet. These
            # are converted to SERVER_STREAM packets followed by a RESPONSE.
            self._enqueue_response(1, self.method, payload=rep1)
            self._enqueue_response(1, self.method, payload=rep2)

            self._next_packets.append((packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.DEPRECATED_SERVER_STREAM_END,
                channel_id=1,
                service_id=self.method.service.id,
                method_id=self.method.id,
                status=Status.INVALID_ARGUMENT.value).SerializeToString(),
                                       Status.OK))

            status, replies = self._service.SomeServerStreaming(magic_number=4)
            self.assertEqual([rep1, rep2], replies)
            self.assertIs(status, Status.INVALID_ARGUMENT)

            self.assertEqual(
                4, self._sent_payload(self.method.request_type).magic_number)

    def test_nonblocking_call(self) -> None:
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            callback = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=3), callback,
                                   callback)

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.ABORTED),
            ])

            self.assertEqual(
                3, self._sent_payload(self.method.request_type).magic_number)

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            callback = mock.Mock()
            call = self.rpc.open(self._request(magic_number=3), callback,
                                 callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.ABORTED),
            ])

    def test_nonblocking_cancel(self) -> None:
        resp = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(1, self.rpc.method, resp)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.rpc.method.response_type(payload='!!!'))

        callback.reset_mock()

        call.cancel()

        self.assertEqual(self.last_request().type,
                         packet_pb2.PacketType.CLIENT_ERROR)
        self.assertEqual(self.last_request().status, Status.CANCELLED.value)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_server_stream(1, self.method, resp)
        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke(self._request(magic_number=3), callback,
                               callback)

        callback.assert_has_calls([
            mock.call(call, self.method.response_type(payload='!!!')),
            mock.call(call, Status.OK),
        ])

    def test_nonblocking_with_request_args(self) -> None:
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138)

    def test_blocking_timeout(self) -> None:
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in call:
                pass

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())

    def test_nonblocking_iterate_over_count(self) -> None:
        reply = self.method.response_type(payload='!?')

        for _ in range(4):
            self._enqueue_server_stream(1, self.method, reply)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses(count=1)), [reply])
        self.assertEqual(next(iter(call)), reply)
        self.assertEqual(list(call.get_responses(count=2)), [reply, reply])

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses()), [reply])
        self.assertEqual(list(call.get_responses()), [])
        self.assertEqual(list(call), [])


class ClientStreamingTest(_CallbackClientImplTestBase):
    """Tests for client streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeClientStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        response = self.method.response_type(payload='yo')
        self._enqueue_response(1, self.method, Status.OK, response)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.OK)
        self.assertEqual(results.response, response)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets is left at its default of 1, so the
        # error is delivered in response to the first packet the client sends.
        self._enqueue_error(1, self.method.service, self.method,
                            Status.NOT_FOUND)

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a successful client streaming RPC ended by the server."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=31)
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                31,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.send(magic_number=32)
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                32,
                self._sent_payload(self.method.request_type).magic_number)

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')
        payload = self.method.response_type(payload='-_-')

        for _ in range(3):
            self._enqueue_response(1, self.method, Status.OK, payload)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, payload),
                mock.call(call, Status.OK),
            ])

    def test_nonblocking_finish(self) -> None:
        """Tests a client streaming RPC ended by the client."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=37)
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                37,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.finish_and_wait()
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM_END,
                             self.last_request().type)

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_nonblocking_cancel(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            stream.send(magic_number=37)

            self.assertTrue(stream.cancel())
            self.assertEqual(packet_pb2.PacketType.CLIENT_ERROR,
                             self.last_request().type)
            self.assertEqual(Status.CANCELLED.value,
                             self.last_request().status)
            self.assertFalse(stream.cancel())

            self.assertTrue(stream.completed())
            self.assertIs(stream.error, Status.CANCELLED)

    def test_nonblocking_server_error(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)
            stream.send(magic_number=2**32 - 1)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeClientStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_response(1, self.method, Status.UNAVAILABLE, reply)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.response, reply)

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        self._enqueue_error(1, self.method.service, self.method,
                            Status.UNAVAILABLE)

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertIsNone(call.response)

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())


class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        self._enqueue_response(1, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets is left at its default of 1, so the
        # error is delivered in response to the first packet the client sends.
        self._enqueue_error(1, self.method.service, self.method,
                            Status.NOT_FOUND)

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res))
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                55,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)

            stream.send(magic_number=66)
            self.assertEqual(packet_pb2.PacketType.CLIENT_STREAM,
                             self.last_request().type)
            self.assertEqual(
                66,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self._enqueue_response(1, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.OK)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.OK),
            ])

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, callback) -> None:
        """Tests that the default response callback receives stream messages."""
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(1, self.method, reply)

        self._service.SomeBidiStreaming.invoke()

        callback.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res))
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(1, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(1, self.method.service, self.method,
                                Status.OUT_OF_RANGE)

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        for _ in range(3):
            stream = self._service.SomeBidiStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        call = self._service.SomeBidiStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.UNAVAILABLE)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_error(1, self.method.service, self.method,
                            Status.UNAVAILABLE)

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(call.responses, [reply])

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())


if __name__ == '__main__':
    unittest.main()