
import java.util.Random;
import java.util.Scanner;

/**
 * Class that uses concurrency to merge sort an array of integers.
 *
 * @author Dave Reed
 * @version 4/2/19
 */
public class MergeSort {

    /** Number of elements sorted in the first benchmark round of {@link #main}. */
    private static final int INITIAL_SIZE = 1000;

    /** Nanoseconds per millisecond, used to report elapsed time in ms. */
    private static final long NANOS_PER_MILLI = 1_000_000L;

    /**
     * Top-level method for sorting an array of numbers using concurrency.
     * The array is sorted in place, in ascending order.
     *
     * @param a the numbers to be sorted
     * @param threadCount the number of threads that can be spawned; a value
     *                    of 1 or less sorts in the calling thread only
     */
    public static void mergeSortConcurrently(int[] a, int threadCount) {
        MergeSort.mergeSortConcurrently(a, 0, a.length - 1, threadCount);
    }

    /**
     * Helper method that recursively sorts the numbers in a range.
     * While {@code threadCount} is greater than 1, each half of the range is
     * handed to its own {@code SortThread} with half of the thread budget;
     * otherwise both halves are sorted recursively in the calling thread.
     * If the calling thread is interrupted while waiting for its workers, it
     * still waits for them to finish (so the merge never reads a half that is
     * being written) and then restores its interrupt status.
     *
     * @param a the numbers to be sorted
     * @param minIndex minimum index in the range
     * @param maxIndex maximum index in the range
     * @param threadCount the number of threads that can be spawned
     */
    public static void mergeSortConcurrently(int[] a, int minIndex, int maxIndex, int threadCount) {
        if (minIndex >= maxIndex) {
            return; // zero or one element: already sorted
        }

        // Overflow-safe midpoint; for non-negative indices this equals
        // (minIndex + maxIndex) / 2, which is the split merge() assumes.
        int mid = minIndex + (maxIndex - minIndex) / 2;

        if (threadCount > 1) {
            Thread leftThread = new SortThread(a, minIndex, mid, threadCount / 2);
            Thread rightThread = new SortThread(a, mid + 1, maxIndex, threadCount / 2);
            leftThread.start();
            rightThread.start();
            MergeSort.joinAll(leftThread, rightThread);
        } else {
            MergeSort.mergeSortConcurrently(a, minIndex, mid, threadCount);
            MergeSort.mergeSortConcurrently(a, mid + 1, maxIndex, threadCount);
        }

        MergeSort.merge(a, minIndex, maxIndex);
    }

    /**
     * Waits for every given thread to terminate, even if the calling thread
     * is interrupted while waiting. Any interruption is remembered and the
     * interrupt status is restored before returning.
     *
     * @param workers the threads to wait for
     */
    private static void joinAll(Thread... workers) {
        boolean interrupted = false;
        for (Thread worker : workers) {
            boolean joined = false;
            while (!joined) {
                try {
                    worker.join();
                    joined = true;
                } catch (InterruptedException ie) {
                    // Keep waiting: merging before the worker is done would
                    // race with it. The interrupt is re-asserted below.
                    interrupted = true;
                }
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Helper method that merges the two halves of a range in an array.
     * The first half is the first {@code (length + 1) / 2} elements of the
     * range and the second half is the remainder; each half is expected to be
     * sorted already. Ties are taken from the first half.
     *
     * @param a the numbers to be merged
     * @param minIndex minimum index in the range
     * @param maxIndex maximum index in the range
     */
    public static void merge(int[] a, int minIndex, int maxIndex) {
        // Work from a snapshot of the range so a[] can be overwritten in place.
        int[] copy = new int[maxIndex - minIndex + 1];
        System.arraycopy(a, minIndex, copy, 0, copy.length);

        int middle = (copy.length + 1) / 2; // first index of the second half
        int front1 = 0;
        int front2 = middle;
        for (int i = minIndex; i <= maxIndex; i++) {
            // Take from the first half when the second half is exhausted, or
            // when the first half still has elements and its front is not larger.
            if (front2 >= copy.length
                    || (front1 < middle && copy[front1] <= copy[front2])) {
                a[i] = copy[front1];
                front1++;
            } else {
                a[i] = copy[front2];
                front2++;
            }
        }
    }

    /////////////////////////////////////////////////////////////////////////

    /**
     * Timing driver: reads a thread limit from standard input, then sorts
     * random arrays of doubling size (starting at 1000 elements), verifying
     * each result and printing the elapsed time. The loop ends once the size
     * can no longer be doubled without overflowing an {@code int}.
     *
     * @param args command-line arguments (unused)
     * @throws RuntimeException if a sorted array is found to be out of order
     */
    public static void main(String[] args) {
        Random randy = new Random();

        System.out.println("Enter the thread limit: ");
        int numThreads;
        try (Scanner input = new Scanner(System.in)) {
            numThreads = input.nextInt();
        }

        int size = INITIAL_SIZE;
        while (true) {
            int[] nums = new int[size];
            for (int j = 0; j < size; j++) {
                nums[j] = randy.nextInt();
            }

            // nanoTime is monotonic, so the measurement is unaffected by
            // wall-clock adjustments.
            long startTime = System.nanoTime();
            MergeSort.mergeSortConcurrently(nums, numThreads);
            long elapsedMillis = (System.nanoTime() - startTime) / NANOS_PER_MILLI;

            for (int k = 0; k < nums.length - 1; k++) {
                if (nums[k] > nums[k + 1]) {
                    throw new RuntimeException("Not sorted correctly");
                }
            }

            System.out.printf("%10d elements  =>  %6d ms \n", size, elapsedMillis);

            if (size > Integer.MAX_VALUE / 2) {
                break; // doubling again would overflow to a negative size
            }
            size *= 2;
        }
    }
}
