#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for the transfer service client."""

import enum
import math
import unittest
from typing import Iterable, List

from pw_status import Status
from pw_rpc import callback_client, client, ids, packets
from pw_rpc.internal import packet_pb2

import pw_transfer
from pw_transfer.transfer_pb2 import Chunk

_TRANSFER_SERVICE_ID = ids.calculate('pw.transfer.Transfer')

# If the default timeout is too short, some tests become flaky on Windows.
DEFAULT_TIMEOUT_S = 0.3


class _Method(enum.Enum):
    """RPC method IDs of the transfer service used by these tests."""
    READ = ids.calculate('Read')
    WRITE = ids.calculate('Write')


class TransferManagerTest(unittest.TestCase):
    """Tests for the transfer manager."""
    def setUp(self) -> None:
        """Creates an RPC client whose channel output is _handle_request."""
        self._client = client.Client.from_modules(
            callback_client.Impl(),
            [client.Channel(1, self._handle_request)],
            (pw_transfer.transfer_pb2, ))
        self._service = self._client.channel(1).rpcs.pw.transfer.Transfer

        # Chunks decoded from the CLIENT_STREAM packets the client has sent.
        self._sent_chunks: List[Chunk] = []
        # Groups of serialized packets; one group is delivered to the client
        # each time it sends a chunk.
        self._packets_to_send: List[List[bytes]] = []

    def _enqueue_server_responses(
            self, method: _Method,
            responses: Iterable[Iterable[Chunk]]) -> None:
        """Queues groups of chunks to send back as SERVER_STREAM packets.

        Args:
          method: Transfer RPC method the packets are addressed to.
          responses: Groups of chunks. Each group is serialized into a list of
              packets and appended to the queue as a single entry.
        """
        for group in responses:
            serialized_group = [
                packet_pb2.RpcPacket(
                    type=packet_pb2.PacketType.SERVER_STREAM,
                    channel_id=1,
                    service_id=_TRANSFER_SERVICE_ID,
                    method_id=method.value,
                    status=Status.OK.value,
                    payload=response.SerializeToString()).SerializeToString()
                for response in group
            ]
            self._packets_to_send.append(serialized_group)

    def _enqueue_server_error(self, method: _Method, error: Status) -> None:
        """Queues a single SERVER_ERROR packet carrying the given status.

        Args:
          method: Transfer RPC method the packet is addressed to.
          error: Status placed in the error packet.
        """
        self._packets_to_send.append([
            packet_pb2.RpcPacket(type=packet_pb2.PacketType.SERVER_ERROR,
                                 channel_id=1,
                                 service_id=_TRANSFER_SERVICE_ID,
                                 method_id=method.value,
                                 status=error.value).SerializeToString()
        ])

    def _handle_request(self, data: bytes) -> None:
        """Channel output: records a sent chunk and replies from the queue.

        Packets that are not CLIENT_STREAM packets are ignored. For each
        CLIENT_STREAM packet, the payload is decoded as a Chunk and recorded,
        then the next queued group of packets (if any) is processed by the
        client.

        Args:
          data: Encoded RPC packet written by the client.
        """
        packet = packets.decode(data)

        # Protobuf enum values are plain ints, so compare by value; an
        # identity check would depend on the interpreter's small-int caching.
        if packet.type != packet_pb2.PacketType.CLIENT_STREAM:
            return

        chunk = Chunk()
        chunk.MergeFromString(packet.payload)
        self._sent_chunks.append(chunk)

        if self._packets_to_send:
            responses = self._packets_to_send.pop(0)
            for response in responses:
                self._client.process_packet(response)

    def _received_data(self) -> bytearray:
        """Returns the data fields of all sent chunks, concatenated in order."""
        data = bytearray()
        for chunk in self._sent_chunks:
            data.extend(chunk.data)
        return data

    def test_read_transfer_basic(self) -> None:
        """Reads a transfer that the server completes with a single chunk."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((Chunk(transfer_id=3, offset=0, data=b'abc',
                    remaining_bytes=0), ), ),
        )

        data = manager.read(3)
        self.assertEqual(data, b'abc')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_multichunk(self) -> None:
        """Reads a transfer delivered as two chunks in one response group."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((
                Chunk(transfer_id=3,
                      offset=0,
                      data=b'abc',
                      remaining_bytes=3),
                Chunk(transfer_id=3,
                      offset=3,
                      data=b'def',
                      remaining_bytes=0),
            ), ),
        )

        data = manager.read(3)
        self.assertEqual(data, b'abcdef')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_progress_callback(self) -> None:
        """Checks the progress stats passed to the callback during a read."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((
                Chunk(transfer_id=3,
                      offset=0,
                      data=b'abc',
                      remaining_bytes=3),
                Chunk(transfer_id=3,
                      offset=3,
                      data=b'def',
                      remaining_bytes=0),
            ), ),
        )

        progress: List[pw_transfer.ProgressStats] = []

        data = manager.read(3, progress.append)
        self.assertEqual(data, b'abcdef')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)
        self.assertEqual(progress, [
            pw_transfer.ProgressStats(3, 3, 6),
            pw_transfer.ProgressStats(6, 6, 6),
        ])

    def test_read_transfer_retry_bad_offset(self) -> None:
        """Server responds with an unexpected offset in a read transfer."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            (
                (
                    Chunk(transfer_id=3,
                          offset=0,
                          data=b'123',
                          remaining_bytes=6),

                    # Incorrect offset; expecting 3.
                    Chunk(transfer_id=3,
                          offset=1,
                          data=b'456',
                          remaining_bytes=3),
                ),
                (
                    Chunk(transfer_id=3,
                          offset=3,
                          data=b'456',
                          remaining_bytes=3),
                    Chunk(transfer_id=3,
                          offset=6,
                          data=b'789',
                          remaining_bytes=0),
                ),
            ))

        data = manager.read(3)
        self.assertEqual(data, b'123456789')

        # Two transfer parameter requests should have been sent.
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_retry_timeout(self) -> None:
        """Server doesn't respond to read transfer parameters."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            (
                (),  # Send nothing in response to the initial parameters.
                (Chunk(transfer_id=3,
                       offset=0,
                       data=b'xyz',
                       remaining_bytes=0), ),
            ))

        data = manager.read(3)
        self.assertEqual(data, b'xyz')

        # Two transfer parameter requests should have been sent.
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertTrue(self._sent_chunks[-1].HasField('status'))
        self.assertEqual(self._sent_chunks[-1].status, 0)

    def test_read_transfer_timeout(self) -> None:
        """Read fails with DEADLINE_EXCEEDED when no responses are queued."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.read(27)

        exception = context.exception
        self.assertEqual(exception.transfer_id, 27)
        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)

        # The client should have sent four transfer parameters requests: one
        # initial, and three retries.
        self.assertEqual(len(self._sent_chunks), 4)

    def test_read_transfer_error(self) -> None:
        """Read raises an Error with the status sent in a server chunk."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.READ,
            ((Chunk(transfer_id=31, status=Status.NOT_FOUND.value), ), ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.read(31)

        exception = context.exception
        self.assertEqual(exception.transfer_id, 31)
        self.assertEqual(exception.status, Status.NOT_FOUND)

    def test_read_transfer_server_error(self) -> None:
        """A SERVER_ERROR packet makes a read fail with status INTERNAL."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_error(_Method.READ, Status.NOT_FOUND)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.read(31)

        exception = context.exception
        self.assertEqual(exception.transfer_id, 31)
        self.assertEqual(exception.status, Status.INTERNAL)

    def test_write_transfer_basic(self) -> None:
        """Writes data that fits within a single chunk."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=32,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'hello')
        self.assertEqual(len(self._sent_chunks), 2)
        self.assertEqual(self._received_data(), b'hello')

    def test_write_transfer_max_chunk_size(self) -> None:
        """Written data is split according to max_chunk_size_bytes."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=32,
                       max_chunk_size_bytes=8), ),
                (),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'hello world')
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertEqual(self._received_data(), b'hello world')
        self.assertEqual(self._sent_chunks[1].data, b'hello wo')
        self.assertEqual(self._sent_chunks[2].data, b'rld')

    def test_write_transfer_multiple_parameters(self) -> None:
        """Write in which the server sends transfer parameters twice."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4,
                       offset=8,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'data to write')
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertEqual(self._received_data(), b'data to write')
        self.assertEqual(self._sent_chunks[1].data, b'data to ')
        self.assertEqual(self._sent_chunks[2].data, b'write')

    def test_write_transfer_progress_callback(self) -> None:
        """Checks the progress stats passed to the callback during a write."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4,
                       offset=8,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        progress: List[pw_transfer.ProgressStats] = []

        manager.write(4, b'data to write', progress.append)
        self.assertEqual(len(self._sent_chunks), 3)
        self.assertEqual(self._received_data(), b'data to write')
        self.assertEqual(self._sent_chunks[1].data, b'data to ')
        self.assertEqual(self._sent_chunks[2].data, b'write')
        self.assertEqual(progress, [
            pw_transfer.ProgressStats(8, 0, 13),
            pw_transfer.ProgressStats(13, 8, 13),
            pw_transfer.ProgressStats(13, 13, 13)
        ])

    def test_write_transfer_rewind(self) -> None:
        """Write transfer in which the server re-requests an earlier offset."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4,
                       offset=8,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (
                    Chunk(
                        transfer_id=4,
                        offset=4,  # rewind
                        pending_bytes=8,
                        max_chunk_size_bytes=8), ),
                (
                    Chunk(
                        transfer_id=4,
                        offset=12,
                        pending_bytes=16,  # update max size
                        max_chunk_size_bytes=16), ),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        manager.write(4, b'pigweed data transfer')
        self.assertEqual(len(self._sent_chunks), 5)
        self.assertEqual(self._sent_chunks[1].data, b'pigweed ')
        self.assertEqual(self._sent_chunks[2].data, b'data tra')
        self.assertEqual(self._sent_chunks[3].data, b'eed data')
        self.assertEqual(self._sent_chunks[4].data, b' transfer')

    def test_write_transfer_bad_offset(self) -> None:
        """Write fails with OUT_OF_RANGE for an offset past the data's end."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            (
                (Chunk(transfer_id=4,
                       offset=0,
                       pending_bytes=8,
                       max_chunk_size_bytes=8), ),
                (
                    Chunk(
                        transfer_id=4,
                        offset=100,  # larger offset than data
                        pending_bytes=8,
                        max_chunk_size_bytes=8), ),
                (Chunk(transfer_id=4, status=Status.OK.value), ),
            ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(4, b'small data')

        exception = context.exception
        self.assertEqual(exception.transfer_id, 4)
        self.assertEqual(exception.status, Status.OUT_OF_RANGE)

    def test_write_transfer_error(self) -> None:
        """Write raises an Error with the status sent in a server chunk."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            ((Chunk(transfer_id=21, status=Status.UNAVAILABLE.value), ), ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(21, b'no write')

        exception = context.exception
        self.assertEqual(exception.transfer_id, 21)
        self.assertEqual(exception.status, Status.UNAVAILABLE)

    def test_write_transfer_server_error(self) -> None:
        """A SERVER_ERROR packet makes a write fail with status INTERNAL."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_error(_Method.WRITE, Status.NOT_FOUND)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(21, b'server error')

        exception = context.exception
        self.assertEqual(exception.transfer_id, 21)
        self.assertEqual(exception.status, Status.INTERNAL)

    def test_write_transfer_timeout_after_initial_chunk(self) -> None:
        """Tests write transfers that time out after the initial chunk."""
        # No server responses are queued for this test, so the initial chunk
        # and each retry go unanswered.
        manager = pw_transfer.Manager(self._service,
                                      default_response_timeout_s=0.001,
                                      max_retries=2)

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(22, b'no server response!')

        self.assertEqual(
            self._sent_chunks,
            [
                Chunk(transfer_id=22,
                      type=Chunk.Type.TRANSFER_START),  # initial chunk
                Chunk(transfer_id=22,
                      type=Chunk.Type.TRANSFER_START),  # retry 1
                Chunk(transfer_id=22,
                      type=Chunk.Type.TRANSFER_START),  # retry 2
            ])

        exception = context.exception
        self.assertEqual(exception.transfer_id, 22)
        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)

    def test_write_transfer_timeout_after_intermediate_chunk(self) -> None:
        """Tests write transfers that time out after an intermediate chunk."""
        manager = pw_transfer.Manager(
            self._service,
            default_response_timeout_s=DEFAULT_TIMEOUT_S,
            max_retries=2)

        # Only the initial transfer parameters are answered; nothing is queued
        # in response to the data chunks that follow.
        self._enqueue_server_responses(
            _Method.WRITE,
            [[Chunk(transfer_id=22, pending_bytes=10, max_chunk_size_bytes=5)]
             ])

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(22, b'0123456789')

        last_data_chunk = Chunk(transfer_id=22,
                                data=b'56789',
                                offset=5,
                                remaining_bytes=0,
                                type=Chunk.Type.TRANSFER_DATA)

        self.assertEqual(
            self._sent_chunks,
            [
                Chunk(transfer_id=22, type=Chunk.Type.TRANSFER_START),
                Chunk(transfer_id=22,
                      data=b'01234',
                      type=Chunk.Type.TRANSFER_DATA),
                last_data_chunk,  # last chunk
                last_data_chunk,  # retry 1
                last_data_chunk,  # retry 2
            ])

        exception = context.exception
        self.assertEqual(exception.transfer_id, 22)
        self.assertEqual(exception.status, Status.DEADLINE_EXCEEDED)

    def test_write_zero_pending_bytes_is_internal_error(self) -> None:
        """Write fails with INTERNAL when the server sends pending_bytes=0."""
        manager = pw_transfer.Manager(
            self._service, default_response_timeout_s=DEFAULT_TIMEOUT_S)

        self._enqueue_server_responses(
            _Method.WRITE,
            ((Chunk(transfer_id=23, pending_bytes=0), ), ),
        )

        with self.assertRaises(pw_transfer.Error) as context:
            manager.write(23, b'no write')

        exception = context.exception
        self.assertEqual(exception.transfer_id, 23)
        self.assertEqual(exception.status, Status.INTERNAL)


