# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the auto_scheduler task scheduler."""

import tempfile

import multiprocessing
import numpy as np

import tvm
import tvm.testing
from tvm import auto_scheduler

from tvm.testing.auto_scheduler import matmul_auto_scheduler_test


@tvm.testing.requires_llvm
def test_task_scheduler_round_robin():
    """Tune three matmul tasks with the "round-robin" strategy.

    Checks that every task ends up with exactly ``num_trials_per_task``
    measurement records in the log file. It then builds a new scheduler with
    ``load_log_file`` set to that log and tunes again, exercising the
    continuous-tuning (status restoring) path.
    """
    tasks = [
        auto_scheduler.SearchTask(
            func=matmul_auto_scheduler_test, args=(n, n, n), target="llvm"
        )
        for n in [2, 4, 8]
    ]

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name
        num_trials_per_task = 2

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        try:
            # Tune all tasks, one measurement per round.
            tune_option = auto_scheduler.TuningOptions(
                num_measure_trials=num_trials_per_task * len(tasks),
                runner=measure_ctx.runner,
                num_measures_per_round=1,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            task_scheduler = auto_scheduler.TaskScheduler(
                tasks, strategy="round-robin", callbacks=[]
            )
            task_scheduler.tune(tune_option, search_policy="sketch.random")

            # Check the result of round robin: each task got the same number of
            # records. A record for a workload key not in `tasks` raises KeyError.
            counters = {task.workload_key: 0 for task in tasks}
            for inp, _ in auto_scheduler.load_records(log_file):
                counters[inp.task.workload_key] += 1

            for task in tasks:
                assert counters[task.workload_key] == num_trials_per_task

            # Test continuous tuning (restoring the status from the log file).
            # NOTE(review): unlike the first run, no `runner` is passed here, so
            # the TuningOptions default runner is used -- confirm this is intended.
            task_scheduler = auto_scheduler.TaskScheduler(
                tasks, strategy="round-robin", load_log_file=log_file, callbacks=[]
            )
            tune_option = auto_scheduler.TuningOptions(
                num_measure_trials=len(tasks),
                num_measures_per_round=1,
            )
            task_scheduler.tune(tune_option, search_policy="sketch.random")
        finally:
            # Drop the measure context even when an assertion or tuning call
            # fails. Otherwise the failing frame's traceback would keep it
            # (and whatever local RPC resources it holds) alive.
            del measure_ctx


@tvm.testing.requires_llvm
def task_scheduler_round_robin_spawn():
    """Child-process body for the spawn variant of the round-robin test.

    This has no ``test_`` prefix, so it is not collected directly. It is the
    ``target`` of the process launched by ``test_task_scheduler_round_robin_spawn``.
    """
    # NOTE(review): a spawned child may inherit the parent's *global* default
    # start method rather than that of the context it was launched from.
    # Confirm this assertion holds on platforms whose default is "fork".
    start_method = multiprocessing.get_start_method(allow_none=False)
    assert start_method == "spawn"
    test_task_scheduler_round_robin()


@tvm.testing.requires_llvm
def test_task_scheduler_round_robin_spawn():
    """Run the round-robin scenario inside a child process started via "spawn".

    The child's exit code is checked so that a failure inside the child (a
    failed assertion or any uncaught exception) fails this test, instead of
    being printed to stderr and silently ignored.
    """
    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=task_scheduler_round_robin_spawn)
    p.start()
    p.join()
    # exitcode is 0 on success, 1 for an uncaught exception in the child, and
    # -N if the child was killed by signal N.
    assert p.exitcode == 0, f"spawned child process exited with code {p.exitcode}"


@tvm.testing.requires_llvm
def test_task_scheduler_gradient():
    """Tune two matmul tasks with the scheduler's default strategy.

    The default strategy is presumably the gradient-based one, per this test's
    name; confirm against ``TaskScheduler``. The objective depends only on the
    first task's cost, and the initial best costs are forced so that task 1 is
    already "good". The test expects all trials but one to go to task 0 and
    exactly one to task 1.
    """
    tasks = [
        auto_scheduler.SearchTask(
            func=matmul_auto_scheduler_test, args=(n, n, n), target="llvm"
        )
        for n in [2, 4]
    ]

    def objective_func(costs):
        # Only the first task's cost contributes to the objective.
        return costs[0]

    with tempfile.NamedTemporaryFile() as fp:
        log_file = fp.name
        n_trials = 5

        measure_ctx = auto_scheduler.LocalRPCMeasureContext()
        try:
            # Tune all tasks, one measurement per round.
            tune_option = auto_scheduler.TuningOptions(
                num_measure_trials=n_trials,
                runner=measure_ctx.runner,
                num_measures_per_round=1,
                measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
            )
            task_scheduler = auto_scheduler.TaskScheduler(
                tasks, objective_func=objective_func, callbacks=[]
            )

            # Forcibly overwrite the initial best costs.
            # This makes the test more stable on slow CI machines.
            task_scheduler.best_costs = np.array([1e2, 1e-8])

            task_scheduler.tune(tune_option, search_policy="sketch.random")

            # Check the allocation results. A record for a workload key not in
            # `tasks` raises KeyError.
            counters = {task.workload_key: 0 for task in tasks}
            for inp, _ in auto_scheduler.load_records(log_file):
                counters[inp.task.workload_key] += 1

            assert counters[tasks[0].workload_key] == n_trials - 1
            assert counters[tasks[1].workload_key] == 1
        finally:
            # Drop the measure context even when an assertion or tuning call
            # fails, so the failing frame's traceback does not keep it alive.
            del measure_ctx


if __name__ == "__main__":
    # Run every test directly (without pytest), in definition order.
    for _test in (
        test_task_scheduler_round_robin,
        test_task_scheduler_round_robin_spawn,
        test_task_scheduler_gradient,
    ):
        _test()
