// blob: 00972289528ec5077535d846e80ea28812dca523 [file] [log] [blame]
 // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2017 Gagan Goel // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #include "main.h" #include using Eigen::Tensor; using Eigen::array; template static void test_0D_trace() { Tensor tensor; tensor.setRandom(); array dims; Tensor result = tensor.trace(dims); VERIFY_IS_EQUAL(result(), tensor()); } template static void test_all_dimensions_trace() { Tensor tensor1(5, 5, 5); tensor1.setRandom(); Tensor result1 = tensor1.trace(); VERIFY_IS_EQUAL(result1.rank(), 0); float sum = 0.0f; for (int i = 0; i < 5; ++i) { sum += tensor1(i, i, i); } VERIFY_IS_EQUAL(result1(), sum); Tensor tensor2(7, 7, 7, 7, 7); tensor2.setRandom(); array dims = { { 2, 1, 0, 3, 4 } }; Tensor result2 = tensor2.trace(dims); VERIFY_IS_EQUAL(result2.rank(), 0); sum = 0.0f; for (int i = 0; i < 7; ++i) { sum += tensor2(i, i, i, i, i); } VERIFY_IS_EQUAL(result2(), sum); } template static void test_simple_trace() { Tensor tensor1(3, 5, 3); tensor1.setRandom(); array dims1 = { { 0, 2 } }; Tensor result1 = tensor1.trace(dims1); VERIFY_IS_EQUAL(result1.rank(), 1); VERIFY_IS_EQUAL(result1.dimension(0), 5); float sum = 0.0f; for (int i = 0; i < 5; ++i) { sum = 0.0f; for (int j = 0; j < 3; ++j) { sum += tensor1(j, i, j); } VERIFY_IS_EQUAL(result1(i), sum); } Tensor tensor2(5, 5, 7, 7); tensor2.setRandom(); array dims2 = { { 2, 3 } }; Tensor result2 = tensor2.trace(dims2); VERIFY_IS_EQUAL(result2.rank(), 2); VERIFY_IS_EQUAL(result2.dimension(0), 5); VERIFY_IS_EQUAL(result2.dimension(1), 5); for (int i = 0; i < 5; ++i) { for (int j = 0; j < 5; ++j) { sum = 0.0f; for (int k = 0; k < 7; ++k) { sum += tensor2(i, j, k, k); } VERIFY_IS_EQUAL(result2(i, j), sum); } } array dims3 = { { 1, 0 } }; Tensor result3 = tensor2.trace(dims3); 
VERIFY_IS_EQUAL(result3.rank(), 2); VERIFY_IS_EQUAL(result3.dimension(0), 7); VERIFY_IS_EQUAL(result3.dimension(1), 7); for (int i = 0; i < 7; ++i) { for (int j = 0; j < 7; ++j) { sum = 0.0f; for (int k = 0; k < 5; ++k) { sum += tensor2(k, k, i, j); } VERIFY_IS_EQUAL(result3(i, j), sum); } } Tensor tensor3(3, 7, 3, 7, 3); tensor3.setRandom(); array dims4 = { { 0, 2, 4 } }; Tensor result4 = tensor3.trace(dims4); VERIFY_IS_EQUAL(result4.rank(), 2); VERIFY_IS_EQUAL(result4.dimension(0), 7); VERIFY_IS_EQUAL(result4.dimension(1), 7); for (int i = 0; i < 7; ++i) { for (int j = 0; j < 7; ++j) { sum = 0.0f; for (int k = 0; k < 3; ++k) { sum += tensor3(k, i, k, j, k); } VERIFY_IS_EQUAL(result4(i, j), sum); } } Tensor tensor4(3, 7, 4, 7, 5); tensor4.setRandom(); array dims5 = { { 1, 3 } }; Tensor result5 = tensor4.trace(dims5); VERIFY_IS_EQUAL(result5.rank(), 3); VERIFY_IS_EQUAL(result5.dimension(0), 3); VERIFY_IS_EQUAL(result5.dimension(1), 4); VERIFY_IS_EQUAL(result5.dimension(2), 5); for (int i = 0; i < 3; ++i) { for (int j = 0; j < 4; ++j) { for (int k = 0; k < 5; ++k) { sum = 0.0f; for (int l = 0; l < 7; ++l) { sum += tensor4(i, l, j, l, k); } VERIFY_IS_EQUAL(result5(i, j, k), sum); } } } } template static void test_trace_in_expr() { Tensor tensor(2, 3, 5, 3); tensor.setRandom(); array dims = { { 1, 3 } }; Tensor result(2, 5); result = result.constant(1.0f) - tensor.trace(dims); VERIFY_IS_EQUAL(result.rank(), 2); VERIFY_IS_EQUAL(result.dimension(0), 2); VERIFY_IS_EQUAL(result.dimension(1), 5); float sum = 0.0f; for (int i = 0; i < 2; ++i) { for (int j = 0; j < 5; ++j) { sum = 0.0f; for (int k = 0; k < 3; ++k) { sum += tensor(i, k, j, k); } VERIFY_IS_EQUAL(result(i, j), 1.0f - sum); } } } EIGEN_DECLARE_TEST(cxx11_tensor_trace) { CALL_SUBTEST(test_0D_trace()); CALL_SUBTEST(test_0D_trace()); CALL_SUBTEST(test_all_dimensions_trace()); CALL_SUBTEST(test_all_dimensions_trace()); CALL_SUBTEST(test_simple_trace()); CALL_SUBTEST(test_simple_trace()); 
CALL_SUBTEST(test_trace_in_expr()); CALL_SUBTEST(test_trace_in_expr()); }