// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/Modulo.cpp
// and modified by Doris

#include <string.h>

#include <cmath>
#include <memory>
#include <utility>

#include "runtime/decimalv2_value.h"
#include "runtime/primitive_type.h"
#include "vec/columns/column_decimal.h"
#include "vec/columns/column_vector.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/data_types/data_type_number.h"
#include "vec/data_types/number_traits.h"
#include "vec/functions/cast_type_to_either.h"
#include "vec/functions/simple_function_factory.h"

namespace doris::vectorized {

template <typename A, typename B>
inline void throw_if_division_leads_to_FPE(A a, B b) {
    // http://avva.livejournal.com/2548306.html
    // (-9223372036854775808 % -1) will cause coredump directly, so check this case to throw exception, or maybe could return 0 as result
    if constexpr (IsSignedV<A> && IsSignedV<B>) {
        if (b == -1 && a == std::numeric_limits<A>::min()) {
            throw Exception(ErrorCode::INVALID_ARGUMENT,
                            "Division of minimal signed number by minus one is an undefined "
                            "behavior, {} % {}. ",
                            a, b);
        }
    }
}

template <typename Impl>
class FunctionMod : public IFunction {
    static constexpr bool result_is_decimal = Impl::result_is_decimal;
    mutable bool need_replace_null_data_to_default_ = false;

public:
    static constexpr auto name = Impl::name;

    static FunctionPtr create() { return std::make_shared<FunctionMod>(); }

    FunctionMod() = default;

    String get_name() const override { return name; }

    bool need_replace_null_data_to_default() const override {
        return need_replace_null_data_to_default_;
    }

    size_t get_number_of_arguments() const override { return 2; }

    DataTypes get_variadic_argument_types_impl() const override {
        return Impl::get_variadic_argument_types();
    }

    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
        need_replace_null_data_to_default_ = is_decimal(arguments[0]->get_primitive_type());
        return make_nullable(arguments[0]);
    }

    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        uint32_t result, size_t input_rows_count) const override {
        auto& column_left = block.get_by_position(arguments[0]).column;
        auto& column_right = block.get_by_position(arguments[1]).column;
        const auto* type_left = assert_cast<const typename Impl::DataTypeA*>(
                block.get_by_position(arguments[0]).type.get());
        const auto* type_right = assert_cast<const typename Impl::DataTypeB*>(
                block.get_by_position(arguments[1]).type.get());
        const auto& res_data_type = remove_nullable(block.get_by_position(result).type);
        bool is_const_left = is_column_const(*column_left);
        bool is_const_right = is_column_const(*column_right);

        ColumnPtr column_result = nullptr;
        if (is_const_left && is_const_right) {
            column_result = constant_constant(column_left, column_right, type_left, type_right,
                                              res_data_type, context->check_overflow_for_decimal());
        } else if (is_const_left) {
            column_result = constant_vector(column_left, column_right, type_left, type_right,
                                            res_data_type, context->check_overflow_for_decimal());
        } else if (is_const_right) {
            column_result = vector_constant(column_left, column_right, type_left, type_right,
                                            res_data_type, context->check_overflow_for_decimal());
        } else {
            column_result = vector_vector(column_left, column_right, type_left, type_right,
                                          res_data_type, context->check_overflow_for_decimal());
        }
        block.replace_by_position(result, std::move(column_result));
        return Status::OK();
    }

