#include "main.h"
#include "gtest.h"

#include <library/cpp/string_utils/relaxed_escaper/relaxed_escaper.h>
#include <library/cpp/testing/common/env.h>
#include <library/cpp/testing/hook/hook.h>
#include <util/generic/scope.h>
#include <util/string/join.h>
#include <util/system/src_root.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>

namespace {
    // Returns true when the C string `str` begins with the C string `pre`
    // (an empty `pre` is a prefix of everything).
    bool StartsWith(const char* str, const char* pre) {
        const auto prefixLength = strlen(pre);
        return strncmp(str, pre, prefixLength) == 0;
    }

    // Rejects a flag that this wrapper recognizes but cannot honour:
    // prints a diagnostic to stderr and terminates the process with exit code 2.
    [[noreturn]] void Unsupported(const char* flag) {
        std::cerr
            << "This GTest wrapper does not support flag "
            << flag
            << '\n'
            << std::flush;
        exit(2);
    }

    // Rejects an unrecognized `--`-prefixed flag: prints a diagnostic to stderr
    // and terminates the process with exit code 2.
    [[noreturn]] void Unknown(const char* flag) {
        std::cerr << "Unknown flag " << flag << std::endl;
        exit(2);
    }

    // Splits a "Suite::Test" selector at the first "::" into {suite, test}.
    // A selector without "::" names a whole suite, so the test part is "*".
    // The returned views refer to the caller's `name` buffer.
    std::pair<std::string_view, std::string_view> ParseName(std::string_view name) {
        constexpr std::string_view separator = "::";
        const auto separatorPos = name.find(separator);
        if (separatorPos != std::string_view::npos) {
            return {name.substr(0, separatorPos), name.substr(separatorPos + separator.size())};
        }
        return {name, "*"};
    }

    // Splits a "key=value" test parameter at the first '=' into {key, value}.
    // Without '=' the whole string is the key and the value is empty.
    // The returned views refer to the caller's `param` buffer.
    std::pair<std::string_view, std::string_view> ParseParam(std::string_view param) {
        using TKeyValue = std::pair<std::string_view, std::string_view>;
        const auto eqPos = param.find('=');
        return eqPos == std::string_view::npos
            ? TKeyValue{param, ""}
            : TKeyValue{param.substr(0, eqPos), param.substr(eqPos + 1)};
    }

    // Compile-time-capable wrapper around ::NPrivate::StripRoot for std::string_view paths.
    // Used on gtest failure locations; presumably removes the source-root prefix so paths are
    // repository-relative (see util/system/src_root.h — TODO confirm).
    constexpr std::string_view StripRoot(std::string_view f) noexcept {
        auto stripped = ::NPrivate::StripRoot(::NPrivate::TStaticBuf(f.data(), f.size()));
        return stripped.As<std::string_view>();
    }

    std::string EscapeJson(std::string_view str) {
        TString result;
        NEscJ::EscapeJ<true, true>(str, result);
        return result;
    }

    class TTraceWriter: public ::testing::EmptyTestEventListener {
    public:
        explicit TTraceWriter(std::ostream* trace)
            : Trace_(trace)
        {
        }

    private:
        void OnTestProgramStart(const testing::UnitTest& test) override {
            auto ts = std::chrono::duration_cast<std::chrono::duration<double>>(
                std::chrono::system_clock::now().time_since_epoch());

            for (int i = 0; i < test.total_test_suite_count(); ++i) {
                auto suite = test.GetTestSuite(i);
                for (int j = 0; j < suite->total_test_count(); ++j) {
                    auto testInfo = suite->GetTestInfo(j);
                    if (testInfo->is_reportable() && !testInfo->should_run()) {
                        PrintTestStatus(*testInfo, "skipped", "test is disabled", {}, ts);
                    }
                }
            }
        }

