I’ll start with this short piece of Scala code:

      import org.apache.spark.sql.Dataset

      // "spark" is the active SparkSession; its implicits provide the Encoder for map
      def oneIter(ds: Dataset[Int]): Dataset[Int] = {
        import spark.implicits._
        ds.map(x => x + 1) // stand-in for the real per-iteration work
      }

      import scala.util.control.Breaks._
      import spark.implicits._

      var before: Dataset[Int] = Seq(1, 2, 3, 4, 5).toDS()
      var after: Dataset[Int] = null
      var iter = 0
      breakable {
        while (true) {
          iter += 1
          after = oneIter(before)
          // In reality the forced evaluation would be the "is it time to exit?"
          // check, but .show() also forces evaluation and is simpler for the
          // demonstration.
          after.show()

          if (iter >= 300) {
            break
          } else {
            before = after
            after = null
          }
        }
      }
      after.show()

The function oneIter takes a Dataset, improves it, and needs to be run as many times as it takes until the Dataset it returns is deemed good enough. To decide that, I compare the result to the “before” Dataset; once it is good enough, the “after” Dataset is the final result and the program may exit.
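For concreteness, the “good enough” comparison I have in mind looks roughly like the sketch below; isGoodEnough and the except-based diff are simplified placeholders rather than my real logic:

      // Hypothetical convergence test: "good enough" when no rows changed
      // between iterations. except() and the zero-change threshold stand in
      // for my real comparison.
      def isGoodEnough(before: Dataset[Int], after: Dataset[Int]): Boolean = {
        after.except(before).count() == 0 // count() forces evaluation of "after"
      }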

The challenge I am running into is that juggling the “before” and “after” leads to some kind of bloat over time: by the time I approach iteration 300, each iteration takes nearly twice as long as the early ones.

I’m trying to make use of caching, and also to uncache (unpersist) each Dataset once it is no longer needed:

      var before: Dataset[Int] = Seq(1, 2, 3, 4, 5).toDS()
      before.cache()
      var after: Dataset[Int] = null
      var iter = 0
      breakable {
        while (true) {
          iter += 1
          if (after != null) {
            after.unpersist()
          }
          after = oneIter(before)
          after.cache()
          after.show() // force evaluation so the cached "after" is materialized

          if (iter >= 300) {
            break
          } else {
            before.unpersist() // the old input should no longer be needed
            before = after
            after = null
          }
        }
      }
      after.show()
      before.unpersist()

This version crawls to (almost) a halt after just 25 iterations! Also, the DAG for the show() is extremely long and ends with MapPartitionsRDD[300]. So far, the last .show() output the driver has managed to display was around iteration 28. How does it even know there will be 300 iterations in total if it hasn’t fully executed “after.show()” more than 28 times yet? I thought .show() was a blocking operation.

What I think I want to achieve is this: once “after” is computed, the resulting Dataset should be cached in memory so it can become the next iteration’s “before” (the input) and produce a new “after”, and so on. Also, before assigning the new “before”, the old “before” should be purged from memory. This design makes sense to me, but since everything slows down so dramatically after about 25 iterations, there seems to be some kind of leak.
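Concretely, the handoff I’m aiming for would look roughly like this sketch; it shows the intended ordering (materialize the new Dataset while the old one is still cached, then drop the old one), not something I’ve verified fixes the slowdown:

      // Sketch of the intended handoff between iterations.
      def handOff(before: Dataset[Int]): Dataset[Int] = {
        val after = oneIter(before)
        after.cache()
        after.count()      // force "after" to materialize from the cached "before"
        before.unpersist() // the previous iteration's input is no longer needed
        after
      }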

What is a good way to structure this program if I want the application to be long-running and to process many such Datasets? That is, it starts on one path, loops until the data is “good enough”, then queries for another path to work on, and so on. It needs to handle an effectively unbounded number of iterations, and I’m sure that’s possible, since each iteration of the loop only needs the result of the previous iteration and none that came before it.
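For the bigger picture, the overall shape I have in mind is roughly the following; nextPath and loadDataset are hypothetical placeholders for however the work items actually arrive, and isGoodEnough is the convergence test sketched above:

      // Outer loop over work items; nextPath() and loadDataset() are
      // placeholders for however paths are queued and read.
      while (true) {
        val path = nextPath()                // block until another path is available
        var current: Dataset[Int] = loadDataset(path)
        var done = false
        while (!done) {
          val next = oneIter(current)
          done = isGoodEnough(current, next) // forces evaluation of "next"
          current = next
        }
        // write out or otherwise consume "current" before moving on
      }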

Thanks!