private:
    ColumnPtr constant_constant(ColumnPtr column_left, ColumnPtr column_right,
                                const typename Impl::DataTypeA* type_left,
                                const typename Impl::DataTypeB* type_right,
                                DataTypePtr res_data_type, bool check_overflow_for_decimal) const {
        const auto* column_left_ptr = assert_cast<const ColumnConst*>(column_left.get());
        const auto* column_right_ptr = assert_cast<const ColumnConst*>(column_right.get());
        DCHECK(column_left_ptr != nullptr && column_right_ptr != nullptr);

        ColumnPtr column_result = nullptr;

        if constexpr (result_is_decimal) {
            if constexpr (Impl::DataTypeA::PType == TYPE_DECIMALV2) {
                if (!cast_type_to_either<DataTypeDecimalV2>(
                            remove_nullable(res_data_type).get(), [&](const auto& type_result) {
                                auto max_and_multiplier = Impl::get_max_and_multiplier(
                                        type_left, type_right, type_result);

                                typename PrimitiveTypeTraits<Impl::DataTypeA::PType>::CppType
                                        left_tmp;
                                auto left_src =
                                        column_left_ptr
                                                ->template get_value<Impl::DataTypeA::PType>();
                                std::memcpy(&left_tmp, &left_src, sizeof(left_src));
                                typename PrimitiveTypeTraits<Impl::DataTypeB::PType>::CppType
                                        right_tmp;
                                auto right_src =
                                        column_right_ptr
                                                ->template get_value<Impl::DataTypeB::PType>();
                                std::memcpy(&right_tmp, &right_src, sizeof(right_src));
                                column_result = Impl::constant_constant(
                                        left_tmp, right_tmp, max_and_multiplier.first,
                                        max_and_multiplier.second, type_result,
                                        check_overflow_for_decimal);
                                return true;
                            })) {
                    throw Exception(ErrorCode::INTERNAL_ERROR,
                                    "Wrong type. Expected: Decimal, Actually: {}",
                                    type_to_string(res_data_type->get_primitive_type()));
                }
            } else {
                if (!cast_type_to_either<DataTypeDecimal32, DataTypeDecimal64, DataTypeDecimal128,
                                         DataTypeDecimal256>(
                            remove_nullable(res_data_type).get(), [&](const auto& type_result) {
                                auto max_and_multiplier = Impl::get_max_and_multiplier(
                                        type_left, type_right, type_result);
                                typename PrimitiveTypeTraits<Impl::DataTypeA::PType>::CppType
                                        left_tmp;
                                auto left_src =
                                        column_left_ptr
                                                ->template get_value<Impl::DataTypeA::PType>();
                                std::memcpy(&left_tmp, &left_src, sizeof(left_src));
                                typename PrimitiveTypeTraits<Impl::DataTypeB::PType>::CppType
                                        right_tmp;
                                auto right_src =
                                        column_right_ptr
                                                ->template get_value<Impl::DataTypeB::PType>();
                                std::memcpy(&right_tmp, &right_src, sizeof(right_src));
                                column_result = Impl::constant_constant(
                                        left_tmp, right_tmp, max_and_multiplier.first,
                                        max_and_multiplier.second, type_result,
                                        check_overflow_for_decimal);
                                return true;
                            })) {
                    throw Exception(ErrorCode::INTERNAL_ERROR,
                                    "Wrong type. Expected: Decimal, Actually: {}",
                                    type_to_string(res_data_type->get_primitive_type()));
                }
            }
        } else {
            typename PrimitiveTypeTraits<Impl::DataTypeA::PType>::CppType left_tmp;
            auto left_src = column_left_ptr->template get_value<Impl::DataTypeA::PType>();
            std::memcpy(&left_tmp, &left_src, sizeof(left_src));
            typename PrimitiveTypeTraits<Impl::DataTypeB::PType>::CppType right_tmp;
            auto right_src = column_right_ptr->template get_value<Impl::DataTypeB::PType>();
            std::memcpy(&right_tmp, &right_src, sizeof(right_src));
            column_result = Impl::constant_constant(left_tmp, right_tmp);
        }

        return ColumnConst::create(std::move(column_result), column_left->size());
    }

    ColumnPtr vector_constant(ColumnPtr column_left, ColumnPtr column_right,
                              const typename Impl::DataTypeA* type_left,
                              const typename Impl::DataTypeB* type_right, DataTypePtr res_data_type,
                              bool check_overflow_for_decimal) const {
        const auto* column_right_ptr = assert_cast<const ColumnConst*>(column_right.get());
        DCHECK(column_right_ptr != nullptr);

        ColumnPtr res = nullptr;
        if constexpr (result_is_decimal) {
            if constexpr (Impl::DataTypeA::PType == TYPE_DECIMALV2) {
                if (!cast_type_to_either<DataTypeDecimalV2>(
                            remove_nullable(res_data_type).get(), [&](const auto& type_result) {
                                auto max_and_multiplier = Impl::get_max_and_multiplier(
                                        type_left, type_right, type_result);
                                typename PrimitiveTypeTraits<Impl::DataTypeB::PType>::CppType tmp;
                                auto src = column_right_ptr
                                                   ->template get_value<Impl::DataTypeB::PType>();
                                std::memcpy(&tmp, &src, sizeof(src));
                                res = Impl::vector_constant(column_left->get_ptr(), tmp,
                                                            max_and_multiplier.first,
                                                            max_and_multiplier.second, type_result,
                                                            check_overflow_for_decimal);
                                return true;
                            })) {
                    throw Exception(ErrorCode::INTERNAL_ERROR,
                                    "Wrong type. Expected: Decimal, Actually: {}",
                                    type_to_string(res_data_type->get_primitive_type()));
                }
            } else {
                if (!cast_type_to_either<DataTypeDecimal32, DataTypeDecimal64, DataTypeDecimal128,
                                         DataTypeDecimal256>(
                            remove_nullable(res_data_type).get(), [&](const auto& type_result) {
                                auto max_and_multiplier = Impl::get_max_and_multiplier(
                                        type_left, type_right, type_result);
                                typename PrimitiveTypeTraits<Impl::DataTypeB::PType>::CppType tmp;
                                auto src = column_right_ptr
                                                   ->template get_value<Impl::DataTypeB::PType>();
                                std::memcpy(&tmp, &src, sizeof(src));
                                res = Impl::vector_constant(column_left->get_ptr(), tmp,
                                                            max_and_multiplier.first,
                                                            max_and_multiplier.second, type_result,
                                                            check_overflow_for_decimal);
                                return true;
                            })) {
                    throw Exception(ErrorCode::INTERNAL_ERROR,
                                    "Wrong type. Expected: Decimal, Actually: {}",
                                    type_to_string(res_data_type->get_primitive_type()));
                }
            }
        } else {
            typename PrimitiveTypeTraits<Impl::DataTypeB::PType>::CppType tmp;
            auto src = column_right_ptr->template get_value<Impl::DataTypeB::PType>();
            std::memcpy(&tmp, &src, sizeof(src));
            res = Impl::vector_constant(column_left->get_ptr(), tmp);
        }
        return res;
    }

    ColumnPtr constant_vector(ColumnPtr column_left, ColumnPtr column_right,
                              const typename Impl::DataTypeA* type_left,
                              const typename Impl::DataTypeB* type_right, DataTypePtr res_data_type,
                              bool check_overflow_for_decimal) const {
        const auto* column_left_ptr = assert_cast<const ColumnConst*>(column_left.get());
        DCHECK(column_left_ptr != nullptr);

        ColumnPtr res = nullptr;
        if constexpr (result_is_decimal) {
            if constexpr (Impl::DataTypeA::PType == TYPE_DECIMALV2) {
                if (!cast_type_to_either<DataTypeDecimalV2>(
                            remove_nullable(res_data_type).get(), [&](const auto& type_result) {
                                auto max_and_multiplier = Impl::get_max_and_multiplier(
                                        type_left, type_right, type_result);
                                typename PrimitiveTypeTraits<Impl::DataTypeA::PType>::CppType tmp;
                                auto src = column_left_ptr
                                                   ->template get_value<Impl::DataTypeA::PType>();
                                std::memcpy(&tmp, &src, sizeof(src));
                                res = Impl::constant_vector(tmp, column_right->get_ptr(),
                                                            max_and_multiplier.first,
                                                            max_and_multiplier.second, type_result,
                                                            check_overflow_for_decimal);
                                return true;
                            })) {
                    throw Exception(ErrorCode::INTERNAL_ERROR,
                                    "Wrong type. Expected: Decimal, Actually: {}",
                                    type_to_string(res_data_type->get_primitive_type()));
                }
            } else {
                if (!cast_type_to_either<DataTypeDecimal32, DataTypeDecimal64, DataTypeDecimal128,
                                         DataTypeDecimal256>(
                            remove_nullable(res_data_type).get(), [&](const auto& type_result) {
                                auto max_and_multiplier = Impl::get_max_and_multiplier(
                                        type_left, type_right, type_result);
                                typename PrimitiveTypeTraits<Impl::DataTypeA::PType>::CppType tmp;
                                auto src = column_left_ptr
                                                   ->template get_value<Impl::DataTypeA::PType>();
                                std::memcpy(&tmp, &src, sizeof(src));
                                res = Impl::constant_vector(tmp, column_right->get_ptr(),
                                                            max_and_multiplier.first,
                                                            max_and_multiplier.second, type_result,
                                                            check_overflow_for_decimal);
                                return true;
                            })) {
                    throw Exception(ErrorCode::INTERNAL_ERROR,
                                    "Wrong type. Expected: Decimal, Actually: {}",
                                    type_to_string(res_data_type->get_primitive_type()));
                }
            }
        } else {
            typename PrimitiveTypeTraits<Impl::DataTypeA::PType>::CppType tmp;
            auto src = column_left_ptr->template get_value<Impl::DataTypeA::PType>();
            std::memcpy(&tmp, &src, sizeof(src));
            res = Impl::constant_vector(tmp, column_right->get_ptr());
        }
        return res;
    }

    ColumnPtr vector_vector(ColumnPtr column_left, ColumnPtr column_right,
                            const typename Impl::DataTypeA* type_left,
                            const typename Impl::DataTypeB* type_right, DataTypePtr res_data_type,
                            bool check_overflow_for_decimal) const {
        ColumnPtr res = nullptr;
        if constexpr (result_is_decimal) {
            if constexpr (Impl::DataTypeA::PType == TYPE_DECIMALV2) {
                if (!cast_type_to_either<DataTypeDecimalV2>(
                            remove_nullable(res_data_type).get(), [&](const auto& type_result) {
                                auto max_and_multiplier = Impl::get_max_and_multiplier(
                                        type_left, type_right, type_result);
                                res = Impl::vector_vector(
                                        column_left->get_ptr(), column_right->get_ptr(),
                                        max_and_multiplier.first, max_and_multiplier.second,
                                        type_result, check_overflow_for_decimal);
                                return true;
                            })) {
                    throw Exception(ErrorCode::INTERNAL_ERROR,
                                    "Wrong type. Expected: Decimal, Actually: {}",
                                    type_to_string(res_data_type->get_primitive_type()));
                }
            } else {
                if (!cast_type_to_either<DataTypeDecimal32, DataTypeDecimal64, DataTypeDecimal128,
                                         DataTypeDecimal256>(
                            remove_nullable(res_data_type).get(), [&](const auto& type_result) {
                                auto max_and_multiplier = Impl::get_max_and_multiplier(
                                        type_left, type_right, type_result);
                                res = Impl::vector_vector(
                                        column_left->get_ptr(), column_right->get_ptr(),
                                        max_and_multiplier.first, max_and_multiplier.second,
                                        type_result, check_overflow_for_decimal);
                                return true;
                            })) {
                    throw Exception(ErrorCode::INTERNAL_ERROR,
                                    "Wrong type. Expected: Decimal, Actually: {}",
                                    type_to_string(res_data_type->get_primitive_type()));
                }
            }
        } else {
            res = Impl::vector_vector(column_left->get_ptr(), column_right->get_ptr());
        }
        return res;
    }
};

