2 * Copyright 2008-2012 NVIDIA Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <thrust/detail/config.h>
18 #include <thrust/detail/temporary_array.h>
19 #include <thrust/detail/copy.h>
20 #include <thrust/system/detail/internal/scalar/sort.h>
21 #include <thrust/iterator/iterator_traits.h>
22 #include <thrust/distance.h>
23 #include <thrust/merge.h>
24 #include <tbb/parallel_invoke.h>
// Size constant for the merge sort in this namespace: 128 * 1024 = 131072.
// NOTE(review): the statement that reads `threshold` is not visible in this
// view of the file; presumably it is the range length below which merge_sort
// switches to the sequential scalar sort -- confirm against the full file.
37 // TODO: tune this based on the data type and the comparator (comp)
38 const static int threshold = 128 * 1024;
// Forward declaration of merge_sort. It is needed because
// merge_sort_closure::operator() calls merge_sort before the definition
// appears later in this file.
40 template <typename DerivedPolicy, typename Iterator1, typename Iterator2, typename StrictWeakOrdering>
41 void merge_sort(execution_policy<DerivedPolicy> &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, StrictWeakOrdering comp, bool inplace);
// Nullary function object that captures one set of merge_sort arguments so a
// recursive merge_sort call can be handed to ::tbb::parallel_invoke (see the
// `left` / `right` closures built inside merge_sort).
//
// NOTE(review): this view of the file omits the struct's braces and, judging
// by the gaps in the embedded line numbers (47 -> 49 -> 52), the declarations
// of the `first2` and `inplace` members that the constructor initializes.
43 template <typename DerivedPolicy, typename Iterator1, typename Iterator2, typename StrictWeakOrdering>
44 struct merge_sort_closure
// Execution policy, held by reference and forwarded to merge_sort.
46 execution_policy<DerivedPolicy> &exec;
// Bounds of the range this closure will sort.
47 Iterator1 first1, last1;
// Ordering forwarded to merge_sort.
49 StrictWeakOrdering comp;
// Stores every argument unchanged; no work is done until operator() runs.
52 merge_sort_closure(execution_policy<DerivedPolicy> &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, StrictWeakOrdering comp, bool inplace)
53 : exec(exec), first1(first1), last1(last1), first2(first2), comp(comp), inplace(inplace)
// Runs merge_sort with the stored arguments.
56 void operator()(void) const
58 merge_sort(exec, first1, last1, first2, comp, inplace);
// Recursive merge sort of [first1, last1) that uses a second buffer starting
// at first2.
//
//   exec          execution policy; forwarded to thrust::merge and to the
//                 recursive calls
//   first1, last1 the range being sorted
//   first2        start of the second buffer; n = distance(first1, last1)
//                 elements starting here are addressed
//   comp          ordering passed to the scalar stable_sort and to thrust::merge
//   inplace       selects where this level's merged result is written:
//                 true  -> halves are read from the first2 buffer and merged
//                          into first1
//                 false -> halves are read from first1 and merged into first2
//
// NOTE(review): this view of the file carries the original line numbers as a
// leading token and omits some original lines (braces and, judging by the
// numbering gaps, a few statements). The notes below flag where that matters.
63 template <typename DerivedPolicy, typename Iterator1, typename Iterator2, typename StrictWeakOrdering>
64 void merge_sort(execution_policy<DerivedPolicy> &exec, Iterator1 first1, Iterator1 last1, Iterator2 first2, StrictWeakOrdering comp, bool inplace)
66 typedef typename thrust::iterator_difference<Iterator1>::type difference_type;
// Number of elements in the range being sorted.
68 difference_type n = thrust::distance(first1, last1);
// NOTE(review): the embedded numbering jumps 68 -> 72 -> 75 -> 80 here. The two
// calls below look like a sequential base case: presumably they are guarded by
// a comparison of n with `threshold`, the copy into the first2 buffer is
// presumably conditional (likely on !inplace), and an early return presumably
// follows -- confirm all three against the full file.
72 thrust::system::detail::internal::scalar::stable_sort(first1, last1, comp);
75 thrust::system::detail::internal::scalar::copy(first1, last1, first2);
// Split both buffers at the same offset, n / 2.
80 Iterator1 mid1 = first1 + (n / 2);
81 Iterator2 mid2 = first2 + (n / 2);
82 Iterator2 last2 = first2 + n;
84 typedef merge_sort_closure<DerivedPolicy,Iterator1,Iterator2,StrictWeakOrdering> Closure;
// Each half is sorted with the flag flipped, so the halves end up in the
// buffer that the merge below reads from.
86 Closure left (exec, first1, mid1, first2, comp, !inplace);
87 Closure right(exec, mid1, last1, mid2, comp, !inplace);
// Hand both sub-sorts to TBB; the merge below runs after this call returns.
89 ::tbb::parallel_invoke(left, right);
// Merge the two sorted halves into the buffer selected by `inplace`.
91 if (inplace) thrust::merge(exec, first2, mid2, mid2, last2, first1, comp);
92 else thrust::merge(exec, first1, mid1, mid1, last1, first2, comp);
95 } // end namespace sort_detail
98 namespace sort_by_key_detail
// Size constant for the key/value merge sort in this namespace:
// 128 * 1024 = 131072.
// NOTE(review): the statement that reads `threshold` is not visible in this
// view of the file; presumably it is the range length below which
// merge_sort_by_key switches to the sequential scalar sort -- confirm.
101 // TODO: tune this based on the data type and the comparator (comp)
102 const static int threshold = 128 * 1024;
// Forward declaration of merge_sort_by_key. It is needed because
// merge_sort_by_key_closure::operator() calls it before the definition appears
// later in this file.
// NOTE(review): most of the template and function parameter lists are missing
// from this view (embedded numbering jumps 104 -> 109 -> 110 -> 116); only the
// exec and comp parameters are visible.
104 template <typename DerivedPolicy,
109 typename StrictWeakOrdering>
110 void merge_sort_by_key(execution_policy<DerivedPolicy> &exec,
116 StrictWeakOrdering comp,
// Nullary function object that captures one set of merge_sort_by_key arguments
// so a recursive call can be handed to ::tbb::parallel_invoke (see the
// `left` / `right` closures built inside merge_sort_by_key).
//
// NOTE(review): this view omits the struct's braces, most template and
// constructor parameters, and the declarations of the first2, first3, first4
// and inplace members that the initializer list below sets.
119 template <typename DerivedPolicy,
124 typename StrictWeakOrdering>
125 struct merge_sort_by_key_closure
// Execution policy, held by reference and forwarded to merge_sort_by_key.
127 execution_policy<DerivedPolicy> &exec;
// Bounds of the range this closure will sort.
128 Iterator1 first1, last1;
// Ordering forwarded to merge_sort_by_key.
132 StrictWeakOrdering comp;
// Stores every argument unchanged; no work is done until operator() runs.
135 merge_sort_by_key_closure(execution_policy<DerivedPolicy> &exec,
141 StrictWeakOrdering comp,
143 : exec(exec), first1(first1), last1(last1), first2(first2), first3(first3), first4(first4), comp(comp), inplace(inplace)
// Runs merge_sort_by_key with the stored arguments.
146 void operator()(void) const
148 merge_sort_by_key(exec, first1, last1, first2, first3, first4, comp, inplace);
// Recursive key/value merge sort. As used in the body:
//   [first1, last1)  keys being sorted
//   first2           values that travel with the keys
//   first3           start of the second key buffer
//   first4           start of the second value buffer
//   comp             ordering passed to the scalar stable_sort_by_key and to
//                    thrust::merge_by_key
//   inplace          forwarded, flipped, to the two recursive closures
//
// NOTE(review): this view of the file carries the original line numbers as a
// leading token and omits some original lines: the braces, most of the
// parameter list (the parameter roles above are read off the body), and the
// conditions around the base case and the two merges.
153 template <typename DerivedPolicy,
158 typename StrictWeakOrdering>
159 void merge_sort_by_key(execution_policy<DerivedPolicy> &exec,
165 StrictWeakOrdering comp,
168 typedef typename thrust::iterator_difference<Iterator1>::type difference_type;
// Number of key/value pairs in the range being sorted.
170 difference_type n = thrust::distance(first1, last1);
// Split all four buffers at the same offset, n / 2.
172 Iterator1 mid1 = first1 + (n / 2);
173 Iterator2 mid2 = first2 + (n / 2);
174 Iterator3 mid3 = first3 + (n / 2);
175 Iterator4 mid4 = first4 + (n / 2);
176 Iterator2 last2 = first2 + n;
177 Iterator3 last3 = first3 + n;
// NOTE(review): the embedded numbering jumps 177 -> 181 -> 185 -> 192 here. The
// calls below look like a sequential base case: presumably they are guarded by
// a comparison of n with `threshold`, the two copies into the first3 / first4
// buffers are presumably conditional (likely on !inplace), and an early return
// presumably follows -- confirm against the full file.
181 thrust::system::detail::internal::scalar::stable_sort_by_key(first1, last1, first2, comp);
185 thrust::system::detail::internal::scalar::copy(first1, last1, first3);
186 thrust::system::detail::internal::scalar::copy(first2, last2, first4);
192 typedef merge_sort_by_key_closure<DerivedPolicy,Iterator1,Iterator2,Iterator3,Iterator4,StrictWeakOrdering> Closure;
// Each half is sorted with the flag flipped.
194 Closure left (exec, first1, mid1, first2, first3, first4, comp, !inplace);
195 Closure right(exec, mid1, last1, mid2, mid3, mid4, comp, !inplace);
// Hand both sub-sorts to TBB; the merges below run after this call returns.
197 ::tbb::parallel_invoke(left, right);
// NOTE(review): the if/else choosing between the two merges is not visible
// (numbering jumps 197 -> 201 -> 205); presumably it tests `inplace`, mirroring
// the key-only merge_sort -- confirm.
// Merge halves held in the first3 (keys) / first4 (values) buffers into
// first1 (keys) / first2 (values).
201 thrust::merge_by_key(exec, first3, mid3, mid3, last3, first4, mid4, first1, first2, comp);
// Merge halves held in first1 (keys) / first2 (values) into the
// first3 (keys) / first4 (values) buffers.
205 thrust::merge_by_key(exec, first1, mid1, mid1, last1, first2, mid2, first3, first4, comp);
209 } // end namespace sort_by_key_detail
// stable_sort entry point of the TBB backend: sorts [first, last) under comp
// by delegating to sort_detail::merge_sort.
//
//   exec         execution policy; also used to build the scratch buffer
//   first, last  the range to sort
//   comp         ordering forwarded to merge_sort
//
// NOTE(review): the function's braces are not visible in this view of the file.
211 template<typename DerivedPolicy,
212 typename RandomAccessIterator,
213 typename StrictWeakOrdering>
214 void stable_sort(execution_policy<DerivedPolicy> &exec,
215 RandomAccessIterator first,
216 RandomAccessIterator last,
217 StrictWeakOrdering comp)
219 typedef typename thrust::iterator_value<RandomAccessIterator>::type key_type;
// Scratch buffer built from (exec, first, last). Presumably this constructor
// sizes the buffer to the range and fills it from [first, last) -- confirm
// against thrust::detail::temporary_array.
221 thrust::detail::temporary_array<key_type, DerivedPolicy> temp(exec, first, last);
// inplace == true: merge_sort's final merge writes into its first1 range,
// which here is [first, last).
223 sort_detail::merge_sort(exec, first, last, temp.begin(), comp, true);
// stable_sort_by_key entry point of the TBB backend: sorts the keys in
// [first1, last1) under comp, together with the values starting at first2, by
// delegating to sort_by_key_detail::merge_sort_by_key.
//
//   exec           execution policy; also used to build the scratch buffers
//   first1, last1  the key range
//   first2         start of the value range; distance(first1, last1) values
//                  are addressed
//   comp           ordering forwarded to merge_sort_by_key
//
// NOTE(review): the function's braces are not visible in this view of the file.
226 template<typename DerivedPolicy,
227 typename RandomAccessIterator1,
228 typename RandomAccessIterator2,
229 typename StrictWeakOrdering>
230 void stable_sort_by_key(execution_policy<DerivedPolicy> &exec,
231 RandomAccessIterator1 first1,
232 RandomAccessIterator1 last1,
233 RandomAccessIterator2 first2,
234 StrictWeakOrdering comp)
236 typedef typename thrust::iterator_value<RandomAccessIterator1>::type key_type;
237 typedef typename thrust::iterator_value<RandomAccessIterator2>::type val_type;
// End of the value range: as many values as there are keys.
239 RandomAccessIterator2 last2 = first2 + thrust::distance(first1, last1);
// Scratch buffers for keys and values, built from the corresponding input
// ranges. Presumably each constructor sizes the buffer to its range and fills
// it from that range -- confirm against thrust::detail::temporary_array.
241 thrust::detail::temporary_array<key_type, DerivedPolicy> temp1(exec, first1, last1);
242 thrust::detail::temporary_array<val_type, DerivedPolicy> temp2(exec, first2, last2);
// Top-level call passes inplace == true, with the scratch buffers as the
// second key/value buffers.
244 sort_by_key_detail::merge_sort_by_key(exec, first1, last1, first2, temp1.begin(), temp2.begin(), comp, true);
247 } // end namespace detail
248 } // end namespace tbb
249 } // end namespace system
250 } // end namespace thrust