Posted on Wed 04 September 2013

Refactoring Scala Code From Imperative to Functional

The Scala programming language is unique among FP languages in that it lets you choose your own balance between pure FP and filthy imperative style. This is both good and bad, but one benefit is that when you have an algorithm that doesn't immediately translate well into a functional style, you can implement it imperatively and refactor it to be more functional later.

This example is from about a year ago, and I'll take you through a few steps of my refactor. I'm still very much a Scala beginner, so keep in mind, there may be some mistakes.

The Problem, and the Algorithm

I was working on a proof-of-concept middleware TCP load balancer using Netty and ZooKeeper. One feature that was going to be important was the ability to do "weighted round-robin" load balancing. At my job, we deploy code to new servers rather than upgrading code on existing servers. Then we move the new servers into the load balancer and the old ones out. With our existing load balancing solution, there wasn't an easy way to shift live traffic incrementally, and that was the motivation behind weighted round-robin.

With a little bit of research on Wikipedia, I found a pretty straightforward algorithm used in packet switching. Then I found the following C/pseudo-code implementation on this site.

Supposing that there is a server set S = {S0, S1, ..., Sn-1};
W(Si) indicates the weight of Si;
i indicates the server selected last time, and i is initialized with -1;
cw is the current weight in scheduling, and cw is initialized with zero;
max(S) is the maximum weight of all the servers in S;
gcd(S) is the greatest common divisor of all server weights in S;

while (true) {
    i = (i + 1) mod n;
    if (i == 0) {
        cw = cw - gcd(S);
        if (cw <= 0) {
            cw = max(S);
            if (cw == 0)
                return NULL;
        }
    }
    if (W(Si) >= cw)
        return Si;
}

First crack at it

My first Scala implementation was almost exactly like this. I used vars for mutable state and a recursive call to an inner function for the while loop.

import scala.reflect.ClassTag
import scala.annotation.tailrec

type Weight = Int

class WeightedRRIterator[A: ClassTag](val itemsToWeights: Map[A, Weight]) {
  val items: Array[A] = itemsToWeights.keys.toArray
  val total = items.length

  // Int does not have gcd, but BigInt does.
  lazy val gcdWeights = itemsToWeights.values.map(BigInt(_)).reduce(_.gcd(_)).toInt
  lazy val maxWeight = itemsToWeights.values.max

  private var i = -1            // index of the item selected last time
  private var currentWeight = 0 // current weight threshold

  def hasNext = total != 0
  def next: A = this.synchronized {
    if (!hasNext) throw new NoSuchElementException("Called next on empty iterator")
    // this inner function will get called, at worst, `items.size`
    // times for each next() call
    @tailrec
    def doNext(): A = {
      i = (i + 1) % total
      if (i == 0) {
        currentWeight -= gcdWeights
        if (currentWeight <= 0) currentWeight = maxWeight
      }
      if (itemsToWeights(items(i)) >= currentWeight) items(i) else doNext()
    }
    doNext()
  }
}

val iter = new WeightedRRIterator(Map("a" -> 5, "b" -> 3))
for (i <- 1 to 20) { print(iter.next) } // prints "aaabababaaabababaaab"

Note

A few things unrelated to the algorithm: the [A: ClassTag] basically says that the generic type A must also come with a ClassTag. This is an unfortunate side effect of Java type erasure. It's needed because I'm creating an Array of type A, and A's actual type is unknown at runtime. In Scala, this is all you need to do, but when trying to use this class as a library from Java, I wasn't able to figure out how to pass the ClassTag, and the generated type signatures of the functions were pretty dense. Some other data structures don't have this problem, but Array was chosen for constant-time access since we're indexing by a counter.
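To make the ClassTag point a bit more concrete, here is a small sketch that is not taken from the original project: a hypothetical helper `keysArray` that builds an Array from a Map's keys. The context bound `[A: ClassTag]` is only sugar for an extra implicit parameter list, so the tag can also be passed by hand. As far as I understand the compiled output, that hidden parameter list is why a Java caller sees an extra ClassTag argument it has to supply itself.

```scala
import scala.reflect.ClassTag

object ClassTagDemo {
  // Hypothetical helper. [A: ClassTag] desugars to roughly:
  //   def keysArray[A](m: Map[A, Int])(implicit tag: ClassTag[A]): Array[A]
  def keysArray[A: ClassTag](m: Map[A, Int]): Array[A] = m.keys.toArray

  // Normal Scala call: the compiler fills in ClassTag[String] implicitly.
  val withImplicitTag: Array[String] = keysArray(Map("a" -> 5, "b" -> 3))

  // Same call with the tag passed explicitly in the second parameter list.
  val withExplicitTag: Array[String] =
    keysArray(Map("a" -> 5, "b" -> 3))(ClassTag[String](classOf[String]))
}
```

Either way, the resulting array is a real `Array[String]` at runtime, which is exactly what the tag makes possible.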

The this.synchronized block is the equivalent of Java's synchronized keyword, and it only matters if you plan on sharing the iterator between multiple threads (not recommended!).

The @tailrec annotation does not change the function at all; it only gives you a compile-time error if the compiler can't turn the recursive call into a loop (for example, because the call isn't in tail position). It's not clear the function is any better than a while-loop, but it does help us avoid needing an explicit return.
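For comparison, the same `next` logic written with a plain while-loop might look like the sketch below. The class name `WhileLoopWRR` is made up for this example. It uses a `found` flag instead of an explicit `return`, and for the a/b weights above it should produce the same sequence as the recursive version.

```scala
import scala.reflect.ClassTag

// Hypothetical while-loop variant of WeightedRRIterator, for comparison only.
class WhileLoopWRR[A: ClassTag](itemsToWeights: Map[A, Int]) {
  private val items: Array[A] = itemsToWeights.keys.toArray
  private val total = items.length

  private lazy val gcdWeights = itemsToWeights.values.map(BigInt(_)).reduce(_ gcd _).toInt
  private lazy val maxWeight = itemsToWeights.values.max

  private var i = -1
  private var currentWeight = 0

  def hasNext: Boolean = total != 0

  def next(): A = {
    if (!hasNext) throw new NoSuchElementException("Called next on empty iterator")
    var found = false
    while (!found) {
      i = (i + 1) % total
      if (i == 0) {
        // Wrapped around: lower the threshold, resetting to the max at the bottom.
        currentWeight -= gcdWeights
        if (currentWeight <= 0) currentWeight = maxWeight
      }
      found = itemsToWeights(items(i)) >= currentWeight
    }
    items(i)
  }
}
```

The loop body is identical; the recursive version just replaces the mutable `found` flag with a tail call.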

The idea of this algorithm is mostly simple, yet it's difficult to work out from the code. One key to refactoring was getting a much more intuitive understanding of what the code is actually doing, and it becomes very clear when illustrated. Imagine we have weights of [10, 5, 20] for hosts a, b, and c, respectively. The gcd is 5, the max weight is 20, and the number of hosts is 3.

i   currentWeight   selected
0   20              --
1   20              --
2   20              c
0   15              --
1   15              --
2   15              c
0   10              a
1   10              --
2   10              c
0   5               a
1   5               b
2   5               c
0   20              --
1   20              --
2   20              c

... and so on.
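The table can also be generated by replaying the imperative loop step by step. The `trace` helper below is hypothetical: it takes the hosts as an ordered Vector (so the rows come out in a, b, c order) and records each (i, currentWeight, selected) step, with `None` standing in for "--".

```scala
object TraceTable {
  // Replays the weighted round-robin loop for `steps` iterations of i.
  def trace(items: Vector[(String, Int)], steps: Int): Vector[(Int, Int, Option[String])] = {
    val weights = items.map(_._2)
    def gcd(x: Int, y: Int): Int = if (y == 0) x else gcd(y, x % y)
    val g = weights.reduce(gcd)
    val max = weights.max
    var i = -1
    var cw = 0
    (1 to steps).map { _ =>
      i = (i + 1) % items.length
      if (i == 0) {
        cw -= g
        if (cw <= 0) cw = max
      }
      val (name, w) = items(i)
      (i, cw, if (w >= cw) Some(name) else None)
    }.toVector
  }
}
```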

So the gist of it is that i is an index that cycles through our array, while currentWeight cycles from maxWeight down to the gcd, dropping by the gcd each time i wraps back to 0. For each value of currentWeight, we yield an item only if its weight is high enough. So for every full cycle of currentWeight (20, 15, 10, 5), a is selected 2 times, b is selected 1 time, and c is selected 4 times. That is exactly our weights of 10, 5, and 20 divided by their gcd of 5, so it works!
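The per-cycle counts follow from one rule: an item is picked once for every threshold in (maxWeight, maxWeight - gcd, ..., gcd) that its weight reaches, which works out to its weight divided by the gcd. A minimal sketch that checks this, using a hypothetical `perCycle` helper:

```scala
object CycleCounts {
  def gcd(x: Int, y: Int): Int = if (y == 0) x else gcd(y, x % y)

  // How many times each item is selected during one full cycle of currentWeight.
  def perCycle(itemsToWeights: Map[String, Int]): Map[String, Int] = {
    val g = itemsToWeights.values.reduce(gcd)
    val max = itemsToWeights.values.max
    val thresholds = max to g by -g // 20, 15, 10, 5 for the example weights
    itemsToWeights.map { case (item, w) => item -> thresholds.count(w >= _) }
  }
}
```

For weights 10, 5, and 20 this gives 2, 1, and 4, the same counts as the table.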

There are a few things that make this difficult to read from the code (for me, anyway). One is that we're indexing an array rather than using some kind of iterator. This is caveman-style iteration through arrays: it's error-prone and low level. The other problem, from an FP perspective, is those pesky vars. Mutable data is shunned in functional programming.

Finally, the intention of the programmer is lost in the details of the implementation. All the arithmetic and updating counters is a lot to keep track of. The code is not very expressive.

Let's Get Functional

It turns out Scala for-comprehensions are a really nice way to express this. The following is analogous to a nested for-loop in an imperative language:

val itemsToWeights = Map("a" -> 10, "b" -> 5, "c" -> 20)
val gcdWeights = 5 // not to be hardcoded in the real code
val maxWeight = itemsToWeights.values.max

val iterator = for {
    weight <- gcdWeights to maxWeight by gcdWeights;
    item <- itemsToWeights.keys if (itemsToWeights(item) >= weight)
} yield item

println(iterator.toList) // prints List(a, b, c, a, c, c, c)

We're iterating over two sequences here: the outer generator walks a scala.collection.immutable.Range(5, 10, 15, 20), and for each of those weights, the inner one walks through ("a", "b", "c"). The if guard filters out any item whose weight isn't high enough. As you can see, it's a much more natural translation of the above table to code. (One difference: the weights go up from gcdWeights to maxWeight here rather than down as in the table, which changes the order within each pass but not how often each item appears.)
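For-comprehensions are syntactic sugar for flatMap, withFilter, and map. The sketch below puts the comprehension next to a hand-written version of roughly what the compiler rewrites it into; the object and value names are just for illustration.

```scala
object Desugared {
  val itemsToWeights = Map("a" -> 10, "b" -> 5, "c" -> 20)
  val gcdWeights = 5 // hardcoded, as in the example above
  val maxWeight = itemsToWeights.values.max

  // The for-comprehension from above...
  val sugared = for {
    weight <- gcdWeights to maxWeight by gcdWeights
    item <- itemsToWeights.keys if itemsToWeights(item) >= weight
  } yield item

  // ...and approximately what it desugars to.
  val desugared = (gcdWeights to maxWeight by gcdWeights).flatMap { weight =>
    itemsToWeights.keys
      .withFilter(item => itemsToWeights(item) >= weight)
      .map(item => item)
  }
}
```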

There's one catch: the iterator only runs once, but we want it to run forever. I was actually stuck on this for a pretty long time, having several ideas, none of them very elegant. Finally, while playing with recursive Streams, I discovered two great features of Scala's Iterator class:

scala> val sixes = Iterator.continually(6)
sixes: Iterator[Int] = non-empty iterator

scala> sixes.next()
res8: Int = 6

scala> sixes.take(10).toList
res9: List[Int] = List(6, 6, 6, 6, 6, 6, 6, 6, 6, 6)

Iterator.continually gives you a new iterator that runs forever, returning only the specified value. So how can we use it?

scala> val oneToFive = Iterator.continually(1 to 5)
oneToFive: Iterator[scala.collection.immutable.Range.Inclusive] = non-empty iterator

scala> oneToFive.next()
res12: scala.collection.immutable.Range.Inclusive = Range(1, 2, 3, 4, 5)

not quite...

scala> val oneToFive = Iterator.continually(1 to 5).flatten
oneToFive: Iterator[Int] = non-empty iterator

scala> oneToFive.take(15).toList
res21: List[Int] = List(1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5)

BOOM! The combination of Iterator.continually and Iterator.flatten lets us do a circular loop of a sequence.

Here's our new iterator class:

import scala.annotation.tailrec

type Weight = Int

// Assumes a non-empty map with positive weights.
class WeightedRRIterator[A](itemsToWeights: Map[A, Weight]) {

  // A recursive method needs an explicit result type.
  @tailrec private def gcd(x: Weight, y: Weight): Weight = if (x == 0) y else gcd(y % x, x)
  lazy val gcdWeights = itemsToWeights.values.reduce(gcd)
  lazy val maxWeight = itemsToWeights.values.max

  // One full pass: for each threshold, every item whose weight reaches it.
  private val permutations = for {
    weight <- gcdWeights to maxWeight by gcdWeights
    item <- itemsToWeights.keys if (itemsToWeights(item) >= weight)
  } yield item

  // Cycle through the pass forever.
  val iterator = Iterator.continually(permutations).flatten

  def hasNext = iterator.hasNext
  def next: A = iterator.next()
}

val iter = new WeightedRRIterator(Map("a" -> 5, "b" -> 3))
for (i <- 1 to 20) { print(iter.next) } // prints abababaaabababaaabab

Note

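We're using the intermediate val permutations because Iterator.continually takes its parameter by name, which basically means the expression inside the parentheses is re-evaluated every time the iterator needs another element. Without the val, the for-comprehension would be recomputed on every pass.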
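A small sketch of why the val matters. The counter and the `buildPass` stand-in are hypothetical; the only point is that the by-name argument is re-run for each new pass, while a val is computed once.

```scala
object ByNameDemo {
  var evaluations = 0

  // Stand-in for the for-comprehension: counts how often it is evaluated.
  def buildPass(): Seq[Int] = { evaluations += 1; Seq(1, 2, 3) }

  // Expression passed directly: re-evaluated for every pass.
  def recomputed(): (List[Int], Int) = {
    evaluations = 0
    val taken = Iterator.continually(buildPass()).flatten.take(9).toList
    (taken, evaluations)
  }

  // Bound to a val first: evaluated exactly once.
  def computedOnce(): (List[Int], Int) = {
    evaluations = 0
    val pass = buildPass()
    val taken = Iterator.continually(pass).flatten.take(9).toList
    (taken, evaluations)
  }
}
```

Both produce the same elements; only the amount of repeated work differs.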

We could actually avoid the hasNext and next methods and just make this class an Iterable rather than an Iterator. (Right now it isn't declared as either; in my actual code, it extends a base class that extends Iterator and has other implementations, like WeightedProbabilisticIterator.)
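A rough sketch of the Iterable variant, assuming a non-empty map with positive weights; `WeightedRRIterable` is a hypothetical name. Two caveats apply. Each call to `iterator` starts a fresh, independent cursor, so callers that should share one rotation need to share one iterator. And because the collection is infinite, strict operations like `size` or `toList` never return, which is why `toString` is overridden here.

```scala
// Hypothetical Iterable-based variant; assumes a non-empty map with positive weights.
class WeightedRRIterable[A](itemsToWeights: Map[A, Int]) extends Iterable[A] {
  private def gcd(x: Int, y: Int): Int = if (y == 0) x else gcd(y, x % y)
  private val gcdWeights = itemsToWeights.values.reduce(gcd)
  private val maxWeight = itemsToWeights.values.max

  private val permutations: Seq[A] = for {
    weight <- gcdWeights to maxWeight by gcdWeights
    item <- itemsToWeights.keys if itemsToWeights(item) >= weight
  } yield item

  // Every call returns a new round-robin cursor starting from the beginning.
  def iterator: Iterator[A] = Iterator.continually(permutations).flatten

  // The default Iterable.toString walks every element, which never ends here.
  override def toString: String = s"WeightedRRIterable($itemsToWeights)"
}
```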

Fixing the Distribution

One problem a colleague of mine was quick to point out was the poor distribution resulting from this approach. Imagine you had weights like 100 and 7 for o and x. What you would get would be:

oxoxoxoxoxoxoxoooooooooooooooooooooooooooooooooooo
oooooooooooooooooooooooooooooooooooooooooooooooooo
ooooooooxoxoxoxoxoxoxooooooooooooooooooooooooooooo
oooooooooooooooooooooooooooooooooooooooooooooooooo

for the first 200 requests. Notice how x gets 50% of the traffic for a short time, and then 0% for a while. Ideally, it would be spread out a bit more. There are a few approaches to accomplish this.

We could shuffle the Range before iterating over it...

import scala.util.Random
private val permutations = for {
  weight <- Random.shuffle(gcdWeights to maxWeight by gcdWeights)
  item <- itemsToWeights.keys if (itemsToWeights(item) >= weight)
} yield item

or we could shuffle the entire generated sequence:

private val permutations = Random.shuffle(for {
  weight <- gcdWeights to maxWeight by gcdWeights
  item <- itemsToWeights.keys if (itemsToWeights(item) >= weight)
} yield item)

The differences here are subtle, but I'd argue the first one will produce better distributions, because it makes it impossible for a lower-weighted item to be selected twice without every higher-weighted item being selected in between.
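A sketch that checks that claim for the o/x example, assuming weights of 100 and 7. `ShuffleStrategies` and `lowRepeats` are hypothetical names, and converting the thresholds to a Vector before shuffling is my choice rather than something the approach requires. With shuffled thresholds, every pass that emits x also emits o, so two x's can never end up adjacent; with a fully shuffled sequence, they can.

```scala
import scala.util.Random

object ShuffleStrategies {
  val itemsToWeights: Map[String, Int] = Map("o" -> 100, "x" -> 7)

  private def gcd(x: Int, y: Int): Int = if (y == 0) x else gcd(y, x % y)
  private val gcdWeights = itemsToWeights.values.reduce(gcd)
  private val maxWeight = itemsToWeights.values.max
  private val thresholds = (gcdWeights to maxWeight by gcdWeights).toVector

  // Option 1: shuffle the thresholds; each threshold still emits items in key order.
  def shuffledThresholds(): Seq[String] = for {
    weight <- Random.shuffle(thresholds)
    item <- itemsToWeights.keys if itemsToWeights(item) >= weight
  } yield item

  // Option 2: build the sequential pass, then shuffle all of it.
  def shuffledSequence(): Seq[String] = Random.shuffle(for {
    weight <- thresholds
    item <- itemsToWeights.keys if itemsToWeights(item) >= weight
  } yield item)

  // True if `low` is selected twice with no `high` in between.
  def lowRepeats(seq: Seq[String], low: String, high: String): Boolean =
    seq.filter(s => s == low || s == high).sliding(2).contains(Seq(low, low))
}
```

Both options keep the same totals per pass (100 o's and 7 x's); they differ only in where the x's can land.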

Finally, if you want the iterator to be continuously shuffled (reshuffled every time we get to the end), you can take advantage of Iterator.continually taking its parameter as call-by-name:

val iterator = Iterator.continually(Random.shuffle(permutations)).flatten
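Because the shuffle is re-run for every pass, each consecutive block of `permutations.size` elements is still one complete pass, just in a new order, so the weight ratio holds exactly over every pass. A minimal sketch, assuming the eight-element pass from the a/b example earlier; `ContinuousShuffleDemo` and `passes` are hypothetical names.

```scala
import scala.util.Random

object ContinuousShuffleDemo {
  // The one-pass sequence produced for Map("a" -> 5, "b" -> 3) earlier.
  val permutations: Seq[String] = Seq("a", "b", "a", "b", "a", "b", "a", "a")

  // Random.shuffle(...) is re-evaluated for every pass (by-name argument).
  def reshuffling: Iterator[String] =
    Iterator.continually(Random.shuffle(permutations)).flatten

  // Splits the first n passes back out of the flattened stream.
  def passes(n: Int): List[Seq[String]] =
    reshuffling.grouped(permutations.size).take(n).toList
}
```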

Since I was being indecisive, I implemented each of these "strategies" as a trait (SequentialOrdering, Shuffled, ContinuouslyShuffled), so when you instantiate a new WeightedRRIterator you must select one by mixing it in, such as

val wi = new WeightedRRIterator(Map("pizza" -> 2, "taco" -> 4)) with Shuffled

I'm not sure if it was the right approach or not yet, because the compiler won't stop you from meaninglessly mixing in multiple strategies.
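One alternative that would rule this out, sketched here and not the design described above, is to make the strategy a constructor argument chosen from a sealed set, so exactly one always applies. All names below (`PassOrdering`, `Sequential`, `ShuffleOnce`, `ShuffleEachPass`, `StrategyWRR`) are hypothetical; the three cases roughly mirror SequentialOrdering, Shuffled, and ContinuouslyShuffled.

```scala
import scala.util.Random

sealed trait PassOrdering
case object Sequential extends PassOrdering      // like SequentialOrdering
case object ShuffleOnce extends PassOrdering     // like Shuffled (thresholds shuffled once)
case object ShuffleEachPass extends PassOrdering // like ContinuouslyShuffled

// Assumes a non-empty map with positive weights.
class StrategyWRR[A](itemsToWeights: Map[A, Int], strategy: PassOrdering) {
  private def gcd(x: Int, y: Int): Int = if (y == 0) x else gcd(y, x % y)
  private val gcdWeights = itemsToWeights.values.reduce(gcd)
  private val maxWeight = itemsToWeights.values.max
  private val thresholds = (gcdWeights to maxWeight by gcdWeights).toVector

  // One pass: for each threshold, every item whose weight reaches it.
  private def pass(ts: Seq[Int]): Seq[A] = for {
    weight <- ts
    item <- itemsToWeights.keys if itemsToWeights(item) >= weight
  } yield item

  // The match is exhaustive, so exactly one strategy is always in effect.
  val iterator: Iterator[A] = strategy match {
    case Sequential =>
      val once = pass(thresholds)
      Iterator.continually(once).flatten
    case ShuffleOnce =>
      val once = pass(Random.shuffle(thresholds))
      Iterator.continually(once).flatten
    case ShuffleEachPass =>
      val base = pass(thresholds)
      Iterator.continually(Random.shuffle(base)).flatten
  }
}
```

The trade-off is that adding a new strategy means touching the match, whereas mixins can be added without changing the class.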

In conclusion

For-comprehensions are very expressive, and the Scala Iterator has some pretty neat tricks. I'm not sure how to do something similar in Java, but I'm sure any approximation would just make me sad.

If there's any interest, I can open source the code discussed here. I made some interesting OO design choices that I'd probably want to revisit before letting other people see.

Category: misc

Tags: scala, functional programming, refactoring

© Chad Selph.