// DecimalV2 constant intended to represent 1 (the meaning of the (1, 0) constructor arguments
// is not visible here — presumably integer part 1, fraction 0; confirm against DecimalV2Value).
// NOTE(review): nothing in this section of the file references it; check whether later code
// still needs it before removing.
static const DecimalV2Value one(1, 0);

/// Adapts a numeric kernel `Impl` (e.g. ModuloNumericImpl / PModuloNumericImpl) to the
/// per-shape entry points FunctionMod calls. Every entry point returns a ColumnNullable whose
/// null map is filled in by `Impl::apply`.
template <typename Impl>
struct ModNumericImpl {
    static constexpr auto name = Impl::name;
    static constexpr bool result_is_decimal = false;
    using ArgA = typename Impl::ArgA;
    using ArgB = typename Impl::ArgB;
    using ColumnType = typename Impl::ColumnType;
    using DataTypeA = typename Impl::DataTypeA;
    using DataTypeB = typename Impl::DataTypeB;

    static DataTypes get_variadic_argument_types() { return Impl::get_variadic_argument_types(); }

    /// const op const: a one-row nullable result.
    static ColumnPtr constant_constant(ArgA a, ArgB b) {
        auto values = ColumnType::create(1);
        auto nulls = ColumnUInt8::create(1, 0);
        values->get_element(0) = Impl::apply(a, b, nulls->get_element(0));
        return ColumnNullable::create(std::move(values), std::move(nulls));
    }

