132 lines
5.4 KiB
Python
132 lines
5.4 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
|
|
"""Tests decoding a proto with tokenized fields."""
|
|
|
|
import base64
|
|
import unittest
|
|
|
|
from pw_tokenizer_tests.detokenize_proto_test_pb2 import TheMessage
|
|
|
|
from pw_tokenizer import detokenize, encode, tokens
|
|
from pw_tokenizer.proto import detokenize_fields, decode_optionally_tokenized
|
|
|
|
# Token database for the detokenize_fields tests. The string for 0x12345678
# embeds the Base64 token $oeQAAA==, whose bytes (a1 e4 00 00) name token
# 0x0000e4a1, so nested (recursive) detokenization can be exercised.
_DATABASE = tokens.Database([
    tokens.TokenizedStringEntry(token, text) for token, text in (
        (0xAABBCCDD, "Luke, we're gonna have %s"),
        (0x12345678, "This string has a $oeQAAA=="),
        (0x0000e4a1, "recursive token"),
    )
])

# Detokenizer over _DATABASE, shared by the proto field tests.
_DETOKENIZER = detokenize.Detokenizer(_DATABASE)
|
|
|
|
|
|
class TestDetokenizeProtoFields(unittest.TestCase):
    """Tests detokenizing optionally tokenized proto fields.

    Each test stores a payload in TheMessage.message, runs detokenize_fields
    with the module-level _DETOKENIZER, and checks what the field holds
    afterwards.
    """

    def _detokenized(self, payload: bytes) -> bytes:
        """Returns the message field after detokenizing a proto with payload."""
        proto = TheMessage(message=payload)
        detokenize_fields(_DETOKENIZER, proto)  # Rewrites the field in place.
        return proto.message

    def test_plain_text(self) -> None:
        # Text that is not tokenized comes back unchanged.
        self.assertEqual(self._detokenized(b'boring conversation anyway!'),
                         b'boring conversation anyway!')

    def test_binary(self) -> None:
        # Token 0xAABBCCDD as little-endian bytes, then the argument
        # 'company' preceded by its length byte (7).
        result = self._detokenized(b'\xDD\xCC\xBB\xAA\x07company')
        self.assertEqual(result, b"Luke, we're gonna have company")

    def test_binary_missing_arguments(self) -> None:
        # With no argument bytes, the %s specifier is left in the output.
        result = self._detokenized(b'\xDD\xCC\xBB\xAA')
        self.assertEqual(result, b"Luke, we're gonna have %s")

    def test_recursive_binary(self) -> None:
        # 0x12345678 expands to text that embeds a Base64 token, which is
        # detokenized as well.
        result = self._detokenized(b'\x78\x56\x34\x12')
        self.assertEqual(result, b"This string has a recursive token")

    def test_base64(self) -> None:
        encoded = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x07company')
        result = self._detokenized(encoded.encode())
        self.assertEqual(result, b"Luke, we're gonna have company")

    def test_recursive_base64(self) -> None:
        encoded = encode.prefixed_base64(b'\x78\x56\x34\x12')
        result = self._detokenized(encoded.encode())
        self.assertEqual(result, b"This string has a recursive token")

    def test_plain_text_with_prefixed_base64(self) -> None:
        # Only the embedded Base64 token is replaced; the surrounding text
        # is kept.
        encoded = encode.prefixed_base64(b'\xDD\xCC\xBB\xAA\x09pancakes!')
        result = self._detokenized(f'Good morning, {encoded}'.encode())
        self.assertEqual(result,
                         b"Good morning, Luke, we're gonna have pancakes!")

    def test_unknown_token_not_utf8(self) -> None:
        # Unknown data that is not UTF-8 is re-encoded as prefixed Base64.
        payload = b'\xFE\xED\xF0\x0D'
        self.assertEqual(self._detokenized(payload).decode(),
                         encode.prefixed_base64(payload))

    def test_only_control_characters(self) -> None:
        # Data made only of control characters is re-encoded as prefixed
        # Base64 rather than kept as text.
        payload = b'\1\2\3\4'
        self.assertEqual(self._detokenized(payload).decode(),
                         encode.prefixed_base64(payload))
|
|
|
|
|
|
class TestDecodeOptionallyTokenized(unittest.TestCase):
    """Tests decode_optionally_tokenized directly, without a proto."""

    def setUp(self):
        # Token N maps to the Nth phrase; in these tests token 1 is written
        # as the little-endian bytes b'\x01\x00\x00\x00'.
        phrases = ('cheese', 'on pizza', 'is quite good', 'they say')
        self.detok = detokenize.Detokenizer(
            tokens.Database([
                tokens.TokenizedStringEntry(token, phrase)
                for token, phrase in enumerate(phrases)
            ]))

    def _decode(self, data, *prefix):
        """Decodes data with this test's detokenizer and optional prefix."""
        return decode_optionally_tokenized(self.detok, data, *prefix)

    def test_found_binary_token(self):
        decoded = self._decode(b'\x01\x00\x00\x00')
        self.assertEqual('on pizza', decoded)

    def test_missing_binary_token(self):
        # Binary data that matches no token comes back as '$'-prefixed Base64.
        raw = b'\xD5\x8A\xF9\x2A\x8A'
        expected = '$' + base64.b64encode(raw).decode()
        self.assertEqual(expected, self._decode(raw))

    def test_found_b64_token(self):
        encoded = b'$' + base64.b64encode(b'\x03\x00\x00\x00')
        self.assertEqual('they say', self._decode(encoded))

    def test_missing_b64_token(self):
        # Base64 for an unknown token is returned unchanged, as text.
        encoded = b'$' + base64.b64encode(b'\xD5\x8A\xF9\x2A\x8A')
        self.assertEqual(encoded.decode(), self._decode(encoded))

    def test_found_alternate_prefix(self):
        # The prefix argument selects the character that marks Base64 tokens.
        encoded = b'~' + base64.b64encode(b'\x00\x00\x00\x00')
        self.assertEqual('cheese', self._decode(encoded, b'~'))

    def test_missing_alternate_prefix(self):
        # Token 2 exists, but the data uses '~' while the decoder is told to
        # look for '^', so the text is returned as-is.
        encoded = b'~' + base64.b64encode(b'\x02\x00\x00\x00')
        self.assertEqual(encoded.decode(), self._decode(encoded, b'^'))
|
|
|
|
|
|
if __name__ == '__main__':
    # Discover and run this module's TestCase classes when run as a script;
    # '__main__' is unittest.main's default module, spelled out explicitly.
    unittest.main(module='__main__')
|