Unit Test of Spark Structured Streaming
더 정밀한 Streaming Unit Test를 위해 2024-01-07

머릿말

안녕하세요? 새해부터 찾아온 JustKode, 박민재입니다. 오늘은 Spark Structured Streaming에 대한 Unit Test를 수행 하는 법에 대해서 공유 드려 보려고 합니다.

What is Spark Structured Streaming?

Spark Structured StreamingSpark SQL API (Dataframe, Dataset)를 이용하여, Streaming 처리를 수행 할 수 있게 하는 Stream Processing Engine 입니다. Spark 2.x 버전 이상부터 지원하고 있으며, 차세대 Spark 기반 Streaming Engine입니다.

In real case

우리는 실제 Streaming 데이터를 처리 하면서 다양한 경우를 경험 하게 됩니다.

  • JSON 데이터가 깨져서 들어 오는 경우
  • 한창 전에 유입 되었어야 할 데이터가 뒤늦게 들어 오는 경우
  • 중복 데이터가 유입 되는 경우

이런 상황에서도 우리는 Streaming 데이터를 처리하는 Application이 종료 되지 않도록, 예외 처리를 수행해야 합니다. 하지만, Spark Structured Streaming 환경에서 Unit Test는 어떻게 수행 하여야 할까요? Spark SQL 쿼리를 함수로 빼서 Test를 수행 할 수 있겠지만, 실제 Streaming 데이터가 들어온다는 가정 하에는 테스트를 자동으로 수행 하게 하기가 어렵습니다.

네... 아래와 같은 함수들은 JUnit이나, ScalaTest로 충분히 가능 하거든요.

def aggProductPerformanceReport(df: DataFrame): DataFrame = {
  df.groupBy(col("productId"))
    .agg(
      sum("price") as "price",
      count("*") as "sales"
    )
}

굳이 한다면, SocketreadStream으로 적용한 Test용 Application을 만들어서 진행 할 수 있을 것 같지만, 5분 단위 Streaming Window에 대해서 이를 테스트 하려고 5분을 기다리는 건 너무 비효율적이며, 수동으로 데이터를 삽입 하여야 한다는 문제가 발생 합니다. 이를 CI/CD 과정에 태우는 것도 어려움이 있을 거에요. Sink Result 같은 경우에도, 콘솔로 출력 하는 것이 최선일 것이기 때문이에요.

val query = wordCounts.writeStream
  .outputMode("complete")
  .format("console")
  .start()

query.awaitTermination()

우리가 정밀한 Test를 위해서 직면한 문제를 정의 하면 다음과 같습니다.

  • 실제 Streaming 데이터가 들어오는 것처럼, 깨진 파일 등의 Test Case를 만들 수 있는가?
  • 지연된 데이터가 유입되는 시나리오를 구현 하기 위해, 원하는 시간대에서 Trigger 되도록 구현 할 수 있는가?
  • Sink Result를 뽑아서, 조건에 일치 하는 지 확인 할 수 있는가?

이제 하나하나씩 문제를 해결 해 보도록 하겠습니다.

해당 실습에서 사용하는 Spark Version은 3.5.0을 사용합니다.

MemoryStream

실제 Streaming 데이터가 들어 오는 것을 구현하기 위해서는 어떻게 해야 할까요? 우리는 이를 위해, MemoryStream을 사용할 수 있습니다. MemoryStream을 통해서, Kafka로 데이터가 유입되는 케이스를 구현 할 수 있습니다. 다음과 같이 말이에요.

해당 Code SnippetMemoryStream 객체에 데이터를 삽입하는 예제입니다.

import org.apache.spark.sql.execution.streaming.MemoryStream

val logs = new MemoryStream[String]
logs.addData("""{"id": 1, "content": "hello"}""")
logs.addData("""{"id": 2, "content": "hello!"}""")
logs.addData("""{"id": 3, "content": "hello!!"}""")
logs.addData("""{"id": 3, "co""")  // 이상 데이터

이를 DataFrame으로 변환하기 위해서는 MemoryStream의 메소드인 toDF()를 호출 하면 되고, JSON을 추출하기 위해서는 from_json()StructType을 이용하면 됩니다.

