Published on

C++ で numpy.argsort

Authors

要素自体をソートするのではなく,ソートしたときの並び順に対応するインデックス列を取得したいことがあります.Python の numpy でいう argsort です.以下はその C++ 版です.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

#include <boost/random.hpp>

/// Comparator over indices: orders two positions of a sequence by applying a
/// binary predicate to the elements stored at those positions.
///
/// The sequence is held by reference, not by value. Algorithms such as
/// std::sort take and pass their comparator by value, so a by-value member
/// would duplicate the entire sequence on every comparator copy. Consequently
/// the referenced sequence must outlive this object; do not construct it from
/// a temporary.
///
/// @tparam Sequence         Container indexable with operator[](size_t).
/// @tparam BinaryPredicate  Callable as comp(element, element) -> bool.
template <typename Sequence, typename BinaryPredicate>
struct IndexCompareT {
  /// @param seq   Sequence whose elements are compared (not copied).
  /// @param comp  Element-level ordering predicate (copied).
  IndexCompareT(const Sequence& seq, const BinaryPredicate comp)
    : seq_(seq), comp_(comp) { }

  /// Returns comp_(seq_[a], seq_[b]). Both a and b must be valid indices
  /// into the referenced sequence; no bounds checking is performed here.
  bool operator()(const size_t a, const size_t b) const
  {
    return comp_(seq_[a], seq_[b]);
  }

  const Sequence& seq_;         // non-owning view of the caller's sequence
  const BinaryPredicate comp_;  // predicate applied to the looked-up elements
};

/// Factory for IndexCompareT that lets the compiler deduce both template
/// arguments from the call, so callers need not spell out the comparator type.
///
/// @param seq   Sequence whose elements the resulting comparator will look up.
/// @param comp  Predicate applied to pairs of elements of seq.
/// @return An IndexCompareT built from seq and comp.
template <typename Sequence, typename BinaryPredicate>
IndexCompareT<Sequence, BinaryPredicate>
IndexCompare(const Sequence& seq, const BinaryPredicate comp)
{
  typedef IndexCompareT<Sequence, BinaryPredicate> Comparator;
  return Comparator(seq, comp);
}

/// Returns the permutation of indices that orders seq according to func,
/// analogous to numpy.argsort: result[k] is the index in seq of the element
/// that lands at position k after sorting. seq itself is not modified.
///
/// Uses std::sort, which is not a stable sort, so the relative order of
/// indices whose elements compare equivalent is unspecified.
///
/// @param seq   Sequence providing size() and operator[](size_t).
/// @param func  Strict-weak-ordering predicate on elements of seq
///              (e.g. std::less<T>() for ascending order).
/// @return Vector of seq.size() indices; empty when seq is empty.
template <typename Sequence, typename BinaryPredicate>
std::vector<size_t> ArgSort(const Sequence& seq, BinaryPredicate func)
{
  std::vector<size_t> index(seq.size());

  // Start from the identity permutation 0, 1, ..., n-1. The counter is a
  // size_t so it matches index.size(): an int counter would be compared
  // signed-vs-unsigned and would overflow for more than INT_MAX elements.
  for (size_t i = 0; i < index.size(); ++i)
    index[i] = i;

  // Sort the indices, comparing the elements they refer to.
  std::sort(index.begin(), index.end(), IndexCompare(seq, func));

  return index;
}

// Demo driver: fills a vector with ten random doubles drawn from a
// default-constructed Mersenne Twister, prints them with their positions,
// then prints the ArgSort result for descending and for ascending order.
int main()
{
  const size_t kCount = 10;  // number of samples generated and printed

  boost::random::mt19937 gen;
  boost::random::uniform_real_distribution<> dist(0, 1000.0);

  std::vector<double> val;
  while (val.size() < kCount) {
    val.push_back(dist(gen));
  }

  // Original values, one "position<TAB>value" pair per line.
  for (size_t pos = 0; pos != kCount; ++pos) {
    std::cout << pos << "\t" << val[pos] << "\n";
  }

  std::cout << "\n";

  // Descending order: "original index<TAB>value".
  std::vector<size_t> sorted_index = ArgSort(val, std::greater<double>());
  for (size_t rank = 0; rank != kCount; ++rank) {
    const size_t src = sorted_index[rank];
    std::cout << src << "\t" << val[src] << "\n";
  }

  std::cout << "\n";

  // Ascending order, same output format.
  sorted_index = ArgSort(val, std::less<double>());
  for (size_t rank = 0; rank != kCount; ++rank) {
    const size_t src = sorted_index[rank];
    std::cout << src << "\t" << val[src] << "\n";
  }

  return 0;
}
実行結果:

0       814.724
1       135.477
2       905.792
3       835.009
4       126.987
5       968.868
6       913.376
7       221.034
8       632.359
9       308.167

5       968.868
6       913.376
2       905.792
3       835.009
0       814.724
8       632.359
9       308.167
7       221.034
1       135.477
4       126.987

4       126.987
1       135.477
7       221.034
9       308.167
8       632.359
0       814.724
3       835.009
2       905.792
6       913.376
5       968.868