/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/grammar/semantics/evaluators/arithmetic-eval.h"

#include <algorithm>
#include <limits>
#include <string>

namespace libtextclassifier3::grammar {
namespace {

// Folds the operands of `expression` into a single value of type T using the
// operator stored in the expression (ADD, MUL, MIN or MAX).
//
// Every operand is evaluated through `composer->Apply(context, ..., arena)`.
// Operands that evaluate to nullptr are skipped. If the expression has no
// operands (or all of them are skipped) the result is the operator's starting
// value: 0 for ADD, 1 for MUL, the largest T for MIN, the lowest T for MAX.
//
// Returns INVALID_ARGUMENT when the operator is not one of the four above or
// when an operand does not hold a value of type T. A failed operand
// evaluation is handed back to the caller via TC3_ASSIGN_OR_RETURN.
// On success the result is wrapped with `SemanticValue::Create(result, arena)`.
template <typename T>
StatusOr<const SemanticValue*> Reduce(
    const SemanticExpressionEvaluator* composer, const EvalContext& context,
    const ArithmeticExpression* expression, UnsafeArena* arena) {
  // Seed the accumulator with the identity element of the operator.
  T result;
  switch (expression->op()) {
    case ArithmeticExpression_::Operator_OP_ADD: {
      result = 0;
      break;
    }
    case ArithmeticExpression_::Operator_OP_MUL: {
      result = 1;
      break;
    }
    case ArithmeticExpression_::Operator_OP_MIN: {
      result = std::numeric_limits<T>::max();
      break;
    }
    case ArithmeticExpression_::Operator_OP_MAX: {
      // `lowest()` rather than `min()`: for floating-point T, `min()` is the
      // smallest positive normalized value, which would win against every
      // negative operand. For integer T both are identical.
      result = std::numeric_limits<T>::lowest();
      break;
    }
    default: {
      return Status(StatusCode::INVALID_ARGUMENT,
                    "Unexpected op: " +
                        std::string(ArithmeticExpression_::EnumNameOperator(
                            expression->op())));
    }
  }

  if (expression->values() != nullptr) {
    for (const SemanticExpression* semantic_expression :
         *expression->values()) {
      TC3_ASSIGN_OR_RETURN(
          const SemanticValue* value,
          composer->Apply(context, semantic_expression, arena));

      // A null value contributes nothing to the reduction.
      if (value == nullptr) {
        continue;
      }

      if (!value->Has<T>()) {
        return Status(
            StatusCode::INVALID_ARGUMENT,
            "Argument didn't evaluate as expected type: " +
                std::string(reflection::EnumNameBaseType(value->base_type())));
      }

      const T scalar_value = value->Value<T>();
      switch (expression->op()) {
        case ArithmeticExpression_::Operator_OP_ADD: {
          result += scalar_value;
          break;
        }
        case ArithmeticExpression_::Operator_OP_MUL: {
          result *= scalar_value;
          break;
        }
        case ArithmeticExpression_::Operator_OP_MIN: {
          result = std::min(result, scalar_value);
          break;
        }
        case ArithmeticExpression_::Operator_OP_MAX: {
          result = std::max(result, scalar_value);
          break;
        }
        default: {
          // Unreachable: unsupported operators were rejected above.
          break;
        }
      }
    }
  }

  return SemanticValue::Create(result, arena);
}

}  // namespace

// Evaluates an arithmetic semantic expression.
//
// `expression` must be of type ArithmeticExpression (checked with
// TC3_DCHECK_EQ). The expression's declared base type selects the arithmetic
// type used for the reduction; base types other than the ten numeric ones
// listed below yield INVALID_ARGUMENT.
//
// NOTE(review): the fixed-width type names passed to Reduce<> are assumed to
// be the project's integral typedefs (int8 ... uint64) - confirm.
StatusOr<const SemanticValue*> ArithmeticExpressionEvaluator::Apply(
    const EvalContext& context, const SemanticExpression* expression,
    UnsafeArena* arena) const {
  TC3_DCHECK_EQ(expression->expression_type(),
                SemanticExpression_::Expression_ArithmeticExpression);
  const ArithmeticExpression* arithmetic_expression =
      expression->expression_as_ArithmeticExpression();
  switch (arithmetic_expression->base_type()) {
    case reflection::BaseType::Byte:
      return Reduce<int8>(composer_, context, arithmetic_expression, arena);
    case reflection::BaseType::UByte:
      return Reduce<uint8>(composer_, context, arithmetic_expression, arena);
    case reflection::BaseType::Short:
      return Reduce<int16>(composer_, context, arithmetic_expression, arena);
    case reflection::BaseType::UShort:
      return Reduce<uint16>(composer_, context, arithmetic_expression, arena);
    case reflection::BaseType::Int:
      return Reduce<int32>(composer_, context, arithmetic_expression, arena);
    case reflection::BaseType::UInt:
      return Reduce<uint32>(composer_, context, arithmetic_expression, arena);
    case reflection::BaseType::Long:
      return Reduce<int64>(composer_, context, arithmetic_expression, arena);
    case reflection::BaseType::ULong:
      return Reduce<uint64>(composer_, context, arithmetic_expression, arena);
    case reflection::BaseType::Float:
      return Reduce<float>(composer_, context, arithmetic_expression, arena);
    case reflection::BaseType::Double:
      return Reduce<double>(composer_, context, arithmetic_expression, arena);
    default:
      return Status(StatusCode::INVALID_ARGUMENT,
                    "Unsupported for ArithmeticExpression: " +
                        std::string(reflection::EnumNameBaseType(
                            static_cast<reflection::BaseType>(
                                arithmetic_expression->base_type()))));
  }
}

}  // namespace libtextclassifier3::grammar