    /// vector op const: hands the whole column to Impl's vector kernel.
    static ColumnPtr vector_constant(ColumnPtr column_left, ArgB b) {
        const auto* lhs = assert_cast<const ColumnType*>(column_left.get());
        DCHECK(lhs != nullptr);
        const size_t rows = column_left->size();

        auto values = ColumnType::create(rows);
        auto nulls = ColumnUInt8::create(rows, 0);
        Impl::apply(lhs->get_data(), b, values->get_data(), nulls->get_data());
        return ColumnNullable::create(std::move(values), std::move(nulls));
    }

    /// const op vector: applies the scalar kernel row by row.
    static ColumnPtr constant_vector(ArgA a, ColumnPtr column_right) {
        const auto* rhs = assert_cast<const ColumnType*>(column_right.get());
        DCHECK(rhs != nullptr);
        const size_t rows = column_right->size();

        auto values = ColumnType::create(rows);
        auto nulls = ColumnUInt8::create(rows, 0);
        const auto& divisors = rhs->get_data();
        auto& out = values->get_data();
        auto& out_nulls = nulls->get_data();
        for (size_t row = 0, n = divisors.size(); row != n; ++row) {
            out[row] = Impl::apply(a, divisors[row], out_nulls[row]);
        }
        return ColumnNullable::create(std::move(values), std::move(nulls));
    }

    /// vector op vector: applies the scalar kernel to each aligned pair of rows.
    static ColumnPtr vector_vector(ColumnPtr column_left, ColumnPtr column_right) {
        const auto* lhs = assert_cast<const ColumnType*>(column_left.get());
        const auto* rhs = assert_cast<const ColumnType*>(column_right.get());
        DCHECK(lhs != nullptr && rhs != nullptr);

        auto values = ColumnType::create(column_left->size());
        auto nulls = ColumnUInt8::create(values->size(), 0);
        const auto& dividends = lhs->get_data();
        const auto& divisors = rhs->get_data();
        auto& out = values->get_data();
        auto& out_nulls = nulls->get_data();
        for (size_t row = 0, n = dividends.size(); row != n; ++row) {
            out[row] = Impl::apply(dividends[row], divisors[row], out_nulls[row]);
        }
        return ColumnNullable::create(std::move(values), std::move(nulls));
    }
};

/// Kernels for `mod` over a single numeric primitive type `Type`.
/// Integers use the built-in `%` (truncating, sign follows the dividend); float/double use
/// std::fmod evaluated in double. A zero divisor produces NULL.
template <PrimitiveType Type>
struct ModuloNumericImpl {
    static constexpr auto name = "mod";
    using ArgA = typename PrimitiveTypeTraits<Type>::CppType;
    using ArgB = typename PrimitiveTypeTraits<Type>::CppType;
    using ColumnType = typename PrimitiveTypeTraits<Type>::ColumnType;
    using DataTypeA = typename PrimitiveTypeTraits<Type>::DataType;
    using DataTypeB = typename PrimitiveTypeTraits<Type>::DataType;

    static DataTypes get_variadic_argument_types() {
        return {std::make_shared<typename PrimitiveTypeTraits<Type>::DataType>(),
                std::make_shared<typename PrimitiveTypeTraits<Type>::DataType>()};
    }

    /// Vector-by-constant kernel: c[i] = a[i] mod b for every row of `c`.
    /// If b == 0, every row is marked NULL and `c` is zero-filled.
    /// @throws Exception(INVALID_ARGUMENT) for (signed minimum) % -1 on integer types.
    static void apply(const typename ColumnType::Container& a, ArgB b,
                      typename ColumnType::Container& c, PaddedPODArray<UInt8>& null_map) {
        size_t size = c.size();
        UInt8 is_null = b == 0;
        memset(null_map.data(), is_null, sizeof(UInt8) * size);

        if (is_null) {
            // `c` comes from ColumnType::create(n) and is presumably left uninitialized. Zero it
            // so NULL rows carry deterministic nested data: a downstream kernel that reads nested
            // values without consulting the null map would otherwise see arbitrary bytes (which
            // could even trip throw_if_division_leads_to_FPE).
            memset(c.data(), 0, sizeof(c[0]) * size);
            return;
        }

        for (size_t i = 0; i < size; i++) {
            if constexpr (is_float_or_double(Type)) {
                c[i] = std::fmod((double)a[i], (double)b);
            } else {
                throw_if_division_leads_to_FPE(a[i], b);
                c[i] = a[i] % b;
            }
        }
    }

    /// Scalar kernel: returns a mod b and sets `is_null` when b == 0. In that case the divisor
    /// is bumped to 1 so the computation stays well-defined; the value is ignored under NULL.
    /// @throws Exception(INVALID_ARGUMENT) for (signed minimum) % -1 on integer types.
    static inline typename PrimitiveTypeTraits<Type>::CppType apply(ArgA a, ArgB b,
                                                                    UInt8& is_null) {
        is_null = b == 0;
        b += is_null;

        if constexpr (is_float_or_double(Type)) {
            return std::fmod((double)a, (double)b);
        } else {
            throw_if_division_leads_to_FPE(a, b);
            return a % b;
        }
    }
};

template <PrimitiveType Type>
struct PModuloNumericImpl {
    using ArgA = typename PrimitiveTypeTraits<Type>::CppType;
    using ArgB = typename PrimitiveTypeTraits<Type>::CppType;
    using ColumnType = typename PrimitiveTypeTraits<Type>::ColumnType;
    using DataTypeA = typename PrimitiveTypeTraits<Type>::DataType;
    using DataTypeB = typename PrimitiveTypeTraits<Type>::DataType;