        void OnTestStart(const ::testing::TestInfo& testInfo) override {
            // We fully format this marker before printing it to stderr/stdout because we want to print it atomically.
            // If we were to write `std::cout << "\n###subtest-finished:" << name`, there would be a chance that
            // someone else could sneak in and print something between `"\n###subtest-finished"` and `name`
            // (this happens when test binary uses both `Cout` and `std::cout`).
            auto marker = Join("", "\n###subtest-started:", testInfo.test_suite_name(), "::", testInfo.name(), "\n");

            // Theoretically, we don't need to flush both `Cerr` and `std::cerr` here because both ultimately
            // result in calling `fflush(stderr)`. However, there may be additional buffering logic
            // going on (custom `std::cerr.tie()`, for example), so just to be sure, we flush both of them.
            std::cout << std::flush;
            Cout << marker << Flush;

            std::cerr << std::flush;
            Cerr << marker << Flush;

            auto ts = std::chrono::duration_cast<std::chrono::duration<double>>(
                std::chrono::system_clock::now().time_since_epoch());

            (*Trace_)
                << "{"
                <<   "\"name\": \"subtest-started\", "
                <<   "\"timestamp\": " << std::setprecision(14) << ts.count() << ", "
                <<   "\"value\": {"
                <<     "\"class\": " << EscapeJson(testInfo.test_suite_name()) << ", "
                <<     "\"subtest\": " << EscapeJson(testInfo.name())
                <<   "}"
                << "}"
                << "\n"
                << std::flush;
        }

        void OnTestPartResult(const testing::TestPartResult& result) override {
            if (!result.passed()) {
                if (result.file_name()) {
                    std::cerr << StripRoot(result.file_name()) << ":" << result.line_number() << ":" << "\n";
                }
                std::cerr << result.message() << "\n";
                std::cerr << std::flush;
            }
        }

        void OnTestEnd(const ::testing::TestInfo& testInfo) override {
            auto ts = std::chrono::duration_cast<std::chrono::duration<double>>(
                std::chrono::system_clock::now().time_since_epoch());

            std::string_view status = "good";
            if (testInfo.result()->Failed()) {
                status = "fail";
            } else if (testInfo.result()->Skipped()) {
                status = "skipped";
            }

            std::ostringstream messages;
            std::unordered_map<std::string, double> properties;

            {
                if (testInfo.value_param()) {
                    messages << "Value param:\n  " << testInfo.value_param() << "\n";
                }

                if (testInfo.type_param()) {
                    messages << "Type param:\n  " << testInfo.type_param() << "\n";
                }

                std::string_view sep;
                for (int i = 0; i < testInfo.result()->total_part_count(); ++i) {
                    auto part = testInfo.result()->GetTestPartResult(i);
                    if (part.failed()) {
                        messages << sep;
                        if (part.file_name()) {
                            messages << StripRoot(part.file_name()) << ":" << part.line_number() << ":\n";
                        }
                        messages << part.message();
                        messages << "\n";
                        sep = "\n";
                    }
                }

                for (int i = 0; i < testInfo.result()->test_property_count(); ++i) {
                    auto& property = testInfo.result()->GetTestProperty(i);

                    double value;

                    try {
                        value = std::stod(property.value());
                    } catch (std::invalid_argument&) {
                        messages
                            << sep
                            << "Arcadia CI only supports numeric properties, property "
                            << property.key() << "=" << EscapeJson(property.value()) << " is not a number\n";
                        std::cerr
                            << "Arcadia CI only supports numeric properties, property "
                            << property.key() << "=" << EscapeJson(property.value()) << " is not a number\n"
                            << std::flush;
                        status = "fail";
                        sep = "\n";
                        continue;
                    } catch (std::out_of_range&) {
                        messages
                            << sep
                            << "Property " << property.key() << "=" << EscapeJson(property.value())
                            << " is too big for a double precision value\n";
                        std::cerr
                            << "Property " << property.key() << "=" << EscapeJson(property.value())
                            << " is too big for a double precision value\n"
                            << std::flush;
                        status = "fail";
                        sep = "\n";
                        continue;
                    }

                    properties[property.key()] = value;
                }
            }

            auto marker = Join("", "\n###subtest-finished:", testInfo.test_suite_name(), "::", testInfo.name(), "\n");

            std::cout << std::flush;
            Cout << marker << Flush;

            std::cerr << std::flush;
            Cerr << marker << Flush;

            PrintTestStatus(testInfo, status, messages.str(), properties, ts);
        }

