Tensor Rank

Rank overload

All xtensor classes have a static member rank that can be used to select an overload based on rank using SFINAE. Consider the following example:

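#include <cstddef>
#include <type_traits>
#include <utility>
#include <xtensor/xarray.hpp>
#include <xtensor/xtensor.hpp>

template <class E, std::enable_if_t<!xt::has_rank_t<E, 2>::value, int> = 0>
inline E foo(E&& a)
{
    // act on object of flexible rank, or fixed rank != 2
    return std::forward<E>(a);
}

template <class E, std::enable_if_t<xt::has_rank_t<E, 2>::value, int> = 0>
inline E foo(E&& a)
{
    // act on object of fixed rank == 2
    return std::forward<E>(a);
}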

int main()
{
    xt::xarray<size_t> a = {{9, 9}, {9, 9}};
    xt::xtensor<size_t, 1> b = {9, 9};
    xt::xtensor<size_t, 2> c = {{9, 9}, {9, 9}};

    foo(a); // flexible rank -> first overload
    foo(b); // fixed rank == 1 -> first overload
    foo(c); // fixed rank == 2 -> second overload

    return 0;
}
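The dispatch above can be tried without xtensor installed. The following is a minimal, self-contained C++14 analogue that assumes only what is described above: fixed-rank types expose their rank as a static member rank, and flexible-rank types report SIZE_MAX. The types fixed_container and flexible_container and the trait has_rank are hypothetical stand-ins for the xtensor containers and xt::has_rank_t, not xtensor's own implementation. Each overload returns a number so that the selected overload can be checked.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for xtensor containers (not the real types):
// a fixed-rank type exposes its rank as a static member ...
template <std::size_t N>
struct fixed_container
{
    static constexpr std::size_t rank = N;
};

// ... and a flexible-rank type reports the default SIZE_MAX.
struct flexible_container
{
    static constexpr std::size_t rank = SIZE_MAX;
};

// Hypothetical analogue of xt::has_rank_t<E, N>; std::decay_t strips the
// reference that E picks up when foo is called with an lvalue.
template <class E, std::size_t N>
using has_rank = std::integral_constant<bool, std::decay_t<E>::rank == N>;

// Flexible rank, or fixed rank != 2: identifies itself by returning 1.
template <class E, std::enable_if_t<!has_rank<E, 2>::value, int> = 0>
int foo(E&&)
{
    return 1;
}

// Fixed rank == 2: identifies itself by returning 2.
template <class E, std::enable_if_t<has_rank<E, 2>::value, int> = 0>
int foo(E&&)
{
    return 2;
}
```

With these definitions, foo(fixed_container<1>{}) returns 1 and foo(fixed_container<2>{}) returns 2, mirroring the calls foo(b) and foo(c) in main above.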

Note

To test for more than a single rank value, one can rely on the default rank SIZE_MAX that flexible-rank objects report: xt::has_fixed_rank_t<E> checks whether the rank of E differs from this default. For example, one could write the following overloads:

// flexible rank
template <class E, std::enable_if_t<!xt::has_fixed_rank_t<E>::value, int> = 0>
inline E foo(E&& a);

// fixed rank == 1
template <class E, std::enable_if_t<xt::has_rank_t<E, 1>::value, int> = 0>
inline E foo(E&& a);

// fixed rank == 2
template <class E, std::enable_if_t<xt::has_rank_t<E, 2>::value, int> = 0>
inline E foo(E&& a);

Note that calling foo with an object whose fixed rank is neither 1 nor 2 results in a compiler error, because none of the overloads is enabled for it.
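The missing overload for other ranks can be observed without triggering the compile error, using a detection trait. The following C++17 sketch mirrors the three overloads above with hypothetical stand-ins (fixed_container, flexible_container, has_rank, has_fixed_rank) in place of the xtensor containers and xt::has_rank_t and xt::has_fixed_rank_t; the hypothetical trait foo_callable then uses std::void_t to check which argument types have a viable foo.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for xtensor containers (not the real types).
template <std::size_t N>
struct fixed_container
{
    static constexpr std::size_t rank = N;
};

struct flexible_container
{
    static constexpr std::size_t rank = SIZE_MAX;  // default for flexible rank
};

// Hypothetical analogues of xt::has_rank_t and xt::has_fixed_rank_t.
template <class E, std::size_t N>
using has_rank = std::integral_constant<bool, std::decay_t<E>::rank == N>;

template <class E>
using has_fixed_rank = std::integral_constant<bool, std::decay_t<E>::rank != SIZE_MAX>;

// flexible rank: returns 0
template <class E, std::enable_if_t<!has_fixed_rank<E>::value, int> = 0>
int foo(E&&) { return 0; }

// fixed rank == 1: returns 1
template <class E, std::enable_if_t<has_rank<E, 1>::value, int> = 0>
int foo(E&&) { return 1; }

// fixed rank == 2: returns 2
template <class E, std::enable_if_t<has_rank<E, 2>::value, int> = 0>
int foo(E&&) { return 2; }

// foo_callable<E>::value is true if some overload of foo accepts an E lvalue.
template <class E, class = void>
struct foo_callable : std::false_type {};

template <class E>
struct foo_callable<E, std::void_t<decltype(foo(std::declval<E&>()))>> : std::true_type {};

// No overload is enabled for rank 3, so a direct call would not compile;
// the trait reports this without a hard error.
static_assert(foo_callable<fixed_container<2>>::value, "rank 2 is supported");
static_assert(!foo_callable<fixed_container<3>>::value, "rank 3 is not");
```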

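If a more limited scope suffices, one can also overload directly on the concrete container types. These overloads accept only those exact types (here with value type double), not arbitrary expressions or views: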

// flexible rank
inline void foo(xt::xarray<double>& a);

// fixed rank == 1
inline void foo(xt::xtensor<double, 1>& a);

// fixed rank == 2
inline void foo(xt::xtensor<double, 2>& a);

Rank as member

If you want to use the rank as a member of your own class, you can use xt::get_rank<E>. Consider the following example:

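#include <cassert>
#include <cstddef>
#include <cstdint>
#include <xtensor/xarray.hpp>
#include <xtensor/xbuilder.hpp>
#include <xtensor/xtensor.hpp>

template <class T>
struct Foo
{
    static constexpr std::size_t rank = xt::get_rank<T>::value;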

    static size_t value()
    {
        return rank;
    }
};

int main()
{
    xt::xtensor<double, 1> A = xt::zeros<double>({2});
    xt::xtensor<double, 2> B = xt::zeros<double>({2, 2});
    xt::xarray<double> C = xt::zeros<double>({2, 2});

    assert(Foo<decltype(A)>::value() == 1);
    assert(Foo<decltype(B)>::value() == 2);
    assert(Foo<decltype(C)>::value() == SIZE_MAX);

    return 0;
}

xt::get_rank<E>::value equals the rank of E if E is an xtensor object with a fixed rank; in all other cases it equals SIZE_MAX. For example, xt::get_rank<xt::xarray<double>>::value is SIZE_MAX, but so is xt::get_rank<double>::value. SIZE_MAX alone therefore does not distinguish a flexible-rank container from a type that has no rank at all.
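A trait with this behaviour can be written with the detection idiom. The sketch below is one possible way to do it, not necessarily how xtensor implements xt::get_rank: the hypothetical trait rank_of defaults to SIZE_MAX and is specialized, via std::void_t (C++17), for any type E for which E::rank is well formed. The container types are hypothetical stand-ins, and the static assertions illustrate the point made above: a flexible-rank type and double both yield SIZE_MAX.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical analogue of xt::get_rank: SIZE_MAX unless E has a static
// member `rank`, in which case that member's value is used.
template <class E, class = void>
struct rank_of
{
    static constexpr std::size_t value = SIZE_MAX;
};

template <class E>
struct rank_of<E, std::void_t<decltype(E::rank)>>
{
    static constexpr std::size_t value = E::rank;
};

// Hypothetical container stand-ins (not the real xtensor types).
template <std::size_t N>
struct fixed_container
{
    static constexpr std::size_t rank = N;
};

struct flexible_container
{
    static constexpr std::size_t rank = SIZE_MAX;
};

// A fixed rank is reported as-is; a flexible-rank type and a type with no
// `rank` member at all (double) both yield SIZE_MAX.
static_assert(rank_of<fixed_container<2>>::value == 2, "fixed rank");
static_assert(rank_of<flexible_container>::value == SIZE_MAX, "flexible rank");
static_assert(rank_of<double>::value == SIZE_MAX, "not a container");
```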