#ifdef _OPENMP
#include <omp.h>
#endif
#include <stdint.h>
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include "../tools/contest.h"
// Always-on runtime check. When `expr` evaluates to false, prints
// "CHECK Failed(<file>:<line>): <expr text>" to stderr and terminates the
// process with exit(EXIT_FAILURE). Unlike assert(), nothing here depends on
// NDEBUG. The macro expands to one complete if/else statement whose
// then-branch is empty; call sites in this file follow it with a semicolon.
#define PSRS_CHECK(expr) \
if (expr) { \
} else { \
fprintf(stderr, "CHECK Failed(%s:%d): %s\n", \
__FILE__, __LINE__, #expr); \
exit(EXIT_FAILURE); \
}
namespace parallel_string_radix_sort {
namespace internal {

// Ordering functor handed to std::sort when a bucket is finished by
// comparison sorting. Compare(depth) orders two strings by the suffix that
// starts at byte offset `depth`; the caller guarantees that both strings are
// at least `depth` bytes long. Only the specializations below are defined.
//
// operator() is const-qualified so the functor can be invoked through a const
// object or reference, and the offset is stored as size_t, the type the
// sorter uses for depths.
template<typename StringType> class Compare;

// NUL-terminated C strings, compared with strcmp().
template<> class Compare<const char*> {
 public:
  explicit Compare(size_t depth) : depth_(depth) {}

  inline bool operator()(const char* const a, const char* const b) const {
    return strcmp(a + depth_, b + depth_) < 0;
  }

 private:
  size_t depth_;
};

// NUL-terminated byte strings. strcmp() already compares bytes as unsigned
// char, so only the pointer type has to be adapted; the named cast keeps the
// const qualifier that the former C-style cast silently dropped.
template<> class Compare<const unsigned char*> {
 public:
  explicit Compare(size_t depth) : depth_(depth) {}

  inline bool operator()(const unsigned char* const a,
                         const unsigned char* const b) const {
    return strcmp(reinterpret_cast<const char*>(a) + depth_,
                  reinterpret_cast<const char*>(b) + depth_) < 0;
  }

 private:
  size_t depth_;
};

// std::string: compares the substrings [depth, end) of both operands.
// std::string::compare throws std::out_of_range if `depth` exceeds the
// length of either string.
template<> class Compare<std::string> {
 public:
  explicit Compare(size_t depth) : depth_(depth) {}

  inline bool operator()(const std::string &a, const std::string &b) const {
    // npos means "through the end of the string", which is what the explicit
    // length() - depth arithmetic expressed before.
    return a.compare(depth_, std::string::npos,
                     b, depth_, std::string::npos) < 0;
  }

 private:
  size_t depth_;
};

}
@param
template<typename StringType>
class ParallelStringRadixSort {
public:
ParallelStringRadixSort();
~ParallelStringRadixSort();
@param
void Init(size_t max_elems);
@param
@param
void Sort(StringType *strings, size_t num_elems);
private:
static const size_t kThreshold = 30;
static const size_t kDepthLimit = 100;
size_t max_elems_;
StringType *data_, *temp_;
uint8_t *letters8_;
uint16_t *letters16_;
void DeleteAll();
@param
void Sort8(size_t bgn, size_t end, size_t depth, bool flip);
@param
void Sort16(size_t bgn, size_t end, size_t depth, bool flip);
@param
void Sort16Parallel(size_t bgn, size_t end, size_t depth, bool flip);
inline void Recurse(size_t bgn, size_t end, size_t depth, bool flip) {
size_t n = end - bgn;
if (depth >= kDepthLimit || n <= kThreshold) {
if (flip) {
for (size_t i = bgn; i < end; ++i) {
std::swap(data_[i], temp_[i]);
}
}
if (n > 1) {
std::sort(data_ + bgn, data_ + end, internal::Compare<StringType>(depth));
}
} else if (n < (1 << 16)) {
Sort8(bgn, end, depth, flip);
} else {
Sort16(bgn, end, depth, flip);
}
}
};
// Creates an empty sorter: zero capacity and no buffers. Every pointer member
// is initialised, including data_, which was previously left indeterminate
// until the first Sort() call. The initialiser order follows the member
// declaration order.
template<typename StringType>
ParallelStringRadixSort<StringType>::ParallelStringRadixSort()
    : max_elems_(0), data_(NULL), temp_(NULL),
      letters8_(NULL), letters16_(NULL) {}
// Destroys the sorter; all buffer cleanup is delegated to DeleteAll().
template<typename StringType>
ParallelStringRadixSort<StringType>::~ParallelStringRadixSort() {
  this->DeleteAll();
}
// Discards any existing scratch buffers and allocates new ones able to hold
// `max_elems` elements each.
//
// @param max_elems  largest element count that Sort() may be called with
//                   afterwards.
//
// The capacity is recorded only after every allocation has succeeded. If a
// `new` expression throws, DeleteAll() has already reset max_elems_ to zero,
// so the object never advertises capacity it does not have; buffers that were
// allocated before the failure stay owned by the object and are released by
// the next DeleteAll().
template<typename StringType>
void ParallelStringRadixSort<StringType>::Init(size_t max_elems) {
  DeleteAll();
  temp_ = new StringType[max_elems];
  letters8_ = new uint8_t[max_elems];
  letters16_ = new uint16_t[max_elems];
  // Only reachable with a null pointer when operator new reports failure by
  // returning NULL instead of throwing.
  PSRS_CHECK(temp_ != NULL && letters8_ != NULL && letters16_ != NULL);
  max_elems_ = max_elems;
}
// Frees the three scratch buffers and returns the sorter to its empty state
// (zero capacity, null buffer pointers). data_ is not touched here.
// Deleting a null pointer is a no-op, so this may be called repeatedly.
template<typename StringType>
void ParallelStringRadixSort<StringType>::DeleteAll() {
  delete [] temp_;
  temp_ = NULL;

  delete [] letters8_;
  letters8_ = NULL;

  delete [] letters16_;
  letters16_ = NULL;

  max_elems_ = 0;
}
// Sorts strings[0 .. num_elems) in place, starting with a two-byte radix pass
// at depth 0.
//
// @param strings    array to sort; results are written back into it.
// @param num_elems  number of elements; must not exceed the capacity set by
//                   Init().
//
// The capacity guard is an always-on PSRS_CHECK rather than assert(): with
// NDEBUG defined an assert disappears, and an oversized call would then write
// past the end of the scratch buffers.
template<typename StringType>
void ParallelStringRadixSort<StringType>
::Sort(StringType *strings, size_t num_elems) {
  PSRS_CHECK(num_elems <= max_elems_);
  data_ = strings;
  Sort16Parallel(0, num_elems, 0, false);
}
// One radix pass over [bgn, end) keyed on the single byte at offset `depth`.
//
// @param bgn, end  half-open index range to process.
// @param depth     byte offset of the key within each string.
// @param flip      true when the range currently lives in temp_, false when
//                  it lives in data_.
//
// Elements are scattered from the source array into the other array, grouped
// by key byte. Bucket 0 is not recursed into; if the scatter target was
// temp_, that bucket is swapped back into data_. Every other non-empty bucket
// is passed to Recurse() at depth + 1 with the flip flag inverted.
template<typename StringType>
void ParallelStringRadixSort<StringType>
::Sort8(size_t bgn, size_t end, size_t depth, bool flip) {
  const size_t kBuckets = 1 << 8;
  const size_t count = end - bgn;
  StringType *const from = (flip ? temp_ : data_) + bgn;
  StringType *const to = (flip ? data_ : temp_) + bgn;
  uint8_t *const keys = letters8_ + bgn;

  // Cache each element's key byte and build the histogram in one sweep.
  size_t offset[kBuckets] = {};
  for (size_t i = 0; i != count; ++i) {
    keys[i] = from[i][depth];
    ++offset[keys[i]];
  }

  // Exclusive prefix sum: offset[b] becomes the first slot of bucket b.
  size_t running = 0;
  for (size_t b = 0; b != kBuckets; ++b) {
    const size_t bucket_size = offset[b];
    offset[b] = running;
    running += bucket_size;
  }

  // Scatter. Each offset[b] advances as its bucket fills, so afterwards it
  // holds the end of bucket b, which is also the start of bucket b + 1.
  for (size_t i = 0; i != count; ++i) {
    const size_t slot = offset[keys[i]]++;
    std::swap(to[slot], from[i]);
  }

  // Bucket 0 gets no further pass, so it has to be in data_ now.
  if (!flip) {
    const size_t finished_end = offset[0];
    for (size_t j = 0; j != finished_end; ++j) {
      std::swap(from[j], to[j]);
    }
  }

  for (size_t b = 1; b != kBuckets; ++b) {
    const size_t lo = offset[b - 1];
    const size_t hi = offset[b];
    if (hi != lo) {
      Recurse(bgn + lo, bgn + hi, depth + 1, !flip);
    }
  }
}
// One radix pass over [bgn, end) keyed on the two bytes at offsets `depth`
// and `depth + 1`.
//
// @param bgn, end  half-open index range to process.
// @param depth     byte offset of the first key byte within each string.
// @param flip      true when the range currently lives in temp_, false when
//                  it lives in data_.
//
// The key is (first byte << 8) | second byte, or 0 when the first byte is 0
// (the second byte is then not read). Buckets whose low key byte is 0 are not
// recursed into; if the scatter target was temp_, they are swapped back into
// data_. Every other non-empty bucket is passed to Recurse() at depth + 2.
//
// Both bytes are converted to uint8_t before being combined. With a signed
// `char` element type, a second byte >= 0x80 used to be sign-extended to
// 0xFF.. by integer promotion, and the OR then wiped out the first byte, so
// strings with different first bytes landed in the same bucket.
template<typename StringType>
void ParallelStringRadixSort<StringType>
::Sort16(size_t bgn, size_t end, size_t depth, bool flip) {
  // Heap-allocated and value-initialised to zero.
  size_t *cnt = new size_t[1 << 16]();
  PSRS_CHECK(cnt != NULL);

  StringType *src = (flip ? temp_ : data_) + bgn;
  StringType *dst = (flip ? data_ : temp_) + bgn;
  uint16_t *let = letters16_ + bgn;
  size_t n = end - bgn;

  for (size_t i = 0; i < n; ++i) {
    const uint16_t hi = static_cast<uint8_t>(src[i][depth]);
    const uint16_t lo =
        hi == 0 ? 0 : static_cast<uint8_t>(src[i][depth + 1]);
    let[i] = static_cast<uint16_t>((hi << 8) | lo);
  }

  for (size_t i = 0; i < n; ++i) {
    ++cnt[let[i]];
  }

  // Exclusive prefix sum: cnt[k] becomes the first slot of bucket k.
  size_t s = 0;
  for (int i = 0; i < 1 << 16; ++i) {
    std::swap(cnt[i], s);
    s += cnt[i];
  }

  // Scatter; afterwards cnt[k] is the end of bucket k.
  for (size_t i = 0; i < n; ++i) {
    std::swap(dst[cnt[let[i]]++], src[i]);
  }

  if (flip == false) {
    // Buckets (i << 8) get no further pass, so they must be in data_ now.
    for (int i = 0; i < 1 << 8; ++i) {
      size_t b = i == 0 ? 0 : cnt[(i << 8) - 1];
      size_t e = cnt[i << 8];
      for (size_t j = b; j < e; ++j) {
        std::swap(src[j], dst[j]);
      }
    }
  }

  for (size_t i = 1; i < 1 << 16; ++i) {
    if ((i & 0xFF) != 0 && cnt[i] - cnt[i - 1] >= 1) {
      Recurse(bgn + cnt[i - 1], bgn + cnt[i], depth + 2, !flip);
    }
  }

  delete [] cnt;
}
// Two-byte radix pass over [bgn, end), with OpenMP pragmas on the key
// extraction, the copy-back of finished buckets and the per-bucket recursion.
// Without _OPENMP the pragmas are compiled out and the loops run serially.
//
// @param bgn, end  half-open index range to process.
// @param depth     byte offset of the first key byte within each string.
// @param flip      true when the range currently lives in temp_, false when
//                  it lives in data_.
//
// The key is (first byte << 8) | second byte, or 0 when the first byte is 0
// (the second byte is then not read). Both bytes are converted to uint8_t
// before being combined: with a signed `char` element type, a second byte
// >= 0x80 used to be sign-extended by integer promotion and wiped out the
// first byte, merging unrelated buckets.
//
// Finished buckets are now swapped back into data_, as Sort8 and Sort16 do,
// instead of copy-assigned. data_ ends up the same, and a swap avoids a deep
// copy for element types where copying is expensive.
template<typename StringType>
void ParallelStringRadixSort<StringType>
::Sort16Parallel(size_t bgn, size_t end, size_t depth, bool flip) {
  // NOTE(review): this histogram is a 2^16-entry automatic array
  // (65536 * sizeof(size_t) bytes of stack).
  size_t cnt[1 << 16] = {};
  StringType *src = (flip ? temp_ : data_) + bgn;
  StringType *dst = (flip ? data_ : temp_) + bgn;
  uint16_t *let = letters16_ + bgn;
  size_t n = end - bgn;

#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
  for (size_t i = 0; i < n; ++i) {
    const uint16_t hi = static_cast<uint8_t>(src[i][depth]);
    const uint16_t lo =
        hi == 0 ? 0 : static_cast<uint8_t>(src[i][depth + 1]);
    let[i] = static_cast<uint16_t>((hi << 8) | lo);
  }

  for (size_t i = 0; i < n; ++i) {
    ++cnt[let[i]];
  }

  {
    // Exclusive prefix sum: cnt[k] becomes the first slot of bucket k.
    size_t s = 0;
    for (int i = 0; i < 1 << 16; ++i) {
      std::swap(cnt[i], s);
      s += cnt[i];
    }
  }

  // Scatter; afterwards cnt[k] is the end of bucket k.
  for (size_t i = 0; i < n; ++i) {
    std::swap(dst[cnt[let[i]]++], src[i]);
  }

  if (flip == false) {
    // Buckets (i << 8) get no further pass, so they must be in data_ now.
    // The index ranges of different i are disjoint.
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
    for (int i = 0; i < 1 << 8; ++i) {
      size_t b = i == 0 ? 0 : cnt[(i << 8) - 1];
      size_t e = cnt[i << 8];
      for (size_t j = b; j < e; ++j) {
        std::swap(src[j], dst[j]);
      }
    }
  }

  // Each remaining non-empty bucket covers its own index range.
#ifdef _OPENMP
#pragma omp parallel for schedule(dynamic)
#endif
  for (size_t i = 1; i < 1 << 16; ++i) {
    if ((i & 0xFF) != 0 && cnt[i] - cnt[i - 1] >= 1) {
      Recurse(bgn + cnt[i - 1], bgn + cnt[i], depth + 2, !flip);
    }
  }
}
@param
@param
template<typename StringType>
void Sort(StringType *strings, size_t num_elems) {
ParallelStringRadixSort<StringType> psrs;
psrs.Init(num_elems);
psrs.Sort(strings, num_elems);
}
// Overload for built-in arrays: the element count is deduced from the array
// type and the work is forwarded to the pointer-and-length overload.
template<typename StringType, size_t kNumElems>
void Sort(StringType (&strings)[kNumElems]) {
  StringType *const first = strings;
  Sort(first, kNumElems);
}
}
// Adapter with a plain-function signature: sorts `count` byte strings in
// place by forwarding to parallel_string_radix_sort::Sort, instantiated for
// `const unsigned char *` elements.
void akiba_parallel_radix_sort(unsigned char **strings, size_t count)
{
  // Only const-ness of the pointed-to bytes is added; the array of pointers
  // itself stays writable so the sorter can reorder it.
  const unsigned char **const as_const =
      const_cast<const unsigned char **>(strings);
  parallel_string_radix_sort::Sort<const unsigned char *>(as_const, count);
}
// Passes akiba_parallel_radix_sort together with a short name and a
// description string to CONTESTANT_REGISTER_PARALLEL. The macro is not
// defined in this file; presumably it comes from ../tools/contest.h and
// registers the function with the benchmark harness -- TODO confirm.
CONTESTANT_REGISTER_PARALLEL(akiba_parallel_radix_sort,
"akiba/parallel_radix_sort",
"Parallel MSD radix sort by Takuya Akiba")