// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <cstdint>
#include <memory>

#include "common/factory_creator.h"
#include "olap/column_predicate.h"
#include "olap/rowset/segment_v2/bloom_filter.h"
#include "olap/rowset/segment_v2/inverted_index_reader.h"
#include "olap/wrapper_field.h"
#include "vec/columns/column_dictionary.h"

namespace doris {

// SharedPredicate is used only for the top-n runtime predicate.
// A single runtime predicate instance is shared globally so that updates to it take effect immediately.
// Initially the nested predicate may be nullptr, in which case this predicate filters out no rows.
class SharedPredicate final : public ColumnPredicate {
    ENABLE_FACTORY_CREATOR(SharedPredicate);

public:
    // Creates a shared predicate with no nested predicate yet. Until set_nested() is called,
    // the evaluation methods below pass every row.
    SharedPredicate(uint32_t column_id, std::string col_name)
            : ColumnPredicate(column_id, std::move(col_name), PrimitiveType::INVALID_TYPE),
              _mtx(std::make_shared<std::shared_mutex>()) {}
    SharedPredicate(const ColumnPredicate& other) = delete;
    // Copies `other` onto `column_id`. The copy gets its own mutex. The nested predicate
    // (if any) is cloned while holding `other`'s shared lock, so a concurrent set_nested()
    // on `other` cannot race with the read.
    SharedPredicate(const SharedPredicate& other, uint32_t column_id)
            : ColumnPredicate(other, column_id),
              _mtx(std::make_shared<std::shared_mutex>()),
              _nested(_clone_nested(other, column_id)) {}
    ~SharedPredicate() override = default;

    // Renders the base predicate description plus the nested predicate ("null" when unset).
    std::string debug_string() const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        fmt::memory_buffer debug_string_buffer;
        fmt::format_to(debug_string_buffer, "SharedPredicate({}, nested={})",
                       ColumnPredicate::debug_string(), _nested ? _nested->debug_string() : "null");
        return fmt::to_string(debug_string_buffer);
    }

    // Returns this very object (column_id is ignored): all scanner threads must share the same
    // SharedPredicate so that a later set_nested() is visible to every one of them.
    std::shared_ptr<ColumnPredicate> clone(uint32_t column_id) const override {
        return std::const_pointer_cast<ColumnPredicate>(shared_from_this());
    }

