Making a text generator with a Markov chain

A Markov chain is a model that describes transitions from one state to another, where each transition happens with a certain probability; it can be drawn as a graph whose edges carry those probabilities. You can experience something similar in daily life when you use a smart keyboard on your portable device: the keyboard slowly learns what you type and, after some time, can predict the next word. This property can be used to create a simple chatbot or an application that mimics someone's writing. For example, if your job requires you to fill in a yearly performance review, this might be helpful ;) In the following paragraphs I'm going to show you how to create your own Markov chain from a provided text file. First, we need to formulate constraints on our text, which should be:

With the above limitations, we construct the following algorithm:

Prepare text:

  1. Split the text on whitespace.
  2. Remove any non-Latin characters from every split word.
  3. If a word contains a dot, treat the dot as a separate word.
  4. Make every word lowercase.

Make chain:

  1. Take the first word from the list and mark it as previous.
  2. Take the next word from the list and add it as a connection of the previous word.
  3. If previous doesn't have this word as a connection yet, set its counter to 1.
  4. Otherwise, if previous already has this word as a connection, increment its counter.
  5. Mark the current word as previous.
  6. Go to step 2 if there are still words to process.
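The chain-building steps above can be sketched compactly with nested maps. This is a minimal illustration under the assumption that the input is already prepared (split, cleaned and lowercased); it is not the NodeRepository/Node design used later, and the names ChainSketch and build are hypothetical.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch class, not part of the original project
class ChainSketch {
    // word -> (next word -> how many times it followed that word)
    static Map<String, Map<String, Integer>> build(List<String> words) {
        Map<String, Map<String, Integer>> chain = new HashMap<>();
        if (words.isEmpty()) {
            return chain;
        }
        // Step 1: the first word is the initial "previous"
        String previous = words.get(0);
        for (String current : words.subList(1, words.size())) {
            // Steps 2-4: start the counter at 1, or increment an existing one
            chain.computeIfAbsent(previous, k -> new HashMap<>())
                 .merge(current, 1, Integer::sum);
            // Step 5: the current word becomes the previous one
            previous = current;
        }
        return chain;
    }
}
```

For the tokens of "united states. united states. united kingdom", this records united with {states=2, kingdom=1}, and the dot with {united=2}.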

The above algorithm builds a chain of words with a structure like the one shown below:

       states
      / (5)
united  -- (2) nations
      \ (1)
       kingdom

As you can see, every connection has a weight, which determines the probability of randomly choosing it as the next word. For the initial word united, the weights add up to 8, so we have a 5/8 chance of choosing states, 2/8 (1/4) for nations and 1/8 for kingdom as the next word. Let's write some code.
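To make the arithmetic concrete, the following small sketch turns the weights from the diagram into probabilities by dividing each weight by the sum of all weights. TransitionProbabilities and its method of are hypothetical names used only for illustration.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper, not part of the original project
class TransitionProbabilities {
    // Divides every weight by the sum of all weights of the same word
    static Map<String, Double> of(Map<String, Integer> weights) {
        int total = weights.values().stream().mapToInt(Integer::intValue).sum();
        // LinkedHashMap keeps the input order in the result
        Map<String, Double> result = new LinkedHashMap<>();
        weights.forEach((word, weight) -> result.put(word, weight / (double) total));
        return result;
    }
}
```

Given the weights states=5, nations=2 and kingdom=1 (in that order), of returns {states=0.625, nations=0.25, kingdom=0.125}.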

First, we split the text into tokens.

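    // Put a space before every dot, so "end." becomes the two tokens "end" and "."
    Arrays.stream(text.replaceAll("\\.", " .").split("\\s+"))
            // lowercase first, then strip everything except a-z and the dot
            .map(s -> s.toLowerCase().replaceAll("[^a-z.]", ""))
            // drop tokens that ended up empty (e.g. pure digits or punctuation)
            .filter(word -> !word.isEmpty())
            .forEach(repository::addConnection);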

As you can see, we insert a space before every dot, which lets us treat the dot as a separate word. Then we lowercase every word and remove every character that is not a Latin letter or a dot. Finally, we filter out empty words and add every remaining word to a repository.
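To check what this pipeline produces, it can be wrapped in a standalone method that collects the tokens into a list instead of passing them to the repository. Tokenizer and tokenize are hypothetical names for this sketch; the stream steps are the same as above.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical wrapper around the tokenizing pipeline shown above
class Tokenizer {
    static List<String> tokenize(String text) {
        return Arrays.stream(text.replaceAll("\\.", " .").split("\\s+"))
                // lowercase first, then keep only a-z and dots
                .map(s -> s.toLowerCase().replaceAll("[^a-z.]", ""))
                // drop tokens that became empty
                .filter(word -> !word.isEmpty())
                .collect(Collectors.toList());
    }
}
```

For "The United Kingdom. The end!" it returns [the, united, kingdom, ., the, end]. Note that only a space before each dot is inserted, so an abbreviation such as "e.g." yields the tokens "e", ".g" and ".", i.e. a dot followed by letters stays attached to them.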

The repository

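import java.util.HashMap;
import java.util.Map;

public class NodeRepository {
    private final Map<String, Node> nodes = new HashMap<>();
    // Start from the shared "." node stored in the map, so the first word of the
    // text is recorded as a follower of "." just like every sentence opener.
    // (`nodes` is declared above, so the map already exists here.)
    private Node last = getNode(".");

    public boolean hasNode(String word) {
        return nodes.containsKey(word);
    }

    // Records the connection "last -> word" and makes `word` the new last node
    void addConnection(String word) {
        last.incConnectionsCountForWord(word);
        last = getNode(word);
    }

    // Returns the node for `word`, creating and caching it on first use
    Node getNode(String word) {
        return nodes.computeIfAbsent(word, Node::new);
    }

    // Draws a weighted random follower of `node` and returns its node
    Node getRandomConnection(Node node) {
        return getNode(node.getRandomWord());
    }
}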

As you can see, NodeRepository holds a map of nodes keyed by word; a node is a simple container for a word and its connections. Every new connection is added to the last used node, and then the current (just added) node becomes the last one. We can also retrieve any node, check whether it already exists in the repository, or get a random connection of the node we are interested in.

The node

import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;

class Node {
    private static final Pattern wordPattern = Pattern.compile("[a-z.]+");
    // Package-private and non-final, so tests can replace it with a mock
    static Random random = new Random();

    private final String word;
    // Sum of all connection weights of this word
    private int connectionsTotal = 0;
    // Follower word -> how many times it followed this word
    private final Map<String, Integer> connections = new HashMap<>();

    Node(String word) {
        if (word == null || !wordPattern.matcher(word).matches()) {
            throw new IllegalArgumentException("Word '" + word + "' argument must contain letters or dot character.");
        }
        this.word = word;
    }

    int getConnectionsTotal() {
        return connectionsTotal;
    }

    // Weighted draw: a follower is picked with probability weight / connectionsTotal.
    // Throws IllegalArgumentException for a node without connections (nextInt(0)).
    String getRandomWord() {
        final int rnd = random.nextInt(connectionsTotal);
        final AtomicInteger i = new AtomicInteger(0);
        return connections.entrySet().stream()
                .filter(e -> rnd < i.addAndGet(e.getValue()))
                .findFirst().get().getKey();
    }

    void incConnectionsCountForWord(String word) {
        connectionsTotal++;
        // Start at 1 for a new follower, otherwise add 1 to its counter
        connections.merge(word, 1, Integer::sum);
    }

    String getWord() {
        return word;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;

        Node node = (Node) o;
        return word.equals(node.word);
    }

    @Override
    public int hashCode() {
        return word.hashCode();
    }
}

The node consists of the word pattern, the word itself, the total connection count and the connections map, which is used to draw a random follower. Here comes a little magic, because we need to draw a random connection using weights. Without weights we could simply pick one of the distinct connections with random.nextInt(connections.size()), but our problem is a little more complicated. Still, it's not rocket science. Let's assume we've got 9 connections in total: 3 for X, 2 for Y and 4 for Z. We can spread them out on an axis, so that Z takes 4 units, X takes 3 and Y takes 2, like below:

  Z4  X3  Y2
|====|===|==|
 0123 456 78
       r

Now we can draw a random number from 0 to 8 and check which section it falls into. There are two simple approaches for that (more sophisticated ones exist, for example binary search over cumulative weights or the alias method, but we will cover only the two simplest). First, we can build an array of possible choices with one slot per unit of weight, which saves time but not memory: each draw takes O(1) time, but the array needs O(W) memory, where W is the total weight (connectionsTotal), and W can be much larger than the number n of distinct connections. Second, we can simply iterate over the connections and check whether our random number fits the current range. This gives O(n) time per draw and needs only O(1) extra memory on top of the connections map. The algorithm is as follows:

  1. Draw a random number r such that 0 <= r < connectionsTotal
  2. Set i := 0
  3. Get the next element from the connections collection
  4. Set i := i + element.weight
  5. If r < i, return element
  6. Go to step 3.

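In our case r = 5. When we process word Z, we check the inequality 5 < 4. It is false, so we move on to the next word, X, for which i = Z + X = 4 + 3 = 7. Now the inequality is 5 < 7, which is true, so we know that we've drawn word X. (This assumes the connections are visited in the order Z, X, Y; a HashMap does not guarantee any particular order, but the order does not change the probabilities, because each word still covers exactly as many values of r as its weight.)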
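The walk-through can be reproduced with a small sketch of the same linear scan that takes r as a parameter instead of drawing it, so the result is deterministic. The class WeightedPick and its method pick are hypothetical names; the caller is assumed to pass a map with a fixed iteration order (such as a LinkedHashMap) to get the Z, X, Y layout from the diagram.

```java
import java.util.Map;

// Hypothetical sketch of the cumulative-weight scan, with r passed in
class WeightedPick {
    // Returns the key whose section contains r (0 <= r < total weight)
    static String pick(Map<String, Integer> weights, int r) {
        int i = 0; // step 2: running sum of weights
        for (Map.Entry<String, Integer> e : weights.entrySet()) {
            i += e.getValue(); // step 4: extend the running sum
            if (r < i) {       // step 5: r falls into this section
                return e.getKey();
            }
        }
        throw new IllegalArgumentException("r must be in [0, " + i + ")");
    }
}
```

With Z=4, X=3, Y=2 inserted in that order, pick(weights, 5) returns "X", matching the example above; r from 0 to 3 maps to Z, 4 to 6 to X, and 7 to 8 to Y.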

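    String getRandomWord() {
        // Step 1: r in [0, connectionsTotal)
        final int rnd = random.nextInt(connectionsTotal);
        // Step 2: running sum of weights; an AtomicInteger, because the lambda
        // below can only capture effectively final variables
        final AtomicInteger i = new AtomicInteger(0);
        return connections.entrySet().stream()
                // Steps 3-5: add each weight and stop at the first entry where r < i
                .filter(e -> rnd < i.addAndGet(e.getValue()))
                .findFirst().get().getKey();
    }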
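For comparison, here is a sketch of the array-based alternative mentioned earlier: the weights are expanded once into a table with one slot per unit of weight, after which each draw is a single index lookup. LookupTable is a hypothetical name. The table has to be rebuilt whenever a counter changes, which is one reason the linear scan is a reasonable fit for a chain that is still being built.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Hypothetical sketch of the O(1)-time, O(W)-memory lookup table
class LookupTable {
    private final List<String> slots = new ArrayList<>();

    // A word with weight w occupies w consecutive slots, so the table size is W
    LookupTable(Map<String, Integer> weights) {
        weights.forEach((word, weight) -> {
            for (int k = 0; k < weight; k++) {
                slots.add(word);
            }
        });
    }

    // O(1) per draw: r in [0, W) indexes the slot directly
    String pick(int r) {
        return slots.get(r);
    }

    String pickRandom(Random random) {
        return pick(random.nextInt(slots.size()));
    }
}
```

For Z=4, X=3, Y=2 (in that order) the table is Z Z Z Z X X X Y Y, so pick(5) is again "X", but the memory grows with the total weight instead of the number of distinct words.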

After we get a random word, we can use the NodeRepository#getNode(word) method to retrieve its node and then repeat from the beginning until we reach an end point, which can be the dot character (if we want to create just a single sentence) or a final word count in the output string. Such code may look like the example below from my TextProcessor class:

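public String buildText(String word) {
    StringBuilder sb = new StringBuilder(word);
    sb.append(" ");
    // Assumes every node reached has at least one connection; a node without
    // followers (e.g. an unknown start word) makes getRandomWord() call nextInt(0), which throws
    Node node = repository.getNode(word);
    int cnt = 0;

    // Keep drawing followers until we hit a dot or the word limit
    while (!(node = repository.getRandomConnection(node)).getWord().equals(".") && cnt < SENTENCE_SIZE_LIMIT) {
        sb.append(node.getWord()).append(" ");
        cnt++;
    }

    return sb.append('.').toString();
}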
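To experiment with the generation loop outside of the web application, it can be sketched over the nested-map representation of the chain. This is a hypothetical stand-alone version, not the TextProcessor code: the chain is a plain Map, the word limit is a parameter instead of SENTENCE_SIZE_LIMIT, and a word without followers ends the sentence instead of throwing.

```java
import java.util.Map;
import java.util.Random;

// Hypothetical stand-alone sketch of the sentence-building loop
class SentenceSketch {
    // Weighted draw over one word's followers (same linear scan as getRandomWord)
    static String next(Map<String, Integer> followers, Random random) {
        int total = followers.values().stream().mapToInt(Integer::intValue).sum();
        int r = random.nextInt(total);
        int i = 0;
        for (Map.Entry<String, Integer> e : followers.entrySet()) {
            i += e.getValue();
            if (r < i) {
                return e.getKey();
            }
        }
        throw new IllegalStateException("unreachable when all weights are positive");
    }

    static String build(Map<String, Map<String, Integer>> chain, String start, int limit, Random random) {
        StringBuilder sb = new StringBuilder(start);
        String word = start;
        for (int cnt = 0; cnt < limit; cnt++) {
            Map<String, Integer> followers = chain.get(word);
            if (followers == null || followers.isEmpty()) {
                break; // dead end: no word ever followed this one
            }
            word = next(followers, random);
            if (word.equals(".")) {
                break; // end of sentence
            }
            sb.append(' ').append(word);
        }
        // Same output format as buildText: words separated by spaces, then " ."
        return sb.append(" .").toString();
    }
}
```

Passing a seeded Random makes the output reproducible, which is handy when comparing runs on the same input text.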

And that's roughly all we need to do. Testing the code that picks a random connection is a little tricky, but we can solve this by mocking the random number generator:

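import java.util.Collection;
import java.util.Iterator;
import java.util.Random;

// Returns the given numbers in order instead of random ones
class RandomMock extends Random {
    private final Iterator<Integer> iterator;

    RandomMock(Collection<Integer> collection) {
        this.iterator = collection.iterator();
    }

    @Override
    public int nextInt(int bound) {
        // `bound` is ignored; the caller must supply values below it
        return iterator.next();
    }
}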

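// Assumed test libraries: JUnit 4 (@Test) and AssertJ (assertThat)
import org.junit.Test;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static java.util.Optional.ofNullable;
import static java.util.stream.IntStream.range;
import static java.util.stream.IntStream.rangeClosed;
import static org.assertj.core.api.Assertions.assertThat;

public class NodeTest {
    //...

    @Test
    public void shouldProperlyReturnRandomConnection() {
        // Every number from 0 to 15 exactly once, in random order
        List<Integer> numbers = range(0, 16).boxed().collect(Collectors.toList());
        Collections.shuffle(numbers);
        Node.random = new RandomMock(numbers);

        // Weights: this = 10, that = 5, nothing = 1 (16 in total)
        Node node = new Node("test");
        rangeClosed(1, 10).forEach(i -> node.incConnectionsCountForWord("this"));
        rangeClosed(1, 5).forEach(i -> node.incConnectionsCountForWord("that"));
        node.incConnectionsCountForWord("nothing");

        assertThat(node.getConnectionsTotal()).isEqualTo(16);
        Map<String, Integer> randomWords = new HashMap<>();

        // Draw 16 times, consuming each mocked number exactly once
        for (int i = 0; i < 16; i++) {
            String word = node.getRandomWord();
            Integer cnt = ofNullable(randomWords.get(word)).orElse(0);
            randomWords.put(word, ++cnt);
        }

        assertThat(randomWords.get("this")).isEqualTo(10);
        assertThat(randomWords.get("that")).isEqualTo(5);
        assertThat(randomWords.get("nothing")).isEqualTo(1);
    }
}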

In the code above, the RandomMock class overrides the Random#nextInt method, so it returns the numbers from the collection we provide. Here the numbers are shuffled, but we know that the collection always contains every integer from 0 to 15 exactly once. So even though the words are returned in random order, each of the 16 possible values of r is drawn exactly once, and the word counts must match the weights exactly: 10, 5 and 1.
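The reasoning behind this test can also be checked without any mock: if every value of r from 0 to connectionsTotal - 1 is fed to the weighted scan exactly once, each word must be picked exactly as many times as its weight. The sketch below uses hypothetical names and a plain map standing in for Node to demonstrate this for the 10/5/1 weights from the test.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical check, not part of the original test suite
class FlatDistributionCheck {
    // Same cumulative-weight scan as Node#getRandomWord, but with r passed in
    static String pick(Map<String, Integer> weights, int r) {
        int i = 0;
        for (Map.Entry<String, Integer> e : weights.entrySet()) {
            i += e.getValue();
            if (r < i) {
                return e.getKey();
            }
        }
        throw new IllegalArgumentException("r out of range: " + r);
    }

    // Feeds every r in [0, total) once and counts how often each word is picked
    static Map<String, Integer> countPicks(Map<String, Integer> weights) {
        int total = weights.values().stream().mapToInt(Integer::intValue).sum();
        Map<String, Integer> counts = new HashMap<>();
        for (int r = 0; r < total; r++) {
            counts.merge(pick(weights, r), 1, Integer::sum);
        }
        return counts;
    }
}
```

Because each r lands in exactly one section, the counts equal the weights regardless of the map's iteration order, which is why shuffling the numbers in the real test does not affect its assertions.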

You can check the full code, built with Spring Boot, Bootstrap and jQuery, on my GitHub and try the live demo on Heroku. For input data, you can easily use Project Gutenberg.