    static constexpr auto name = "pmod";
    static DataTypes get_variadic_argument_types() {
        return {std::make_shared<typename PrimitiveTypeTraits<Type>::DataType>(),
                std::make_shared<typename PrimitiveTypeTraits<Type>::DataType>()};
    }

    static void apply(const typename ColumnType::Container& a, ArgB b,
                      typename PrimitiveTypeTraits<Type>::ColumnType::Container& c,
                      PaddedPODArray<UInt8>& null_map) {
        size_t size = c.size();
        UInt8 is_null = b == 0;
        memset(null_map.data(), is_null, size);

        if (!is_null) {
            for (size_t i = 0; i < size; i++) {
                if constexpr (is_float_or_double(Type)) {
                    c[i] = std::fmod(std::fmod((double)a[i], (double)b) + (double)b, double(b));
                } else {
                    throw_if_division_leads_to_FPE(a[i], b);
                    c[i] = (a[i] % b + b) % b;
                }
            }
        }
    }

    static inline typename PrimitiveTypeTraits<Type>::CppType apply(ArgA a, ArgB b,
                                                                    UInt8& is_null) {
        is_null = b == 0;
        b += is_null;

        if constexpr (is_float_or_double(Type)) {
            return std::fmod(std::fmod((double)a, (double)b) + (double)b, (double)b);
        } else {
            throw_if_division_leads_to_FPE(a, b);
            return (a % b + b) % b;
        }
    }

    template <PrimitiveType Result = TYPE_DECIMALV2>
    static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b, UInt8& is_null) {
        is_null = b == DecimalV2Value(0);
        b += DecimalV2Value(is_null);
        return (a % b + b) % b;
    }
};

/// Type description for `mod` over two decimal operands.
/// Both operands must be decimals, and either both are DecimalV2 or neither is. `is_pmod`
/// is false, selecting plain (truncating) modulo in the decimal kernels that read it.
template <PrimitiveType TypeA, PrimitiveType TypeB>
struct ModuloDecimalImpl {
    static_assert(is_decimal(TypeA));
    static_assert(is_decimal(TypeB));
    // DecimalV2 cannot be mixed with any other decimal type.
    static_assert((TypeA == TYPE_DECIMALV2) == (TypeB == TYPE_DECIMALV2));

    static constexpr auto name = "mod";
    static constexpr auto is_pmod = false;

    using DataTypeA = typename PrimitiveTypeTraits<TypeA>::DataType;
    using DataTypeB = typename PrimitiveTypeTraits<TypeB>::DataType;
    using ColumnTypeA = typename PrimitiveTypeTraits<TypeA>::ColumnType;
    using ColumnTypeB = typename PrimitiveTypeTraits<TypeB>::ColumnType;
    using ArgA = typename PrimitiveTypeTraits<TypeA>::CppType;
    using ArgB = typename PrimitiveTypeTraits<TypeB>::CppType;
    using ArgNativeTypeA = typename PrimitiveTypeTraits<TypeA>::CppType::NativeType;
    using ArgNativeTypeB = typename PrimitiveTypeTraits<TypeB>::CppType::NativeType;

    static DataTypes get_variadic_argument_types() {
        return {std::make_shared<typename PrimitiveTypeTraits<TypeA>::DataType>(),
                std::make_shared<typename PrimitiveTypeTraits<TypeB>::DataType>()};
    }

    /// DecimalV2 scalar kernel: returns a % b and sets `is_null` when b is zero. The divisor is
    /// then offset by DecimalV2Value(is_null) so the remainder is computed on a non-zero value.
    static inline DecimalV2Value apply(DecimalV2Value a, DecimalV2Value b, UInt8& is_null) {
        is_null = b == DecimalV2Value(0);
        const DecimalV2Value divisor = b + DecimalV2Value(is_null);
        return a % divisor;
    }
};

template <typename Impl>
struct ModDecimalImpl {
    static constexpr auto name = Impl::name;
    static constexpr bool result_is_decimal = true;
    using ArgA = typename Impl::ArgA;
    using ArgB = typename Impl::ArgB;
    using ArgNativeTypeA = typename Impl::ArgNativeTypeA;
    using ArgNativeTypeB = typename Impl::ArgNativeTypeB;
    using DataTypeA = typename Impl::DataTypeA;
    using DataTypeB = typename Impl::DataTypeB;
    using ColumnTypeA = typename Impl::ColumnTypeA;
    using ColumnTypeB = typename Impl::ColumnTypeB;

    static DataTypes get_variadic_argument_types() { return Impl::get_variadic_argument_types(); }

    template <PrimitiveType ResultType>
        requires(is_decimal(ResultType) && ResultType != TYPE_DECIMALV2)
    static inline typename PrimitiveTypeTraits<ResultType>::CppType::NativeType impl(
            ArgNativeTypeA a, ArgNativeTypeB b, UInt8& is_null) {
        is_null = b == 0;
        b += is_null;

        throw_if_division_leads_to_FPE(a, b);
        if constexpr (Impl::is_pmod) {
            return (a % b + b) % b;
        } else {
            return static_cast<typename PrimitiveTypeTraits<ResultType>::CppType::NativeType>(a) %
                   b;
        }
    }