방금 우리가 MemoryStream으로 만들어 낸 DataFrame을 가상의 In-Memory TableSink하기 위해서는 writeStream을 이용하여 다음과 같이 .format("memory").queryName("원하는 테이블 명") 을 입력 하면 됩니다. 추가적으로, 인위적인 Trigger 없이, 즉시 데이터를 Processing 하고 싶다면, DataFrame으로 만들어 낸 StreamingQueryprocessAllAvailable() 메서드를 호출하여 주면 됩니다.

import org.apache.spark.sql.types._

val schema = StructType(
  Seq(
    StructField("id", DataTypes.LongType, true),
    StructField("content", DataTypes.StringType, true)
  )
)

val df = logs.toDF()
  .select(from_json(col("value"), schema) as "data")
  .select("data.*")

val streamingQuery = df.writeStream
  .format("memory")
  .queryName("agg")
  .outputMode("append")
  .start()

streamingQuery.processAllAvailable()

val result = spark.sql("select * from agg").collectAsList()
result.forEach(println)

출력 결과는 다음과 같습니다.

[1,hello]
[2,hello!]
[3,hello!!]
[null,null]

하지만, 다음과 같은 구성은 Watermark 같이, 과거의 데이터를 드랍 하는 로직이 정상적으로 작동 하는지를 확인 하기 어렵습니다. Trigger가 호출 되는 논리적 시간을 조정하는 로직은 StreamingQuery Class에 존재 하지 않기 때문이에요. 그렇다면 우리는 다른 방법을 찾아 볼 필요가 있습니다.

StreamingQueryManager.startQuery()

StreamingQueryManager.startQuery()StreamingQuery를 호출하는 저레벨의 Private 함수 입니다. 파라미터로 Option, 연산을 수행 할 DataFrame, Clock, Sink, OutputMode 등을 입력 하면, 이에 맞는 StreamingQuery 를 반환 합니다.

private[sql] def startQuery(
    userSpecifiedName: Option[String],
    userSpecifiedCheckpointLocation: Option[String],
    df: DataFrame,
    extraOptions: Map[String, String],
    sink: Table,
    outputMode: OutputMode,
    useTempCheckpointLocation: Boolean = false,
    recoverFromCheckpointLocation: Boolean = true,
    trigger: Trigger = Trigger.ProcessingTime(0),
    triggerClock: Clock = new SystemClock(),
    catalogAndIdent: Option[(TableCatalog, Identifier)] = None,
    catalogTable: Option[CatalogTable] = None): StreamingQuery

StreamingQueryManager.startQuery()와, MemorySink, ManualClock을 이용하여, 원하는 논리적 시간대Trigger를 호출 할 수 있습니다.

중요!!!

그 전에 우리가 해야 할 일이 있습니다. StreamingQueryManager.startQuery()ManualClock은 Private Method라서 일반적인 방법으로는 사용이 불가능 하기 때문에, 인위적으로 Package에 접근 하여 꺼내오는 방법은 다음과 같습니다. 약간의 트릭을 사용 하여, Test 폴더 내 org.apache.spark.sql 경로에서 파일을 생성 합니다.

org.apache.spark.sql.StreamingTestUtil

package org.apache.spark.sql

import org.apache.spark.sql.execution.streaming.sources.MemorySink
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
import org.apache.spark.util.ManualClock

object StreamingTestUtil {
  def getStreamingQuery(df: DataFrame,
                        clock: ManualClock,
                        sink: MemorySink,
                        checkpoint: String,
                        outputMode: OutputMode): StreamingQuery = {
    df.sparkSession
      .streams
      .startQuery(
        userSpecifiedName = Some("spark-structured-streaming-unit-test"),
        userSpecifiedCheckpointLocation = Some(checkpoint),
        df = df,
        extraOptions = Map[String, String](),
        sink = sink,
        outputMode = outputMode,
        recoverFromCheckpointLocation = false,
        triggerClock = clock
      )
  }

  def getClock(time: Long): ManualClock = {
    new ManualClock(time)
  }
}

SparkStreamingTestRunner

