// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include "main.h"

// Returns true when `actual == expected`, or when both values are NaN
// according to numext::isnan.
template<typename T, typename U>
bool check_if_equal_or_nans(const T& actual, const U& expected) {
  if (actual == expected) {
    return true;
  }
  // Not equal: still accept the pair when both sides are NaN.
  return (numext::isnan)(actual) && (numext::isnan)(expected);
}

// Complex overload: the real parts and the imaginary parts are compared
// separately, each with the "equal or both NaN" rule.
template<typename T, typename U>
bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
  if (!check_if_equal_or_nans(numext::real(actual), numext::real(expected))) {
    return false;
  }
  return check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
}

// Returns whether `actual` matches `expected` under check_if_equal_or_nans.
// On a mismatch both values are written to std::cerr before returning false.
template<typename T, typename U>
bool test_is_equal_or_nans(const T& actual, const U& expected)
{
    const bool matches = check_if_equal_or_nans(actual, expected);
    if (!matches) {
      // Report the offending pair.
      std::cerr << "\n    actual   = " << actual
                << "\n    expected = " << expected << "\n\n";
    }
    return matches;
}

// Passes the result of test_is_equal_or_nans(a, b) to VERIFY, i.e. the check
// succeeds when a and b are equal or both NaN.
#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b))

template<typename T>
void check_abs() {
  typedef typename NumTraits<T>::Real Real;
  Real zero(0);

  if(NumTraits<T>::IsSigned)
    VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1));
  VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
  VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));

  for(int k=0; k<100; ++k)
  {
    T x = internal::random<T>();
    if(!internal::is_same<T,bool>::value)
      x = x/Real(2);
    if(NumTraits<T>::IsSigned)
    {
      VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x));
      VERIFY( numext::abs(-x) >= zero );
    }
    VERIFY( numext::abs(x) >= zero );
    VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) );
  }
}

template<typename T>
void check_arg() {
  typedef typename NumTraits<T>::Real Real;
  VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
  VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));

  for(int k=0; k<100; ++k)
  {
    T x = internal::random<T>();
    Real y = numext::arg(x);
    VERIFY_IS_APPROX( y, std::arg(x) );
  }
}

// Primary template of the numext::sqrt checks (used for the non-complex
// instantiations): random round-trip checks followed by special values.
template<typename T>
struct check_sqrt_impl {
  static void run() {
    // Special values used by the corner-case checks at the end.
    const T zero = T(0);
    const T one = T(1);
    const T inf = std::numeric_limits<T>::infinity();
    const T nan = std::numeric_limits<T>::quiet_NaN();

    // Squaring the square root of a non-negative sample must give it back.
    int remaining = 1000;
    while (remaining-- > 0) {
      const T x = numext::abs(internal::random<T>());
      const T sqrtx = numext::sqrt(x);
      VERIFY_IS_APPROX(sqrtx*sqrtx, x);
    }

    // Corner cases: 0 and +inf map to themselves, NaN and -1 yield NaN.
    VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
    VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
    VERIFY((numext::isnan)(numext::sqrt(nan)));
    VERIFY((numext::isnan)(numext::sqrt(-one)));
  }
};

