OSDN Git Service

Implement kth order selection algorithm in new class Select
authorJon Renner <rennerjc@gmail.com>
Mon, 16 Sep 2013 10:13:34 +0000 (18:13 +0800)
committerJon Renner <rennerjc@gmail.com>
Thu, 19 Sep 2013 03:23:21 +0000 (11:23 +0800)
classes Select, QuickSelect and Array.selectRanked method

gdx/src/com/badlogic/gdx/utils/Array.java
gdx/src/com/badlogic/gdx/utils/QuickSelect.java [new file with mode: 0644]
gdx/src/com/badlogic/gdx/utils/Select.java [new file with mode: 0644]
tests/gdx-tests/src/com/badlogic/gdx/tests/SelectTest.java [new file with mode: 0644]
tests/gdx-tests/src/com/badlogic/gdx/tests/utils/GdxTests.java

index d878f30..325fe31 100644 (file)
@@ -335,6 +335,35 @@ public class Array<T> implements Iterable<T> {
                Sort.instance().sort(items, comparator, 0, size);\r
        }\r
 \r
+       /** Selects the nth-lowest element from the Array according to Comparator ranking.\r
+        * @see Select\r
+        * @param comparator used for comparison\r
+        * @param nth_lowest rank of desired object according to comparison,\r
+        * n is based on ordinal numbers, not array indices.\r
+        * for min value use 1, for max value use size of array, using 0 results in runtime exception.\r
+        * @return the value of the Nth lowest ranked object.\r
+        */\r
+       public T selectRanked(Comparator<T> comparator, int nth_lowest) {\r
+               if (nth_lowest < 1) {\r
+                       throw new GdxRuntimeException("nth_lowest must be greater than 0, 1 = first, 2 = second...");\r
+               }\r
+               return Select.instance().select(items, comparator, nth_lowest, size);\r
+       }\r
+\r
+       /** @see Array#selectRanked(java.util.Comparator, int)\r
+       * @param comparator used for comparison\r
+        * @param nth_lowest rank of desired object according to comparison,\r
+        * n is based on ordinal numbers, not array indices.\r
+        * for min value use 1, for max value use size of array, using 0 results in runtime exception.\r
+        * @return the index of the Nth lowest ranked object.\r
+        */\r
+       public int selectRankedIndex(Comparator<T> comparator, int nth_lowest) {\r
+               if (nth_lowest < 1) {\r
+                       throw new GdxRuntimeException("nth_lowest must be greater than 0, 1 = first, 2 = second...");\r
+               }\r
+               return Select.instance().selectIndex(items, comparator, nth_lowest, size);\r
+       }\r
+\r
        public void reverse () {\r
                T[] items = this.items;\r
                for (int i = 0, lastIndex = size - 1, n = size / 2; i < n; i++) {\r
@@ -363,7 +392,7 @@ public class Array<T> implements Iterable<T> {
        }\r
 \r
        /** Returns an iterable for the selected items in the array. Remove is supported, but not between hasNext() and next(). Note\r
-        * that the same iteratable instance is returned each time this method is called. Use the {@link Predicate.PredicateIterable}\r
+        * that the same iterable instance is returned each time this method is called. Use the {@link Predicate.PredicateIterable}\r
         * constructor for nested or multithreaded iteration. */\r
        public Iterable<T> select (Predicate<T> predicate) {\r
                if (predicateIterable == null)\r
diff --git a/gdx/src/com/badlogic/gdx/utils/QuickSelect.java b/gdx/src/com/badlogic/gdx/utils/QuickSelect.java
new file mode 100644 (file)
index 0000000..00ac18e
--- /dev/null
@@ -0,0 +1,87 @@
+package com.badlogic.gdx.utils;
+
+import java.util.Comparator;
+
+
+/**
+ * Implementation of Tony Hoare's quickselect algorithm.
+ * Running time is generally O(n), but worst case is O(n**2)
+ * Pivot choice is median of three method, providing better performance
+ * than a random pivot for partially sorted data.
+ * @author Jon Renner
+ */
+public class QuickSelect<T> {
+       private T[] array;
+       private Comparator<? super T> comp;
+
+       public int select(T[] items, Comparator<T> comp, int n, int size) {
+               this.array = items;
+               this.comp = comp;
+               return recursiveSelect(0, size - 1, n);
+       }
+
+       private int partition(int left, int right, int pivot) {
+               T pivotValue = array[pivot];
+               swap(right, pivot);
+               int storage = left;
+               for (int i = left; i < right; i++) {
+                       if (comp.compare(array[i], pivotValue) == -1) {
+                               swap(storage, i);
+                               storage++;
+                       }
+               }
+               swap(right, storage);
+               return storage;
+       }
+
+       private int recursiveSelect(int left, int right, int k) {
+               if (left == right) return left;
+               int pivotIndex = medianOfThreePivot(left, right);
+               int pivotNewIndex = partition(left, right, pivotIndex);
+               int pivotDist = (pivotNewIndex - left) + 1;
+               int result;
+               if (pivotDist == k) {
+                       result = pivotNewIndex;
+               }
+               else if (k < pivotDist) {
+                       result = recursiveSelect(left, pivotNewIndex - 1, k);
+               } else {
+                       result = recursiveSelect(pivotNewIndex + 1, right, k - pivotDist);
+               }
+               return result;
+       }
+
+       /** Median of Three has the potential to outperform a random pivot, especially for partially sorted arrays */
+       private int medianOfThreePivot(int leftIdx, int rightIdx) {
+               T left = array[leftIdx];
+               int midIdx = (leftIdx + rightIdx) / 2;
+               T mid = array[midIdx];
+               T right = array[rightIdx];
+
+               // spaghetti median of three algorithm
+               // does at most 2 comparisons
+               if (comp.compare(left, mid) > 0) {
+                       if (comp.compare(mid, right) > 0) {
+                               return midIdx;
+                       } else if (comp.compare(left, right) > 0) {
+                               return rightIdx;
+                       } else {
+                               return leftIdx;
+                       }
+               } else {
+                       if (comp.compare(left, right) > 0) {
+                               return leftIdx;
+                       } else if (comp.compare(mid, right) > 0) {
+                               return rightIdx;
+                       } else {
+                               return midIdx;
+                       }
+               }
+       }
+
+       private void swap(int left, int right) {
+               T tmp = array[left];
+               array[left] = array[right];
+               array[right] = tmp;
+       }
+}
diff --git a/gdx/src/com/badlogic/gdx/utils/Select.java b/gdx/src/com/badlogic/gdx/utils/Select.java
new file mode 100644 (file)
index 0000000..772920e
--- /dev/null
@@ -0,0 +1,73 @@
+package com.badlogic.gdx.utils;
+
+import java.util.Comparator;
+
+/**
+ * This class is for selecting a ranked element (kth ordered statistic) from
+ * an unordered list in faster time than sorting the whole array.
+ * Typical applications include finding the nearest enemy unit(s), and other
+ * operations which are likely to run as often as every x frames.
+ * <p>The lowest ranking element starts at 1, not 0. 1 = first, 2 = second, 3 = third, etc.
+ * calling with a value of zero will result in a {@link GdxRuntimeException} </p>
+ * <p> This class uses very minimal extra memory, as it makes no copies of the array.
+ * The underlying algorithms used are a naive single-pass for k=min and k=max, and Hoare's
+ * quickselect for values in between. </p> 
+ * @author Jon Renner
+ */
+public class Select {
+       private static Select instance;
+       private QuickSelect quickSelect;
+
+       /** Provided for convenience */
+       public static Select instance() {
+               if (instance == null) instance = new Select();
+               return instance;
+       }
+
+       public <T> T select (T[] items, Comparator<T> comp, int kthLowest, int size) {
+               int idx = selectIndex(items, comp, kthLowest, size);
+               return items[idx];
+       }
+
+       public <T> int selectIndex(T[] items, Comparator<T> comp, int kthLowest, int size) {
+               if (size < 1) throw new GdxRuntimeException("cannot select from empty array (size < 1)");
+               int idx;
+               // naive partial selection sort almost certain to outperform quickselect where n is min or max
+               if (kthLowest == 1) {
+                       // find min
+                       idx = fastMin(items, comp, size);
+               } else if (kthLowest == size) {
+                       // find max
+                       idx = fastMax(items, comp, size);
+               } else {
+                       // quickselect a better choice for cases of k between min and max
+                       if (quickSelect == null) quickSelect = new QuickSelect();
+                       idx = quickSelect.select(items, comp, kthLowest, size);
+               }
+               return idx;
+       }
+
+       /** Faster than quickselect for n = min */
+       private <T> int fastMin(T[] items, Comparator<T> comp, int size) {
+               int lowestIdx = 0;
+               for (int i = 1; i < size; i++) {
+                       int comparison = comp.compare(items[i], items[lowestIdx]);
+                       if (comparison < 0) {
+                               lowestIdx = i;
+                       }
+               }
+               return lowestIdx;
+       }
+
+       /** Faster than quickselect for n = max */
+       private <T> int fastMax(T[] items, Comparator<T> comp, int size) {
+               int highestIdx = 0;
+               for (int i = 1; i < size; i++) {
+                       int comparison = comp.compare(items[i], items[highestIdx]);
+                       if (comparison > 0) {
+                               highestIdx = i;
+                       }
+               }
+               return highestIdx;
+       }
+}
diff --git a/tests/gdx-tests/src/com/badlogic/gdx/tests/SelectTest.java b/tests/gdx-tests/src/com/badlogic/gdx/tests/SelectTest.java
new file mode 100644 (file)
index 0000000..b62a3b0
--- /dev/null
@@ -0,0 +1,266 @@
+package com.badlogic.gdx.tests;
+
+import com.badlogic.gdx.Gdx;
+import com.badlogic.gdx.math.MathUtils;
+import com.badlogic.gdx.math.Vector2;
+import com.badlogic.gdx.tests.utils.GdxTest;
+import com.badlogic.gdx.utils.Array;
+import com.badlogic.gdx.utils.GdxRuntimeException;
+import com.badlogic.gdx.utils.PerformanceCounter;
+
+import java.util.Comparator;
+
+/**
+ * For testing and benchmarking of gdx.utils.Select and its associated algorithms/classes
+ * @author Jon renner
+ */
+public class SelectTest extends GdxTest {
+       private static PerformanceCounter perf = new PerformanceCounter("bench");
+       private static boolean verify; // verify and report the results of each selection
+       private static boolean quiet;
+
+       @Override
+       public void create() {
+               int n = 100;
+               player = createDummies(n);
+               enemy = createDummies(n);
+
+               int runs = 100;
+               // run correctness first to warm up the JIT and other black magic
+               quiet = true;
+               allRandom();
+               print("VERIFY CORRECTNESS FIND LOWEST RANKED");
+               correctnessTest(runs, 1);
+               print("VERIFY CORRECTNESS FIND MIDDLE RANKED");
+               correctnessTest(runs, enemy.size / 2);
+               print("VERIFY CORRECTNESS FIND HIGHEST RANKED");
+               correctnessTest(runs, enemy.size);
+
+               runs = 1000;
+               quiet = true;
+               print("BENCHMARK FIND LOWEST RANKED");
+               performanceTest(runs, 1);
+               print("BENCHMARK FIND MIDDLE RANKED");
+               performanceTest(runs, enemy.size / 2);
+               print("BENCHMARK FIND HIGHEST RANKED");
+               performanceTest(runs, enemy.size);
+
+               print("TEST CONSISTENCY FOR LOWEST RANKED");
+               consistencyTest(runs, 1);
+               print("TEST CONSISTENCY FOR MIDDLE RANKED");
+               consistencyTest(runs, enemy.size / 2);
+               print("TEST CONSISTENCY FOR HIGHEST RANKED");
+               consistencyTest(runs, enemy.size);
+               
+               // test that selectRanked and selectRankedIndex return the same
+               print("TEST selectRanked AND selectRankedIndex RETURN MATCHING RESULTS - LOWEST RANKED");
+               testValueMatchesIndex(runs, 1);
+               print("TEST selectRanked AND selectRankedIndex RETURN MATCHING RESULTS - MIDDLE RANKED");
+               testValueMatchesIndex(runs, enemy.size / 2);
+               print("TEST selectRanked AND selectRankedIndex RETURN MATCHING RESULTS - HIGHEST RANKED");
+               testValueMatchesIndex(runs, enemy.size);
+               
+               print("ALL TESTS PASSED");
+       }
+
+       public static void correctnessTest(int runs, int k) {
+               String msg = String.format("[%d runs with %dx%d dummy game units] - ",
+                               runs, player.size, enemy.size);
+               verify = true;
+               test(runs, k);
+               print(msg + "VERIFIED");
+       }
+
+       public static void performanceTest(int runs, int k) {
+               verify = false;
+               test(runs, k);
+               String msg = String.format("[%d runs with %dx%d dummy game units] - ",
+                               runs, player.size, enemy.size);
+               print(msg + String.format("avg: %.5f, min/max: %.4f/%.4f, total time: %.3f (ms), made %d comparisons",
+                               allPerf.time.min, allPerf.time.max, allPerf.time.average * 1000, allPerf.time.total * 1000,
+                               comparisonsMade));
+       }
+
+       public static void consistencyTest(int runs, int k) {
+               verify = false;
+               Dummy test = player.get(0);
+               Dummy lastFound = null;
+               allRandom();
+               for (int i = 0; i < runs; i++) {
+                       Dummy found = test.getKthNearestEnemy(k);
+                       if (lastFound == null) {
+                               lastFound = found;
+                       } else {
+                               if (!(lastFound.equals(found))) {
+                                       print("CONSISTENCY TEST FAILED");
+                                       print("lastFound: " + lastFound);
+                                       print("justFound: " + found);
+                                       throw new GdxRuntimeException("test failed");
+                               }
+                       }
+               }
+       }
+       
+       public static void testValueMatchesIndex(int runs, int k) {
+               verify = false;
+               for (int i = 0; i < runs; i++) {
+                       allRandom();
+                       player.shuffle();
+                       enemy.shuffle();
+                       originDummy = player.random();
+                       int idx = enemy.selectRankedIndex(distComp, k);
+                       Dummy indexDummy = enemy.get(idx);
+                       Dummy valueDummy = enemy.selectRanked(distComp, k);
+                       if (!(indexDummy.equals(valueDummy))) {
+                               throw new GdxRuntimeException("results of selectRankedIndex and selectRanked do not return the same object\n" +
+                                       "selectRankedIndex -> " + indexDummy + "\n" +
+                                       "selectRanked      -> " + valueDummy);
+                       }
+                       
+               }                               
+       }
+
+       public static void test(int runs, int k) {
+               // k = kth order statistic
+               comparisonsMade = 0;
+               perf.reset();
+               allPerf.reset();
+               allRandom();
+               enemy.shuffle();
+               player.shuffle();
+               for (int i = 0; i < runs; i++) {
+                       getKthNearestEnemy(quiet, k);
+               }
+       }
+
+       public static void allRandom() {
+               for (Dummy d : player) {
+                       d.setRandomPos();
+               }
+               for (Dummy d : enemy) {
+                       d.setRandomPos();
+               }
+       }
+
+       private static PerformanceCounter allPerf = new PerformanceCounter("all");
+       public static void getKthNearestEnemy(boolean silent, int k) {
+               Dummy kthDummy = null;
+               perf.reset();
+               allPerf.start();
+               for (Dummy d : player) {
+                       Dummy found = d.getKthNearestEnemy(k);
+               }
+               allPerf.stop();
+               allPerf.tick();
+               if (silent) return;
+               print(String.format("found nearest. min: %.4f, max: %.4f, avg: %.4f, total: %.3f ms",
+                               perf.time.min * 1000, perf.time.max * 1000, perf.time.average * 1000, perf.time.total * 1000));
+       }
+
+       public static void verifyCorrectness(Dummy d, int k) {
+               enemy.sort(distComp);
+               int idx = enemy.indexOf(d, true);
+               // remember that k = min value = 0 position in the array, therefore k - 1
+               //if (idx != k - 1) {
+               //print("verified - idx: " + idx + ", (k - 1): " + (k - 1));
+               if (enemy.get(idx) != enemy.get(k - 1)) {
+                       System.out.println("origin dummy: " + originDummy);
+                       System.out.println("TEST FAILURE: " + "idx: " + idx + " does not equal (k - 1): " + (k - 1));
+                       throw new GdxRuntimeException("test failed");
+               }
+       }
+
+       static class Dummy {
+               public Vector2 pos;
+               public int id;
+
+               public Dummy() {
+                       // set the position manually
+               }
+
+               @Override
+               public boolean equals(Object obj) {
+                       if (!(obj instanceof Dummy)) {
+                               throw new GdxRuntimeException("do not compare to anything but other Dummy objects");
+                       }
+                       Dummy d = (Dummy) obj;
+                       // we only care about position/distance
+                       float epsilon = 0.0001f;
+                       float diff = Math.abs(d.pos.x - this.pos.x) + Math.abs(d.pos.y - this.pos.y);
+                       if (diff > epsilon)
+                               return false;
+                       return true;
+
+               }
+
+               public Dummy getKthNearestEnemy(int k) {
+                       perf.start();
+                       originDummy = this;
+                       Dummy found = enemy.selectRanked(distComp, k);
+                       //print(this + " found enemy: " + found);
+                       perf.stop();
+                       perf.tick();
+                       if (verify) {
+                               verifyCorrectness(found, k);
+                       }
+                       return found;
+               }
+
+               public void setRandomPos() {
+                       float max = 100;
+                       this.pos.x = -max + MathUtils.random(max * 2);
+                       this.pos.y = -max + MathUtils.random(max * 2);
+                       float xShift = 100;
+                       if (player.contains(this, true)) {
+                               this.pos.x -= xShift;
+                       } else if (enemy.contains(this, true)) {
+                               this.pos.x += xShift;
+                       } else {
+                               throw new RuntimeException("unhandled");
+                       }
+               }
+
+               @Override
+               public String toString() {
+                       return String.format("Dummy at: %.2f, %.2f", pos.x, pos.y);
+               }
+       }
+
+       public static int nextID = 1;
+       public static Array<Dummy> player;
+       public static Array<Dummy> enemy;
+
+       public static Array<Dummy> createDummies(int n) {
+               float variance = 20;
+               Array<Dummy> dummies = new Array<Dummy>();
+               for (int i = 0; i < n; i++) {
+                       Dummy d = new Dummy();
+                       dummies.add(d);
+                       d.pos = new Vector2();
+                       d.id = nextID++;
+               }
+               return dummies;
+       }
+
+       private static Dummy originDummy;
+       private static long comparisonsMade = 0;
+       static Comparator<Dummy> distComp = new Comparator<Dummy>() {
+               @Override
+               public int compare(Dummy o1, Dummy o2) {
+                       comparisonsMade++;
+                       float d1 = originDummy.pos.dst2(o1.pos);
+                       float d2 = originDummy.pos.dst2(o2.pos);
+                       float diff = d1 - d2;
+                       if (diff < 0) return -1;
+                       if (diff > 0) return 1;
+                       return 0;
+               }
+       };
+
+       public static void print(Object ...objs) {
+               for (Object o : objs) {
+                       System.out.print(o);
+               }
+               System.out.println();
+       }
+}
index d9af717..b2b6a33 100644 (file)
@@ -72,6 +72,7 @@ public class GdxTests {
                OrthoCamBorderTest.class, ParallaxTest.class, ParticleEmitterTest.class, PickingTest.class, PixelsPerInchTest.class,\r
                PixmapBlendingTest.class, PixmapTest.class, PixmapPackerTest.class, PolygonRegionTest.class, PolygonSpriteTest.class, PreferencesTest.class,\r
                ProjectiveTextureTest.class, Pong.class, ProjectTest.class, RemoteTest.class, RotationTest.class, DragAndDropTest.class,\r
+               SelectTest.class,\r
                ShaderMultitextureTest.class, ShadowMappingTest.class, PathTest.class, SimpleAnimationTest.class, SimpleDecalTest.class,\r
                SimpleStageCullingTest.class, SoundTest.class, SpriteCacheTest.class, SpriteCacheOffsetTest.class, LetterBoxTest1.class,\r
                SpriteBatchRotationTest.class, SpriteBatchShaderTest.class, SpriteBatchTest.class, SpritePerformanceTest.class,\r