#!/usr/bin/env python
#
# Copyright (C) 2022 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for verify_overlaps_test.py."""
import io
import unittest

from signature_trie import InteriorNode
from signature_trie import signature_trie


class TestSignatureToElements(unittest.TestCase):
    """Tests converting signatures to ``(type, value)`` elements and back.

    For class and member signatures the expectations prepend "L" to the
    result of ``InteriorNode.elements_to_selector`` before comparing it with
    the original signature; for wildcard patterns the result is compared
    unchanged.
    """

    @staticmethod
    def signature_to_elements(signature):
        """Parse ``signature`` with ``InteriorNode.signature_to_elements``."""
        return InteriorNode.signature_to_elements(signature)

    @staticmethod
    def elements_to_signature(elements):
        """Join ``elements`` with ``InteriorNode.elements_to_selector``."""
        return InteriorNode.elements_to_selector(elements)

    def _assert_round_trip(self, signature, elements, prefix="L"):
        """Assert ``signature`` parses to ``elements`` and converts back.

        Args:
            signature: the signature or pattern string to parse.
            elements: the expected list of ``(type, value)`` tuples.
            prefix: text prepended to the rebuilt selector before it is
                compared with ``signature``.
        """
        parsed = self.signature_to_elements(signature)
        self.assertEqual(elements, parsed)
        rebuilt = prefix + self.elements_to_signature(elements)
        self.assertEqual(signature, rebuilt)

    def _assert_rejected(self, signature, expected_message):
        """Assert parsing ``signature`` raises with ``expected_message``.

        The check passes when ``expected_message`` is a substring of the
        string form of the raised exception.
        """
        with self.assertRaises(Exception) as caught:
            self.signature_to_elements(signature)
        self.assertIn(expected_message, str(caught.exception))

    def test_nested_inner_classes(self):
        # Every "$" separated part of the class name is its own element.
        self._assert_round_trip(
            "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "ProcessBuilder"),
                ("class", "Redirect"),
                ("class", "1"),
                ("member", "<init>()V"),
            ],
        )

    def test_basic_member(self):
        self._assert_round_trip(
            "Ljava/lang/Object;->hashCode()I",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "Object"),
                ("member", "hashCode()I"),
            ],
        )

    def test_double_dollar_class(self):
        # "$$" is expected to produce an empty class element in between.
        self._assert_round_trip(
            "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;"
            "-><init>(Ljava/lang/CharSequence;)V",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "CharSequence"),
                ("class", ""),
                ("class", "ExternalSyntheticLambda0"),
                ("member", "<init>(Ljava/lang/CharSequence;)V"),
            ],
        )

    def test_no_member(self):
        self._assert_round_trip(
            "Ljava/lang/CharSequence$$ExternalSyntheticLambda0",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "CharSequence"),
                ("class", ""),
                ("class", "ExternalSyntheticLambda0"),
            ],
        )

    def test_wildcard(self):
        self._assert_round_trip(
            "java/lang/*",
            [
                ("package", "java"),
                ("package", "lang"),
                ("wildcard", "*"),
            ],
            prefix="",
        )

    def test_recursive_wildcard(self):
        self._assert_round_trip(
            "java/lang/**",
            [
                ("package", "java"),
                ("package", "lang"),
                ("wildcard", "**"),
            ],
            prefix="",
        )

    def test_no_packages_wildcard(self):
        self._assert_round_trip("*", [("wildcard", "*")], prefix="")

    def test_no_packages_recursive_wildcard(self):
        self._assert_round_trip("**", [("wildcard", "**")], prefix="")

    def test_invalid_no_class_or_wildcard(self):
        self._assert_rejected(
            "java/lang",
            "last element 'lang' is lower case but should be an "
            "upper case class name or wildcard",
        )

    def test_non_standard_class_name(self):
        # A lower case class name is accepted when the signature starts
        # with "L".
        self._assert_round_trip(
            "Ljavax/crypto/extObjectInputStream",
            [
                ("package", "javax"),
                ("package", "crypto"),
                ("class", "extObjectInputStream"),
            ],
        )

    def test_invalid_pattern_wildcard(self):
        self._assert_rejected("Ljava/lang/Class*", "invalid wildcard 'Class*'")

    def test_invalid_pattern_wildcard_and_member(self):
        self._assert_rejected(
            "Ljava/lang/*;->hashCode()I",
            "contains wildcard '*' and member signature 'hashCode()I'",
        )


