
Perform operation on n random distinct elements from Collection using Streams API

I’m trying to pick n unique random elements from a Collection for further processing, using the Streams API in Java 8, but so far without success.

More precisely I’d want something like this:

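Set<Integer> subList = new HashSet<>();
Queue<Integer> collection = new PriorityQueue<>();
collection.addAll(Arrays.asList(1,2,3,4,5,6,7,8,9));
List<Integer> indexed = new ArrayList<>(collection); // a Queue has no get(int), so copy for indexed access
Random random = new Random();
int n = 4;
while (subList.size() < n) {
  // nextInt(bound) keeps the index in range; the Set ignores duplicate picks
  subList.add(indexed.get(random.nextInt(indexed.size())));
}
subList.forEach(v -> v.doSomethingFancy()); // doSomethingFancy() is a placeholder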

I want to do it as efficiently as possible.

Can this be done?

edit: My second attempt — although not exactly what I was aiming for:

List<Integer> sublist = new ArrayList<>(collection);
Collections.shuffle(sublist);
sublist.stream().limit(n).forEach(v -> v.doSomethingFancy());

edit: Third attempt (inspired by Holger), which avoids most of the shuffle overhead when collection.size() is huge and n is small:

int n = // unique element count
List<Integer> sublist = new ArrayList<>(collection);   
Random r = new Random();
for(int i = 0; i < n; i++)
    Collections.swap(sublist, i, i + r.nextInt(sublist.size() - i));
sublist.stream().limit(n).forEach(v -> v.doSomethingFancy());


Answer

The shuffling approach works reasonably well, as suggested by fge in a comment and by ZouZou in another answer. Here’s a generified version of the shuffling approach:

static <E> List<E> shuffleSelectN(Collection<? extends E> coll, int n) {
    assert n <= coll.size();
    List<E> list = new ArrayList<>(coll);
    Collections.shuffle(list);
    return list.subList(0, n);
}

I’ll note that using subList is preferable to getting a stream and then calling limit(n), as shown in some other answers, because the resulting stream has a known size and can be split more efficiently.
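
As a rough illustration of the "known size" point, the sketch below asks the stream's spliterator whether it reports the SIZED characteristic. The class name SizedStreamDemo and the helper subListStreamIsSized are hypothetical names made up for this example. What the limit(n) pipeline reports may depend on the JDK version, so the sketch only prints that value rather than assuming it.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Spliterator;

class SizedStreamDemo {
    // Hypothetical helper: true if the stream over subList(0, n) reports
    // an exact size of n through its spliterator.
    static boolean subListStreamIsSized(List<?> list, int n) {
        Spliterator<?> split = list.subList(0, n).stream().spliterator();
        return split.hasCharacteristics(Spliterator.SIZED) && split.estimateSize() == n;
    }

    public static void main(String[] args) {
        List<Integer> list = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
        System.out.println("subList sized: " + subListStreamIsSized(list, 4));

        // Printed only for comparison; the result is not assumed here.
        Spliterator<Integer> limited = list.stream().limit(4).spliterator();
        System.out.println("limit sized:   " + limited.hasCharacteristics(Spliterator.SIZED));
    }
}
```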

The shuffling approach has a couple disadvantages. It needs to copy out all the elements, and then it needs to shuffle all the elements. This can be quite expensive if the total number of elements is large and the number of elements to be chosen is small.

An approach suggested by the OP and by a couple other answers is to choose elements at random, while rejecting duplicates, until the desired number of unique elements has been chosen. This works well if the number of elements to choose is small relative to the total, but as the number to choose rises, this slows down quite a bit because the likelihood of choosing duplicates rises as well.
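
For reference, the rejection approach could look roughly like the sketch below. The class RejectionSample and the method rejectSelectN are hypothetical names, and the sketch assumes it is acceptable to copy the collection into a list for indexed access and to reject duplicate indices rather than duplicate values.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

class RejectionSample {
    // Hypothetical helper: draws random indices until n distinct ones were seen,
    // then returns the elements at those indices in draw order.
    static <E> List<E> rejectSelectN(Collection<? extends E> coll, int n, Random random) {
        if (n < 0 || n > coll.size()) {
            throw new IllegalArgumentException("n must be between 0 and coll.size()");
        }
        List<E> list = new ArrayList<>(coll);        // copy for indexed access
        Set<Integer> chosen = new LinkedHashSet<>(); // distinct indices, in draw order
        while (chosen.size() < n) {
            // Adding an index that is already present changes nothing,
            // which is the "reject the duplicate" step.
            chosen.add(random.nextInt(list.size()));
        }
        List<E> result = new ArrayList<>(n);
        for (int index : chosen) {
            result.add(list.get(index));
        }
        return result;
    }

    public static void main(String[] args) {
        List<Integer> source = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9);
        System.out.println(rejectSelectN(source, 4, new Random()));
    }
}
```

The loop runs longer as n approaches the collection size, since more and more draws land on indices that were already chosen.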

Wouldn’t it be nice if there were a way to make a single pass over the space of input elements and choose exactly the number wanted, with the choices made uniformly at random? It turns out that there is, and as usual, the answer can be found in Knuth. See TAOCP Vol 2, sec 3.4.2, Random Sampling and Shuffling, Algorithm S.

Briefly, the algorithm is to visit each element and decide whether to choose it based on the number of elements visited and the number of elements chosen. In Knuth’s notation, suppose you have N elements and you want to choose n of them at random. The next element should be chosen with probability

(n - m) / (N - t)

where t is the number of elements visited so far, and m is the number of elements chosen so far.

It’s not at all obvious that this will give a uniform distribution of chosen elements, but apparently it does. The proof is left as an exercise to the reader; see Exercise 3 of this section.
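
Short of the proof, one can at least check a weaker property numerically. The sketch below computes, for each position t, the exact probability that the element at that position is chosen when the rule (n - m) / (N - t) is applied, by tracking the distribution of m step by step. The class AlgorithmSCheck and the method inclusionProbabilities are hypothetical names. Note that this checks only the per-element selection probability, not that every n-element subset is equally likely.

```java
import java.util.Arrays;

class AlgorithmSCheck {
    // Returns p[t] = probability that the element at position t (0-based)
    // is selected when choosing toChoose out of total elements.
    static double[] inclusionProbabilities(int total, int toChoose) {
        double[] dist = new double[toChoose + 1]; // dist[m] = P(m elements chosen so far)
        dist[0] = 1.0;
        double[] result = new double[total];
        for (int t = 0; t < total; t++) {
            double[] next = new double[toChoose + 1];
            for (int m = 0; m <= toChoose; m++) {
                if (dist[m] == 0.0) {
                    continue; // state cannot occur at this step
                }
                double p = (double) (toChoose - m) / (total - t); // (n - m) / (N - t)
                result[t] += dist[m] * p;
                if (m < toChoose) {
                    next[m + 1] += dist[m] * p;  // element chosen
                }
                next[m] += dist[m] * (1.0 - p);  // element skipped
            }
            dist = next;
        }
        return result;
    }

    public static void main(String[] args) {
        // Each entry should be approximately 4/10.
        System.out.println(Arrays.toString(inclusionProbabilities(10, 4)));
    }
}
```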

Given this algorithm, it’s pretty straightforward to implement it in “conventional” Java by looping over the collection and adding to the result list based on the random test. The OP asked about using streams, so here’s a shot at that.

Algorithm S doesn’t lend itself obviously to Java stream operations. It’s described entirely sequentially, and the decision about whether to select the current element depends on a random decision plus state derived from all previous decisions. That might make it seem inherently sequential, but I’ve been wrong about that before. I’ll just say that it’s not immediately obvious how to make this algorithm run in parallel.

There is a way to adapt this algorithm to streams, though. What we need is a stateful predicate. This predicate will return a random result based on a probability determined by the current state, and the state will be updated — yes, mutated — based on this random result. This seems hard to run in parallel, but at least it’s easy to make thread-safe in case it’s run from a parallel stream: just make it synchronized. It’ll degrade to running sequentially if the stream is parallel, though.

The implementation is pretty straightforward. Knuth’s description uses random numbers between 0 and 1, but the Java Random class lets us choose a random integer within a half-open interval. Thus all we need to do is keep counters of how many elements are left to visit and how many are left to choose, et voilà:

/**
 * A stateful predicate that, given a total number
 * of items and the number to choose, will return 'true'
 * the chosen number of times distributed randomly
 * across the total number of calls to its test() method.
 */
static class Selector implements Predicate<Object> {
    int total;  // total number of items remaining
    int remain; // number of items remaining to select
    Random random = new Random();

    Selector(int total, int remain) {
        this.total = total;
        this.remain = remain;
    }

    @Override
    public synchronized boolean test(Object o) {
        assert total > 0;
        if (random.nextInt(total--) < remain) {
            remain--;
            return true;
        } else {
            return false;
        }
    }
}

Now that we have our predicate, it’s easy to use in a stream:

static <E> List<E> randomSelectN(Collection<? extends E> coll, int n) {
    assert n <= coll.size();
    return coll.stream()
        .filter(new Selector(coll.size(), n))
        .collect(Collectors.toList()); // or toList() with a static import of Collectors.toList
}
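
A small usage sketch follows. It repeats a compact copy of the predicate and of randomSelectN only so that it compiles on its own; the wrapper class SelectorDemo is a hypothetical name. One thing it makes visible: with a sequential stream over a list, the selected elements come out in the source's encounter order, so the result is a random subset but not a randomly ordered one.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.function.Predicate;
import java.util.stream.Collectors;

class SelectorDemo {
    // Compact copy of the stateful predicate, repeated so this sketch stands alone.
    static class Selector implements Predicate<Object> {
        private int total;   // items not yet visited
        private int remain;  // items still to be selected
        private final Random random = new Random();

        Selector(int total, int remain) {
            this.total = total;
            this.remain = remain;
        }

        @Override
        public synchronized boolean test(Object o) {
            if (random.nextInt(total--) < remain) {
                remain--;
                return true;
            }
            return false;
        }
    }

    static <E> List<E> randomSelectN(Collection<? extends E> coll, int n) {
        if (n < 0 || n > coll.size()) {
            throw new IllegalArgumentException("n must be between 0 and coll.size()");
        }
        return coll.stream()
                   .filter(new Selector(coll.size(), n))
                   .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> source = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9);
        // Prints four distinct values in ascending order, since the source is sorted.
        System.out.println(randomSelectN(source, 4));
    }
}
```

If a random order is also needed, the selected list can be shuffled afterwards, which is much cheaper than shuffling the whole source when n is small.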

An alternative also mentioned in the same section of Knuth suggests choosing an element at random with a constant probability of n / N. This is useful if you don’t need to choose exactly n elements. It’ll choose n elements on average, but of course there will be some variation. If this is acceptable, the stateful predicate becomes much simpler. Instead of writing a whole class, we can simply create the random state and capture it from a local variable:

/**
 * Returns a predicate that evaluates to true with a probability
 * of toChoose/total.
 */
static Predicate<Object> randomPredicate(int total, int toChoose) {
    Random random = new Random();
    return obj -> random.nextInt(total) < toChoose;
}

To use this, replace the filter line in the stream pipeline above with

        .filter(randomPredicate(coll.size(), n))
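
To see the variation in practice, the sketch below counts how many elements pass the predicate over repeated runs. It uses a variant of randomPredicate that takes the Random as a parameter, which is an assumption made here so that runs can be seeded; the class RandomPredicateDemo and the helpers countSelected and averageSelected are hypothetical names.

```java
import java.util.Random;
import java.util.function.Predicate;
import java.util.stream.IntStream;

class RandomPredicateDemo {
    // Same idea as randomPredicate, but with the Random supplied by the caller.
    static Predicate<Object> randomPredicate(int total, int toChoose, Random random) {
        return obj -> random.nextInt(total) < toChoose;
    }

    // Hypothetical helper: number of items out of 'total' that pass one predicate.
    static long countSelected(int total, int toChoose, Random random) {
        return IntStream.range(0, total)
                        .boxed()
                        .filter(randomPredicate(total, toChoose, random))
                        .count();
    }

    // Hypothetical helper: mean number of selected items over several trials.
    static double averageSelected(int total, int toChoose, int trials, Random random) {
        long sum = 0;
        for (int i = 0; i < trials; i++) {
            sum += countSelected(total, toChoose, random);
        }
        return (double) sum / trials;
    }

    public static void main(String[] args) {
        Random random = new Random();
        // Individual runs scatter around 10; the average should be close to 10.
        for (int i = 0; i < 5; i++) {
            System.out.println("selected: " + countSelected(100, 10, random));
        }
        System.out.println("average:  " + averageSelected(100, 10, 10_000, random));
    }
}
```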

Finally, for comparison purposes, here’s an implementation of the selection algorithm written using conventional Java, that is, using a for-loop and adding to a collection:

static <E> List<E> conventionalSelectN(Collection<? extends E> coll, int remain) {
    assert remain <= coll.size();
    int total = coll.size();
    List<E> result = new ArrayList<>(remain);
    Random random = new Random();

    for (E e : coll) {
        if (random.nextInt(total--) < remain) {
            remain--;
            result.add(e);
        }
    }            

    return result;
}

This is quite straightforward, and there’s nothing really wrong with this. It’s simpler and more self-contained than the stream approach. Still, the streams approach illustrates some interesting techniques that might be useful in other contexts.


Reference:

Knuth, Donald E. The Art of Computer Programming: Volume 2, Seminumerical Algorithms, 2nd edition. Copyright 1981, 1969 Addison-Wesley.