        void PrintTestStatus(
                const ::testing::TestInfo& testInfo,
                std::string_view status,
                std::string_view messages,
                const std::unordered_map<std::string, double>& properties,
                std::chrono::duration<double> ts)
        {
            (*Trace_)
                << "{"
                <<   "\"name\": \"subtest-finished\", "
                <<   "\"timestamp\": " << std::setprecision(14) << ts.count() << ", "
                <<   "\"value\": {"
                <<     "\"class\": " << EscapeJson(testInfo.test_suite_name()) << ", "
                <<     "\"subtest\": " << EscapeJson(testInfo.name()) << ", "
                <<     "\"comment\": " << EscapeJson(messages) << ", "
                <<     "\"status\": " << EscapeJson(status) << ", "
                <<     "\"time\": " << (testInfo.result()->elapsed_time() * (1 / 1000.0)) << ", "
                <<     "\"metrics\": {";
            {
                std::string_view sep = "";
                for (auto& [key, value]: properties) {
                    (*Trace_) << sep << EscapeJson(key) << ": " << value;
                    sep = ", ";
                }
            }
            (*Trace_)
                <<     "}"
                <<   "}"
                << "}"
                << "\n"
                << std::flush;
        }

        std::ostream* Trace_;
    };
}

int NGTest::Main(int argc, char** argv) {
    auto flags = ParseFlags(argc, argv);

    ::testing::GTEST_FLAG(filter) = flags.Filter;

    std::ofstream trace;
    if (!flags.TracePath.empty()) {
        trace.open(flags.TracePath, (flags.AppendTrace ? std::ios::app : std::ios::out) | std::ios::binary);

        if (!trace.is_open()) {
            std::cerr << "Failed to open file " << flags.TracePath << " for write" << std::endl;
            exit(2);
        }

        UnsetDefaultReporter();
        SetTraceReporter(&trace);
    }

    NTesting::THook::CallBeforeInit();

    ::testing::InitGoogleMock(&flags.GtestArgc, flags.GtestArgv.data());

    ListTests(flags.ListLevel, flags.ListPath);

    NTesting::THook::CallBeforeRun();

    Y_DEFER { NTesting::THook::CallAfterRun(); };

    return RUN_ALL_TESTS();
}

namespace {
    // Consumes the value of the flag at `argv[i]`: advances `i` to the following argument and returns it.
    // When the flag is the last argument, prints an error naming that flag and exits with code 2.
    const char* TakeFlagValue(int& i, int argc, char** argv) {
        const char* flag = argv[i];
        ++i;
        if (i >= argc) {
            std::cerr << "Missing value for argument " << flag << std::endl;
            exit(2);
        }
        return argv[i];
    }

    // Returns true for flags this wrapper explicitly rejects as unsupported.
    bool IsUnsupportedFlag(const char* name) {
        static constexpr const char* unsupportedFlags[] = {
            "--print-before-suite",
            "--print-before-test",
            "--show-fails",
            "--dont-show-fails",
            "--print-times",
            "--from",
            "--to",
            "--fork-tests",
            "--is-forked-internal",
            "--loop",
        };
        for (const char* flag : unsupportedFlags) {
            if (strcmp(name, flag) == 0) {
                return true;
            }
        }
        return false;
    }
}