class TestValues(unittest.TestCase):
    """Tests that values added to the trie can be read back from a node."""

    def test_add_then_get(self):
        root = signature_trie()
        for signature, value in (
                ("La/b/C;->l()", 1),
                ("La/b/C$D;->m()", "A"),
                ("La/b/C$D;->n()", {}),
        ):
            root.add(signature, value)

        # Descend through the first child at each level, checking the type
        # and selector of every node on the way down to class C.
        expected_path = (
            ("package", "a"),
            ("package", "a/b"),
            ("class", "a/b/C"),
        )
        current = root
        for expected_type, expected_selector in expected_path:
            current = next(iter(current.child_nodes()))
            self.assertEqual(expected_type, current.type)
            self.assertEqual(expected_selector, current.selector)

        # The class C node is expected to report the values added for C and
        # for its nested class C$D.
        def accept_all(_):
            return True

        self.assertEqual([1, "A", {}], current.values(accept_all))

class TestGetMatchingRows(unittest.TestCase):
    """Tests ``get_matching_rows`` against a small trie of signatures.

    Each test builds a fresh trie from ``extractInput``, where every
    signature is stored as its own value, and compares the sorted matches
    for a pattern with the expected rows.
    """

    extractInput = """
Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;
Ljava/lang/Character;->serialVersionUID:J
Ljava/lang/Object;->hashCode()I
Ljava/lang/Object;->toString()Ljava/lang/String;
Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V
Ljava/util/zip/ZipFile;-><clinit>()V
"""

    # Expected rows, in sorted order, shared by the tests below.
    # pylint: disable=line-too-long
    _CHARACTER_ROWS = [
        "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
        "Ljava/lang/Character;->serialVersionUID:J",
    ]
    _OBJECT_ROWS = [
        "Ljava/lang/Object;->hashCode()I",
        "Ljava/lang/Object;->toString()Ljava/lang/String;",
    ]
    _JAVA_LANG_ROWS = _CHARACTER_ROWS + _OBJECT_ROWS + [
        "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
    ]
    _JAVA_ROWS = _JAVA_LANG_ROWS + [
        "Ljava/util/zip/ZipFile;-><clinit>()V",
    ]

    def read_trie(self):
        """Return a new trie holding every signature in ``extractInput``."""
        result = signature_trie()
        with io.StringIO(self.extractInput.strip()) as stream:
            for raw_line in stream:
                signature = raw_line.rstrip()
                result.add(signature, signature)
        return result

    def check_patterns(self, pattern, expected):
        """Check ``pattern`` against the root of a freshly built trie."""
        self.check_node_patterns(self.read_trie(), pattern, expected)

    def check_node_patterns(self, node, pattern, expected):
        """Assert the sorted rows ``node`` matches for ``pattern``."""
        matches = sorted(node.get_matching_rows(pattern))
        self.assertEqual(expected, matches)

    def test_member_pattern(self):
        self.check_patterns(
            "java/util/zip/ZipFile;-><clinit>()V",
            ["Ljava/util/zip/ZipFile;-><clinit>()V"],
        )

    def test_class_pattern(self):
        self.check_patterns("java/lang/Object", self._OBJECT_ROWS)

    def test_nested_class_pattern(self):
        # A class pattern is expected to match nested class members too.
        self.check_patterns("java/lang/Character", self._CHARACTER_ROWS)

    def test_wildcard(self):
        self.check_patterns("java/lang/*", self._JAVA_LANG_ROWS)

    def test_recursive_wildcard(self):
        self.check_patterns("java/**", self._JAVA_ROWS)

    def test_node_wildcard(self):
        # Match relative to the first child of the root rather than the root.
        children = list(self.read_trie().child_nodes())
        first_child = children[0]
        self.check_node_patterns(first_child, "**", self._JAVA_ROWS)

    # pylint: enable=line-too-long


# Run the tests in this module with verbose output when the file is executed
# directly as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)