원활한 테스트 코드를 작성 하기 위해서, SparkStreamingTestRunner 라는 trait 또한 만들어 보겠습니다. 구현된 내용은 다음과 같습니다.

  • Scalatest의 API를 사용합니다. (AnyFlatSpec, BeforeAndAfter, BeforeAndAfterAll)
  • val spark: 로컬에서 구동 되는 SparkContext 를 가지고 있습니다.
  • checkpointLocation, logs, memorySink: 각각 체크포인트 경로, MemoryStream, MemorySink를 가지고 있습니다. startQuery로는 특정 in-memory table에 데이터를 저장 하도록 설정할 수 없으므로, MemorySink에 데이터를 저장하여, MemorySink.allDataSeq[Row] 정보를 추출 합니다.
  • getDataFrameFromJsonRecordsBySchema(schema: StructType): DataFrame: StructType을 입력 받아, MemoryStream내에 존재하는 JSON을 DataFrame으로 변환 합니다.
  • caseClassObjectToJson[T](obj: T): String: case class 객체를 JSON String으로 변환 합니다.
  • caseClassToStructType[T: scala.reflect.runtime.universe.TypeTag]: StructType: case class Type을 삽입하면, 이를 Struct로 변환 합니다.
  • calenderToTimestamp(calender: Calendar): Timestamp: java.util.Calendar 객체를 java.sql.Timestamp 객체로 변환 합니다.
  • rowListToString(rows: Seq[Row]): String: Row가 담긴 리스트를 String으로 변환 합니다.
package kr.justkode.util

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.{DefaultScalaModule, ScalaObjectMapper}
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.sources.MemorySink
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.apache.spark.sql.catalyst.ScalaReflection

import java.io.File
import java.sql.Timestamp
import java.util.Calendar

trait SparkStreamingTestRunner extends AnyFlatSpec
  with BeforeAndAfter
  with BeforeAndAfterAll {

  private val mapper = new ObjectMapper() with ScalaObjectMapper
  mapper.registerModule(DefaultScalaModule)

  val spark = SparkSession
      .builder()
      .config("spark.sql.shuffle.partitions", 1)
      .master("local[2]")
      .getOrCreate()

  import spark.implicits._

  implicit val ctx = spark.sqlContext

  spark.sparkContext.setLogLevel("WARN")

  val checkpointLocation = "/tmp/spark-structured-streaming-unit-test"
  val logs: MemoryStream[String] = MemoryStream[String]
  val memorySink = new MemorySink

  override def beforeAll(): Unit = {
    FileUtils.deleteDirectory(new File(checkpointLocation))
  }

  override def afterAll(): Unit = {
    FileUtils.deleteDirectory(new File(checkpointLocation))
  }

  protected def getDataFrameFromJsonRecordsBySchema(schema: StructType): DataFrame = {
    logs.toDF()
      .select(from_json(col("value"), schema) as "data")
      .select("data.*")
  }

  protected def caseClassObjectToJson[T](obj: T): String = {
    mapper.writeValueAsString(obj)
  }

  protected def caseClassObjectToJson[T](objList: Seq[T]): String = {
    objList.foldLeft("")((x, y) => x + mapper.writeValueAsString(y) + '\n').trim
  }

  protected def caseClassToStructType[T: scala.reflect.runtime.universe.TypeTag]: StructType = {
    ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
  }

  protected def calenderToTimestamp(calender: Calendar): Timestamp = {
    new Timestamp(calender.getTimeInMillis / 1000)
  }

  protected def rowListToString(rows: Seq[Row]): String = {
    rows.foldLeft("")((x, row) => x + row.toString() + '\n').trim
  }
}

그 다음, 이 함수들을 이용해서 한 번 Test 환경을 구축 해 보겠습니다. 자세한 설명은 주석으로 남기겠습니다.

kr.justkode.aggregator.WatermarkAggregator

package kr.justkode.aggregator

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

import java.sql.Timestamp

object WatermarkAggregator {
  case class Event(adId: Long, eventId: Long, timestamp: Timestamp)

  def aggClick(df: DataFrame): DataFrame = {
    df.filter(col("eventId") === 1)
      .withWatermark("timestamp", "10 minutes")  // 특정 기준에서 10분 이상 벗어난 데이터는 Drop
      .groupBy(
        window(col("timestamp"), "10 minutes", "5 minutes"), 
        col("adId")
      )  // 5분 간격, 10분 길이 Window에 대해 Group By 연산
      .count()
  }
}

kr.justkode.streaming.WatermarkStreamingTest

package kr.justkode.streaming

import kr.justkode.aggregator.WatermarkAggregator.{Event, aggClick}
import kr.justkode.util.SparkStreamingTestRunner
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.StreamingTestUtil.{getClock, getStreamingQuery}
import org.apache.spark.sql.execution.streaming.sources.MemorySink
import org.apache.spark.sql.streaming.OutputMode

import java.io.File
import java.util.Calendar