    template <PrimitiveType ResultType>
        requires(is_decimal(ResultType) && ResultType != TYPE_DECIMALV2)
    static ColumnPtr constant_constant(
            ArgA a, ArgB b,
            const typename PrimitiveTypeTraits<ResultType>::CppType& max_result_number,
            const typename PrimitiveTypeTraits<ResultType>::CppType& scale_diff_multiplier,
            const DataTypeDecimal<ResultType>& res_data_type, bool check_overflow_for_decimal) {
        auto column_result = ColumnDecimal<ResultType>::create(1, res_data_type.get_scale());

        auto null_map = ColumnUInt8::create(1, 0);
        if (check_overflow_for_decimal) {
            column_result->get_element(0) =
                    typename PrimitiveTypeTraits<ResultType>::CppType(apply<true, ResultType>(
                            a.value, b.value, null_map->get_element(0), max_result_number));
        } else {
            column_result->get_element(0) =
                    typename PrimitiveTypeTraits<ResultType>::CppType(apply<false, ResultType>(
                            a.value, b.value, null_map->get_element(0), max_result_number));
        }
        return ColumnNullable::create(std::move(column_result), std::move(null_map));
    }

    template <PrimitiveType ResultType>
        requires(ResultType == TYPE_DECIMALV2)
    static ColumnPtr constant_constant(
            ArgA a, ArgB b,
            const typename PrimitiveTypeTraits<ResultType>::CppType& max_result_number,
            const typename PrimitiveTypeTraits<ResultType>::CppType& scale_diff_multiplier,
            const DataTypeDecimal<ResultType>& res_data_type, bool check_overflow_for_decimal) {
        auto column_result = ColumnDecimal<ResultType>::create(1, res_data_type.get_scale());

        auto null_map = ColumnUInt8::create(1, 0);
        if (check_overflow_for_decimal) {
            column_result->get_element(0) =
                    typename PrimitiveTypeTraits<ResultType>::CppType(apply<true, ResultType>(
                            a.value(), b.value(), null_map->get_element(0), max_result_number));
        } else {
            column_result->get_element(0) =
                    typename PrimitiveTypeTraits<ResultType>::CppType(apply<false, ResultType>(
                            a.value(), b.value(), null_map->get_element(0), max_result_number));
        }
        return ColumnNullable::create(std::move(column_result), std::move(null_map));
    }

