안녕하세요? 새해부터 찾아온 JustKode, 박민재입니다. 오늘은 Spark Structured Streaming에 대한 Unit Test를 수행 하는 법에 대해서 공유 드려 보려고 합니다.
Spark Structured Streaming은 Spark SQL API (Dataframe, Dataset)를 이용하여, Streaming 처리를 수행 할 수 있게 하는 Stream Processing Engine 입니다. Spark 2.x 버전 이상부터 지원하고 있으며, 차세대 Spark 기반 Streaming Engine입니다.
우리는 실제 Streaming 데이터를 처리 하면서 다양한 경우를 경험 하게 됩니다.
이런 상황에서도 우리는 Streaming 데이터를 처리하는 Application이 종료 되지 않도록, 예외 처리를 수행해야 합니다. 하지만, Spark Structured Streaming 환경에서 Unit Test는 어떻게 수행 하여야 할까요? Spark SQL 쿼리를 함수로 빼서 Test를 수행 할 수 있겠지만, 실제 Streaming 데이터가 들어온다는 가정 하에는 테스트를 자동으로 수행 하게 하기가 어렵습니다.
네... 아래와 같은 함수들은 JUnit이나, ScalaTest로 충분히 가능 하거든요.
def aggProductPerformanceReport(df: DataFrame): DataFrame = {
df.groupBy(col("productId"))
.agg(
sum("price") as "price",
count("*") as "sales"
)
}
굳이 한다면, Socket을 readStream으로 적용한 Test용 Application을 만들어서 진행 할 수 있을 것 같지만, 5분 단위 Streaming Window에 대해서 이를 테스트 하려고 5분을 기다리는 건 너무 비효율적이며, 수동으로 데이터를 삽입 하여야 한다는 문제가 발생 합니다. 이를 CI/CD 과정에 태우는 것도 어려움이 있을 거에요. Sink Result 같은 경우에도, 콘솔로 출력 하는 것이 최선일 것이기 때문이에요.
val query = wordCounts.writeStream
.outputMode("complete")
.format("console")
.start()
query.awaitTermination()
우리가 정밀한 Test를 위해서 직면한 문제를 정의 하면 다음과 같습니다.
이제 하나하나씩 문제를 해결 해 보도록 하겠습니다.
해당 실습에서 사용하는 Spark Version은 3.5.0을 사용합니다.
실제 Streaming 데이터가 들어 오는 것을 구현하기 위해서는 어떻게 해야 할까요? 우리는 이를 위해, MemoryStream
을 사용할 수 있습니다. MemoryStream
을 통해서, Kafka로 데이터가 유입되는 케이스를 구현 할 수 있습니다. 다음과 같이 말이에요.
해당 Code Snippet은 MemoryStream
객체에 데이터를 삽입하는 예제입니다.
import org.apache.spark.sql.execution.streaming.MemoryStream
val logs = new MemoryStream[String]
logs.addData("""{"id": 1, "content": "hello"}""")
logs.addData("""{"id": 2, "content": "hello!"}""")
logs.addData("""{"id": 3, "content": "hello!!"}""")
logs.addData("""{"id": 3, "co""") // 이상 데이터
이를 DataFrame
으로 변환하기 위해서는 MemoryStream
의 메소드인 toDF()
를 호출 하면 되고, JSON을 추출하기 위해서는 from_json()
과 StructType
을 이용하면 됩니다.
방금 우리가 MemoryStream
으로 만들어 낸 DataFrame
을 가상의 In-Memory Table에 Sink하기 위해서는 writeStream
을 이용하여 다음과 같이 .format("memory")
와 .queryName("원하는 테이블 명")
을 입력 하면 됩니다. 추가적으로, 인위적인 Trigger 없이, 즉시 데이터를 Processing 하고 싶다면, DataFrame
으로 만들어 낸 StreamingQuery
의 processAllAvailable()
메서드를 호출하여 주면 됩니다.
import org.apache.spark.sql.types._
val schema = StructType(
Seq(
StructField("id", DataTypes.LongType, true),
StructField("content", DataTypes.StringType, true)
)
)
val df = logs.toDF()
.select(from_json(col("value"), schema) as "data")
.select("data.*")
val streamingQuery = df.writeStream
.format("memory")
.queryName("agg")
.outputMode("append")
.start()
streamingQuery.processAllAvailable()
val result = spark.sql("select * from agg").collectAsList()
result.forEach(println)
출력 결과는 다음과 같습니다.
[1,hello]
[2,hello!]
[3,hello!!]
[null,null]
하지만, 다음과 같은 구성은 Watermark 같이, 과거의 데이터를 드랍 하는 로직이 정상적으로 작동 하는지를 확인 하기 어렵습니다. Trigger가 호출 되는 논리적 시간을 조정하는 로직은 StreamingQuery
Class에 존재 하지 않기 때문이에요. 그렇다면 우리는 다른 방법을 찾아 볼 필요가 있습니다.
StreamingQueryManager.startQuery()
는 StreamingQuery
를 호출하는 저레벨의 Private 함수 입니다. 파라미터로 Option, 연산을 수행 할 DataFrame, Clock, Sink, OutputMode 등을 입력 하면, 이에 맞는 StreamingQuery
를 반환 합니다.
private[sql] def startQuery(
userSpecifiedName: Option[String],
userSpecifiedCheckpointLocation: Option[String],
df: DataFrame,
extraOptions: Map[String, String],
sink: Table,
outputMode: OutputMode,
useTempCheckpointLocation: Boolean = false,
recoverFromCheckpointLocation: Boolean = true,
trigger: Trigger = Trigger.ProcessingTime(0),
triggerClock: Clock = new SystemClock(),
catalogAndIdent: Option[(TableCatalog, Identifier)] = None,
catalogTable: Option[CatalogTable] = None): StreamingQuery
StreamingQueryManager.startQuery()
와, MemorySink
, ManualClock
을 이용하여, 원하는 논리적 시간대에 Trigger를 호출 할 수 있습니다.
그 전에 우리가 해야 할 일이 있습니다. StreamingQueryManager.startQuery()
과 ManualClock
은 Private Method라서 일반적인 방법으로는 사용이 불가능 하기 때문에, 인위적으로 Package에 접근 하여 꺼내오는 방법은 다음과 같습니다. 약간의 트릭을 사용 하여, Test 폴더 내 org.apache.spark.sql
경로에서 파일을 생성 합니다.
org.apache.spark.sql.StreamingTestUtil
package org.apache.spark.sql
import org.apache.spark.sql.execution.streaming.sources.MemorySink
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
import org.apache.spark.util.ManualClock
object StreamingTestUtil {
def getStreamingQuery(df: DataFrame,
clock: ManualClock,
sink: MemorySink,
checkpoint: String,
outputMode: OutputMode): StreamingQuery = {
df.sparkSession
.streams
.startQuery(
userSpecifiedName = Some("spark-structured-streaming-unit-test"),
userSpecifiedCheckpointLocation = Some(checkpoint),
df = df,
extraOptions = Map[String, String](),
sink = sink,
outputMode = outputMode,
recoverFromCheckpointLocation = false,
triggerClock = clock
)
}
def getClock(time: Long): ManualClock = {
new ManualClock(time)
}
}
원활한 테스트 코드를 작성 하기 위해서, SparkStreamingTestRunner
라는 trait
또한 만들어 보겠습니다. 구현된 내용은 다음과 같습니다.
AnyFlatSpec
, BeforeAndAfter
, BeforeAndAfterAll
)val spark
: 로컬에서 구동 되는 SparkContext
를 가지고 있습니다.checkpointLocation, logs, memorySink
: 각각 체크포인트 경로, MemoryStream
, MemorySink
를 가지고 있습니다. startQuery
로는 특정 in-memory table에 데이터를 저장 하도록 설정할 수 없으므로, MemorySink에 데이터를 저장하여, MemorySink.allData
로 Seq[Row]
정보를 추출 합니다.getDataFrameFromJsonRecordsBySchema(schema: StructType): DataFrame
: StructType
을 입력 받아, MemoryStream
내에 존재하는 JSON을 DataFrame
으로 변환 합니다.caseClassObjectToJson[T](obj: T): String
: case class 객체를 JSON String으로 변환 합니다.caseClassToStructType[T: scala.reflect.runtime.universe.TypeTag]: StructType
: case class Type을 삽입하면, 이를 Struct로 변환 합니다.calenderToTimestamp(calender: Calendar): Timestamp
: java.util.Calendar
객체를 java.sql.Timestamp
객체로 변환 합니다.rowListToString(rows: Seq[Row]): String
: Row
가 담긴 리스트를 String
으로 변환 합니다.package kr.justkode.util
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.{DefaultScalaModule, ScalaObjectMapper}
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.sources.MemorySink
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.apache.spark.sql.catalyst.ScalaReflection
import java.io.File
import java.sql.Timestamp
import java.util.Calendar
trait SparkStreamingTestRunner extends AnyFlatSpec
with BeforeAndAfter
with BeforeAndAfterAll {
private val mapper = new ObjectMapper() with ScalaObjectMapper
mapper.registerModule(DefaultScalaModule)
val spark = SparkSession
.builder()
.config("spark.sql.shuffle.partitions", 1)
.master("local[2]")
.getOrCreate()
import spark.implicits._
implicit val ctx = spark.sqlContext
spark.sparkContext.setLogLevel("WARN")
val checkpointLocation = "/tmp/spark-structured-streaming-unit-test"
val logs: MemoryStream[String] = MemoryStream[String]
val memorySink = new MemorySink
override def beforeAll(): Unit = {
FileUtils.deleteDirectory(new File(checkpointLocation))
}
override def afterAll(): Unit = {
FileUtils.deleteDirectory(new File(checkpointLocation))
}
protected def getDataFrameFromJsonRecordsBySchema(schema: StructType): DataFrame = {
logs.toDF()
.select(from_json(col("value"), schema) as "data")
.select("data.*")
}
protected def caseClassObjectToJson[T](obj: T): String = {
mapper.writeValueAsString(obj)
}
protected def caseClassObjectToJson[T](objList: Seq[T]): String = {
objList.foldLeft("")((x, y) => x + mapper.writeValueAsString(y) + '\n').trim
}
protected def caseClassToStructType[T: scala.reflect.runtime.universe.TypeTag]: StructType = {
ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
}
protected def calenderToTimestamp(calender: Calendar): Timestamp = {
new Timestamp(calender.getTimeInMillis / 1000)
}
protected def rowListToString(rows: Seq[Row]): String = {
rows.foldLeft("")((x, row) => x + row.toString() + '\n').trim
}
}
그 다음, 이 함수들을 이용해서 한 번 Test 환경을 구축 해 보겠습니다. 자세한 설명은 주석으로 남기겠습니다.
kr.justkode.aggregator.WatermarkAggregator
package kr.justkode.aggregator
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import java.sql.Timestamp
object WatermarkAggregator {
case class Event(adId: Long, eventId: Long, timestamp: Timestamp)
def aggClick(df: DataFrame): DataFrame = {
df.filter(col("eventId") === 1)
.withWatermark("timestamp", "10 minutes") // 특정 기준에서 10분 이상 벗어난 데이터는 Drop
.groupBy(
window(col("timestamp"), "10 minutes", "5 minutes"),
col("adId")
) // 5분 간격, 10분 길이 Window에 대해 Group By 연산
.count()
}
}
kr.justkode.streaming.WatermarkStreamingTest
package kr.justkode.streaming
import kr.justkode.aggregator.WatermarkAggregator.{Event, aggClick}
import kr.justkode.util.SparkStreamingTestRunner
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.StreamingTestUtil.{getClock, getStreamingQuery}
import org.apache.spark.sql.execution.streaming.sources.MemorySink
import org.apache.spark.sql.streaming.OutputMode
import java.io.File
import java.util.Calendar
class WatermarkStreamingTest extends SparkStreamingTestRunner {
// case class로부터, StructType 추출
val schema = caseClassToStructType[Event]
// Test 종료 시, Checkpoint 제거 및 MemoryStream 초기화
after {
logs.reset()
FileUtils.deleteDirectory(new File(checkpointLocation))
}
"row count / sum of imp" should "2, 10 / 2, 14" in {
// Calendar 객체, 2024-01-01 00:00으로 초기화
val currentTime = Calendar.getInstance()
currentTime.set(2024, 0, 1, 0, 0)
logs.addData(caseClassObjectToJson(Event(1, 1, calenderToTimestamp(currentTime))))
logs.addData(caseClassObjectToJson(Event(1, 1, calenderToTimestamp(currentTime))))
logs.addData(caseClassObjectToJson(Event(2, 0, calenderToTimestamp(currentTime))))
logs.addData(caseClassObjectToJson(Event(2, 0, calenderToTimestamp(currentTime))))
logs.addData(caseClassObjectToJson(Event(3, 0, calenderToTimestamp(currentTime))))
logs.addData("{asdfd}") // 깨진 데이터
val df = getDataFrameFromJsonRecordsBySchema(schema) // JSON으로 부터 데이터 추출
val sink = new MemorySink
val clock = getClock(currentTime.getTimeInMillis) // ManualClock을 2024-01-01 00:00으로 초기화
// StreamingQuery 초기화 후, Mini-Batch 1회 수행
val streamingQuery = getStreamingQuery(aggClick(df), clock, sink, checkpointLocation, OutputMode.Update)
streamingQuery.processAllAvailable()
// 1회차 Mini-Batch Test. sink.allData는 Seq[Row]를 반환 합니다.
info("=== 1 ===")
info(rowListToString(sink.allData))
assert(sink.allData.size == 2)
assert(sink.allData.head.getAs[Long]("count") == 2L)
// Clock, CurrentTime 5분 증가
clock.advance(1000 * 60 * 5)
currentTime.add(Calendar.MINUTE, 5)
logs.addData(caseClassObjectToJson(Event(1, 1, calenderToTimestamp(currentTime))))
logs.addData(caseClassObjectToJson(Event(2, 1, calenderToTimestamp(currentTime))))
logs.addData(caseClassObjectToJson(Event(3, 1, calenderToTimestamp(currentTime))))
logs.addData(caseClassObjectToJson(Event(4, 1, calenderToTimestamp(currentTime))))
// Clock 에서 부터 22분 전의 데이터 삽입
currentTime.add(Calendar.MINUTE, -22)
logs.addData(caseClassObjectToJson(Event(1, 1, calenderToTimestamp(currentTime))))
logs.addData(caseClassObjectToJson(Event(2, 1, calenderToTimestamp(currentTime))))
streamingQuery.processAllAvailable()
// 2회차 Mini-Batch Test
info("=== 2 ===")
info(rowListToString(sink.allData))
assert(sink.allData.size == 10)
assert(sink.allData.foldLeft(0L)((x, y) => x + y.getAs[Long]("count")) == 14L)
}
}
Output은 다음은 같습니다.
=== 1 ===
[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],1,2]
[[2023-12-31 23:55:00.0,2024-01-01 00:05:00.0],1,2]
=== 2 ===
[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],1,2]
[[2023-12-31 23:55:00.0,2024-01-01 00:05:00.0],1,2]
[[2024-01-01 00:05:00.0,2024-01-01 00:15:00.0],1,1]
[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],1,3]
[[2024-01-01 00:05:00.0,2024-01-01 00:15:00.0],2,1]
[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],2,1]
[[2024-01-01 00:05:00.0,2024-01-01 00:15:00.0],3,1]
[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],3,1]
[[2024-01-01 00:05:00.0,2024-01-01 00:15:00.0],4,1]
[[2024-01-01 00:00:00.0,2024-01-01 00:10:00.0],4,1]
코드의 길이가 길어서 따라오기가 힘들었을 것으로 예상 됩니다. 해당 Github Repository에 예제 코드를 기록 해 놨으니, 더 자세한 예제가 필요 하다면 Repository를 참조 해 주세요.
긴 글 읽어 주셔서 감사합니다. 좋은 하루 보내세요! :D