// Specialization of the numext::sqrt checks for std::complex<T>: random
// round-trip checks followed by a table of special inputs.
template<typename T>
struct check_sqrt_impl<std::complex<T>  > {
  static void run() {
    typedef std::complex<T> ComplexT;

    // Squaring the square root of a random sample must give it back.
    for (int sample = 0; sample != 1000; ++sample) {
      const ComplexT x = internal::random<ComplexT>();
      const ComplexT sqrtx = numext::sqrt(x);
      VERIFY_IS_APPROX(sqrtx*sqrtx, x);
    }

    // Building blocks for the corner-case table.
    const T zero = T(0);
    const T one = T(1);
    const T inf = std::numeric_limits<T>::infinity();
    const T nan = std::numeric_limits<T>::quiet_NaN();

    // Each row is {input, expected sqrt(input)}.
    // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
    const ComplexT corners[][2] = {
      {ComplexT(zero, zero), ComplexT(zero, zero)},
      {ComplexT(-zero, zero), ComplexT(zero, zero)},
      {ComplexT(zero, -zero), ComplexT(zero, zero)},
      {ComplexT(-zero, -zero), ComplexT(zero, zero)},
      {ComplexT(one, inf), ComplexT(inf, inf)},
      {ComplexT(nan, inf), ComplexT(inf, inf)},
      {ComplexT(one, -inf), ComplexT(inf, -inf)},
      {ComplexT(nan, -inf), ComplexT(inf, -inf)},
      {ComplexT(-inf, one), ComplexT(zero, inf)},
      {ComplexT(inf, one), ComplexT(inf, zero)},
      {ComplexT(-inf, -one), ComplexT(zero, -inf)},
      {ComplexT(inf, -one), ComplexT(inf, -zero)},
      {ComplexT(-inf, nan), ComplexT(nan, inf)},
      {ComplexT(inf, nan), ComplexT(inf, nan)},
      {ComplexT(zero, nan), ComplexT(nan, nan)},
      {ComplexT(one, nan), ComplexT(nan, nan)},
      {ComplexT(nan, zero), ComplexT(nan, nan)},
      {ComplexT(nan, one), ComplexT(nan, nan)},
      {ComplexT(nan, -one), ComplexT(nan, nan)},
      {ComplexT(nan, nan), ComplexT(nan, nan)},
    };
    const int corner_count = static_cast<int>(sizeof(corners) / sizeof(corners[0]));

    // NaN components count as matching when both sides are NaN.
    for (int row = 0; row != corner_count; ++row) {
      const ComplexT& x = corners[row][0];
      const ComplexT& sqrtx = corners[row][1];
      VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
    }
  }
};

template<typename T>
void check_sqrt() {
  check_sqrt_impl<T>::run();
}

// Primary template of the numext::rsqrt checks (used for the non-complex
// instantiations): random checks against 1/x followed by special values.
template<typename T>
struct check_rsqrt_impl {
  static void run() {
    const T zero = T(0);
    const T one = T(1);
    const T inf = std::numeric_limits<T>::infinity();
    const T nan = std::numeric_limits<T>::quiet_NaN();

    // rsqrt(x) squared must approximate 1/x for non-negative samples.
    int remaining = 1000;
    while (remaining-- > 0) {
      const T x = numext::abs(internal::random<T>());
      const T rsqrtx = numext::rsqrt(x);
      const T invx = one / x;
      VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
    }

    // Corner cases: 0 -> +inf, +inf -> 0, NaN and -1 yield NaN.
    VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
    VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
    VERIFY((numext::isnan)(numext::rsqrt(nan)));
    VERIFY((numext::isnan)(numext::rsqrt(-one)));
  }
};

