2 * Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
4 * Please refer to the NVIDIA end user license agreement (EULA) associated
5 * with this source code for terms and conditions that govern your use of
6 * this software. Any use, reproduction, disclosure, or distribution of
7 * this software and related documentation outside the terms of the EULA
8 * is strictly prohibited.
18 #include "mergeSort_common.h"
22 ////////////////////////////////////////////////////////////////////////////////
24 ////////////////////////////////////////////////////////////////////////////////
// Sanity check that data[0..N-1] is ordered according to sortDir.
// With a nonzero sortDir any data[i] > data[i + 1] is a violation; with a zero
// sortDir any data[i] < data[i + 1] is a violation. Equal neighbours are
// accepted in both cases. A violation is reported on stderr.
// NOTE(review): several lines of this function are not visible in this excerpt.
// If uint is an unsigned type, the bound N - 1 wraps around for N == 0 unless a
// guard that is not shown here rules that case out -- confirm.
25 static void checkOrder(uint *data, uint N, uint sortDir)
// Compare every element with its right-hand neighbour.
32 for (uint i = 0; i < N - 1; i++)
33 if ((sortDir && (data[i] > data[i + 1])) || (!sortDir && (data[i] < data[i + 1])))
35 fprintf(stderr, "checkOrder() failed!!!\n");
40 static uint umin(uint a, uint b)
42 return (a <= b) ? a : b;
45 static uint getSampleCount(uint dividend)
47 return ((dividend % SAMPLE_STRIDE) != 0) ? (dividend / SAMPLE_STRIDE + 1) : (dividend / SAMPLE_STRIDE);
// NOTE(review): only the signature is visible in this excerpt. The binary
// searches in this file use the result as the initial stride for a range of
// length L; presumably it is the smallest power of two that is >= x -- confirm
// against the body.
50 static uint nextPowerOfTwo(uint x)
// Stride-halving search over data[0..L-1] for val, in the ordering selected by
// sortDir (nonzero: probe succeeds on <=, zero: probe succeeds on >=).
// Inclusive means that an element equal to val also satisfies the probe.
// NOTE(review): the declaration and update of pos and the return statement are
// not visible in this excerpt; presumably the result is the number of leading
// elements that satisfy the probe -- confirm.
61 static uint binarySearchInclusive(uint val, uint *data, uint L, uint sortDir)
// The stride starts at nextPowerOfTwo(L) and is halved until it reaches zero.
70 for (uint stride = nextPowerOfTwo(L); stride > 0; stride >>= 1)
// Candidate position, clamped so that it never exceeds L.
72 uint newPos = umin(pos + stride, L);
// Probe the element just before the candidate position.
74 if ((sortDir && (data[newPos - 1] <= val)) || (!sortDir && (data[newPos - 1] >= val)))
// Stride-halving search over data[0..L-1] for val, in the ordering selected by
// sortDir (nonzero: probe succeeds on <, zero: probe succeeds on >).
// Exclusive means that an element equal to val does NOT satisfy the probe.
// NOTE(review): the declaration and update of pos and the return statement are
// not visible in this excerpt; presumably the result is the number of leading
// elements that satisfy the probe -- confirm.
83 static uint binarySearchExclusive(uint val, uint *data, uint L, uint sortDir)
// The stride starts at nextPowerOfTwo(L) and is halved until it reaches zero.
92 for (uint stride = nextPowerOfTwo(L); stride > 0; stride >>= 1)
// Candidate position, clamped so that it never exceeds L.
94 uint newPos = umin(pos + stride, L);
// Probe the element just before the candidate position (strict comparison).
96 if ((sortDir && (data[newPos - 1] < val)) || (!sortDir && (data[newPos - 1] > val)))
107 ////////////////////////////////////////////////////////////////////////////////
108 // Merge step 1: find sample ranks in each segment
109 ////////////////////////////////////////////////////////////////////////////////
// Merge step 1 worker. For every pair of adjacent segments (A starting at
// segmentBase, B starting at segmentBase + stride) it records, for each element
// sampled every SAMPLE_STRIDE positions, a rank within A (ranksA) and a rank
// within B (ranksB).
// NOTE(review): the parameter list and the conditions guarding the two groups
// of stores below are not visible in this excerpt.
110 static void generateSampleRanks(
// Elements left over after the last complete pair of 2 * stride elements.
119 uint lastSegmentElements = N % (2 * stride);
// When the leftover is longer than stride, N is rounded up to a whole pair so
// the leftover is processed too; otherwise the leftover is excluded.
// Dividing by 2 * SAMPLE_STRIDE gives the number of loop iterations.
120 uint sampleCount = (lastSegmentElements > stride) ? (N + 2 * stride - lastSegmentElements) / (2 * SAMPLE_STRIDE) : (N - lastSegmentElements) / (2 * SAMPLE_STRIDE);
122 for (uint pos = 0; pos < sampleCount; pos++)
// Low bits of pos: equal to pos modulo (stride / SAMPLE_STRIDE) as long as that
// quotient is a power of two.
124 const uint i = pos & ((stride / SAMPLE_STRIDE) - 1);
// Element offset of the segment pair this sample belongs to.
125 const uint segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
// A always spans stride elements; B is limited to what remains after A, so it
// can be shorter than stride.
127 const uint lenA = stride;
128 const uint lenB = umin(stride, N - segmentBase - stride);
// Sample counts of A and B; the count for B is rounded up by getSampleCount.
129 const uint nA = stride / SAMPLE_STRIDE;
130 const uint nB = getSampleCount(lenB);
// Sample i of A: its rank within A is its own offset, its rank within B comes
// from an exclusive search of B for the sample key.
134 ranksA[(segmentBase + 0) / SAMPLE_STRIDE + i] = i * SAMPLE_STRIDE;
135 ranksB[(segmentBase + 0) / SAMPLE_STRIDE + i] = binarySearchExclusive(srcKey[segmentBase + i * SAMPLE_STRIDE], srcKey + segmentBase + stride, lenB, sortDir);
// Sample i of B: its rank within B is its own offset, its rank within A comes
// from an inclusive search of A for the sample key.
140 ranksB[(segmentBase + stride) / SAMPLE_STRIDE + i] = i * SAMPLE_STRIDE;
141 ranksA[(segmentBase + stride) / SAMPLE_STRIDE + i] = binarySearchInclusive(srcKey[segmentBase + stride + i * SAMPLE_STRIDE], srcKey + segmentBase, lenA, sortDir);
148 ////////////////////////////////////////////////////////////////////////////////
149 // Merge step 2: merge ranks and indices to derive elementary intervals
150 ////////////////////////////////////////////////////////////////////////////////
// Merge step 2 worker. Within each segment pair, the rank entries stored for
// the A half and for the B half are combined into a single sequence written to
// limits. The destination slot of an entry is its own index plus the result of
// searching the rank list of the other half.
// NOTE(review): the parameter list and the conditions guarding the two groups
// of statements below are not visible in this excerpt.
151 static void mergeRanksAndIndices(
// Elements left over after the last complete pair of 2 * stride elements.
158 uint lastSegmentElements = N % (2 * stride);
// Same iteration count as in generateSampleRanks: the leftover is included only
// when it is longer than stride.
159 uint sampleCount = (lastSegmentElements > stride) ? (N + 2 * stride - lastSegmentElements) / (2 * SAMPLE_STRIDE) : (N - lastSegmentElements) / (2 * SAMPLE_STRIDE);
161 for (uint pos = 0; pos < sampleCount; pos++)
// Sample index within the pair (mask form of a modulo, valid when
// stride / SAMPLE_STRIDE is a power of two) and element offset of the pair.
163 const uint i = pos & ((stride / SAMPLE_STRIDE) - 1);
164 const uint segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
// Lengths and sample counts of the two halves; B may be shorter than stride.
166 const uint lenA = stride;
167 const uint lenB = umin(stride, N - segmentBase - stride);
168 const uint nA = stride / SAMPLE_STRIDE;
169 const uint nB = getSampleCount(lenB);
// Entry i of the A half: exclusive search among the nB entries of the B half.
// The search direction argument is fixed to 1 here, independent of the sort
// direction used for the keys.
173 uint dstPosA = binarySearchExclusive(ranks[(segmentBase + 0) / SAMPLE_STRIDE + i], ranks + (segmentBase + stride) / SAMPLE_STRIDE, nB, 1) + i;
174 assert(dstPosA < nA + nB);
175 limits[(segmentBase / SAMPLE_STRIDE) + dstPosA] = ranks[(segmentBase + 0) / SAMPLE_STRIDE + i];
// Entry i of the B half: inclusive search among the nA entries of the A half.
// NOTE(review): this variable is also called dstPosA although it positions an
// entry taken from the B half -- naming only.
180 uint dstPosA = binarySearchInclusive(ranks[(segmentBase + stride) / SAMPLE_STRIDE + i], ranks + (segmentBase + 0) / SAMPLE_STRIDE, nA, 1) + i;
181 assert(dstPosA < nA + nB);
182 limits[(segmentBase / SAMPLE_STRIDE) + dstPosA] = ranks[(segmentBase + stride) / SAMPLE_STRIDE + i];
189 ////////////////////////////////////////////////////////////////////////////////
190 // Merge step 3: merge elementary intervals (each interval is <= SAMPLE_STRIDE)
191 ////////////////////////////////////////////////////////////////////////////////
// NOTE(review): the header of this routine is not visible in this excerpt
// (presumably the merge of one elementary interval -- confirm); the comments
// below describe only the statements that are shown.
// Verify that both inputs are already ordered according to sortDir.
204 checkOrder(srcAKey, lenA, sortDir);
205 checkOrder(srcBKey, lenB, sortDir);
// Scatter every element of A: its destination is its own index i plus the
// result of an exclusive search of B for its key.
207 for (uint i = 0; i < lenA; i++)
209 uint dstPos = binarySearchExclusive(srcAKey[i], srcBKey, lenB, sortDir) + i;
210 assert(dstPos < lenA + lenB);
211 dstKey[dstPos] = srcAKey[i];
212 dstVal[dstPos] = srcAVal[i];
// Scatter every element of B the same way, but with an inclusive search of A.
// Presumably the exclusive/inclusive pairing keeps equal keys from A and B from
// being sent to the same destination slot -- confirm.
215 for (uint i = 0; i < lenB; i++)
217 uint dstPos = binarySearchInclusive(srcBKey[i], srcAKey, lenA, sortDir) + i;
218 assert(dstPos < lenA + lenB);
219 dstKey[dstPos] = srcBKey[i];
220 dstVal[dstPos] = srcBVal[i];
// Merge step 3 driver. Walks the intervals delimited by limitsA and limitsB and
// passes the matching slices of the key and value arrays to a per-interval
// merge call.
// NOTE(review): the parameter list and the line naming the callee of the final
// call are not visible in this excerpt.
224 static void mergeElementaryIntervals(
// Elements left over after the last complete pair of 2 * stride elements.
236 uint lastSegmentElements = N % (2 * stride);
// One interval per SAMPLE_STRIDE elements. When the leftover is longer than
// stride all N elements are covered (rounded up); otherwise the leftover is
// left out of the count.
237 uint mergePairs = (lastSegmentElements > stride) ? getSampleCount(N) : (N - lastSegmentElements) / SAMPLE_STRIDE;
239 for (uint pos = 0; pos < mergePairs; pos++)
// Interval index within the current segment pair (mask over
// 2 * stride / SAMPLE_STRIDE entries) and element offset of that pair.
241 uint i = pos & ((2 * stride) / SAMPLE_STRIDE - 1);
242 uint segmentBase = (pos - i) * SAMPLE_STRIDE;
// Lengths and sample counts of the two halves; n is the number of intervals in
// this pair.
244 const uint lenA = stride;
245 const uint lenB = umin(stride, N - segmentBase - stride);
246 const uint nA = stride / SAMPLE_STRIDE;
247 const uint nB = getSampleCount(lenB);
248 const uint n = nA + nB;
// Interval bounds: the start comes from the limits arrays, the end is the next
// limit or, for the last interval of the pair, the length of the half.
250 const uint startPosA = limitsA[pos];
251 const uint endPosA = (i + 1 < n) ? limitsA[pos + 1] : lenA;
252 const uint startPosB = limitsB[pos];
253 const uint endPosB = (i + 1 < n) ? limitsB[pos + 1] : lenB;
// Output offset within the pair is the sum of the two start positions.
254 const uint startPosDst = startPosA + startPosB;
// Sanity checks: bounds are ordered and inside their half, and each interval
// spans at most SAMPLE_STRIDE elements.
256 assert(startPosA <= endPosA && endPosA <= lenA);
257 assert(startPosB <= endPosB && endPosB <= lenB);
258 assert((endPosA - startPosA) <= SAMPLE_STRIDE);
259 assert((endPosB - startPosB) <= SAMPLE_STRIDE);
// Arguments of the per-interval call: destination key and value pointers,
// then the key and value pointers of the A slice and of the B slice, all
// offset into the current pair.
262 dstKey + segmentBase + startPosDst,
263 dstVal + segmentBase + startPosDst,
264 (srcKey + segmentBase + 0) + startPosA,
265 (srcVal + segmentBase + 0) + startPosA,
266 (srcKey + segmentBase + stride) + startPosB,
267 (srcVal + segmentBase + stride) + startPosB,
277 ////////////////////////////////////////////////////////////////////////////////
278 // Naive bubble sort
279 ////////////////////////////////////////////////////////////////////////////////
// Sorts key[0..N-1] in place in the direction given by sortDir; val[] entries
// are moved together with their keys in the statements that are visible.
// NOTE(review): several statements of the body are not visible in this excerpt.
// What is shown (track a candidate in savePos/saveKey while scanning, then
// exchange it with position bottom) resembles a selection sort rather than a
// bubble sort -- confirm against the full body.
280 static void bubbleSort(uint *key, uint *val, uint N, uint sortDir)
// bottom is the position being filled in this pass.
287 for (uint bottom = 0; bottom < N - 1; bottom++)
// Start with the element already at bottom as the candidate.
289 uint savePos = bottom;
290 uint saveKey = key[bottom];
// Scan the remaining elements.
292 for (uint i = bottom + 1; i < N; i++)
// Condition fragment: true when key[i] should precede the current candidate
// (smaller for a nonzero sortDir, larger for a zero sortDir).
294 (sortDir && (key[i] < saveKey)) ||
295 (!sortDir && (key[i] > saveKey))
// A different element was selected, so the entries at bottom and savePos are
// exchanged (only part of the exchange is visible here).
302 if (savePos != bottom)
306 key[savePos] = key[bottom];
309 val[savePos] = val[bottom];
317 ////////////////////////////////////////////////////////////////////////////////
318 // Interface function
319 ////////////////////////////////////////////////////////////////////////////////
320 extern "C" void mergeSortHost(
331 uint *ikey, *ival, *okey, *oval;
334 for (uint stride = SHARED_SIZE_LIMIT; stride < N; stride <<= 1, stageCount++);
351 printf("Bottom-level sort...\n");
352 memcpy(ikey, srcKey, N * sizeof(uint));
353 memcpy(ival, srcVal, N * sizeof(uint));
355 for (uint pos = 0; pos < N; pos += SHARED_SIZE_LIMIT)
357 bubbleSort(ikey + pos, ival + pos, umin(SHARED_SIZE_LIMIT, N - pos), sortDir);
360 printf("Merge...\n");
361 uint *ranksA = (uint *)malloc(getSampleCount(N) * sizeof(uint));
362 uint *ranksB = (uint *)malloc(getSampleCount(N) * sizeof(uint));
363 uint *limitsA = (uint *)malloc(getSampleCount(N) * sizeof(uint));
364 uint *limitsB = (uint *)malloc(getSampleCount(N) * sizeof(uint));
365 memset(ranksA, 0xFF, getSampleCount(N) * sizeof(uint));
366 memset(ranksB, 0xFF, getSampleCount(N) * sizeof(uint));
367 memset(limitsA, 0xFF, getSampleCount(N) * sizeof(uint));
368 memset(limitsB, 0xFF, getSampleCount(N) * sizeof(uint));
370 for (uint stride = SHARED_SIZE_LIMIT; stride < N; stride <<= 1)
372 uint lastSegmentElements = N % (2 * stride);
374 //Find sample ranks and prepare for limiters merge
375 generateSampleRanks(ranksA, ranksB, ikey, stride, N, sortDir);
377 //Merge ranks and indices
378 mergeRanksAndIndices(limitsA, ranksA, stride, N);
379 mergeRanksAndIndices(limitsB, ranksB, stride, N);
381 //Merge elementary intervals
382 mergeElementaryIntervals(okey, oval, ikey, ival, limitsA, limitsB, stride, N, sortDir);
384 if (lastSegmentElements <= stride)
386 //Last merge segment consists of a single array which just needs to be passed through
387 memcpy(okey + (N - lastSegmentElements), ikey + (N - lastSegmentElements), lastSegmentElements * sizeof(uint));
388 memcpy(oval + (N - lastSegmentElements), ival + (N - lastSegmentElements), lastSegmentElements * sizeof(uint));