class WatermarkStreamingTest extends SparkStreamingTestRunner {
  // case class로부터, StructType 추출 
  val schema = caseClassToStructType[Event]

  // Test 종료 시, Checkpoint 제거 및 MemoryStream 초기화
  after {
    logs.reset()
    FileUtils.deleteDirectory(new File(checkpointLocation))
  }

  "row count / sum of imp" should "2, 10 / 2, 14" in {
    // Calendar 객체, 2024-01-01 00:00으로 초기화
    val currentTime = Calendar.getInstance()
    currentTime.set(2024, 0, 1, 0, 0)

    logs.addData(caseClassObjectToJson(Event(1, 1, calenderToTimestamp(currentTime))))
    logs.addData(caseClassObjectToJson(Event(1, 1, calenderToTimestamp(currentTime))))
    logs.addData(caseClassObjectToJson(Event(2, 0, calenderToTimestamp(currentTime))))
    logs.addData(caseClassObjectToJson(Event(2, 0, calenderToTimestamp(currentTime))))
    logs.addData(caseClassObjectToJson(Event(3, 0, calenderToTimestamp(currentTime))))
    logs.addData("{asdfd}")  // 깨진 데이터

    val df = getDataFrameFromJsonRecordsBySchema(schema)  // JSON으로 부터 데이터 추출
    val sink = new MemorySink
    val clock = getClock(currentTime.getTimeInMillis)  // ManualClock을 2024-01-01 00:00으로 초기화

    // StreamingQuery 초기화 후, Mini-Batch 1회 수행
    val streamingQuery = getStreamingQuery(aggClick(df), clock, sink, checkpointLocation, OutputMode.Update)
    streamingQuery.processAllAvailable()

    // 1회차 Mini-Batch Test. sink.allData는 Seq[Row]를 반환 합니다.
    info("=== 1 ===")
    info(rowListToString(sink.allData))

    assert(sink.allData.size == 2)
    assert(sink.allData.head.getAs[Long]("count") == 2L)

    // Clock, CurrentTime 5분 증가
    clock.advance(1000 * 60 * 5)
    currentTime.add(Calendar.MINUTE, 5)

    logs.addData(caseClassObjectToJson(Event(1, 1, calenderToTimestamp(currentTime))))
    logs.addData(caseClassObjectToJson(Event(2, 1, calenderToTimestamp(currentTime))))
    logs.addData(caseClassObjectToJson(Event(3, 1, calenderToTimestamp(currentTime))))
    logs.addData(caseClassObjectToJson(Event(4, 1, calenderToTimestamp(currentTime))))

    // Clock 에서 부터 22분 전의 데이터 삽입
    currentTime.add(Calendar.MINUTE, -22)
    logs.addData(caseClassObjectToJson(Event(1, 1, calenderToTimestamp(currentTime))))
    logs.addData(caseClassObjectToJson(Event(2, 1, calenderToTimestamp(currentTime))))
    streamingQuery.processAllAvailable()

    // 2회차 Mini-Batch Test
    info("=== 2 ===")
    info(rowListToString(sink.allData))

    assert(sink.allData.size == 10)
    assert(sink.allData.foldLeft(0L)((x, y) => x + y.getAs[Long]("count")) == 14L)
  }
}

Output은 다음은 같습니다.

=== 1 ===

[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],1,2]
[[2023-12-31 23:55:00.0,2024-01-01 00:05:00.0],1,2]

=== 2 ===

[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],1,2]
[[2023-12-31 23:55:00.0,2024-01-01 00:05:00.0],1,2]
[[2024-01-01 00:05:00.0,2024-01-01 00:15:00.0],1,1]
[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],1,3]
[[2024-01-01 00:05:00.0,2024-01-01 00:15:00.0],2,1]
[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],2,1]
[[2024-01-01 00:05:00.0,2024-01-01 00:15:00.0],3,1]
[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],3,1]
[[2024-01-01 00:05:00.0,2024-01-01 00:15:00.0],4,1]
[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],4,1]

마치며

코드의 길이가 길어서 따라오기가 힘들었을 것으로 예상 됩니다. 해당 Github Repository에 예제 코드를 기록 해 놨으니, 더 자세한 예제가 필요 하다면 Repository를 참조 해 주세요.

긴 글 읽어 주셔서 감사합니다. 좋은 하루 보내세요! :D

Spark Structured Streaming 시리즈의 다른 글