    template <PrimitiveType ResultType>
        requires(is_decimal(ResultType) && ResultType != TYPE_DECIMALV2)
    static ColumnPtr vector_constant(
            ColumnPtr column_left, ArgB b,
            const typename PrimitiveTypeTraits<ResultType>::CppType& max_result_number,
            const typename PrimitiveTypeTraits<ResultType>::CppType& scale_diff_multiplier,
            const DataTypeDecimal<ResultType>& res_data_type, bool check_overflow_for_decimal) {
        const auto* column_left_ptr = assert_cast<const ColumnTypeA*>(column_left.get());
        auto column_result =
                ColumnDecimal<ResultType>::create(column_left->size(), res_data_type.get_scale());
        DCHECK(column_left_ptr != nullptr);

        auto null_map = ColumnUInt8::create(column_left->size(), 0);
        const auto& a = column_left_ptr->get_data().data();
        const auto& c = column_result->get_data().data();
        auto& n = null_map->get_data();
        auto sz = column_left->size();
        if (check_overflow_for_decimal) {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = typename DataTypeDecimal<ResultType>::FieldType(
                        apply<true, ResultType>(a[i].value, b.value, n[i], max_result_number));
            }
        } else {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = typename DataTypeDecimal<ResultType>::FieldType(
                        apply<false, ResultType>(a[i].value, b.value, n[i], max_result_number));
            }
        }
        return ColumnNullable::create(std::move(column_result), std::move(null_map));
    }

    template <PrimitiveType ResultType>
        requires(ResultType == TYPE_DECIMALV2)
    static ColumnPtr vector_constant(
            ColumnPtr column_left, ArgB b,
            const typename PrimitiveTypeTraits<ResultType>::CppType& max_result_number,
            const typename PrimitiveTypeTraits<ResultType>::CppType& scale_diff_multiplier,
            const DataTypeDecimal<ResultType>& res_data_type, bool check_overflow_for_decimal) {
        const auto* column_left_ptr = assert_cast<const ColumnTypeA*>(column_left.get());
        auto column_result =
                ColumnDecimal<ResultType>::create(column_left->size(), res_data_type.get_scale());
        DCHECK(column_left_ptr != nullptr);

        auto null_map = ColumnUInt8::create(column_left->size(), 0);
        const auto& a = column_left_ptr->get_data().data();
        const auto& c = column_result->get_data().data();
        auto& n = null_map->get_data();
        auto sz = column_left->size();
        if (check_overflow_for_decimal) {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = typename DataTypeDecimal<ResultType>::FieldType(
                        apply<true, ResultType>(a[i].value(), b.value(), n[i], max_result_number));
            }
        } else {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = typename DataTypeDecimal<ResultType>::FieldType(
                        apply<false, ResultType>(a[i].value(), b.value(), n[i], max_result_number));
            }
        }
        return ColumnNullable::create(std::move(column_result), std::move(null_map));
    }

    template <PrimitiveType ResultType>
        requires(is_decimal(ResultType) && ResultType != TYPE_DECIMALV2)
    static ColumnPtr constant_vector(
            ArgA a, ColumnPtr column_right,
            const typename PrimitiveTypeTraits<ResultType>::CppType& max_result_number,
            const typename PrimitiveTypeTraits<ResultType>::CppType& scale_diff_multiplier,
            const DataTypeDecimal<ResultType>& res_data_type, bool check_overflow_for_decimal) {
        const auto* column_right_ptr = assert_cast<const ColumnTypeB*>(column_right.get());
        auto column_result =
                ColumnDecimal<ResultType>::create(column_right->size(), res_data_type.get_scale());
        DCHECK(column_right_ptr != nullptr);

        auto null_map = ColumnUInt8::create(column_right->size(), 0);
        const auto& b = column_right_ptr->get_data().data();
        const auto& c = column_result->get_data().data();
        auto& n = null_map->get_data();
        auto sz = column_right->size();
        if (check_overflow_for_decimal) {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = typename DataTypeDecimal<ResultType>::FieldType(
                        apply<true, ResultType>(a.value, b[i].value, n[i], max_result_number));
            }
        } else {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = typename DataTypeDecimal<ResultType>::FieldType(
                        apply<false, ResultType>(a.value, b[i].value, n[i], max_result_number));
            }
        }

        return ColumnNullable::create(std::move(column_result), std::move(null_map));
    }

    template <PrimitiveType ResultType>
        requires(ResultType == TYPE_DECIMALV2)
    static ColumnPtr constant_vector(
            ArgA a, ColumnPtr column_right,
            const typename PrimitiveTypeTraits<ResultType>::CppType& max_result_number,
            const typename PrimitiveTypeTraits<ResultType>::CppType& scale_diff_multiplier,
            const DataTypeDecimal<ResultType>& res_data_type, bool check_overflow_for_decimal) {
        const auto* column_right_ptr = assert_cast<const ColumnTypeB*>(column_right.get());
        auto column_result =
                ColumnDecimal<ResultType>::create(column_right->size(), res_data_type.get_scale());
        DCHECK(column_right_ptr != nullptr);

        auto null_map = ColumnUInt8::create(column_right->size(), 0);
        const auto& b = column_right_ptr->get_data().data();
        const auto& c = column_result->get_data().data();
        auto& n = null_map->get_data();
        auto sz = column_right->size();
        if (check_overflow_for_decimal) {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = typename DataTypeDecimal<ResultType>::FieldType(
                        apply<true, ResultType>(a.value(), b[i].value(), n[i], max_result_number));
            }
        } else {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = typename DataTypeDecimal<ResultType>::FieldType(
                        apply<false, ResultType>(a.value(), b[i].value(), n[i], max_result_number));
            }
        }

        return ColumnNullable::create(std::move(column_result), std::move(null_map));
    }

    template <PrimitiveType ResultType>
        requires(is_decimal(ResultType) && ResultType != TYPE_DECIMALV2)
    static ColumnPtr vector_vector(
            ColumnPtr column_left, ColumnPtr column_right,
            const typename PrimitiveTypeTraits<ResultType>::CppType max_result_number,
            const typename PrimitiveTypeTraits<ResultType>::CppType scale_diff_multiplier,
            const DataTypeDecimal<ResultType>& res_data_type, bool check_overflow_for_decimal) {
        const auto* column_left_ptr = assert_cast<const ColumnTypeA*>(column_left.get());
        const auto* column_right_ptr = assert_cast<const ColumnTypeB*>(column_right.get());

        auto column_result =
                ColumnDecimal<ResultType>::create(column_left->size(), res_data_type.get_scale());
        DCHECK(column_left_ptr != nullptr && column_right_ptr != nullptr);

        // function divide, modulo and pmod
        auto null_map = ColumnUInt8::create(column_result->size(), 0);
        const auto& a = column_left_ptr->get_data().data();
        const auto& b = column_right_ptr->get_data().data();
        const auto& c = column_result->get_data().data();
        auto& n = null_map->get_data();
        auto sz = column_right->size();
        if (check_overflow_for_decimal) {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = typename DataTypeDecimal<ResultType>::FieldType(
                        apply<true, ResultType>(a[i].value, b[i].value, n[i], max_result_number));
            }
        } else {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = typename DataTypeDecimal<ResultType>::FieldType(
                        apply<false, ResultType>(a[i].value, b[i].value, n[i], max_result_number));
            }
        }
        return ColumnNullable::create(std::move(column_result), std::move(null_map));
    }

    template <PrimitiveType ResultType>
        requires(ResultType == TYPE_DECIMALV2)
    static ColumnPtr vector_vector(
            ColumnPtr column_left, ColumnPtr column_right,
            const typename PrimitiveTypeTraits<ResultType>::CppType max_result_number,
            const typename PrimitiveTypeTraits<ResultType>::CppType scale_diff_multiplier,
            const DataTypeDecimal<ResultType>& res_data_type, bool check_overflow_for_decimal) {
        const auto* column_left_ptr = assert_cast<const ColumnTypeA*>(column_left.get());
        const auto* column_right_ptr = assert_cast<const ColumnTypeB*>(column_right.get());

        auto column_result =
                ColumnDecimal<ResultType>::create(column_left->size(), res_data_type.get_scale());
        DCHECK(column_left_ptr != nullptr && column_right_ptr != nullptr);

        // function divide, modulo and pmod
        auto null_map = ColumnUInt8::create(column_result->size(), 0);
        const auto& a = column_left_ptr->get_data().data();
        const auto& b = column_right_ptr->get_data().data();
        const auto& c = column_result->get_data().data();
        auto& n = null_map->get_data();
        auto sz = column_right->size();
        if (check_overflow_for_decimal) {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = DecimalV2Value(apply<true, TYPE_DECIMALV2>(a[i].value(), b[i].value(), n[i],
                                                                  max_result_number));
            }
        } else {
            for (size_t i = 0; i < sz; ++i) {
                c[i] = DecimalV2Value(apply<false, TYPE_DECIMALV2>(a[i].value(), b[i].value(), n[i],
                                                                   max_result_number));
            }
        }
        return ColumnNullable::create(std::move(column_result), std::move(null_map));
    }

    template <bool check_overflow_for_decimal, PrimitiveType ResultType>
        requires(is_decimal(ResultType))
    static ALWAYS_INLINE typename PrimitiveTypeTraits<ResultType>::CppType::NativeType apply(
            ArgNativeTypeA a, ArgNativeTypeB b, UInt8& is_null,
            const typename PrimitiveTypeTraits<ResultType>::CppType& max_result_number) {
        if constexpr (DataTypeA::PType == TYPE_DECIMALV2) {
            DecimalV2Value l(a);
            DecimalV2Value r(b);
            auto ans = Impl::apply(l, r, is_null);
            using ANS_TYPE = std::decay_t<decltype(ans)>;
            if constexpr (check_overflow_for_decimal) {
                if constexpr (std::is_same_v<ANS_TYPE, DecimalV2Value>) {
                    if (ans.value() > max_result_number.value() ||
                        ans.value() < -max_result_number.value()) {
                        throw Exception(
                                ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
                                "Arithmetic overflow: {} {} {} = {}, result type: {}",
                                DecimalV2Value(a).to_string(), name, DecimalV2Value(b).to_string(),
                                DecimalV2Value(ans).to_string(), type_to_string(ResultType));
                    }
                } else if constexpr (IsDecimalNumber<ANS_TYPE>) {
                    if (ans.value > max_result_number.value ||
                        ans.value < -max_result_number.value) {
                        throw Exception(
                                ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
                                "Arithmetic overflow: {} {} {} = {}, result type: {}",
                                DecimalV2Value(a).to_string(), name, DecimalV2Value(b).to_string(),
                                DecimalV2Value(ans).to_string(), type_to_string(ResultType));
                    }
                } else {
                    if (ans > max_result_number.value || ans < -max_result_number.value) {
                        throw Exception(
                                ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
                                "Arithmetic overflow: {} {} {} = {}, result type: {}",
                                DecimalV2Value(a).to_string(), name, DecimalV2Value(b).to_string(),
                                DecimalV2Value(ans).to_string(), type_to_string(ResultType));
                    }
                }
            }
            typename PrimitiveTypeTraits<ResultType>::CppType::NativeType result {};
            memcpy(&result, &ans, std::min(sizeof(result), sizeof(ans)));
            return result;
        } else {
            return impl<ResultType>(a, b, is_null);
        }
    }

    template <PrimitiveType PT>
    static std::pair<typename PrimitiveTypeTraits<PT>::CppType,
                     typename PrimitiveTypeTraits<PT>::CppType>
    get_max_and_multiplier(const DataTypeA* type_left, const DataTypeB* type_right,
                           const DataTypeDecimal<PT>& type_result) {
        auto max_result_number =
                DataTypeDecimal<PT>::get_max_digits_number(type_result.get_precision());

        auto orig_result_scale = type_left->get_scale() + type_right->get_scale();
        auto result_scale = type_result.get_scale();
        DCHECK(orig_result_scale >= result_scale);
        auto scale_diff_multiplier =
                DataTypeDecimal<PT>::get_scale_multiplier(orig_result_scale - result_scale);
        return {typename PrimitiveTypeTraits<PT>::CppType(max_result_number),
                typename PrimitiveTypeTraits<PT>::CppType(scale_diff_multiplier)};
    }
};