class ProgressStatsTest(unittest.TestCase):
    """Tests for pw_transfer.ProgressStats."""
    def test_received_percent_known_total(self) -> None:
        """percent_received() returns a percentage when the total is known."""
        self.assertEqual(
            pw_transfer.ProgressStats(75, 0, 100).percent_received(), 0.0)
        self.assertEqual(
            pw_transfer.ProgressStats(75, 50, 100).percent_received(), 50.0)
        self.assertEqual(
            pw_transfer.ProgressStats(100, 100, 100).percent_received(),
            100.0)

    def test_received_percent_unknown_total(self) -> None:
        """percent_received() returns NaN when the total is None."""
        self.assertTrue(
            math.isnan(
                pw_transfer.ProgressStats(75, 50, None).percent_received()))
        self.assertTrue(
            math.isnan(
                pw_transfer.ProgressStats(100, 100,
                                          None).percent_received()))

    def test_str_known_total(self) -> None:
        """str() includes all three values when the total is known."""
        stats = str(pw_transfer.ProgressStats(75, 50, 100))
        self.assertIn('75', stats)
        self.assertIn('50', stats)
        self.assertIn('100', stats)

    def test_str_unknown_total(self) -> None:
        """str() contains 'unknown' when the total is None."""
        stats = str(pw_transfer.ProgressStats(75, 50, None))
        self.assertIn('75', stats)
        self.assertIn('50', stats)
        self.assertIn('unknown', stats)


if __name__ == '__main__':
    unittest.main()