906 lines
37 KiB
C++
906 lines
37 KiB
C++
/*
|
|
* Copyright (C) 2018 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "annotator/pod_ner/utils.h"
|
|
|
|
#include <iterator>
|
|
|
|
#include "annotator/model_generated.h"
|
|
#include "annotator/types.h"
|
|
#include "utils/tokenizer-utils.h"
|
|
#include "gmock/gmock.h"
|
|
#include "gtest/gtest.h"
|
|
#include "absl/container/flat_hash_map.h"
|
|
#include "absl/strings/str_split.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace {
|
|
|
|
using ::testing::IsEmpty;
|
|
using ::testing::Not;
|
|
|
|
using PodNerModel_::CollectionT;
|
|
using PodNerModel_::LabelT;
|
|
using PodNerModel_::Label_::BoiseType;
|
|
using PodNerModel_::Label_::BoiseType_BEGIN;
|
|
using PodNerModel_::Label_::BoiseType_END;
|
|
using PodNerModel_::Label_::BoiseType_INTERMEDIATE;
|
|
using PodNerModel_::Label_::BoiseType_O;
|
|
using PodNerModel_::Label_::BoiseType_SINGLE;
|
|
using PodNerModel_::Label_::MentionType;
|
|
using PodNerModel_::Label_::MentionType_NAM;
|
|
using PodNerModel_::Label_::MentionType_NOM;
|
|
using PodNerModel_::Label_::MentionType_UNDEFINED;
|
|
|
|
constexpr float kPriorityScore = 0.;
|
|
const std::vector<std::string>& kCollectionNames =
|
|
*new std::vector<std::string>{"undefined", "location", "person", "art",
|
|
"organization", "entitiy", "xxx"};
|
|
const auto& kStringToBoiseType = *new absl::flat_hash_map<
|
|
absl::string_view, libtextclassifier3::PodNerModel_::Label_::BoiseType>({
|
|
{"B", libtextclassifier3::PodNerModel_::Label_::BoiseType_BEGIN},
|
|
{"O", libtextclassifier3::PodNerModel_::Label_::BoiseType_O},
|
|
{"I", libtextclassifier3::PodNerModel_::Label_::BoiseType_INTERMEDIATE},
|
|
{"S", libtextclassifier3::PodNerModel_::Label_::BoiseType_SINGLE},
|
|
{"E", libtextclassifier3::PodNerModel_::Label_::BoiseType_END},
|
|
});
|
|
const auto& kStringToMentionType = *new absl::flat_hash_map<
|
|
absl::string_view, libtextclassifier3::PodNerModel_::Label_::MentionType>(
|
|
{{"NAM", libtextclassifier3::PodNerModel_::Label_::MentionType_NAM},
|
|
{"NOM", libtextclassifier3::PodNerModel_::Label_::MentionType_NOM}});
|
|
// Builds a LabelT whose three fields are set from the given arguments.
LabelT CreateLabel(BoiseType boise_type, MentionType mention_type,
                   int collection_id) {
  LabelT result;
  result.collection_id = collection_id;
  result.mention_type = mention_type;
  result.boise_type = boise_type;
  return result;
}
|
|
std::vector<PodNerModel_::LabelT> TagsToLabels(
|
|
const std::vector<std::string>& tags) {
|
|
std::vector<PodNerModel_::LabelT> labels;
|
|
for (const auto& tag : tags) {
|
|
if (tag == "O") {
|
|
labels.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
|
|
} else {
|
|
std::vector<absl::string_view> tag_parts = absl::StrSplit(tag, '-');
|
|
labels.emplace_back(CreateLabel(
|
|
kStringToBoiseType.at(tag_parts[0]),
|
|
kStringToMentionType.at(tag_parts[1]),
|
|
std::distance(
|
|
kCollectionNames.begin(),
|
|
std::find(kCollectionNames.begin(), kCollectionNames.end(),
|
|
std::string(tag_parts[2].substr(
|
|
tag_parts[2].rfind('/') + 1))))));
|
|
}
|
|
}
|
|
return labels;
|
|
}
|
|
|
|
std::vector<CollectionT> GetCollections() {
|
|
std::vector<CollectionT> collections;
|
|
for (const std::string& collection_name : kCollectionNames) {
|
|
CollectionT collection;
|
|
collection.name = collection_name;
|
|
collection.single_token_priority_score = kPriorityScore;
|
|
collection.multi_token_priority_score = kPriorityScore;
|
|
collections.emplace_back(collection);
|
|
}
|
|
return collections;
|
|
}
|
|
|
|
class ConvertTagsToAnnotatedSpansTest : public testing::TestWithParam<bool> {};
|
|
INSTANTIATE_TEST_SUITE_P(TagsAndLabelsTest, ConvertTagsToAnnotatedSpansTest,
|
|
testing::Values(true, false));
|
|
|
|
// A B/I/E tag run over the last three tokens must produce exactly one
// "location" span covering codepoints [10, 23), for both overloads.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesBIESequence) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {"O",
                                   "O",
                                   "O",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // ASSERT rather than EXPECT: the element accesses below would be out of
  // bounds if the result or its classification list were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}
|
|
|
|
// A single S tag on the second token must produce exactly one "person" span
// covering codepoints [4, 10), for both overloads.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesSSequence) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "His father was it.";
  std::vector<std::string> tags = {"O", "S-NAM-/saft/person", "O", "O"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // ASSERT rather than EXPECT: the element accesses below would be out of
  // bounds if the result or its classification list were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(4, 10));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "person");
}
|
|
|
|
// Several entities in one sentence (two persons, one organization, one
// location) must each produce their own span, for both overloads.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesMultiple) {
  std::vector<AnnotatedSpan> annotations;
  std::string text =
      "Jaromir Jagr, Barak Obama and I met in Google New York City";
  std::vector<std::string> tags = {"B-NAM-/saft/person",
                                   "E-NAM-/saft/person",
                                   "B-NOM-/saft/person",
                                   "E-NOM-/saft/person",
                                   "O",
                                   "O",
                                   "O",
                                   "O",
                                   "S-NAM-/saft/organization",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // These checks previously sat inside the else branch, so the string-tag
  // overload was never verified. They now run for both parameter values.
  ASSERT_EQ(annotations.size(), 4);
  EXPECT_EQ(annotations[0].span, CodepointSpan(0, 13));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "person");
  EXPECT_EQ(annotations[1].span, CodepointSpan(14, 25));
  ASSERT_THAT(annotations[1].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[1].classification[0].collection, "person");
  EXPECT_EQ(annotations[2].span, CodepointSpan(39, 45));
  ASSERT_THAT(annotations[2].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[2].classification[0].collection, "organization");
  EXPECT_EQ(annotations[3].span, CodepointSpan(46, 59));
  ASSERT_THAT(annotations[3].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[3].classification[0].collection, "location");
}
|
|
|
|
// The token span handed to the converter starts at the third token of the
// sentence; the resulting spans are still expected at the codepoint positions
// of the full sentence.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesMultipleFirstTokenNotFirst) {
  std::vector<Token> all_tokens = TokenizeOnSpace(
      "Jaromir Jagr, Barak Obama and I met in Google New York City");
  std::vector<std::string> tag_sequence = {
      "B-NOM-/saft/person",       "E-NOM-/saft/person",
      "O",                        "O",
      "O",                        "O",
      "S-NAM-/saft/organization", "B-NAM-/saft/location",
      "I-NAM-/saft/location",     "E-NAM-/saft/location"};
  std::vector<AnnotatedSpan> spans;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(all_tokens.begin() + 2, all_tokens.end()),
        TagsToLabels(tag_sequence), GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &spans));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(all_tokens.begin() + 2, all_tokens.end()),
        tag_sequence,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &spans));
  }

  ASSERT_EQ(spans.size(), 3);

  // "Barak Obama"
  EXPECT_EQ(spans[0].span, CodepointSpan(14, 25));
  ASSERT_THAT(spans[0].classification, Not(IsEmpty()));
  EXPECT_EQ(spans[0].classification[0].collection, "person");

  // "Google"
  EXPECT_EQ(spans[1].span, CodepointSpan(39, 45));
  ASSERT_THAT(spans[1].classification, Not(IsEmpty()));
  EXPECT_EQ(spans[1].classification[0].collection, "organization");

  // "New York City"
  EXPECT_EQ(spans[2].span, CodepointSpan(46, 59));
  ASSERT_THAT(spans[2].classification, Not(IsEmpty()));
  EXPECT_EQ(spans[2].classification[0].collection, "location");
}
|
|
|
|
// A tag naming a collection that is not in kCollectionNames makes
// TagsToLabels() emit an id past the end of the collection list; the
// label-based overload is expected to return false for it.
TEST(PodNerUtilsTest, ConvertTagsToAnnotatedSpansInvalidCollection) {
  std::string sentence = "We met in New York City";
  std::vector<std::string> tag_sequence = {"O", "O",
                                           "S-NAM-/saft/invalid_collection"};
  std::vector<Token> tokens = TokenizeOnSpace(sentence);
  std::vector<AnnotatedSpan> spans;

  ASSERT_FALSE(ConvertTagsToAnnotatedSpans(
      VectorSpan<Token>(tokens), TagsToLabels(tag_sequence), GetCollections(),
      /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
      /*relaxed_inside_label_matching=*/false,
      /*relaxed_mention_type_matching=*/false, &spans));
}
|
|
|
|
// The B tag names collection "xxx" while the I and E tags name "location".
// With both relaxed flags off, no span is expected.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentStart) {
  std::string sentence = "We met in New York City";
  std::vector<std::string> tag_sequence = {
      "O", "O", "O",
      "B-NAM-/saft/xxx", "I-NAM-/saft/location", "E-NAM-/saft/location"};
  std::vector<Token> tokens = TokenizeOnSpace(sentence);
  std::vector<AnnotatedSpan> spans;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tag_sequence),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &spans));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tag_sequence,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &spans));
  }

  EXPECT_THAT(spans, IsEmpty());
}
|
|
|
|
// The B tag has mention type NOM while the I and E tags have NAM. With both
// relaxed flags off, no span is expected.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeStart) {
  std::string sentence = "We met in New York City";
  std::vector<std::string> tag_sequence = {
      "O", "O", "O",
      "B-NOM-/saft/location", "I-NAM-/saft/location", "E-NAM-/saft/location"};
  std::vector<Token> tokens = TokenizeOnSpace(sentence);
  std::vector<AnnotatedSpan> spans;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tag_sequence),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &spans));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tag_sequence,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &spans));
  }

  EXPECT_THAT(spans, IsEmpty());
}
|
|
|
|
// The I tag names collection "xxx" while the B and E tags name "location".
// With relaxed inside matching off, no span is expected.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentInside) {
  std::string sentence = "We met in New York City";
  std::vector<std::string> tag_sequence = {
      "O", "O", "O",
      "B-NAM-/saft/location", "I-NAM-/saft/xxx", "E-NAM-/saft/location"};
  std::vector<Token> tokens = TokenizeOnSpace(sentence);
  std::vector<AnnotatedSpan> spans;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tag_sequence),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &spans));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tag_sequence,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &spans));
  }

  EXPECT_THAT(spans, IsEmpty());
}
|
|
|
|
// The I tag has mention type NOM while the B and E tags have NAM. With both
// relaxed flags off, no span is expected.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeInside) {
  std::string sentence = "We met in New York City";
  std::vector<std::string> tag_sequence = {
      "O", "O", "O",
      "B-NAM-/saft/location", "I-NOM-/saft/location", "E-NAM-/saft/location"};
  std::vector<Token> tokens = TokenizeOnSpace(sentence);
  std::vector<AnnotatedSpan> spans;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tag_sequence),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &spans));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tag_sequence,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &spans));
  }

  EXPECT_THAT(spans, IsEmpty());
}
|
|
|
|
// With relaxed inside matching ON, an I tag whose collection ("xxx") differs
// from the B/E collection ("location") must still yield one "location" span.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansHandlesInconsistentInside) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {"O",
                                   "O",
                                   "O",
                                   "B-NAM-/saft/location",
                                   "I-NAM-/saft/xxx",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/true,
        /*relaxed_label_category_matching=*/false, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/true,
        /*relaxed_mention_type_matching=*/false, &annotations));
  }

  // ASSERT rather than EXPECT: the element accesses below would be out of
  // bounds if the result or its classification list were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}