void register_function_modulo(SimpleFunctionFactory& factory) {
    factory.register_function<FunctionMod<ModNumericImpl<ModuloNumericImpl<TYPE_TINYINT>>>>();
    factory.register_function<FunctionMod<ModNumericImpl<ModuloNumericImpl<TYPE_SMALLINT>>>>();
    factory.register_function<FunctionMod<ModNumericImpl<ModuloNumericImpl<TYPE_INT>>>>();
    factory.register_function<FunctionMod<ModNumericImpl<ModuloNumericImpl<TYPE_BIGINT>>>>();
    factory.register_function<FunctionMod<ModNumericImpl<ModuloNumericImpl<TYPE_LARGEINT>>>>();
    factory.register_function<FunctionMod<ModNumericImpl<ModuloNumericImpl<TYPE_FLOAT>>>>();
    factory.register_function<FunctionMod<ModNumericImpl<ModuloNumericImpl<TYPE_DOUBLE>>>>();

    factory.register_function<FunctionMod<ModNumericImpl<PModuloNumericImpl<TYPE_BIGINT>>>>();
    factory.register_function<FunctionMod<ModNumericImpl<PModuloNumericImpl<TYPE_DOUBLE>>>>();

    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMALV2, TYPE_DECIMALV2>>>>();

    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL32, TYPE_DECIMAL32>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL32, TYPE_DECIMAL64>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL32, TYPE_DECIMAL128I>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL32, TYPE_DECIMAL256>>>>();

    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL64, TYPE_DECIMAL32>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL64, TYPE_DECIMAL64>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL64, TYPE_DECIMAL128I>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL64, TYPE_DECIMAL256>>>>();

    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL128I, TYPE_DECIMAL32>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL128I, TYPE_DECIMAL64>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL128I, TYPE_DECIMAL128I>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL128I, TYPE_DECIMAL256>>>>();

    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL256, TYPE_DECIMAL32>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL256, TYPE_DECIMAL64>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL256, TYPE_DECIMAL128I>>>>();
    factory.register_function<
            FunctionMod<ModDecimalImpl<ModuloDecimalImpl<TYPE_DECIMAL256, TYPE_DECIMAL256>>>>();
    factory.register_alias("mod", "fmod");
}

} // namespace doris::vectorized