    // Until a nested predicate is set, report LE (the top-n filter is either LE or GE).
    PredicateType type() const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            return PredicateType::LE;
        }
        return _nested->type();
    }

    // INVALID_TYPE until a nested predicate is set; afterwards the nested predicate's type.
    PrimitiveType primitive_type() const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            return PrimitiveType::INVALID_TYPE;
        }
        return _nested->primitive_type();
    }

    // Replaces the nested predicate under an exclusive lock; readers take a shared lock.
    void set_nested(const std::shared_ptr<ColumnPredicate>& nested) {
        std::unique_lock<std::shared_mutex> lock(*_mtx);
        _nested = nested;
    }

    // Index evaluation: leaves `bitmap` untouched and returns OK when no nested predicate is set.
    Status evaluate(const vectorized::IndexFieldNameAndTypePair& name_with_type,
                    IndexIterator* iterator, uint32_t num_rows,
                    roaring::Roaring* bitmap) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            return Status::OK();
        }
        return _nested->evaluate(name_with_type, iterator, num_rows, bitmap);
    }

    // AND-combines into `flags`; with no nested predicate, `flags` is left unchanged.
    void evaluate_and(const vectorized::IColumn& column, const uint16_t* sel, uint16_t size,
                      bool* flags) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            return;
        }
        _nested->evaluate_and(column, sel, size, flags);
    }

    // Not supported for the shared top-n predicate. Debug builds abort here; in release builds
    // this is a no-op that leaves `flags` unchanged.
    void evaluate_or(const vectorized::IColumn& column, const uint16_t* sel, uint16_t size,
                     bool* flags) const override {
        DCHECK(false) << "should not reach here";
    }

    // Zone-map style pruning; falls back to the ColumnPredicate default when no nested predicate.
    bool evaluate_and(const std::pair<WrapperField*, WrapperField*>& statistic) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            return ColumnPredicate::evaluate_and(statistic);
        }
        return _nested->evaluate_and(statistic);
    }

    // Falls back to the ColumnPredicate default when no nested predicate is set.
    bool evaluate_del(const std::pair<WrapperField*, WrapperField*>& statistic) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            return ColumnPredicate::evaluate_del(statistic);
        }
        return _nested->evaluate_del(statistic);
    }

    // Bloom-filter pruning; falls back to the ColumnPredicate default when no nested predicate.
    bool evaluate_and(const BloomFilter* bf) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            return ColumnPredicate::evaluate_and(bf);
        }
        return _nested->evaluate_and(bf);
    }

    // Falls back to the ColumnPredicate default when no nested predicate is set.
    bool can_do_bloom_filter(bool ngram) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            return ColumnPredicate::can_do_bloom_filter(ngram);
        }
        return _nested->can_do_bloom_filter(ngram);
    }

    // Overwrites `flags[0, size)`; with no nested predicate every flag is set to true.
    void evaluate_vec(const vectorized::IColumn& column, uint16_t size,
                      bool* flags) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            for (uint16_t i = 0; i < size; ++i) {
                flags[i] = true;
            }
            return;
        }
        _nested->evaluate_vec(column, size, flags);
    }

    // AND-combines into `flags`; with no nested predicate, `flags` is left unchanged.
    void evaluate_and_vec(const vectorized::IColumn& column, uint16_t size,
                          bool* flags) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            return;
        }
        _nested->evaluate_and_vec(column, size, flags);
    }

    // Delegates to the nested predicate. Calling this before set_nested() is a programming
    // error: debug builds abort. Release builds (where DCHECK is compiled out) return an empty
    // string instead of dereferencing a null _nested.
    std::string get_search_str() const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            DCHECK(false) << "should not reach here";
            return {};
        }
        return _nested->get_search_str();
    }

    // Parquet column-statistics pruning.
    bool evaluate_and(vectorized::ParquetPredicate::ColumnStat* statistic) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            // At the beginning _nested will be null, so return true (keep the row group).
            return true;
        }
        return _nested->evaluate_and(statistic);
    }

    // Parquet page-index pruning: appends the surviving ranges to `row_ranges`.
    bool evaluate_and(vectorized::ParquetPredicate::CachedPageIndexStat* statistic,
                      RowRanges* row_ranges) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            // At the beginning _nested will be null: keep the whole row-group range and
            // return true.
            row_ranges->add(statistic->row_group_range);
            return true;
        }
        return _nested->evaluate_and(statistic, row_ranges);
    }

private:
    // Selection-vector evaluation; with no nested predicate all `size` rows are kept.
    uint16_t _evaluate_inner(const vectorized::IColumn& column, uint16_t* sel,
                             uint16_t size) const override {
        std::shared_lock<std::shared_mutex> lock(*_mtx);
        if (!_nested) {
            return size;
        }
        return _nested->evaluate(column, sel, size);
    }

    // Clones `other._nested` onto `column_id` under `other`'s shared lock; nullptr if unset.
    static std::shared_ptr<ColumnPredicate> _clone_nested(const SharedPredicate& other,
                                                          uint32_t column_id) {
        std::shared_lock<std::shared_mutex> lock(*other._mtx);
        return other._nested ? other._nested->clone(column_id) : nullptr;
    }

    // Guards _nested. Every constructor allocates a fresh mutex for the new object.
    mutable std::shared_ptr<std::shared_mutex> _mtx;
    // Current top-n predicate; nullptr until set_nested() is called.
    std::shared_ptr<ColumnPredicate> _nested;
};

} //namespace doris