// Specialization of the numext::rsqrt checks for std::complex<T>: random
// checks against 1/x, plus a table of special inputs that is only evaluated
// when EIGEN_COMP_GNUC is set.
template<typename T>
struct check_rsqrt_impl<std::complex<T> > {
  static void run() {
    typedef std::complex<T> ComplexT;
    const T zero = T(0);
    const T one = T(1);
    const T inf = std::numeric_limits<T>::infinity();
    const T nan = std::numeric_limits<T>::quiet_NaN();
    // Numerator of every reciprocal computed below.
    const ComplexT unit(one, zero);

    // rsqrt(x) squared must approximate 1/x.
    for (int sample = 0; sample != 1000; ++sample) {
      const ComplexT x = internal::random<ComplexT>();
      const ComplexT invx = unit / x;
      const ComplexT rsqrtx = numext::rsqrt(x);
      VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
    }

    // The corner cases are restricted to GCC-compatible compilers because
    // complex division differs between toolchains:
    //   1/(0 + 0i):    GCC/clang = (inf, nan), MSVC = (nan, nan)
    //   1/(x + inf i): GCC/clang = (0, 0),     MSVC = (nan, nan)
    #if (EIGEN_COMP_GNUC)
    {
      // Each row is {input, sqrt(input)}; the expected rsqrt is the
      // reciprocal of the second entry.
      const ComplexT corners[][2] = {
        // Only consistent across GCC, clang
        {ComplexT(zero, zero), ComplexT(zero, zero)},
        {ComplexT(-zero, zero), ComplexT(zero, zero)},
        {ComplexT(zero, -zero), ComplexT(zero, zero)},
        {ComplexT(-zero, -zero), ComplexT(zero, zero)},
        {ComplexT(one, inf), ComplexT(inf, inf)},
        {ComplexT(nan, inf), ComplexT(inf, inf)},
        {ComplexT(one, -inf), ComplexT(inf, -inf)},
        {ComplexT(nan, -inf), ComplexT(inf, -inf)},
        // Consistent across GCC, clang, MSVC
        {ComplexT(-inf, one), ComplexT(zero, inf)},
        {ComplexT(inf, one), ComplexT(inf, zero)},
        {ComplexT(-inf, -one), ComplexT(zero, -inf)},
        {ComplexT(inf, -one), ComplexT(inf, -zero)},
        {ComplexT(-inf, nan), ComplexT(nan, inf)},
        {ComplexT(inf, nan), ComplexT(inf, nan)},
        {ComplexT(zero, nan), ComplexT(nan, nan)},
        {ComplexT(one, nan), ComplexT(nan, nan)},
        {ComplexT(nan, zero), ComplexT(nan, nan)},
        {ComplexT(nan, one), ComplexT(nan, nan)},
        {ComplexT(nan, -one), ComplexT(nan, nan)},
        {ComplexT(nan, nan), ComplexT(nan, nan)},
      };
      const int corner_count = static_cast<int>(sizeof(corners) / sizeof(corners[0]));

      // NaN components count as matching when both sides are NaN.
      for (int row = 0; row != corner_count; ++row) {
        const ComplexT& x = corners[row][0];
        const ComplexT rsqrtx = unit / corners[row][1];
        VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
      }
    }
    #endif
  }
};

template<typename T>
void check_rsqrt() {
  check_rsqrt_impl<T>::run();
}

// Test entry point: runs every numext check g_repeat times.
EIGEN_DECLARE_TEST(numext) {
  for (int repetition = 0; repetition != g_repeat; ++repetition) {
    // numext::abs -- bool, integer, floating-point and complex scalars.
    CALL_SUBTEST( check_abs<bool>() );
    CALL_SUBTEST( check_abs<signed char>() );
    CALL_SUBTEST( check_abs<unsigned char>() );
    CALL_SUBTEST( check_abs<short>() );
    CALL_SUBTEST( check_abs<unsigned short>() );
    CALL_SUBTEST( check_abs<int>() );
    CALL_SUBTEST( check_abs<unsigned int>() );
    CALL_SUBTEST( check_abs<long>() );
    CALL_SUBTEST( check_abs<unsigned long>() );
    CALL_SUBTEST( check_abs<half>() );
    CALL_SUBTEST( check_abs<bfloat16>() );
    CALL_SUBTEST( check_abs<float>() );
    CALL_SUBTEST( check_abs<double>() );
    CALL_SUBTEST( check_abs<long double>() );
    CALL_SUBTEST( check_abs<std::complex<float> >() );
    CALL_SUBTEST( check_abs<std::complex<double> >() );

    // numext::arg -- complex scalars.
    CALL_SUBTEST( check_arg<std::complex<float> >() );
    CALL_SUBTEST( check_arg<std::complex<double> >() );

    // numext::sqrt -- real and complex scalars.
    CALL_SUBTEST( check_sqrt<float>() );
    CALL_SUBTEST( check_sqrt<double>() );
    CALL_SUBTEST( check_sqrt<std::complex<float> >() );
    CALL_SUBTEST( check_sqrt<std::complex<double> >() );

    // numext::rsqrt -- real and complex scalars.
    CALL_SUBTEST( check_rsqrt<float>() );
    CALL_SUBTEST( check_rsqrt<double>() );
    CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
    CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
  }
}