android13/external/libtextclassifier/native/utils/grammar/utils/ir_test.cc

244 lines
8.6 KiB
C++

/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/grammar/utils/ir.h"
#include "utils/grammar/rules_generated.h"
#include "utils/grammar/types.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier3::grammar {
namespace {
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::Ne;
using ::testing::SizeIs;
TEST(IrTest, HandlesSharingWithTerminalRules) {
  // Terminal rules added with an unassigned left-hand side are expected to
  // reuse one nonterminal per distinct terminal, whereas nonterminals obtained
  // from AddUnshareableNonterminal() must stay distinct from all of them.
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Ir ir(shard_map);

  // <the> ::= the
  const Nonterm the = ir.Add(kUnassignedNonterm, "the");

  // <quick> ::= quick, added twice -- the second add should reuse the first.
  const Nonterm quick = ir.Add(kUnassignedNonterm, "quick");
  const Nonterm quick_again = ir.Add(kUnassignedNonterm, "quick");

  // <unshareable_a> ::= quick
  // <unshareable_a> ::= brown
  // Both rules hang off an explicitly unshareable nonterminal.
  const Nonterm unshareable_a = ir.AddUnshareableNonterminal();
  ir.Add(unshareable_a, "quick");
  ir.Add(unshareable_a, "brown");

  // <brown> ::= brown -- must not be shared with <unshareable_a>.
  const Nonterm brown = ir.Add(kUnassignedNonterm, "brown");

  // <unshareable_b> ::= brown -- a second explicitly unshareable nonterminal.
  const Nonterm unshareable_b = ir.AddUnshareableNonterminal();
  ir.Add(unshareable_b, "brown");

  // <brown_again> ::= brown -- should share with <brown>.
  const Nonterm brown_again = ir.Add(kUnassignedNonterm, "brown");

  // Shareable terminals: assigned, distinct per terminal, reused on repeats.
  EXPECT_THAT(the, Ne(kUnassignedNonterm));
  EXPECT_THAT(quick, Ne(kUnassignedNonterm));
  EXPECT_THAT(the, Ne(quick));
  EXPECT_THAT(quick, Eq(quick_again));

  // The first unshareable nonterminal is assigned and not reused.
  EXPECT_THAT(unshareable_a, Ne(kUnassignedNonterm));
  EXPECT_THAT(unshareable_a, Ne(quick_again));
  EXPECT_THAT(unshareable_a, Ne(brown));

  // The second unshareable nonterminal is distinct from everything else.
  EXPECT_THAT(unshareable_b, Ne(kUnassignedNonterm));
  EXPECT_THAT(unshareable_b, Ne(unshareable_a));
  EXPECT_THAT(unshareable_b, Ne(brown));

  // Unshareable rules in between do not prevent later sharing.
  EXPECT_THAT(brown_again, Eq(brown));
}
TEST(IrTest, HandlesSharingWithNonterminalRules) {
  // Checks that nonterminal rules with identical right-hand sides and an
  // unassigned left-hand side resolve to the same nonterminal, and never to an
  // unshareable nonterminal that was given the same right-hand side.
  grammar::LocaleShardMap locale_shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Ir ir(locale_shard_map);

  // Setup a few terminal rules.
  const std::vector<Nonterm> rhs = {
      ir.Add(kUnassignedNonterm, "the"), ir.Add(kUnassignedNonterm, "quick"),
      ir.Add(kUnassignedNonterm, "brown"), ir.Add(kUnassignedNonterm, "fox")};

  // Check for proper sharing using every non-empty prefix of `rhs`.
  // The prefix is delimited by an iterator rather than an `int` counter so
  // that no signed value is compared against the unsigned `rhs.size()`.
  for (auto prefix_last = rhs.begin(); prefix_last != rhs.end();
       ++prefix_last) {
    // Elements [begin, prefix_last] inclusive, i.e. lengths 1..rhs.size().
    const std::vector<Nonterm> rhs_truncated(rhs.begin(), prefix_last + 1);

    // The unshareable rule is registered first; the two shareable adds that
    // follow must agree with each other but not with the unshareable one.
    const Nonterm nt_u = ir.AddUnshareableNonterminal();
    ir.Add(nt_u, rhs_truncated);
    const Nonterm nt_1 = ir.Add(kUnassignedNonterm, rhs_truncated);
    const Nonterm nt_2 = ir.Add(kUnassignedNonterm, rhs_truncated);

    EXPECT_THAT(nt_1, Eq(nt_2));
    EXPECT_THAT(nt_1, Ne(nt_u));
  }
}
TEST(IrTest, HandlesSharingWithCallbacksWithSameParameters) {
  // Adds the same terminal rule once without a callback and several times with
  // a callback attached to the left-hand side, and expects every add to yield
  // the same nonterminal.
  constexpr CallbackId kOutput1 = 1;
  constexpr CallbackId kOutput2 = 2;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Ir ir(shard_map);

  // Baseline: no callback on the left-hand side.
  const Nonterm plain = ir.Add(kUnassignedNonterm, "hello");

  // Same terminal, first callback.
  const Nonterm with_output1 =
      ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput1, 0}}, "hello");

  // Same terminal, second callback.
  const Nonterm with_output2 =
      ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput2, 0}}, "hello");

  // Exact repeat of the previous rule.
  const Nonterm with_output2_again =
      ir.Add(Ir::Lhs{kUnassignedNonterm, {kOutput2, 0}}, "hello");

  EXPECT_THAT(with_output1, Eq(plain));
  EXPECT_THAT(with_output2, Eq(plain));
  EXPECT_THAT(with_output2_again, Eq(plain));
}
TEST(IrTest, SerializesRulesToFlatbufferFormat) {
  // Builds a small single-shard grammar, serializes it into a RulesSetT and
  // inspects the resulting rule tables and terminal string pool.
  constexpr CallbackId kOutput = 1;

  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Ir ir(shard_map);

  // <verb> ::= buy | bring | upbring | remind, where only the "bring" rule
  // carries a callback.
  const Nonterm verb = ir.AddUnshareableNonterminal();
  ir.Add(verb, "buy");
  ir.Add(Ir::Lhs{verb, {kOutput}}, "bring");
  ir.Add(verb, "upbring");
  ir.Add(verb, "remind");

  // <set_reminder> ::= remind me to <verb>
  const Nonterm set_reminder = ir.AddUnshareableNonterminal();
  const Nonterm remind = ir.Add(kUnassignedNonterm, "remind");
  const Nonterm me = ir.Add(kUnassignedNonterm, "me");
  const Nonterm to = ir.Add(kUnassignedNonterm, "to");
  ir.Add(set_reminder, std::vector<Nonterm>{remind, me, to, verb});

  // <action> ::= <set_reminder>
  const Nonterm action = ir.AddUnshareableNonterminal();
  ir.Add(action, set_reminder);

  RulesSetT serialized;
  ir.Serialize(/*include_debug_information=*/false, &serialized);

  // A single shard is expected.
  EXPECT_THAT(serialized.rules, SizeIs(1));

  // Only one rule uses a callback, the rest will be encoded directly.
  EXPECT_THAT(serialized.lhs, SizeIs(1));
  EXPECT_THAT(serialized.lhs.front().callback_id(), kOutput);

  const auto& shard = serialized.rules.front();

  // 6 distinct terminals: "buy", "upbring", "bring", "remind", "me" and "to".
  // All of them are expected in the lowercase terminal table.
  EXPECT_THAT(shard->lowercase_terminal_rules->terminal_offsets, SizeIs(6));
  EXPECT_THAT(shard->terminal_rules->terminal_offsets, IsEmpty());

  // As "bring" is a suffix of "upbring" it is expected to be suffix merged in
  // the string pool, so it has no entry of its own.
  EXPECT_THAT(serialized.terminals,
              Eq(std::string("buy\0me\0remind\0to\0upbring\0", 25)));

  EXPECT_THAT(shard->binary_rules, SizeIs(3));

  // One unary rule: <action> ::= <set_reminder>
  EXPECT_THAT(shard->unary_rules, SizeIs(1));
}
TEST(IrTest, HandlesRulesSharding) {
  // Adds the same two nonterminals with rules in two shards and verifies how
  // the serialized output distributes terminals and binary rules across them.
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({"", "de"});
  Ir ir(shard_map);

  const Nonterm verb = ir.AddUnshareableNonterminal();
  const Nonterm set_reminder = ir.AddUnshareableNonterminal();

  // Shard 0: English rules, added without an explicit shard argument.
  ir.Add(verb, "buy");
  ir.Add(verb, "bring");
  ir.Add(verb, "remind");
  const Nonterm remind = ir.Add(kUnassignedNonterm, "remind");
  const Nonterm me = ir.Add(kUnassignedNonterm, "me");
  const Nonterm to = ir.Add(kUnassignedNonterm, "to");
  ir.Add(set_reminder, std::vector<Nonterm>{remind, me, to, verb});

  // Shard 1: German rules.
  ir.Add(verb, "kaufen", /*case_sensitive=*/false, /*shard=*/1);
  ir.Add(verb, "bringen", /*case_sensitive=*/false, /*shard=*/1);
  ir.Add(verb, "erinnern", /*case_sensitive=*/false, /*shard=*/1);
  const Nonterm erinnere = ir.Add(kUnassignedNonterm, "erinnere",
                                  /*case_sensitive=*/false, /*shard=*/1);
  const Nonterm mich = ir.Add(kUnassignedNonterm, "mich",
                              /*case_sensitive=*/false, /*shard=*/1);
  const Nonterm zu = ir.Add(kUnassignedNonterm, "zu",
                            /*case_sensitive=*/false, /*shard=*/1);
  ir.Add(set_reminder, std::vector<Nonterm>{erinnere, mich, zu, verb},
         /*shard=*/1);

  // Test that terminal strings are correctly merged into the shared
  // string pool.
  RulesSetT serialized;
  ir.Serialize(/*include_debug_information=*/false, &serialized);

  // One serialized rules entry per shard.
  EXPECT_THAT(serialized.rules, SizeIs(2));

  // 5 distinct terminals: "buy", "bring", "remind", "me" and "to".
  const auto& shard0 = serialized.rules[0];
  EXPECT_THAT(shard0->lowercase_terminal_rules->terminal_offsets, SizeIs(5));
  EXPECT_THAT(shard0->terminal_rules->terminal_offsets, IsEmpty());

  // 6 distinct terminals: "kaufen", "bringen", "erinnern", "erinnere", "mich"
  // and "zu".
  const auto& shard1 = serialized.rules[1];
  EXPECT_THAT(shard1->lowercase_terminal_rules->terminal_offsets, SizeIs(6));
  EXPECT_THAT(shard1->terminal_rules->terminal_offsets, IsEmpty());

  // Terminals of both shards end up in the one shared pool.
  EXPECT_THAT(serialized.terminals,
              Eq(std::string("bring\0bringen\0buy\0erinnere\0erinnern\0kaufen\0"
                             "me\0mich\0remind\0to\0zu\0",
                             64)));

  // Intermediate rules should be in shard 0.
  EXPECT_THAT(shard0->binary_rules, SizeIs(6));
  EXPECT_THAT(shard1->binary_rules, SizeIs(0));
}
TEST(IrTest, DeduplicatesLhsSets) {
  // Two terminal rules with the same left-hand side are expected to serialize
  // to a single lhs set that contains just that nonterminal.
  grammar::LocaleShardMap shard_map =
      grammar::LocaleShardMap::CreateLocaleShardMap({""});
  Ir ir(shard_map);

  const Nonterm lhs = ir.AddUnshareableNonterminal();

  // First rule for the nonterminal.
  ir.Add(lhs, "test");

  // Second rule for the very same nonterminal.
  ir.Add(lhs, "again");

  RulesSetT serialized;
  ir.Serialize(/*include_debug_information=*/false, &serialized);

  EXPECT_THAT(serialized.lhs_set, SizeIs(1));
  EXPECT_THAT(serialized.lhs_set.front()->lhs, ElementsAre(lhs));
}
} // namespace
} // namespace libtextclassifier3::grammar