// Parses the command line into TFlags.
//
// `--gtest_*` / `--gmock_*` (and `--help`) are forwarded to gtest via GtestArgv; wrapper flags fill the
// corresponding TFlags fields; `--test-param key=value` is registered in the test environment.
// Bare "Suite::Test" or "+Suite::Test" selectors become positive gtest filter patterns and
// "-Suite::Test" negative ones ("Suite" alone selects "Suite.*"). Invalid usage exits with code 2.
NGTest::TFlags NGTest::ParseFlags(int argc, char** argv) {
    TFlags result;

    // Positive and negative gtest filter patterns, each list joined with ':'.
    std::ostringstream filtersPos;
    std::string_view filterPosSep = "";
    std::ostringstream filtersNeg;
    std::string_view filterNegSep = "";

    // Appends the "Suite::Test" selector `spec` to `out` as a gtest "Suite.Test" pattern.
    auto addFilter = [](std::ostringstream& out, std::string_view& sep, std::string_view spec) {
        auto [suite, test] = ParseName(spec);
        out << sep << suite << "." << test;
        sep = ":";
    };

    if (argc > 0) {
        result.GtestArgv.push_back(argv[0]);
    }

    for (int i = 1; i < argc; ++i) {
        auto name = argv[i];

        if (strcmp(name, "--help") == 0) {
            // Let gtest print its help; arguments after `--help` are ignored.
            result.GtestArgv.push_back(name);
            break;
        } else if (StartsWith(name, "--gtest_") || StartsWith(name, "--gmock_")) {
            result.GtestArgv.push_back(name);
        } else if (strcmp(name, "--list") == 0 || strcmp(name, "-l") == 0) {
            result.ListLevel = std::max(result.ListLevel, 1);
        } else if (strcmp(name, "--list-verbose") == 0) {
            result.ListLevel = std::max(result.ListLevel, 2);
        } else if (IsUnsupportedFlag(name)) {
            Unsupported(name);
        } else if (strcmp(name, "--trace-path") == 0 || strcmp(name, "--trace-path-append") == 0) {
            const char* tracePath = TakeFlagValue(i, argc, argv);

            if (!result.TracePath.empty()) {
                std::cerr << "Multiple --trace-path or --trace-path-append given" << std::endl;
                exit(2);
            }

            result.TracePath = tracePath;
            result.AppendTrace = strcmp(name, "--trace-path-append") == 0;
        } else if (strcmp(name, "--list-path") == 0) {
            result.ListPath = TakeFlagValue(i, argc, argv);
        } else if (strcmp(name, "--test-param") == 0) {
            auto [key, value] = ParseParam(TakeFlagValue(i, argc, argv));
            Singleton<NPrivate::TTestEnv>()->AddTestParam(key, value);
        } else if (StartsWith(name, "--")) {
            Unknown(name);
        } else if (*name == '-') {
            addFilter(filtersNeg, filterNegSep, name + 1);
        } else if (*name == '+') {
            addFilter(filtersPos, filterPosSep, name + 1);
        } else {
            addFilter(filtersPos, filterPosSep, name);
        }
    }

    const std::string positive = filtersPos.str();
    const std::string negative = filtersNeg.str();
    if (!positive.empty() || !negative.empty()) {
        result.Filter = positive;
        if (!negative.empty()) {
            result.Filter += "-";
            result.Filter += negative;
        }
    }

    // Main-like functions need a null sentinel at the end of `argv' argument.
    // This sentinel is not counted in `argc' argument.
    result.GtestArgv.push_back(nullptr);
    result.GtestArgc = static_cast<int>(result.GtestArgv.size()) - 1;

    return result;
}

// When listing was requested (`listLevel` > 0), writes the registered tests to `listPath`
// (or stdout when it is empty) and terminates the process with exit code 0.
// Level 1 lists suite names; level 2 and above list "Suite::Test" names.
// Otherwise it returns without doing anything.
void NGTest::ListTests(int listLevel, const std::string& listPath) {
    if (listLevel <= 0) {
        return;
    }

    // Lines are terminated with a plain "\n" and the list file is opened in binary mode, so no
    // "\r\n" line endings end up in it on Windows; ya make does not handle them well.
    std::ofstream listFile;
    std::ostream* listOut = &std::cout;
    if (!listPath.empty()) {
        listFile.open(listPath, std::ios::out | std::ios::binary);
        if (!listFile.is_open()) {
            std::cerr << "Failed to open file " << listPath << " for write" << std::endl;
            exit(2);
        }
        listOut = &listFile;
    }

    const auto* unitTest = testing::UnitTest::GetInstance();
    for (int suiteIndex = 0; suiteIndex < unitTest->total_test_suite_count(); ++suiteIndex) {
        const auto* suite = unitTest->GetTestSuite(suiteIndex);

        if (listLevel <= 1) {
            (*listOut) << suite->name() << "\n";
            continue;
        }

        for (int testIndex = 0; testIndex < suite->total_test_count(); ++testIndex) {
            (*listOut) << suite->name() << "::" << suite->GetTestInfo(testIndex)->name() << "\n";
        }
    }

    (*listOut) << std::flush;

    exit(0);
}

void NGTest::UnsetDefaultReporter() {
    ::testing::TestEventListeners& listeners = ::testing::UnitTest::GetInstance()->listeners();
    delete listeners.Release(listeners.default_result_printer());
}

// Registers a TTraceWriter that writes the machine-readable trace to `traceFile`.
// Per gtest's TestEventListeners::Append contract gtest takes ownership of the listener;
// `traceFile` itself stays owned by the caller and must outlive the test run.
void NGTest::SetTraceReporter(std::ostream* traceFile) {
    auto& eventListeners = ::testing::UnitTest::GetInstance()->listeners();
    auto* traceWriter = new TTraceWriter{traceFile};
    eventListeners.Append(traceWriter);
}