|
|
|
|
// The E tag names collection "xxx" while the B and I tags name "location".
// With both relaxed flags off, no span is expected.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentEnd) {
  std::string sentence = "We met in New York City";
  std::vector<std::string> tag_sequence = {
      "O", "O", "O",
      "B-NAM-/saft/location", "I-NAM-/saft/location", "E-NAM-/saft/xxx"};
  std::vector<Token> tokens = TokenizeOnSpace(sentence);
  std::vector<AnnotatedSpan> spans;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tag_sequence),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &spans));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tag_sequence,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &spans));
  }

  EXPECT_THAT(spans, IsEmpty());
}
|
|
|
|
// The E tag has mention type NOM while the B and I tags have NAM. With both
// relaxed flags off, no span is expected.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresInconsistentLabelTypeEnd) {
  std::string sentence = "We met in New York City";
  std::vector<std::string> tag_sequence = {
      "O", "O", "O",
      "B-NAM-/saft/location", "I-NAM-/saft/location", "E-NOM-/saft/location"};
  std::vector<Token> tokens = TokenizeOnSpace(sentence);
  std::vector<AnnotatedSpan> spans;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tag_sequence),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &spans));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tag_sequence,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &spans));
  }

  EXPECT_THAT(spans, IsEmpty());
}
|
|
|
|
// With relaxed label-category / mention-type matching ON, a run whose mention
// types differ (NOM, NOM, NAM) but whose collection is "location" throughout
// must still yield one "location" span.
TEST_P(
    ConvertTagsToAnnotatedSpansTest,
    ConvertTagsToAnnotatedSpansHandlesInconsistentLabelTypeWhenEntityMatches) {
  std::vector<AnnotatedSpan> annotations;
  std::string text = "We met in New York City";
  std::vector<std::string> tags = {"O",
                                   "O",
                                   "O",
                                   "B-NOM-/saft/location",
                                   "I-NOM-/saft/location",
                                   "E-NAM-/saft/location"};
  if (GetParam()) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), tags,
        /*label_filter=*/{"NAM", "NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/true, kPriorityScore,
        &annotations));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(TokenizeOnSpace(text)), TagsToLabels(tags),
        GetCollections(),
        /*mention_filter=*/{MentionType_NAM, MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/true, &annotations));
  }

  // ASSERT rather than EXPECT: the element accesses below would be out of
  // bounds if the result or its classification list were empty.
  ASSERT_EQ(annotations.size(), 1);
  EXPECT_EQ(annotations[0].span, CodepointSpan(10, 23));
  ASSERT_THAT(annotations[0].classification, Not(IsEmpty()));
  EXPECT_EQ(annotations[0].classification[0].collection, "location");
}
|
|
|
|
// All entity tags are NAM but the filter only lists NOM, so no span is
// expected.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansIgnoresFilteredLabel) {
  std::string sentence = "We met in New York City";
  std::vector<std::string> tag_sequence = {
      "O", "O", "O",
      "B-NAM-/saft/location", "I-NAM-/saft/location", "E-NAM-/saft/location"};
  std::vector<Token> tokens = TokenizeOnSpace(sentence);
  std::vector<AnnotatedSpan> spans;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tag_sequence),
        GetCollections(),
        /*mention_filter=*/{MentionType_NOM},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &spans));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tag_sequence,
        /*label_filter=*/{"NOM"},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &spans));
  }

  EXPECT_THAT(spans, IsEmpty());
}
|
|
|
|
// An empty filter is passed, and no span is expected even though the tags
// form a consistent B/I/E run.
TEST_P(ConvertTagsToAnnotatedSpansTest,
       ConvertTagsToAnnotatedSpansWithEmptyLabelFilterIgnoresAll) {
  std::string sentence = "We met in New York City";
  std::vector<std::string> tag_sequence = {
      "O", "O", "O",
      "B-NOM-/saft/location", "I-NOM-/saft/location", "E-NOM-/saft/location"};
  std::vector<Token> tokens = TokenizeOnSpace(sentence);
  std::vector<AnnotatedSpan> spans;

  const bool use_string_tags = GetParam();
  if (!use_string_tags) {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), TagsToLabels(tag_sequence),
        GetCollections(),
        /*mention_filter=*/{},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_mention_type_matching=*/false, &spans));
  } else {
    ASSERT_TRUE(ConvertTagsToAnnotatedSpans(
        VectorSpan<Token>(tokens), tag_sequence,
        /*label_filter=*/{},
        /*relaxed_inside_label_matching=*/false,
        /*relaxed_label_category_matching=*/false, kPriorityScore, &spans));
  }

  EXPECT_THAT(spans, IsEmpty());
}
|
|
|
|
// Merges a right label sequence into a copy of the left one at two different
// offsets and checks the resulting length and per-position collection ids.
TEST(PodNerUtilsTest, MergeLabelsIntoLeftSequence) {
  std::vector<PodNerModel_::LabelT> original_labels_left;
  original_labels_left.emplace_back(
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
  original_labels_left.emplace_back(
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
  original_labels_left.emplace_back(
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
  original_labels_left.emplace_back(
      CreateLabel(BoiseType_SINGLE, MentionType_NAM, 1));
  original_labels_left.emplace_back(
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
  original_labels_left.emplace_back(
      CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
  original_labels_left.emplace_back(
      CreateLabel(BoiseType_SINGLE, MentionType_NAM, 2));

  std::vector<PodNerModel_::LabelT> labels_right;
  labels_right.emplace_back(
      CreateLabel(BoiseType_BEGIN, MentionType_UNDEFINED, 3));
  labels_right.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
  labels_right.emplace_back(CreateLabel(BoiseType_O, MentionType_UNDEFINED, 0));
  labels_right.emplace_back(CreateLabel(BoiseType_BEGIN, MentionType_NAM, 4));
  labels_right.emplace_back(
      CreateLabel(BoiseType_INTERMEDIATE, MentionType_UNDEFINED, 4));
  labels_right.emplace_back(
      CreateLabel(BoiseType_END, MentionType_UNDEFINED, 4));

  // First merge: the right sequence starts at index 3 of the left one.
  std::vector<PodNerModel_::LabelT> labels_left = original_labels_left;
  ASSERT_TRUE(MergeLabelsIntoLeftSequence(labels_right,
                                          /*index_first_right_tag_in_left=*/3,
                                          &labels_left));
  const std::vector<int> expected_ids_at_3 = {0, 0, 0, 1, 0, 0, 4, 4, 4};
  // ASSERT on the size so the indexed checks below can never read past the
  // end of labels_left.
  ASSERT_EQ(labels_left.size(), expected_ids_at_3.size());
  for (size_t i = 0; i < expected_ids_at_3.size(); ++i) {
    EXPECT_EQ(labels_left[i].collection_id, expected_ids_at_3[i])
        << "at index " << i;
  }

  // Second merge, from a fresh copy: the right sequence starts at index 2.
  labels_left = original_labels_left;
  ASSERT_TRUE(MergeLabelsIntoLeftSequence(labels_right,
                                          /*index_first_right_tag_in_left=*/2,
                                          &labels_left));
  const std::vector<int> expected_ids_at_2 = {0, 0, 0, 1, 0, 4, 4, 4};
  ASSERT_EQ(labels_left.size(), expected_ids_at_2.size());
  for (size_t i = 0; i < expected_ids_at_2.size(); ++i) {
    EXPECT_EQ(labels_left[i].collection_id, expected_ids_at_2[i])
        << "at index " << i;
  }
}
|
|
|
|
// The window limit (15) exceeds the number of wordpieces (12), and the whole
// range [0, 12) is expected back.
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanAllWordpices) {
  std::vector<int32_t> token_word_starts = {0, 2, 3, 5, 6, 7, 10, 11};
  std::vector<Token> token_list = {
      {"a", 0, 1},  {"b", 2, 3},   {"c", 4, 5},    {"d", 6, 7},
      {"e", 8, 9},  {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}};

  WordpieceSpan window = internal::FindWordpiecesWindowAroundSpan(
      {2, 3}, token_list, token_word_starts,
      /*num_wordpieces=*/12,
      /*max_num_wordpieces_in_window=*/15);

  EXPECT_EQ(window, WordpieceSpan(0, 12));
}
|
|
|
|
// Checks the window returned for three span / window-size combinations on
// the same token sequence.
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanInMiddle) {
  std::vector<int32_t> token_word_starts = {0, 2, 3, 5, 6, 7, 10, 11};
  std::vector<Token> token_list = {
      {"a", 0, 1},  {"b", 2, 3},   {"c", 4, 5},    {"d", 6, 7},
      {"e", 8, 9},  {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}};

  // Span {6, 7}, at most 5 wordpieces.
  WordpieceSpan window_of_five = internal::FindWordpiecesWindowAroundSpan(
      {6, 7}, token_list, token_word_starts,
      /*num_wordpieces=*/12,
      /*max_num_wordpieces_in_window=*/5);
  EXPECT_EQ(window_of_five, WordpieceSpan(3, 8));

  // Span {6, 7}, at most 6 wordpieces.
  WordpieceSpan window_of_six = internal::FindWordpiecesWindowAroundSpan(
      {6, 7}, token_list, token_word_starts,
      /*num_wordpieces=*/12,
      /*max_num_wordpieces_in_window=*/6);
  EXPECT_EQ(window_of_six, WordpieceSpan(3, 9));

  // Span {12, 14}, at most 3 wordpieces.
  WordpieceSpan window_of_three = internal::FindWordpiecesWindowAroundSpan(
      {12, 14}, token_list, token_word_starts,
      /*num_wordpieces=*/12,
      /*max_num_wordpieces_in_window=*/3);
  EXPECT_EQ(window_of_three, WordpieceSpan(9, 12));
}
|
|
|
|
// Span {2, 3} with a 7-wordpiece limit: the expected window is [0, 7).
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanCloseToStart) {
  std::vector<int32_t> token_word_starts = {0, 2, 3, 5, 6, 7, 10, 11};
  std::vector<Token> token_list = {
      {"a", 0, 1},  {"b", 2, 3},   {"c", 4, 5},    {"d", 6, 7},
      {"e", 8, 9},  {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}};

  WordpieceSpan window = internal::FindWordpiecesWindowAroundSpan(
      {2, 3}, token_list, token_word_starts,
      /*num_wordpieces=*/12,
      /*max_num_wordpieces_in_window=*/7);

  EXPECT_EQ(window, WordpieceSpan(0, 7));
}
|
|
|
|
// Span {15, 19} with a 7-wordpiece limit: the expected window is [5, 12).
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanCloseToEnd) {
  std::vector<int32_t> token_word_starts = {0, 2, 3, 5, 6, 7, 10, 11};
  std::vector<Token> token_list = {
      {"a", 0, 1},  {"b", 2, 3},   {"c", 4, 5},    {"d", 6, 7},
      {"e", 8, 9},  {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}};

  WordpieceSpan window = internal::FindWordpiecesWindowAroundSpan(
      {15, 19}, token_list, token_word_starts,
      /*num_wordpieces=*/12,
      /*max_num_wordpieces_in_window=*/7);

  EXPECT_EQ(window, WordpieceSpan(5, 12));
}
|
|
|
|
// Span {0, 19} with a 5-wordpiece limit: the expected window is [0, 12),
// which is larger than the limit.
TEST(PodNerUtilsTest, FindWordpiecesWindowAroundSpanBigSpan) {
  std::vector<int32_t> token_word_starts = {0, 2, 3, 5, 6, 7, 10, 11};
  std::vector<Token> token_list = {
      {"a", 0, 1},  {"b", 2, 3},   {"c", 4, 5},    {"d", 6, 7},
      {"e", 8, 9},  {"f", 10, 11}, {"my", 12, 14}, {"name", 15, 19}};

  WordpieceSpan window = internal::FindWordpiecesWindowAroundSpan(
      {0, 19}, token_list, token_word_starts,
      /*num_wordpieces=*/12,
      /*max_num_wordpieces_in_window=*/5);

  EXPECT_EQ(window, WordpieceSpan(0, 12));
}
|
|
|
|
// Calls FindFullTokensSpanInWindow with wordpiece spans {0, 6} and {2, 6};
// in both cases the span is expected back unchanged, with the first token
// index and token count checked below.
TEST(PodNerUtilsTest, FindFullTokensSpanInWindow) {
  std::vector<int32_t> token_word_starts = {0, 2, 3, 5, 6, 7, 10, 11};
  int first_token;
  int token_count;

  WordpieceSpan span_from_zero = internal::FindFullTokensSpanInWindow(
      token_word_starts, /*wordpiece_span=*/{0, 6},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token,
      &token_count);
  EXPECT_EQ(span_from_zero, WordpieceSpan(0, 6));
  EXPECT_EQ(first_token, 0);
  EXPECT_EQ(token_count, 4);

  WordpieceSpan span_from_two = internal::FindFullTokensSpanInWindow(
      token_word_starts, /*wordpiece_span=*/{2, 6},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token,
      &token_count);
  EXPECT_EQ(span_from_two, WordpieceSpan(2, 6));
  EXPECT_EQ(first_token, 1);
  EXPECT_EQ(token_count, 3);
}
|
|
|
|
// The requested span {1, 6} starts at wordpiece 1, which is not in the
// word-start list; the expected result is {0, 6} with 4 tokens from index 0.
TEST(PodNerUtilsTest, FindFullTokensSpanInWindowStartInMiddleOfToken) {
  std::vector<int32_t> token_word_starts = {0, 2, 3, 5, 6, 7, 10, 11};
  int first_token;
  int token_count;

  WordpieceSpan adjusted_span = internal::FindFullTokensSpanInWindow(
      token_word_starts, /*wordpiece_span=*/{1, 6},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token,
      &token_count);

  EXPECT_EQ(adjusted_span, WordpieceSpan(0, 6));
  EXPECT_EQ(first_token, 0);
  EXPECT_EQ(token_count, 4);
}
|
|
|
|
// The requested span {1, 9} ends at wordpiece 9, which is not in the
// word-start list; the expected result is {0, 6} with 4 tokens from index 0.
TEST(PodNerUtilsTest, FindFullTokensSpanInWindowEndsInMiddleOfToken) {
  std::vector<int32_t> token_word_starts = {0, 2, 3, 5, 6, 7, 10, 11};
  int first_token;
  int token_count;

  WordpieceSpan adjusted_span = internal::FindFullTokensSpanInWindow(
      token_word_starts, /*wordpiece_span=*/{1, 9},
      /*max_num_wordpieces=*/6, /*num_wordpieces=*/12, &first_token,
      &token_count);

  EXPECT_EQ(adjusted_span, WordpieceSpan(0, 6));
  EXPECT_EQ(first_token, 0);
  EXPECT_EQ(token_count, 4);
}
|
|
// Expects token index 1 for first_wordpiece_index 2.
TEST(PodNerUtilsTest, FindFirstFullTokenIndexSizeOne) {
  std::vector<int32_t> token_word_starts = {1, 2, 3, 5, 6, 7, 10, 11};

  int first_full_token = internal::FindFirstFullTokenIndex(
      token_word_starts, /*first_wordpiece_index=*/2);

  EXPECT_EQ(first_full_token, 1);
}
|
|
|
|
// Expects token index 0 for first_wordpiece_index 0.
TEST(PodNerUtilsTest, FindFirstFullTokenIndexFirst) {
  std::vector<int32_t> token_word_starts = {1, 2, 3, 5, 6, 7, 10, 11};

  int first_full_token = internal::FindFirstFullTokenIndex(
      token_word_starts, /*first_wordpiece_index=*/0);

  EXPECT_EQ(first_full_token, 0);
}
|
|
|
|
// Expects token index 2 for first_wordpiece_index 4.
TEST(PodNerUtilsTest, FindFirstFullTokenIndexSizeGreaterThanOne) {
  std::vector<int32_t> token_word_starts = {1, 2, 3, 5, 6, 7, 10, 11};

  int first_full_token = internal::FindFirstFullTokenIndex(
      token_word_starts, /*first_wordpiece_index=*/4);

  EXPECT_EQ(first_full_token, 2);
}
|
|
|
|
// Expects token index 1 for wordpiece_end 3 out of 12 wordpieces.
TEST(PodNerUtilsTest, FindLastFullTokenIndexSizeOne) {
  std::vector<int32_t> token_word_starts = {1, 2, 3, 5, 6, 7, 10, 11};

  int last_full_token = internal::FindLastFullTokenIndex(
      token_word_starts, /*num_wordpieces=*/12, /*wordpiece_end=*/3);

  EXPECT_EQ(last_full_token, 1);
}
|
|
|
|
// Checks the last full token index for wordpiece_end values 6, 7 and 5 on
// the same word-start list.
TEST(PodNerUtilsTest, FindLastFullTokenIndexSizeGreaterThanOne) {
  std::vector<int32_t> token_word_starts = {1, 3, 4, 6, 8, 9};

  int last_for_end_six = internal::FindLastFullTokenIndex(
      token_word_starts, /*num_wordpieces=*/10, /*wordpiece_end=*/6);
  EXPECT_EQ(last_for_end_six, 2);

  int last_for_end_seven = internal::FindLastFullTokenIndex(
      token_word_starts, /*num_wordpieces=*/10, /*wordpiece_end=*/7);
  EXPECT_EQ(last_for_end_seven, 2);

  int last_for_end_five = internal::FindLastFullTokenIndex(
      token_word_starts, /*num_wordpieces=*/10, /*wordpiece_end=*/5);
  EXPECT_EQ(last_for_end_five, 1);
}
|
|
|
|
// When wordpiece_end equals num_wordpieces (12 and 14 here), the last token
// index (7) is expected.
TEST(PodNerUtilsTest, FindLastFullTokenIndexLast) {
  std::vector<int32_t> token_word_starts = {1, 2, 3, 5, 6, 7, 10, 11};

  int last_with_twelve = internal::FindLastFullTokenIndex(
      token_word_starts, /*num_wordpieces=*/12, /*wordpiece_end=*/12);
  EXPECT_EQ(last_with_twelve, 7);

  int last_with_fourteen = internal::FindLastFullTokenIndex(
      token_word_starts, /*num_wordpieces=*/14, /*wordpiece_end=*/14);
  EXPECT_EQ(last_with_fourteen, 7);
}
|
|
|
|
// wordpiece_end (12) is smaller than num_wordpieces (15); token index 6 is
// expected.
TEST(PodNerUtilsTest, FindLastFullTokenIndexBeforeLast) {
  std::vector<int32_t> token_word_starts = {1, 2, 3, 5, 6, 7, 10, 11};

  int last_full_token = internal::FindLastFullTokenIndex(
      token_word_starts, /*num_wordpieces=*/15, /*wordpiece_end=*/12);

  EXPECT_EQ(last_full_token, 6);
}
|
|
|
|
// With 8 wordpieces and a limit of 10, the expected window is [0, 8).
TEST(PodNerUtilsTest, ExpandWindowAndAlignSequenceSmallerThanMax) {
  WordpieceSpan expanded = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/8,
      /*wordpiece_span_to_expand=*/{2, 5});

  EXPECT_EQ(expanded, WordpieceSpan(0, 8));
}
|
|
|
|
// The input span {2, 51} is longer than the limit of 10 and is expected
// back unchanged.
TEST(PodNerUtilsTest, ExpandWindowAndAlignWindowLengthGreaterThanMax) {
  WordpieceSpan expanded = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/100,
      /*wordpiece_span_to_expand=*/{2, 51});

  EXPECT_EQ(expanded, WordpieceSpan(2, 51));
}
|
|
|
|
// Span {2, 4} with a limit of 10 out of 20 wordpieces: the expected window
// is [0, 10).
TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexCloseToStart) {
  WordpieceSpan expanded = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{2, 4});

  EXPECT_EQ(expanded, WordpieceSpan(0, 10));
}
|
|
|
|
// Span {18, 20} with a limit of 10 out of 20 wordpieces: the expected window
// is [10, 20).
TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexCloseToEnd) {
  WordpieceSpan expanded = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{18, 20});

  EXPECT_EQ(expanded, WordpieceSpan(10, 20));
}
|
|
|
|
// Expands spans {10, 12} and {10, 13} with a limit of 10 out of 20
// wordpieces and checks the resulting windows.
//
// The locals window_first_wordpiece_index / window_last_wordpiece_index were
// removed: they were assigned but never read, and their values did not match
// the span literals actually passed to ExpandWindowAndAlign.
TEST(PodNerUtilsTest, ExpandWindowAndAlignFirstIndexInTheMiddle) {
  WordpieceSpan maxWordpieceSpan = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{10, 12});
  EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(6, 16));

  maxWordpieceSpan = internal::ExpandWindowAndAlign(
      /*max_num_wordpieces_in_window=*/10, /*num_wordpieces=*/20,
      /*wordpiece_span_to_expand=*/{10, 13});
  EXPECT_EQ(maxWordpieceSpan, WordpieceSpan(7, 17));
}
|
|
|
|
// Drives a WindowGenerator (max 4 wordpieces, overlap 1) over 8 wordpieces
// and 6 tokens, and checks the contents of the three windows it yields as
// well as the Done() state after each call.
TEST(PodNerUtilsTest, WindowGenerator) {
  std::vector<int32_t> all_wordpieces = {10, 20, 30, 40, 50, 60, 70, 80};
  std::vector<Token> all_tokens = {{"a", 0, 1}, {"b", 2, 3},  {"c", 4, 5},
                                   {"d", 6, 7}, {"e", 8, 9},  {"f", 10, 11}};
  std::vector<int32_t> all_token_starts = {0, 2, 3, 5, 6, 7};
  WindowGenerator generator(all_wordpieces, all_token_starts, all_tokens,
                            /*max_num_wordpieces=*/4,
                            /*sliding_window_overlap=*/1,
                            /*span_of_interest=*/{0, 12});
  VectorSpan<int32_t> window_wordpieces;
  VectorSpan<int32_t> window_token_starts;
  VectorSpan<Token> window_tokens;

  // First window: wordpieces [0, 3), tokens [0, 2).
  ASSERT_TRUE(generator.Next(&window_wordpieces, &window_token_starts,
                             &window_tokens));
  ASSERT_FALSE(generator.Done());
  ASSERT_EQ(window_wordpieces.size(), 3);
  for (int i = 0; i < 3; ++i) {
    ASSERT_EQ(window_wordpieces[i], all_wordpieces[i]);
  }
  ASSERT_EQ(window_token_starts.size(), 2);
  ASSERT_EQ(window_tokens.size(), 2);
  for (int i = 0; i < window_tokens.size(); ++i) {
    ASSERT_EQ(window_token_starts[i], all_token_starts[i]);
    ASSERT_EQ(window_tokens[i], all_tokens[i]);
  }

  // Second window: wordpieces [2, 6), tokens [1, 4).
  ASSERT_TRUE(generator.Next(&window_wordpieces, &window_token_starts,
                             &window_tokens));
  ASSERT_FALSE(generator.Done());
  ASSERT_EQ(window_wordpieces.size(), 4);
  for (int i = 0; i < window_wordpieces.size(); ++i) {
    ASSERT_EQ(window_wordpieces[i], all_wordpieces[i + 2]);
  }
  ASSERT_EQ(window_token_starts.size(), 3);
  ASSERT_EQ(window_tokens.size(), 3);
  for (int i = 0; i < window_tokens.size(); ++i) {
    ASSERT_EQ(window_token_starts[i], all_token_starts[i + 1]);
    ASSERT_EQ(window_tokens[i], all_tokens[i + 1]);
  }

  // Third window: wordpieces [5, 8), tokens [3, 6); generator reports Done().
  ASSERT_TRUE(generator.Next(&window_wordpieces, &window_token_starts,
                             &window_tokens));
  ASSERT_TRUE(generator.Done());
  ASSERT_EQ(window_wordpieces.size(), 3);
  for (int i = 0; i < window_wordpieces.size(); ++i) {
    ASSERT_EQ(window_wordpieces[i], all_wordpieces[i + 5]);
  }
  ASSERT_EQ(window_token_starts.size(), 3);
  ASSERT_EQ(window_tokens.size(), 3);
  for (int i = 0; i < window_tokens.size(); ++i) {
    ASSERT_EQ(window_token_starts[i], all_token_starts[i + 3]);
    ASSERT_EQ(window_tokens[i], all_tokens[i + 3]);
  }

  // No further window is available.
  ASSERT_FALSE(generator.Next(&window_wordpieces, &window_token_starts,
                              &window_tokens));
}
|
|
} // namespace
|
|
} // namespace libtextclassifier3
|